Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/compiler/xla/xla.py: 20%
240 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""xla is an experimental library that provides XLA support APIs."""
17import contextlib
20from tensorflow.compiler.jit.ops import xla_ops
21from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import
22from tensorflow.core.framework import attr_value_pb2
23from tensorflow.python.distribute import summary_op_util
24from tensorflow.python.eager import context
25from tensorflow.python.eager import def_function
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import variable_scope
30from tensorflow.python.platform import tf_logging as logging
31from tensorflow.python.util import compat
32from tensorflow.python.util import nest
33from tensorflow.python.util import tf_inspect
34from tensorflow.python.util.compat import collections_abc
35from tensorflow.python.util.deprecation import deprecated
36from tensorflow.python.util.tf_export import tf_export
# Attribute stamped onto every op created inside an xla.compile() cluster; its
# value is the cluster's unique name (see XLACompileContext.AddOp).
_XLA_COMPILE_ATTR = '_xla_compile_id'

# Cap on how many unsupported ops are listed individually in the warning
# emitted by report_unsupported_operations(); the rest are summarized.
_MAX_WARNING_LINES = 5

# Operations that indicate some error in the users graph. For example, XLA
# computation should not have any Placeholder op.
_DENYLISTED_OPS = set([
    'Placeholder',
])

# XLA doesn't currently support reading of intermediate tensors, thus some ops
# are not supported.
_UNSUPPORTED_OPS = set([
    'AudioSummary',
    'AudioSummaryV2',
    'HistogramSummary',
    'ImageSummary',
    'MergeSummary',
    'Print',
    'ScalarSummary',
    'TensorSummary',
    'TensorSummaryV2',
])
@tf_export('xla.experimental.compile')
@deprecated(
    None, 'xla.experimental.compile is deprecated. Consider using '
    '`@tf.function(jit_compile=True)`.',
    warn_once=True)
def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
  """Builds an operator that compiles and runs `computation` with XLA.

  NOTE: In eager mode, `computation` will have `@tf.function` semantics.

  Args:
    computation: A Python function that builds a computation to apply to the
      input. If the function takes n inputs, 'inputs' should be a list of n
      `Tensor`s. `computation` may return a list of `Tensor`s and
      `Operation`s; `Tensor`s must come before `Operation`s in the returned
      list. All `Operation`s returned from `computation` will be executed
      when evaluating any of the returned output tensors.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each
      input can be a nested structure containing values that can be converted
      to `Tensor`s. Note that passing an N-dimension list of compatible
      values will result in an N-dimension list of scalar `Tensor`s rather
      than a single Rank-N `Tensor`. If you need a different behavior,
      convert parts of `inputs` to `Tensor`s with `tf.convert_to_tensor`.

  Returns:
    List of `Tensor`s corresponding to the `Tensor`s from the output of
    `computation`, i.e. the same return value as if computation(*inputs) were
    called directly, with the following exceptions:
      * None output: a NoOp would be returned with a control dependency on
        `computation`.
      * Single value output: a tuple containing the value would be returned.
      * Operation-only outputs: a NoOp would be returned with a control
        dependency on `computation`.
    TODO(b/121383831): Investigate into removing these special cases.

  Raises:
    RuntimeError: When eager execution is enabled.

  Known issues:
    When a tf.random operation is built with XLA, the implementation doesn't
    pass the user provided seed to the XLA compiler. As such, the XLA compiler
    generates a random number and uses it as a seed when compiling the
    operation. This means changing the user-defined seed doesn't change the
    numbers generated by the operation, and running the program multiple
    times without a seed will generate the same numbers.
  """
  # Graph mode: build the cluster directly into the default graph.
  if not context.executing_eagerly():
    return _compile_internal(computation, inputs)

  # Eager mode: trace the compilation inside a tf.function so the cluster is
  # built into a function graph, giving `computation` tf.function semantics.
  @def_function.function
  def xla_compile_wrapper():
    return _compile_internal(computation, inputs)

  return xla_compile_wrapper()
class XLACompileContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside an XLA computation cluster.

  THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NO USE DIRECTLY.

  The primary role of `XLACompileContext` is to mark operators inside a
  xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is
  a unique name.

  `ControlFlowContext` is used to perform the annotation since it integrates
  with Tensorflow constructs like ResourceVariables. For example, if a
  `ResourceVariable` is constructed inside a xla.compile() block, the
  `ResourceVariable` implementation can use
  `with ops.control_dependencies(None)` to build the variable's definition
  outside the compiled computation.
  """

  def __init__(self, name, pivot):
    """Builds a new XLACompileContext.

    Args:
      name: a unique name for the context, used to populate the
        `_xla_compile_id` attribute.
      pivot: a pivot node. Nodes in the XLACompileContext that do not have any
        inputs will have a control dependency on the pivot node. This ensures
        that nodes are correctly included in any enclosing control flow
        contexts.
    """
    super(XLACompileContext, self).__init__()
    self._name = name
    # Cached bytes form of the name; ops attributes take bytes values.
    self._name_as_bytes = compat.as_bytes(name)
    # Ops from _UNSUPPORTED_OPS seen so far; reported on context exit via
    # report_unsupported_operations().
    self._unsupported_ops = []
    self._pivot = pivot

  def report_unsupported_operations(self):
    """Logs a warning listing any unsupported ops added to this context."""
    if self._unsupported_ops:
      # Show at most _MAX_WARNING_LINES ops individually; summarize the rest.
      op_str = '\n'.join([
          '  %s (%s)' % (op.type, op.name)
          for op in self._unsupported_ops[:_MAX_WARNING_LINES]
      ])
      logging.warning('%d unsupported operations found: \n%s',
                      len(self._unsupported_ops), op_str)
      if len(self._unsupported_ops) > _MAX_WARNING_LINES:
        logging.warning('... and %d more',
                        len(self._unsupported_ops) - _MAX_WARNING_LINES)

  def _RemoveExternalControlEdges(self, op):
    """Remove any external control dependency on this op.

    Control inputs whose control-flow context is this context (or nested
    inside it) are kept; all others are stripped from `op`.

    Args:
      op: the `Operation` to strip external control inputs from.

    Returns:
      A pair (internal_control_inputs, external_control_inputs) of the
      control-input ops that were kept and removed, respectively.
    """
    internal_control_inputs = []
    external_control_inputs = []
    for x in op.control_inputs:
      # pylint: disable=protected-access
      is_internal_op = False
      # Walk up the control-flow context chain of the control input; it is
      # internal iff this context appears somewhere on that chain.
      ctxt = x._get_control_flow_context()
      while ctxt is not None:
        if ctxt == self:
          is_internal_op = True
          break
        ctxt = ctxt._outer_context
      if is_internal_op:
        internal_control_inputs.append(x)
      else:
        external_control_inputs.append(x)
      # pylint: enable=protected-access
    # Rebuild op's control inputs with only the internal ones.
    # pylint: disable=protected-access
    op._remove_all_control_inputs()
    op._add_control_inputs(internal_control_inputs)
    # pylint: enable=protected-access
    return internal_control_inputs, external_control_inputs

  def AddOp(self, op):
    """Create op in XLACompileContext and notifies outer context recursively."""
    # pylint: disable=protected-access
    if op.type in _DENYLISTED_OPS:
      logging.error(
          'Operation of type %s (%s) is not supported in XLA. Execution will '
          'fail if this op is used in the graph. ', op.type, op.name)

    # TODO(ycao): Automatically disable summaries instead of reporting them.
    if op.type in _UNSUPPORTED_OPS:
      self._unsupported_ops.append(op)

    # Reference-typed inputs come from non-resource variables, which XLA
    # cannot handle.
    if any(x.dtype._is_ref_dtype for x in op.inputs):
      raise NotImplementedError(
          'Non-resource Variables are not supported inside XLA computations '
          '(operator name: %s)' % op.name)

    # An op already carrying the compile attribute was built inside another
    # xla.compile() cluster.
    if _XLA_COMPILE_ATTR in op.node_def.attr:
      raise ValueError('XLA compiled computations cannot be nested, (operator '
                       'name: %s)' % op.name)

    # Tag the op as belonging to this cluster.
    op._set_attr(
        _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))

    # Intermediate tensors of the cluster cannot be fed or fetched.
    op.graph.prevent_feeding(op)
    op.graph.prevent_fetching(op)

    # Remove any control edges from outer control flow contexts. These may cause
    # mismatched frame errors. An example is when one of op's inputs is
    # generated in a different While control flow context.
    (internal_control_inputs,
     external_control_inputs) = self._RemoveExternalControlEdges(op)

    if not op.inputs:
      # Add a control edge from the control pivot to this op.
      if not internal_control_inputs:
        # pylint: disable=protected-access
        op._add_control_input(self._pivot)
        # pylint: enable=protected-access
    else:
      # Route each data input through AddValue so outer-context values are
      # replaced by their in-context equivalents.
      for index in range(len(op.inputs)):
        x = op.inputs[index]
        real_x = self.AddValue(x)
        if real_x is not x:
          op._update_input(index, real_x)  # pylint: disable=protected-access

    if external_control_inputs:
      # Use an identity to pull control inputs as data inputs. Note that we
      # ignore ops which don't have outputs. TODO(phawkins): fix that.
      with ops.control_dependencies(None):
        self.Enter()
        external_control_inputs = [
            array_ops.identity(x.outputs[0]).op
            for x in external_control_inputs
            if x.outputs
        ]
        self.Exit()
      # pylint: disable=protected-access
      op._add_control_inputs(external_control_inputs)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    context = self
    while context is not None:
      # pylint: disable=protected-access
      context._values.update(output_names)
      context = context._outer_context
      # pylint: enable=protected-access

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def AddValue(self, val):
    """Add `val` to the current context and its outer context recursively."""
    if val.name in self._values:
      # Use the real value if it comes from outer context.
      result = self._external_values.get(val.name)
      return val if result is None else result

    result = val
    self._values.add(val.name)
    if self._outer_context:
      # Let the outer context substitute its own value, and record that
      # substitution for future lookups.
      result = self._outer_context.AddValue(val)
      self._values.add(result.name)

    self._external_values[val.name] = result

    return result

  def AddInnerOp(self, op):
    # Treat ops from nested contexts like ops added directly here, then
    # propagate outward.
    self.AddOp(op)
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  @property
  def grad_state(self):
    # Define the gradient loop state associated with the XLACompileContext to
    # be None as the XLACompileContext does not get nested nor does the
    # grad_state outside the XLACompileContext affect the graph inside so the
    # grad_state should be as if this is the top-level gradient state.
    return None

  @property
  def back_prop(self):
    """Forwards to the enclosing while context, if any."""
    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False
