Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_replication.py: 17%
311 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ======================================
16"""OutsideCompilation, TPUReplicateContext, and supporting functions."""
18from typing import Any, Callable, List, Optional, Text, Tuple, Union
19from absl import logging
20from tensorflow.core.framework import attr_value_pb2
21from tensorflow.python.distribute import device_util
22from tensorflow.python.distribute import distribute_lib
23from tensorflow.python.framework import device as pydev
24from tensorflow.python.framework import errors
25from tensorflow.python.framework import func_graph
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import variables
30from tensorflow.python.tpu import device_assignment as device_assignment_lib
31from tensorflow.python.tpu.ops import tpu_ops
32from tensorflow.python.types import core as core_types
33from tensorflow.python.util import compat
34from tensorflow.python.util.tf_export import tf_export
36_MAX_WARNING_LINES = 5
37_TPU_REPLICATE_ATTR = "_tpu_replicate"
38_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
40# Operations that indicate some error in the user's graph, e.g. a placeholder
41# that's introduced outside of the infeed.
42_DENYLISTED_OPS = frozenset([
43 "Placeholder",
44])
47# XLA doesn't currently support reading of intermediate tensors, thus some ops
48# are not supported.
49_UNSUPPORTED_OPS = frozenset([
50 "AudioSummary",
51 "AudioSummaryV2",
52 "HistogramSummary",
53 "ImageSummary",
54 "MergeSummary",
55 "Print",
56 "ScalarSummary",
57 "TensorSummary",
58 "TensorSummaryV2",
59])
62def is_tpu_strategy(strategy: Any) -> bool:
63 is_tpu_strat = lambda k: k.__name__.startswith("TPUStrategy")
64 clz = strategy.__class__
65 return is_tpu_strat(clz) or any(map(is_tpu_strat, clz.__bases__))
68def _enclosing_tpu_device_assignment(
69) -> Optional[device_assignment_lib.DeviceAssignment]:
70 if not distribute_lib.has_strategy():
71 return None
72 strategy = distribute_lib.get_strategy()
73 if not is_tpu_strategy(strategy):
74 return None
75 return strategy.extended._device_assignment # pylint: disable=protected-access
78class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
79 """A `ControlFlowContext` for nodes inside a TPU computation.
81 The primary role of `TPUReplicateContext` is to mark operators inside a
82 tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ
83 is a unique name.
85 We use a `ControlFlowContext` to perform the annotation since it integrates
86 with TensorFlow constructs like ResourceVariables. For example, if a
87 `ResourceVariable` is constructed inside a tpu.replicate() block, the
88 `ResourceVariable` implementation can use
89 `with ops.control_dependencies(None)` to build the variable's definition
90 outside the replicated computation.
91 """
93 def __init__(self, name: Text, num_replicas: int, pivot: ops.Operation):
94 """Builds a new TPUReplicateContext.
96 Args:
97 name: a unique name for the context, used to populate the `_tpu_replicate`
98 attribute.
99 num_replicas: an integer that gives the number of replicas for the
100 computation.
101 pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any
102 inputs will have a control dependency on the pivot node. This ensures
103 that nodes are correctly included in any enclosing control flow
104 contexts.
105 """
106 super(TPUReplicateContext, self).__init__()
107 self._num_replicas = num_replicas
108 self._outer_device_function_stack = None
109 self._oc_dev_fn_stack = None
110 self._outside_compilation_cluster = None
111 self._outside_compilation_v2_context = None
112 self._outside_compilation_counter = 0
113 self._in_gradient_colocation = None
114 self._gradient_colocation_stack = []
115 self._host_compute_core = []
116 self._name = name
117 self._tpu_replicate_attr = attr_value_pb2.AttrValue(
118 s=compat.as_bytes(self._name)
119 )
120 self._unsupported_ops = []
121 self._pivot = pivot
122 self._replicated_vars = {}
124 def get_replicated_var_handle(self,
125 name: Text,
126 handle_id: Text,
127 vars_: Union[List[core_types.Tensor],
128 List[variables.Variable]],
129 is_mirrored: bool = False,
130 is_packed: bool = False) -> core_types.Tensor:
131 """Returns a variable handle for replicated TPU variable 'var'.
133 This is a method used by an experimental replicated variable implementation
134 and is not intended as a public API.
136 Args:
137 name: The common name of the variable.
138 handle_id: Unique ID of the variable handle, used as the cache key.
139 vars_: The replicated TPU variables or handles.
140 is_mirrored: Whether the variables are mirrored, which guarantees the
141 values in each replica are always the same.
142 is_packed: Whether the replicated variables are packed into one variable.
144 Returns:
145 The handle of the TPU replicated input node.
146 """
147 device_assignment = _enclosing_tpu_device_assignment()
148 # We don't need to put device assignment as part of the replicated_vars key
149 # because each TPUReplicateContext will only have one device assignment.
150 handle = self._replicated_vars.get(handle_id)
151 if handle is not None:
152 return handle
154 if device_assignment is not None and not is_packed:
155 # Find a variable copy for each replica in the device assignment.
156 # Note that the order of devices for replicas for the variable and the
157 # device assignment might not match.
158 job_name = pydev.DeviceSpec.from_string(vars_[0].device).job
159 devices_to_vars = {device_util.canonicalize(v.device): v for v in vars_}
160 replicated_vars = []
161 for replica_id in range(device_assignment.num_replicas):
162 for logical_core in range(device_assignment.num_cores_per_replica):
163 device = device_util.canonicalize(
164 device_assignment.tpu_device(
165 replica=replica_id, logical_core=logical_core, job=job_name))
166 if device in devices_to_vars:
167 replicated_vars.append(devices_to_vars[device])
168 break
169 else:
170 raise ValueError(
171 "Failed to find a variable on any device in replica {} for "
172 "current device assignment".format(replica_id)
173 )
174 else:
175 replicated_vars = vars_
177 # Builds a TPUReplicatedInput node for the variable, if one does not already
178 # exist. The TPUReplicatedInput node must belong to the enclosing
179 # control-flow scope of the TPUReplicateContext.
180 # TODO(phawkins): consider changing the contract of the TPU encapsulation
181 # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope
182 # instead.
184 _, graph = _enclosing_tpu_context_and_graph()
185 with graph.as_default():
186 # If replicated_vars are variables, get the handles. Note that this can be
187 # done inside TPUReplicateContext because replicated_vars.handle may
188 # create new ops.
189 if isinstance(replicated_vars[0], variables.Variable):
190 replicated_vars = [v.handle for v in replicated_vars]
191 # pylint: disable=protected-access
192 saved_context = graph._get_control_flow_context()
193 graph._set_control_flow_context(self.outer_context)
194 handle = tpu_ops.tpu_replicated_input(
195 replicated_vars,
196 name=name + "/handle",
197 is_mirrored_variable=is_mirrored,
198 is_packed=is_packed)
199 graph._set_control_flow_context(saved_context)
200 # pylint: enable=protected-access
201 self._replicated_vars[handle_id] = handle
202 return handle
204 def report_unsupported_operations(self) -> None:
205 if self._unsupported_ops:
206 op_str = "\n".join(
207 " %s (%s)" % (op.type, op.name) for op in
208 self._unsupported_ops[:_MAX_WARNING_LINES])
209 logging.warning("%d unsupported operations found: \n%s",
210 len(self._unsupported_ops), op_str)
211 if len(self._unsupported_ops) > _MAX_WARNING_LINES:
213 logging.warning("... and %d more",
214 (len(self._unsupported_ops) - _MAX_WARNING_LINES))
216 def EnterGradientColocation(self, op: ops.Operation, gradient_uid: Text):
217 if op is not None:
218 if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access
219 # If we are in TF 2 functions (control flow V2 functions, or
220 # tf.function()), we need to attach _xla_outside_compilation attribute
221 # directly because we are not in TPUReplicateContext.
222 try:
223 outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii")
224 except ValueError:
225 # The attr was not present: do nothing.
226 return
227 parts = outside_attr.split(".")
228 cluster = parts[0] + "." + gradient_uid
229 self._outside_compilation_v2_context = OutsideCompilationV2Context(
230 cluster)
231 self._outside_compilation_v2_context.Enter()
232 return
233 self._gradient_colocation_stack.append(op)
234 if not self._outside_compilation_cluster:
235 try:
236 outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii")
237 if self._in_gradient_colocation:
238 raise NotImplementedError(
239 "Cannot nest gradient colocation operations outside compilation"
240 )
241 if gradient_uid == "__unsupported__":
242 raise NotImplementedError(
243 "No gradient_uid calling gradient within outside_compilation")
244 # When we take the gradient of an op X in an outside_compilation
245 # cluster C in a forward computation we would like to put the ops
246 # corresponding to the gradient of X into a new outside_compilation
247 # cluster C'. However, if we take the gradient of X twice, the second
248 # one should get yet another new outside_compilation cluster C''.
249 #
250 # The mechanism we adopt is to use a 'root_cluster' which is the
251 # cluster that X was in before we took gradients, and a 'gradient_uid'
252 # which is different for every invocation of gradients, and put the
253 # gradient of X in cluster 'root_cluster.gradient_uid'.
254 #
255 # When taking a gradient of a gradient, some ops will be colocated
256 # with Op in the forward pass (e.g., cluster root_cluster) and some in
257 # the backward pass (e.g., cluster root_cluster.initial_gradient_uid).
258 # We need all of the grad-of-grad ops to be in the same cluster to
259 # avoid cyclic dependencies between clusters. We adopt a heuristic
260 # that puts any op clustered with root_cluster.<xxx> in
261 # root_cluster.gradient_uid, even if xxx was initial_gradient_uid.
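# Illustrative example (cluster names are arbitrary): an op in cluster
# "oc0" gets its first gradient ops placed in "oc0.<gradient_uid_1>";
# when that gradient is differentiated again, grad-of-grad ops colocated
# with either "oc0" or "oc0.<gradient_uid_1>" all land in
# "oc0.<gradient_uid_2>".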
262 self._in_gradient_colocation = op
263 parts = outside_attr.split(".")
264 cluster = parts[0] + "." + gradient_uid
265 self._EnterOutsideCompilationScope(cluster=cluster)
266 except ValueError:
267 # The attr was not present: do nothing.
268 pass
270 def ExitGradientColocation(self, op: ops.Operation, gradient_uid: Text):
271 if op is not None:
272 if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access
273 # Inside a TF2 tf.function or control flow graph and `op` was not
274 # marked to be outside compiled.
275 assert self._outside_compilation_v2_context is None
276 return
277 if self._outside_compilation_v2_context is not None:
278 # Inside a TF2 tf.function or control flow graph and `op` was
279 # marked to be outside compiled.
280 self._outside_compilation_v2_context.Exit()
281 self._outside_compilation_v2_context = None
282 return
283 if not self._gradient_colocation_stack:
284 raise errors.InternalError(
285 op.node_def, op,
286 ("Badly nested gradient colocation: "
287 + f"empty stack when popping Op {op.name}")
288 )
289 last_op = self._gradient_colocation_stack.pop()
290 if op is last_op:
291 if op is self._in_gradient_colocation:
292 self._in_gradient_colocation = None
293 self._ExitOutsideCompilationScope()
294 else:
295 raise errors.InternalError(
296 op.node_def, op,
297 ("Badly nested gradient colocation, " +
298 f"expected {last_op}, got {op.name}")
299 )
301 def _EnterOutsideCompilationScope(self, cluster: Optional[Text] = None):
303 class FakeOp(object):
304 """A helper class to determine the current device.
306 Supports only the type and device set/get methods needed to run the
307 graph's _apply_device_functions method.
308 """
310 def __init__(self):
311 self._device = ""
313 @property
314 def type(self):
315 return "FakeOp"
317 @property
318 def device(self):
319 return self._device
321 def _set_device(self, device):
322 if isinstance(device, pydev.DeviceSpec):
323 self._device = device.to_string()
324 else:
325 self._device = device
327 def _set_device_from_string(self, device_str):
328 self._device = device_str
330 if self._outside_compilation_cluster:
331 raise NotImplementedError("Cannot nest outside_compilation clusters")
332 if cluster:
333 self._outside_compilation_cluster = cluster
334 else:
335 self._outside_compilation_cluster = str(self._outside_compilation_counter)
336 self._outside_compilation_counter += 1
337 graph = ops.get_default_graph()
338 fake_op = FakeOp()
339 graph._apply_device_functions(fake_op) # pylint: disable=protected-access
340 device = pydev.DeviceSpec.from_string(fake_op.device)
341 if (device.device_type == "TPU_REPLICATED_CORE" and
342 device.device_index is not None):
343 self._host_compute_core.append(self._outside_compilation_cluster + ":" +
344 str(device.device_index))
345 self._oc_dev_fn_stack = graph._device_function_stack # pylint: disable=protected-access
346 graph._device_function_stack = self._outer_device_function_stack # pylint: disable=protected-access
348 def _ExitOutsideCompilationScope(self):
349 if not self._outside_compilation_cluster:
350 raise ValueError(
351 "Attempted to exit outside_compilation scope when not in scope")
352 self._outside_compilation_cluster = None
353 graph = ops.get_default_graph()
354 graph._device_function_stack = self._oc_dev_fn_stack # pylint: disable=protected-access
356 def Enter(self) -> None:
357 if not self._outer_device_function_stack:
358 # Capture the device function stack at the time of first entry
359 # since that is the stack that will be used outside_compilation.
360 graph = ops.get_default_graph()
361 # pylint: disable=protected-access
362 self._outer_device_function_stack = graph._device_function_stack.copy()
363 # pylint: enable=protected-access
364 super(TPUReplicateContext, self).Enter()
366 def HostComputeCore(self) -> List[Text]:
367 return self._host_compute_core
369 def _RemoveExternalControlEdges(
370 self,
371 op: ops.Operation) -> Tuple[List[ops.Operation], List[ops.Operation]]:
372 """Remove any external control dependency on this op."""
373 internal_control_inputs = []
374 external_control_inputs = []
375 for x in op.control_inputs:
376 # pylint: disable=protected-access
377 is_internal_op = False
378 ctxt = x._get_control_flow_context()
379 while ctxt is not None:
380 if ctxt == self:
381 is_internal_op = True
382 break
383 ctxt = ctxt._outer_context
384 if is_internal_op:
385 internal_control_inputs.append(x)
386 else:
387 external_control_inputs.append(x)
388 # pylint: enable=protected-access
389 # pylint: disable=protected-access
390 op._remove_all_control_inputs()
391 op._add_control_inputs(internal_control_inputs)
392 # pylint: enable=protected-access
393 return internal_control_inputs, external_control_inputs
395 def AddOp(self, op: ops.Operation) -> None:
396 # pylint: disable=protected-access
397 if op.type in _DENYLISTED_OPS:
398 logging.error(
399 "Operation of type %s (%s) is not supported on the TPU. "
400 "Execution will fail if this op is used in the graph. ", op.type,
401 op.name)
403 if op.type in _UNSUPPORTED_OPS:
404 self._unsupported_ops.append(op)
406 if any(x.dtype._is_ref_dtype for x in op.inputs):
407 raise NotImplementedError(
408 f"Non-resource Variables are not supported inside TPU computations "
409 f"(operator name: {op.name})")
411 # TensorFlowOpLayer may clone nodes that are in tpu.rewrite()s. It'll add
412 # the "_cloned" attribute and we should continue in that case.
413 if (_TPU_REPLICATE_ATTR in op.node_def.attr and
414 "_cloned" not in op.node_def.attr):
415 raise ValueError(f"TPU computations cannot be nested on op ({op})")
416 op._set_attr(_TPU_REPLICATE_ATTR, self._tpu_replicate_attr)
417 if self._outside_compilation_cluster:
418 op._set_attr(
419 _OUTSIDE_COMPILATION_ATTR,
420 attr_value_pb2.AttrValue(
421 s=compat.as_bytes(self._outside_compilation_cluster)))
422 if self._num_replicas > 1 or not self._outside_compilation_cluster:
423 # Prevent feeding or fetching anything that is being compiled,
424 # and any replicated outside_compilation Op.
425 op.graph.prevent_feeding(op)
426 op.graph.prevent_fetching(op)
428 # Remove any control edges from outer control flow contexts. These may cause
429 # mismatched frame errors.
430 (internal_control_inputs,
431 external_control_inputs) = self._RemoveExternalControlEdges(op)
433 if not op.inputs:
434 # Add a control edge from the control pivot to this op.
435 if not internal_control_inputs:
436 # pylint: disable=protected-access
437 op._add_control_input(self.GetControlPivot())
438 # pylint: enable=protected-access
439 else:
440 for index in range(len(op.inputs)):
441 x = op.inputs[index]
442 real_x = self.AddValue(x)
443 if real_x is not x:
444 op._update_input(index, real_x) # pylint: disable=protected-access
446 if external_control_inputs:
447 # Use an identity to pull control inputs as data inputs. Note that we
448 # ignore ops which don't have outputs. TODO(phawkins): fix that.
449 with ops.control_dependencies(None):
450 self.Enter()
451 external_control_inputs = [
452 array_ops.identity(x.outputs[0]).op
453 for x in external_control_inputs
454 if x.outputs
455 ]
456 self.Exit()
457 # pylint: disable=protected-access
458 op._add_control_inputs(external_control_inputs)
459 # pylint: enable=protected-access
461 # Mark op's outputs as seen by this context and any outer contexts.
462 output_names = [x.name for x in op.outputs]
463 context = self
464 while context is not None:
465 # pylint: disable=protected-access
466 context._values.update(output_names)
467 context = context._outer_context
468 # pylint: enable=protected-access
470 if self._outer_context:
471 self._outer_context.AddInnerOp(op)
473 def AddValue(self, val: core_types.Tensor) -> core_types.Tensor:
474 """Add `val` to the current context and its outer context recursively."""
475 if not self._outer_context:
476 return val
478 if val.name in self._values:
479 # Use the real value if it comes from outer context.
480 result = self._external_values.get(val.name)
481 return val if result is None else result
483 result = val
484 self._values.add(val.name)
485 if self._outer_context:
486 result = self._outer_context.AddValue(val)
487 self._values.add(result.name)
489 self._external_values[val.name] = result
491 return result
493 def AddInnerOp(self, op: ops.Operation):
494 self.AddOp(op)
495 if self._outer_context:
496 self._outer_context.AddInnerOp(op)
498 @property
499 def grad_state(self):
500 # Define the gradient loop state associated with the TPUReplicateContext as
501 # None: the TPUReplicateContext is never nested, and the grad_state outside
502 # it does not affect the graph inside, so this behaves as the top-level
503 # gradient state.
504 return None
506 @property
507 def back_prop(self):
508 """Forwards to the enclosing while context, if any."""
509 if self.GetWhileContext():
510 return self.GetWhileContext().back_prop
511 return False
513 def GetControlPivot(self) -> ops.Operation:
514 return self._pivot
516 def RequiresUniqueFunctionRetracing(self):
517 # More context: b/158152827. TPU stack uses the TPUReplicateContext to
518 # create replicated variable handles and cluster TPU computations, thus we
519 # always retrace a tf.function when the wrapped TPUReplicateContext changes.
520 return True
523def _enclosing_tpu_context_and_graph() -> Tuple[Any, Any]:
524 """Returns the TPUReplicateContext and its associated graph."""
525 graph = ops.get_default_graph()
526 while graph is not None:
527 # pylint: disable=protected-access
528 context_ = graph._get_control_flow_context()
529 # pylint: enable=protected-access
530 while context_ is not None:
531 if isinstance(context_, TPUReplicateContext):
532 return context_, graph
533 context_ = context_.outer_context
534 graph = getattr(graph, "outer_graph", None)
535 raise ValueError("get_replicated_var_handle() called without "
536 "TPUReplicateContext. This shouldn't happen. Please file "
537 "a bug.")
540class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
541 """The context for outside compilation in Tensorflow 2.0.
543 Every op added in this context will be assigned an _xla_outside_compilation
544 attribute.
545 """
547 def __init__(self, name: Text):
548 control_flow_ops.ControlFlowContext.__init__(self)
549 self._name = name
551 def AddOp(self, op: ops.Operation) -> None:
552 if self._outer_context:
553 self._outer_context.AddOp(op)
554 # pylint: disable=protected-access
555 op._set_attr("_xla_outside_compilation",
556 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
557 # pylint: enable=protected-access
559 def AddInnerOp(self, op: ops.Operation) -> None:
560 if self._outer_context:
561 self._outer_context.AddInnerOp(op)
562 # pylint: disable=protected-access
563 op._set_attr("_xla_outside_compilation",
564 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
565 # pylint: enable=protected-access
567 def to_control_flow_context_def(self, context_def, export_scope=None):
568 raise NotImplementedError
571@tf_export(v1=["tpu.outside_compilation"])
572def outside_compilation(computation: Callable[..., Any], *args,
573 **kwargs) -> Any:
574 """Builds part of a computation outside any current TPU replicate scope.
576 `tf.tpu.outside_compilation()` is used to run ops in `computation` on the CPU
577 instead of on the TPU. For example, users can run ops that are not
578 supported on TPUs (e.g. tf.summary.write()) by explicitly placing those
579 ops on CPUs. The usage below places the ops in
580 `computation_with_string_ops` on the CPU.
582 Example usage:
584 ```python
585 def computation_with_string_ops(x):
586 # String types are not supported on TPUs, so the ops below must
587 # run on the CPU instead.
588 output = tf.strings.format('1{}', x)
589 return tf.strings.to_number(output)
591 def tpu_computation():
592 # Expected output is 11.
593 output = tf.tpu.outside_compilation(computation_with_string_ops, 1)
594 ```
596 Outside compilation should be called inside TPUReplicateContext. That is,
597 `tf.tpu.outside_compilation()` should be called inside a function that is
598 passed to `tpu.split_compile_and_replicate()` -- this is implied when
599 outside compilation is invoked inside a function passed to TPUStrategy
600 `run()`. If invoked outside of a TPUReplicateContext,
601 this simply returns the result of `computation` and is
602 therefore a no-op. Note that outside compilation differs from
603 `tf.distribute.experimental.TPUStrategy.merge_call()`: logic in
604 outside compilation is replicated and executed separately for each
605 replica, whereas `merge_call()` requires a `merge_fn`
606 to aggregate the inputs from different replicas and is executed only
607 once.
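A minimal sketch of the TPUStrategy `run()` pattern (illustrative only: it
assumes an existing `tf.distribute.TPUStrategy` instance named `strategy`,
and `host_side_log` is a hypothetical helper):

```python
def host_side_log(x):
  # Print ops are in the unsupported set for TPU, so run them on the host.
  tf.print("step value:", x)
  return x

@tf.function
def step_fn(x):
  y = x * 2.0
  # Runs per replica; `y` is sent to the host and `host_side_log` runs there.
  return tf.tpu.outside_compilation(host_side_log, y)

# strategy.run(step_fn, args=(tf.constant(1.0),))
```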
609 For variables placed on a TPU device, which includes variables created inside
610 a TPUStrategy scope, outside compilation logic must not read or write the
611 variable. For variables placed on the host, as is the case for variables
612 created via TPUEstimator, reads and writes are only allowed if the variable
613 is not accessed by any other op in the TPU computation. Variable reads and
614 writes from the outside compilation cluster are not visible to the TPU
615 computation, and vice versa. Therefore, if outside compilation reads or
616 writes such host variables and the TPU computation also accesses them,
617 this may lead to deadlock.
619 Internally, `tf.tpu.outside_compilation()` adds an outside compilation
620 attribute to all ops in `computation`. During a later graph pass, the
621 ops with this attribute are extracted and replicated
622 into a host-side graph. Inputs to this extracted host-side graph are sent
623 from the TPU computation graph to the host graph via a pair of XlaSendToHost
624 and XlaRecvFromHost ops. Note that using `tf.tpu.outside_compilation()`
625 may result in tensor transfers between the TPU and the CPU, leading to
626 non-trivial performance impact.
628 Args:
629 computation: A Python function that builds the computation to place on the
630 host.
631 *args: the positional arguments for the computation.
632 **kwargs: the keyword arguments for the computation.
634 Returns:
635 The Tensors returned by computation.
636 """
637 args = [] if args is None else args
638 graph = ops.get_default_graph()
640 # If we are in TF 2 functions (control flow V2 functions, or tf.function()),
641 # we need to attach _xla_outside_compilation attribute directly because we are
642 # not in TPUReplicateContext.
643 if isinstance(graph, func_graph.FuncGraph):
644 try:
645 tpu_context, _ = _enclosing_tpu_context_and_graph()
646 except ValueError:
647 logging.warning(
648 "Outside compilation attempted outside TPUReplicateContext "
649 "scope. As no enclosing TPUReplicateContext can be found, "
650 "returning the result of `computation` as is.")
651 return computation(*args, **kwargs)
653 # pylint: disable=protected-access
654 outside_compilation_name = str(tpu_context._outside_compilation_counter)
655 tpu_context._outside_compilation_counter = (
656 tpu_context._outside_compilation_counter + 1)
657 # pylint: enable=protected-access
659 outside_compilation_context = OutsideCompilationV2Context(
660 outside_compilation_name)
661 outside_compilation_context.Enter()
662 args = [] if args is None else args
663 retval = computation(*args, **kwargs)
664 outside_compilation_context.Exit()
665 return retval
667 # If we are in a TPUReplicateContext, signal that we are now
668 # outside_compilation
669 initial_context = graph._get_control_flow_context() # pylint: disable=protected-access
670 context = initial_context
671 while context:
672 if isinstance(context, TPUReplicateContext):
673 context._EnterOutsideCompilationScope() # pylint: disable=protected-access
674 context = context.outer_context
676 retval = computation(*args, **kwargs)
678 # If we are in a TPUReplicateContext, signal that we are no longer
679 # outside_compilation
680 final_context = graph._get_control_flow_context() # pylint: disable=protected-access
681 if initial_context is not final_context:
682 raise NotImplementedError(
683 "Control-flow context cannot be different at start and end of an "
684 "outside_compilation scope")
685 context = initial_context
686 while context:
687 if isinstance(context, TPUReplicateContext):
688 context._ExitOutsideCompilationScope() # pylint: disable=protected-access
689 context = context.outer_context
691 return retval