def _compile_internal(computation, inputs=None):
  """Builds graph operators that compiles and symbolically executes computation.

  Args:
    computation: A Python function that builds the computation to compile and
      execute.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each input
      can be a nested structure containing values that are convertible to
      tensors. Note that passing an N-dimension list of compatible values will
      result in a N-dimension list of scalar tensors rather than a single Rank-N
      tensors. If you need different behavior, convert part of inputs to tensors
      with `tf.convert_to_tensor`.

  Returns:
    Same data structure as if computation(*inputs) is called directly with some
    exceptions for correctness. Exceptions include: 1) None output 2) Single
    value output 3) Operation-only outputs
  Raises:
    ValueError: If any element in computation outputs is neither an operations
      or a value that can be converted to tensor.
    ValueError: If computation outputs is non-flat and contains any Operations.
    TypeError: If `inputs` is not a list or tuple.
  """
  if inputs is None:
    inputs = []

  if not isinstance(inputs, collections_abc.Sequence):
    raise TypeError('inputs must be a list')

  # Flatten inputs.
  flat_inputs = nest.flatten(inputs)
  # Converts inputs to Tensors.
  flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]

  # A fresh cluster name and pivot no_op anchor this compilation context.
  cluster_name = ops.get_default_graph().unique_name('cluster')
  pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
  context = XLACompileContext(name=cluster_name, pivot=pivot)
  try:
    context.Enter()

    # Add identity ops so even unused inputs are 'consumed' by the
    # computation.
    flat_inputs = [
        array_ops.identity(x, name='input_{}'.format(i))
        for i, x in enumerate(flat_inputs)
    ]

    # Re-pack flat_inputs in same structure as 'inputs'.
    computation_inputs = nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_inputs)

    # Only resource variables work inside an XLA computation, so turn on
    # resource variables for the computation.
    vscope = variable_scope.get_variable_scope()
    saved_use_resource = vscope.use_resource
    vscope.set_use_resource(True)

    # Summaries are not supported under XLA; skip them while tracing.
    with _disable_summary_context():
      outputs = computation(*computation_inputs)

    # Restore variable scope after computation.
    vscope.set_use_resource(saved_use_resource)

    # Flat and non-flat output structures take different validation paths.
    outputs_is_flat = is_flat(outputs)
    if outputs_is_flat:
      output_tensors, control_deps = _postprocess_flat_outputs(outputs)
    else:
      output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)

    context.ExitResult(output_tensors)
  finally:
    # Always report unsupported ops and leave the context, even on error.
    context.report_unsupported_operations()
    context.Exit()

  # When XLA computation returns only operations and no tensors, a NoOp
  # dependent on the operations in outputs is returned. Otherwise final
  # outputs would be empty and there is no way to trigger returned
  # operations.
  if not output_tensors:
    return control_flow_ops.group(control_deps, name='output_0')

  output_tensors = [
      xla_ops.xla_cluster_output(o, name='output{}'.format(i))
      for i, o in enumerate(output_tensors)
  ]

  with ops.control_dependencies(control_deps):
    # Wraps the outputs in identity operators that carries control
    # dependencies.
    output_tensors = [
        array_ops.identity(o, name='output_%d' % i)
        for i, o in enumerate(output_tensors)
    ]

  # If `computation` returned non-flat output structure, pack output tensors
  # back into same structure.
  if not outputs_is_flat:
    output_tensors = nest.pack_sequence_as(
        structure=outputs, flat_sequence=output_tensors)

  return output_tensors
def is_flat(outputs):
  """Checks if outputs is a flat structure.

  Following structures and values are considered flat:
  1) None
  2) A single object
  3) A list or tuple of Tensors/Operations

  The only structures that this function understands are sequences,
  dictionaries and types defined using the attrs library. E.g. this means
  that if outputs contains a single user-defined Object, it is considered to
  be flat. Errors are raised later on if that Object cannot be converted to a
  Tensor.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    A boolean indicates whether outputs is flat.
  """
  # A sequence is non-flat as soon as any element is itself structured
  # (sequence, mapping, or attrs-defined class).
  if isinstance(outputs, collections_abc.Sequence):
    if any(
        isinstance(element, (collections_abc.Sequence, collections_abc.Mapping))
        or hasattr(element.__class__, '__attrs_attrs__')
        for element in outputs):
      return False

  # A mapping or an attrs-defined object is non-flat regardless of contents;
  # anything else (including a flat sequence) is flat.
  return not (isinstance(outputs, collections_abc.Mapping) or
              hasattr(outputs.__class__, '__attrs_attrs__'))
def _postprocess_flat_outputs(outputs):
  """Validates flat outputs and adds back device assignments.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    Tensors and Operations extracted from outputs.

  Raises:
    ValueError: If an output is neither an Operation nor convertible to a
      Tensor, or if Tensors do not precede Operations in `outputs`.
  """
  # Following code segment is to preserve legacy behavior. Previously we only
  # supported flat outputs and thus for consistency it was nice to convert even
  # single element into a tuple. But now that we support arbitrary output
  # structure, this is no longer necessary.
  # TODO(b/121383831): Migrate all legacy use cases and delete this special
  # case.
  # If the computation returns `None`, make it an empty tuple.
  if outputs is None:
    outputs = tuple()
  # If the computation only returned one value, make it a tuple.
  if not isinstance(outputs, collections_abc.Sequence):
    outputs = (outputs,)

  # Append `no_op` here so that return value of this function always contains
  # at least one op that can trigger XlaLaunch node.
  outputs += (control_flow_ops.no_op(),)
  try:
    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]
  except Exception as e:
    raise ValueError(
        'XLA computation function return values must all either be Operations'
        ' or convertible to Tensors. Got error: "%s"' % str(e))

  # Separates the returned Operations and Tensors.
  output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
  output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]

  # Enforce the contract that all Tensors precede all Operations: if the
  # re-concatenation differs from the original order, they were interleaved.
  if outputs != output_tensors + output_operations:
    raise ValueError(
        'XLA computation function must return zero or more Tensor values '
        'followed by zero or more Operations.')

  # Wrap each tensor in an identity on its own device so device placement is
  # preserved on the outputs.
  new_output_tensors = []
  for t in output_tensors:
    with ops.device(t.device if t.device else ''):
      new_output_tensors.append(array_ops.identity(t))

  return new_output_tensors, output_operations
def _postprocess_non_flat_outputs(outputs):
  """Validates non-flat outputs and adds back device assignments.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    Tensors extracted from outputs and an empty list because Operations are not
    allowed in non-flat outputs..

  Raises:
    ValueError: If `outputs` contains an Operation, or contains a value that
      cannot be converted to a Tensor.
  """
  # Convert all non-Operation outputs to Tensors.
  new_output_tensors = []
  for o in nest.flatten(outputs):
    # Operations cannot be packed back into the caller's nested structure,
    # so they are rejected outright.
    if isinstance(o, ops.Operation):
      raise ValueError(
          'xla.compile does not support Operation as return value in non-flat '
          'output structure. You can set returned Operations as control '
          'dependencies of returned Tensors so Operations are triggered when '
          'Tensors are evaluated. Operation found: "%s"' % o.name)

    try:
      o = ops.convert_to_tensor(o)
    except Exception as e:
      raise ValueError(
          'XLA computation function return values must all either be '
          'Operations or convertible to Tensors. Got error: "%s"' % str(e))

    # Makes sure even pass-through inputs/outputs are touched in compile
    # context by creating an Identity node inside compile context.
    with ops.device(o.device if o.device else ''):
      new_output_tensors.append(array_ops.identity(o))

  return new_output_tensors, []
@contextlib.contextmanager
def _disable_summary_context():
  """Enters a context where all summary ops are skipped.

  Summaries are not yet supported in xla.compile(). So we provide this context
  manager that can skip creating summary ops. This is a temporary workaround
  due to XLA not supporting summary ops.

  Yields:
    None.
  """
  saved_skip_summary = summary_op_util.skip_summary
  # Patch the predicate so every summary op creation is skipped while the
  # context is active.
  summary_op_util.skip_summary = lambda: True

  try:
    yield
  finally:
    # Restore the original predicate even if the body raised.
    summary_op_util.skip_summary = saved_skip_summary
557class _CapturedObject(object):
558 """A placeholder to capture an object."""
560 def __init__(self):
561 self._object = None
563 def capture(self, o):
564 if self._object:
565 raise RuntimeError(
566 'InternalError: _CapturedObject can capture only once. Please file '
567 'bug.')
569 self._object = o
571 def get(self):
572 return self._object
575def _get_scaffold(captured_scaffold_fn):
576 """Retrieves the Scaffold from `captured_scaffold_fn`."""
577 scaffold_fn = captured_scaffold_fn.get()
579 if not scaffold_fn:
580 return None
582 scaffold = scaffold_fn()
583 if scaffold is None:
584 raise ValueError(
585 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
587 return scaffold
def check_function_argument_count(func, input_arity, infeed_queue):
  """Validate the number of input arguments to an XLA function.

  Args:
    func: the Python function that will be called to generate the body of an
      XLA computation graph.
    input_arity: the number of explicit arguments supplied by the caller.
    infeed_queue: if not None, the infeed queue that will supply additional
      arguments to the function.

  Returns:
    None if function can be called with the supplied number of arguments, or
    an error string if it cannot.
  """

  def format_error(complaint, quantity):
    # e.g. 'exactly 1 argument' / 'at least 2 arguments'.
    suffix = '' if quantity == 1 else 's'
    return '%s %d argument%s' % (complaint, quantity, suffix)

  # The infeed queue, when present, supplies extra positional arguments.
  num_args_supplied = input_arity
  if infeed_queue is not None:
    num_args_supplied += infeed_queue.number_of_tuple_elements

  arg_spec = tf_inspect.getargspec(func)
  num_func_args = len(arg_spec.args)
  num_func_defaults = (
      0 if arg_spec.defaults is None else len(arg_spec.defaults))
  min_func_args = num_func_args - num_func_defaults

  if num_args_supplied < min_func_args:
    # The required number of arguments is not enough to call the function.
    if num_func_defaults == 0 and arg_spec.varargs is None:
      return format_error('exactly', num_func_args)
    return format_error('at least', min_func_args)

  if arg_spec.varargs is None and num_args_supplied > num_func_args:
    # The required number of arguments is too many to call the function.
    if num_func_defaults == 0:
      return format_error('exactly', num_func_args)
    return format_error('at most', num_func_args)

  # Reaching here means either
  # 1) There are varargs, func can accept any number of arguments greater than
  # the minimum.
  # 2) Number of supplied arguments falls in range of acceptable argument count
  # of func.
  return None
635 return None