Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py: 33%
2184 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes and functions used to construct graphs."""
16# pylint: disable=g-bad-name
17import collections
18import copy
19import re
20import sys
21import threading
22import types
23from typing import Optional
24from absl import app
26import numpy as np
28from tensorflow.core.framework import attr_value_pb2
29from tensorflow.core.framework import full_type_pb2
30from tensorflow.core.framework import function_pb2
31from tensorflow.core.framework import graph_pb2
32from tensorflow.core.framework import node_def_pb2
33from tensorflow.core.framework import op_def_pb2
34from tensorflow.core.framework import versions_pb2
35from tensorflow.core.protobuf import config_pb2
36# pywrap_tensorflow must be imported first to avoid protobuf issues.
37# (b/143110113)
38# pylint: disable=invalid-import-order,g-bad-import-order,unused-import
39from tensorflow.python import pywrap_tensorflow
40from tensorflow.python import pywrap_tfe
41# pylint: enable=invalid-import-order,g-bad-import-order,unused-import
42from tensorflow.python import tf2
43from tensorflow.python.client import pywrap_tf_session
44from tensorflow.python.eager import context
45from tensorflow.python.eager import core
46from tensorflow.python.eager import monitoring
47from tensorflow.python.eager import record
48from tensorflow.python.framework import c_api_util
49from tensorflow.python.framework import composite_tensor
50from tensorflow.python.framework import device as pydev
51from tensorflow.python.framework import dtypes
52from tensorflow.python.framework import errors
53from tensorflow.python.framework import op_callbacks
54from tensorflow.python.framework import registry
55from tensorflow.python.framework import stack
56from tensorflow.python.framework import tensor_conversion_registry
57from tensorflow.python.framework import tensor_shape
58from tensorflow.python.framework import tensor_util
59from tensorflow.python.framework import traceable_stack
60from tensorflow.python.framework import versions
61from tensorflow.python.ops import control_flow_util
62from tensorflow.python.ops import handle_data_util
63from tensorflow.python.platform import tf_logging as logging
64from tensorflow.python.profiler import trace as profiler_trace
65from tensorflow.python.types import core as core_tf_types
66from tensorflow.python.types import internal
67from tensorflow.python.util import compat
68from tensorflow.python.util import decorator_utils
69from tensorflow.python.util import deprecation
70from tensorflow.python.util import function_utils
71from tensorflow.python.util import lock_util
72from tensorflow.python.util import object_identity
73from tensorflow.python.util import tf_contextlib
74from tensorflow.python.util import tf_stack
75from tensorflow.python.util import traceback_utils
76from tensorflow.python.util.compat import collections_abc
77from tensorflow.python.util.deprecation import deprecated_args
78from tensorflow.python.util.lazy_loader import LazyLoader
79from tensorflow.python.util.tf_export import kwarg_only
80from tensorflow.python.util.tf_export import tf_export
# TODO(b/218887885): Loaded lazily due to a circular dependency with this file.
# Both modules are bound here as LazyLoader proxies instead of regular imports.
# Within this file `tensor_spec` is used by, among others,
# `Tensor.__tf_tracing_type__`, and `ag_ctx` by the Tensor bool-casting /
# iteration error helpers. NOTE(review): presumably the real import happens on
# first attribute access -- confirm against util/lazy_loader.py.
tensor_spec = LazyLoader(
    "tensor_spec", globals(),
    "tensorflow.python.framework.tensor_spec")
ag_ctx = LazyLoader(
    "ag_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")
91# Temporary global switches determining if we should enable the work-in-progress
92# calls to the C API. These will be removed once all functionality is supported.
93_USE_C_API = True
94_USE_C_SHAPES = True
# Boolean monitoring gauges, created once at import time, each recording
# whether a particular public API has been used. The first argument is the
# metric path and the second its human-readable description.
_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/ops_eager_execution",
    "Whether ops.enable_eager_execution() is called.")

# Set to True/False by enable_tensor_equality() / disable_tensor_equality().
_tensor_equality_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/enable_tensor_equality",
    "Whether ops.enable_tensor_equality() is called.")

_control_flow_api_gauge = monitoring.BoolGauge(
    "/tensorflow/api/enable_control_flow_v2",
    "Whether enable_control_flow_v2() is called.")

_tf_function_api_gauge = monitoring.BoolGauge(
    "/tensorflow/api/tf_function",
    "Whether tf.function() is used.")

# Module-level alias of the dtypes module's private intern table.
# NOTE(review): presumably used elsewhere in this file for fast DType lookups
# -- verify at the use sites.
# pylint: disable=protected-access
_DTYPES_INTERN_TABLE = dtypes._INTERN_TABLE
# pylint: enable=protected-access
def tensor_id(tensor):
  """Returns the unique identifier stored on `tensor` (its `_id` attribute)."""
  identifier = tensor._id  # pylint: disable=protected-access
  return identifier
class _UserDeviceSpec(object):
  """Store user-specified device and provide computation of merged device.

  The user may supply a `pydev.MergeDevice`, a callable taking a node_def, a
  device string, or `None` (a break in the device stack). Each case sets
  `function`, `display_name`, `raw_string` and `is_null_merge` accordingly.
  """

  def __init__(self, device_name_or_function):
    spec = device_name_or_function
    self._device_name_or_function = spec
    self.display_name = str(spec)
    self.function = spec
    self.raw_string = None

    if isinstance(spec, pydev.MergeDevice):
      self.is_null_merge = spec.is_null_merge
    elif callable(spec):
      self.is_null_merge = False
      self.display_name = self._describe_callable(spec)
    elif spec is None:
      # NOTE(taylorrobie): This MUST be False. None signals a break in the
      # device stack, so `is_null_merge` must be False for such a case to
      # allow callers to safely skip over null merges without missing a None.
      self.is_null_merge = False
    else:
      # A plain device string: keep it and turn it into a merge function.
      self.raw_string = spec
      self.function = pydev.merge_device(spec)
      self.is_null_merge = self.function.is_null_merge

    # Computed once here because the isinstance check is of non-trivial cost
    # and string_merge is typically called many times.
    self.fast_string_merge = isinstance(self.function, pydev.MergeDevice)

  @staticmethod
  def _describe_callable(device_fn):
    """Formats `device_fn` as "name<file, first_line>" for display."""
    fn_name = function_utils.get_func_name(device_fn)
    fn_code = function_utils.get_func_code(device_fn)
    if not fn_code:
      return "%s<%s, %d>" % (fn_name, "unknown", -1)
    return "%s<%s, %d>" % (fn_name, fn_code.co_filename,
                           fn_code.co_firstlineno)

  def string_merge(self, node_def):
    """Returns the device string obtained by applying this spec to node_def."""
    if not self.fast_string_merge:
      return compat.as_str(_device_string(self.function(node_def)))
    return self.function.shortcut_string_merge(node_def)
class NullContextmanager(object):
  """A no-op context manager that accepts and ignores any constructor args."""

  def __init__(self, *args, **kwargs):
    del args, kwargs  # Unused; accepted only for call-site compatibility.

  def __enter__(self):
    return None

  def __exit__(self, type_arg, value_arg, traceback_arg):
    # A falsy return value lets any in-flight exception propagate.
    return False
def _override_helper(clazz_object, operator, func):
  """Installs `func` as the (string-named) `operator` on `clazz_object`.

  Args:
    clazz_object: the class to override for; either Tensor or SparseTensor.
    operator: the string name of the operator to override.
    func: the function that replaces the overridden operator.

  Raises:
    ValueError: If `operator` is not in `Tensor.OVERLOADABLE_OPERATORS`.
  """
  allowed = Tensor.OVERLOADABLE_OPERATORS
  if operator in allowed:
    setattr(clazz_object, operator, func)
    return
  raise ValueError(f"Overriding {operator} is disallowed. "
                   f"Allowed operators are {allowed}.")
198def _as_graph_element(obj):
199 """Convert `obj` to a graph element if possible, otherwise return `None`.
201 Args:
202 obj: Object to convert.
204 Returns:
205 The result of `obj._as_graph_element()` if that method is available;
206 otherwise `None`.
207 """
208 conv_fn = getattr(obj, "_as_graph_element", None)
209 if conv_fn and callable(conv_fn):
210 return conv_fn()
211 return None
# Deprecated - do not use.
# This API exists only to avoid breaking estimator and tensorflow-mesh, which
# depend on this internal API. The stub should be safe to remove after TF 2.3
# is released.
def is_dense_tensor_like(t):
  """Deprecated: returns whether `t` is an instance of `core_tf_types.Tensor`."""
  tensor_type = core_tf_types.Tensor
  return isinstance(t, tensor_type)
def uid():
  """Returns an integer that is unique within this program execution."""
  next_uid = pywrap_tfe.TFE_Py_UID()
  return next_uid
def numpy_text(tensor, is_repr=False):
  """Returns a human-readable rendering of `tensor`'s numpy value.

  Args:
    tensor: Object exposing `dtype.is_numpy_compatible` and `_numpy()`.
    is_repr: If True use `repr()` of the value, otherwise `str()`.

  Returns:
    The rendered text, or "<unprintable>" for non-numpy-compatible dtypes.
    Multi-line text is prefixed with a newline so it starts on its own line.
  """
  if not tensor.dtype.is_numpy_compatible:
    text = "<unprintable>"
  else:
    value = tensor._numpy()  # pylint: disable=protected-access
    text = repr(value) if is_repr else str(value)
  return "\n" + text if "\n" in text else text
def value_text(tensor, is_repr=False):
  """Formats `tensor` either via its custom summarizer or its NumPy value.

  Custom formatting is used for custom device tensors, e.g. parallel tensors
  with multiple components on different devices.

  Args:
    tensor: The tensor to format.
    is_repr: If True, prefix the text with "value=" (custom summary) or
      "numpy=" (NumPy value), and render NumPy values with `repr`.

  Returns:
    The formatted tensor text.
  """
  # pylint: disable=protected-access  # friend access
  if tensor._prefer_custom_summarizer():
    text, prefix = tensor._summarize_value(), "value="
  else:
    text, prefix = numpy_text(tensor, is_repr=is_repr), "numpy="
  # pylint: enable=protected-access
  return prefix + text if is_repr else text
@tf_export(v1=["enable_tensor_equality"])
def enable_tensor_equality():
  """Compare Tensors with element-wise comparison and thus be unhashable.

  With element-wise comparison, expressions such as
  `tf.Variable(1.0) == 1.0` compare values. A consequence is that tensors
  become unhashable, so they can no longer be used directly in sets or as
  dictionary keys (use `tensor.ref()` instead).
  """
  logging.vlog(1, "Enabling tensor equality")
  usage_cell = _tensor_equality_api_usage_gauge.get_cell()
  usage_cell.set(True)
  Tensor._USE_EQUALITY = True  # pylint: disable=protected-access
@tf_export(v1=["disable_tensor_equality"])
def disable_tensor_equality():
  """Compare Tensors by their id and be hashable.

  This restores TensorFlow's legacy behaviour and is highly discouraged.
  """
  logging.vlog(1, "Disabling tensor equality")
  usage_cell = _tensor_equality_api_usage_gauge.get_cell()
  usage_cell.set(False)
  Tensor._USE_EQUALITY = False  # pylint: disable=protected-access
290# Tensor subclassing has historically been a mess.
291#
292# There is no "Tensor" base class for Graph & Eager tensors. Instead, when we
293# introduced EagerTensor, we had it subclass the graph "Tensor" class, and
# override a bunch of behavior. Introducing a proper subclassing relationship
# is complicated because many users check for type(t) == Tensor or isinstance.
296#
297# This is done internally for "bad" reasons as a way to separate out Graph and
298# Eager tensors, or subclasses which "look like" Tensor, e.g. distribute.Value.
299#
300# For now, we work around this by deferring initialization of graph tensors to
301# a separate `_init` method. `GraphTensor` is a free function, not a class, that
302# returns a Tensor object.
303#
304# b(XXX) -- fix type(t) == Tensor checks in the code base
305@tf_export("Tensor", "experimental.numpy.ndarray", v1=["Tensor"])
306class Tensor(
307 pywrap_tf_session.PyTensor, internal.NativeObject, core_tf_types.Symbol
308):
309 """A `tf.Tensor` represents a multidimensional array of elements.
311 All elements are of a single known data type.
313 When writing a TensorFlow program, the main object that is
314 manipulated and passed around is the `tf.Tensor`.
316 A `tf.Tensor` has the following properties:
318 * a single data type (float32, int32, or string, for example)
319 * a shape
321 TensorFlow supports eager execution and graph execution. In eager
322 execution, operations are evaluated immediately. In graph
323 execution, a computational graph is constructed for later
324 evaluation.
326 TensorFlow defaults to eager execution. In the example below, the
327 matrix multiplication results are calculated immediately.
329 >>> # Compute some values using a Tensor
330 >>> c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
331 >>> d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
332 >>> e = tf.matmul(c, d)
333 >>> print(e)
334 tf.Tensor(
335 [[1. 3.]
336 [3. 7.]], shape=(2, 2), dtype=float32)
338 Note that during eager execution, you may discover your `Tensors` are actually
339 of type `EagerTensor`. This is an internal detail, but it does give you
340 access to a useful function, `numpy`:
342 >>> type(e)
343 <class '...ops.EagerTensor'>
344 >>> print(e.numpy())
345 [[1. 3.]
346 [3. 7.]]
348 In TensorFlow, `tf.function`s are a common way to define graph execution.
350 A Tensor's shape (that is, the rank of the Tensor and the size of
351 each dimension) may not always be fully known. In `tf.function`
352 definitions, the shape may only be partially known.
354 Most operations produce tensors of fully-known shapes if the shapes of their
355 inputs are also fully known, but in some cases it's only possible to find the
356 shape of a tensor at execution time.
358 A number of specialized tensors are available: see `tf.Variable`,
359 `tf.constant`, `tf.placeholder`, `tf.sparse.SparseTensor`, and
360 `tf.RaggedTensor`.
362 Caution: when constructing a tensor from a numpy array or pandas dataframe
363 the underlying buffer may be re-used:
365 ```python
366 a = np.array([1, 2, 3])
367 b = tf.constant(a)
368 a[0] = 4
369 print(b) # tf.Tensor([4 2 3], shape=(3,), dtype=int64)
370 ```
372 Note: this is an implementation detail that is subject to change and users
373 should not rely on this behaviour.
375 For more on Tensors, see the [guide](https://tensorflow.org/guide/tensor).
376 """
377 # List of Python operators that we allow to override.
378 OVERLOADABLE_OPERATORS = {
379 # Binary.
380 "__add__",
381 "__radd__",
382 "__sub__",
383 "__rsub__",
384 "__mul__",
385 "__rmul__",
386 "__div__",
387 "__rdiv__",
388 "__truediv__",
389 "__rtruediv__",
390 "__floordiv__",
391 "__rfloordiv__",
392 "__mod__",
393 "__rmod__",
394 "__lt__",
395 "__le__",
396 "__gt__",
397 "__ge__",
398 "__ne__",
399 "__eq__",
400 "__and__",
401 "__rand__",
402 "__or__",
403 "__ror__",
404 "__xor__",
405 "__rxor__",
406 "__getitem__",
407 "__pow__",
408 "__rpow__",
409 # Unary.
410 "__invert__",
411 "__neg__",
412 "__abs__",
413 "__matmul__",
414 "__rmatmul__"
415 }
417 # Whether to allow hashing or numpy-style equality
418 _USE_EQUALITY = tf2.enabled()
420 def __getattr__(self, name):
421 if name in {"T", "astype", "ravel", "transpose", "reshape", "clip", "size",
422 "tolist", "data"}:
423 # TODO(wangpeng): Export the enable_numpy_behavior knob
424 raise AttributeError(
425 f"{type(self).__name__} object has no attribute '{name}'. " + """
426 If you are looking for numpy-related methods, please run the following:
427 from tensorflow.python.ops.numpy_ops import np_config
428 np_config.enable_numpy_behavior()
429 """)
430 self.__getattribute__(name)
432 @property
433 def dtype(self):
434 """The `DType` of elements in this tensor."""
435 return self._dtype
437 @property
438 def name(self):
439 """The string name of this tensor."""
440 if self._name is None:
441 assert self._op.name
442 self._name = "%s:%d" % (self._op.name, self.value_index)
443 return self._name
445 @property
446 def shape(self):
447 """Returns a `tf.TensorShape` that represents the shape of this tensor.
449 >>> t = tf.constant([1,2,3,4,5])
450 >>> t.shape
451 TensorShape([5])
453 `tf.Tensor.shape` is equivalent to `tf.Tensor.get_shape()`.
455 In a `tf.function` or when building a model using
456 `tf.keras.Input`, they return the build-time shape of the
457 tensor, which may be partially unknown.
459 A `tf.TensorShape` is not a tensor. Use `tf.shape(t)` to get a tensor
460 containing the shape, calculated at runtime.
462 See `tf.Tensor.get_shape()`, and `tf.TensorShape` for details and examples.
463 """
464 if self._shape_val is None:
465 dims, unknown_shape = self._shape
466 if unknown_shape:
467 self._shape_val = tensor_shape.unknown_shape()
468 else:
469 self._shape_val = tensor_shape.TensorShape(dims)
470 return self._shape_val
472 @property
473 def ndim(self):
474 return self.shape.rank
476 def _disallow_when_autograph_unavailable(self, task):
477 raise errors.OperatorNotAllowedInGraphError(
478 f"{task} is not allowed: AutoGraph is unavailable in this runtime. See"
479 " https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code"
480 " for more information.")
482 def _disallow_when_autograph_disabled(self, task):
483 raise errors.OperatorNotAllowedInGraphError(
484 f"{task} is not allowed: AutoGraph is disabled in this function."
485 " Try decorating it directly with @tf.function.")
487 def _disallow_when_autograph_enabled(self, task):
488 raise errors.OperatorNotAllowedInGraphError(
489 f"{task} is not allowed: AutoGraph did convert this function. This"
490 " might indicate you are trying to use an unsupported feature.")
492 def _disallow_in_graph_mode(self, task):
493 raise errors.OperatorNotAllowedInGraphError(
494 f"{task} is not allowed in Graph execution. Use Eager execution or"
495 " decorate this function with @tf.function.")
497 def _disallow_bool_casting(self):
498 if not ag_ctx.INSPECT_SOURCE_SUPPORTED:
499 self._disallow_when_autograph_unavailable(
500 "Using a symbolic `tf.Tensor` as a Python `bool`")
501 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
502 self._disallow_when_autograph_disabled(
503 "Using a symbolic `tf.Tensor` as a Python `bool`")
504 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
505 self._disallow_when_autograph_enabled(
506 "Using a symbolic `tf.Tensor` as a Python `bool`")
507 else:
508 # Default: V1-style Graph execution.
509 self._disallow_in_graph_mode(
510 "Using a symbolic `tf.Tensor` as a Python `bool`")
512 def _disallow_iteration(self):
513 if not ag_ctx.INSPECT_SOURCE_SUPPORTED:
514 self._disallow_when_autograph_unavailable(
515 "Iterating over a symbolic `tf.Tensor`")
516 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
517 self._disallow_when_autograph_disabled(
518 "Iterating over a symbolic `tf.Tensor`")
519 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
520 self._disallow_when_autograph_enabled(
521 "Iterating over a symbolic `tf.Tensor`")
522 else:
523 # Default: V1-style Graph execution.
524 self._disallow_in_graph_mode("Iterating over a symbolic `tf.Tensor`")
526 def __iter__(self):
527 if not context.executing_eagerly():
528 self._disallow_iteration()
530 shape = self._shape_tuple()
531 if shape is None:
532 raise TypeError("Cannot iterate over a tensor with unknown shape.")
533 if not shape:
534 raise TypeError("Cannot iterate over a scalar tensor.")
535 if shape[0] is None:
536 raise TypeError(
537 "Cannot iterate over a tensor with unknown first dimension.")
538 return _TensorIterator(self, shape[0])
540 def _shape_as_list(self):
541 if self.shape.ndims is not None:
542 return [dim.value for dim in self.shape.dims]
543 else:
544 return None
546 def _shape_tuple(self):
547 shape = self._shape_as_list()
548 if shape is None:
549 return None
550 return tuple(shape)
552 def _record_tape(self, capture):
553 """Connect this graph tensor with capture for gradients calculation."""
554 record.record_operation(
555 "captured_value",
556 [self], [capture],
557 backward_function=lambda x: [x],
558 forward_function=lambda x: [x])
560 def get_shape(self):
561 """Returns a `tf.TensorShape` that represents the shape of this tensor.
563 In eager execution the shape is always fully-known.
565 >>> a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
566 >>> print(a.shape)
567 (2, 3)
569 `tf.Tensor.get_shape()` is equivalent to `tf.Tensor.shape`.
572 When executing in a `tf.function` or building a model using
573 `tf.keras.Input`, `Tensor.shape` may return a partial shape (including
574 `None` for unknown dimensions). See `tf.TensorShape` for more details.
576 >>> inputs = tf.keras.Input(shape = [10])
577 >>> # Unknown batch size
578 >>> print(inputs.shape)
579 (None, 10)
581 The shape is computed using shape inference functions that are
582 registered for each `tf.Operation`.
584 The returned `tf.TensorShape` is determined at *build* time, without
585 executing the underlying kernel. It is not a `tf.Tensor`. If you need a
586 shape *tensor*, either convert the `tf.TensorShape` to a `tf.constant`, or
587 use the `tf.shape(tensor)` function, which returns the tensor's shape at
588 *execution* time.
590 This is useful for debugging and providing early errors. For
591 example, when tracing a `tf.function`, no ops are being executed, shapes
592 may be unknown (See the [Concrete Functions
593 Guide](https://www.tensorflow.org/guide/concrete_function) for details).
595 >>> @tf.function
596 ... def my_matmul(a, b):
597 ... result = a@b
598 ... # the `print` executes during tracing.
599 ... print("Result shape: ", result.shape)
600 ... return result
602 The shape inference functions propagate shapes to the extent possible:
604 >>> f = my_matmul.get_concrete_function(
605 ... tf.TensorSpec([None,3]),
606 ... tf.TensorSpec([3,5]))
607 Result shape: (None, 5)
609 Tracing may fail if a shape missmatch can be detected:
611 >>> cf = my_matmul.get_concrete_function(
612 ... tf.TensorSpec([None,3]),
613 ... tf.TensorSpec([4,5]))
614 Traceback (most recent call last):
615 ...
616 ValueError: Dimensions must be equal, but are 3 and 4 for 'matmul' (op:
617 'MatMul') with input shapes: [?,3], [4,5].
619 In some cases, the inferred shape may have unknown dimensions. If
620 the caller has additional information about the values of these
621 dimensions, `tf.ensure_shape` or `Tensor.set_shape()` can be used to augment
622 the inferred shape.
624 >>> @tf.function
625 ... def my_fun(a):
626 ... a = tf.ensure_shape(a, [5, 5])
627 ... # the `print` executes during tracing.
628 ... print("Result shape: ", a.shape)
629 ... return a
631 >>> cf = my_fun.get_concrete_function(
632 ... tf.TensorSpec([None, None]))
633 Result shape: (5, 5)
635 Returns:
636 A `tf.TensorShape` representing the shape of this tensor.
638 """
639 return self.shape
641 def set_shape(self, shape):
642 """Updates the shape of this tensor.
644 Note: It is recommended to use `tf.ensure_shape` instead of
645 `Tensor.set_shape`, because `tf.ensure_shape` provides better checking for
646 programming errors and can create guarantees for compiler
647 optimization.
649 With eager execution this operates as a shape assertion.
650 Here the shapes match:
652 >>> t = tf.constant([[1,2,3]])
653 >>> t.set_shape([1, 3])
655 Passing a `None` in the new shape allows any value for that axis:
657 >>> t.set_shape([1,None])
659 An error is raised if an incompatible shape is passed.
661 >>> t.set_shape([1,5])
662 Traceback (most recent call last):
663 ...
664 ValueError: Tensor's shape (1, 3) is not compatible with supplied
665 shape [1, 5]
667 When executing in a `tf.function`, or building a model using
668 `tf.keras.Input`, `Tensor.set_shape` will *merge* the given `shape` with
669 the current shape of this tensor, and set the tensor's shape to the
670 merged value (see `tf.TensorShape.merge_with` for details):
672 >>> t = tf.keras.Input(shape=[None, None, 3])
673 >>> print(t.shape)
674 (None, None, None, 3)
676 Dimensions set to `None` are not updated:
678 >>> t.set_shape([None, 224, 224, None])
679 >>> print(t.shape)
680 (None, 224, 224, 3)
682 The main use case for this is to provide additional shape information
683 that cannot be inferred from the graph alone.
685 For example if you know all the images in a dataset have shape [28,28,3] you
686 can set it with `tf.set_shape`:
688 >>> @tf.function
689 ... def load_image(filename):
690 ... raw = tf.io.read_file(filename)
691 ... image = tf.image.decode_png(raw, channels=3)
692 ... # the `print` executes during tracing.
693 ... print("Initial shape: ", image.shape)
694 ... image.set_shape([28, 28, 3])
695 ... print("Final shape: ", image.shape)
696 ... return image
698 Trace the function, see the [Concrete Functions
699 Guide](https://www.tensorflow.org/guide/concrete_function) for details.
701 >>> cf = load_image.get_concrete_function(
702 ... tf.TensorSpec([], dtype=tf.string))
703 Initial shape: (None, None, 3)
704 Final shape: (28, 28, 3)
706 Similarly the `tf.io.parse_tensor` function could return a tensor with
707 any shape, even the `tf.rank` is unknown. If you know that all your
708 serialized tensors will be 2d, set it with `set_shape`:
710 >>> @tf.function
711 ... def my_parse(string_tensor):
712 ... result = tf.io.parse_tensor(string_tensor, out_type=tf.float32)
713 ... # the `print` executes during tracing.
714 ... print("Initial shape: ", result.shape)
715 ... result.set_shape([None, None])
716 ... print("Final shape: ", result.shape)
717 ... return result
719 Trace the function
721 >>> concrete_parse = my_parse.get_concrete_function(
722 ... tf.TensorSpec([], dtype=tf.string))
723 Initial shape: <unknown>
724 Final shape: (None, None)
726 Make sure it works:
728 >>> t = tf.ones([5,3], dtype=tf.float32)
729 >>> serialized = tf.io.serialize_tensor(t)
730 >>> print(serialized.dtype)
731 <dtype: 'string'>
732 >>> print(serialized.shape)
733 ()
734 >>> t2 = concrete_parse(serialized)
735 >>> print(t2.shape)
736 (5, 3)
738 Caution: `set_shape` ensures that the applied shape is compatible with
739 the existing shape, but it does not check at runtime. Setting
740 incorrect shapes can result in inconsistencies between the
741 statically-known graph and the runtime value of tensors. For runtime
742 validation of the shape, use `tf.ensure_shape` instead. It also modifies
743 the `shape` of the tensor.
745 >>> # Serialize a rank-3 tensor
746 >>> t = tf.ones([5,5,5], dtype=tf.float32)
747 >>> serialized = tf.io.serialize_tensor(t)
748 >>> # The function still runs, even though it `set_shape([None,None])`
749 >>> t2 = concrete_parse(serialized)
750 >>> print(t2.shape)
751 (5, 5, 5)
753 Args:
754 shape: A `TensorShape` representing the shape of this tensor, a
755 `TensorShapeProto`, a list, a tuple, or None.
757 Raises:
758 ValueError: If `shape` is not compatible with the current shape of
759 this tensor.
760 """
761 # Reset cached shape.
762 self._shape_val = None
764 # We want set_shape to be reflected in the C API graph for when we run it.
765 if not isinstance(shape, tensor_shape.TensorShape):
766 shape = tensor_shape.TensorShape(shape)
767 dim_list = []
768 if shape.dims is None:
769 unknown_shape = True
770 else:
771 unknown_shape = False
772 for dim in shape.dims:
773 if dim.value is None:
774 dim_list.append(-1)
775 else:
776 dim_list.append(dim.value)
777 self._set_shape(dim_list, unknown_shape)
779 def _as_node_def_input(self):
780 """Return a value to use for the NodeDef "input" attribute.
782 The returned string can be used in a NodeDef "input" attribute
783 to indicate that the NodeDef uses this Tensor as input.
785 Raises:
786 ValueError: if this Tensor's Operation does not have a name.
788 Returns:
789 a string.
790 """
791 assert self._op.name
792 if self.value_index == 0:
793 return self._op.name
794 else:
795 return "%s:%d" % (self._op.name, self.value_index)
797 def __str__(self):
798 return "Tensor(\"%s\"%s%s%s)" % (
799 self.name,
800 (", shape=%s" %
801 self.get_shape()) if self.get_shape().ndims is not None else "",
802 (", dtype=%s" % self._dtype.name) if self._dtype else "",
803 (", device=%s" % self.device) if self.device else "")
805 def __repr__(self):
806 return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(),
807 self._dtype.name)
809 def __hash__(self):
810 g = getattr(self, "graph", None)
811 if (Tensor._USE_EQUALITY and (g is None or g.building_function)):
812 raise TypeError("Tensor is unhashable. "
813 "Instead, use tensor.ref() as the key.")
814 else:
815 return id(self)
817 def __copy__(self):
818 # TODO(b/77597810): get rid of Tensor copies.
819 cls = self.__class__
820 result = cls.__new__(cls)
821 result._init(self.op, self.value_index, self.dtype)
822 result.__dict__.update(self.__dict__)
823 return result
825 # NOTE(mrry): This enables the Tensor's overloaded "right" binary
826 # operators to run when the left operand is an ndarray, because it
827 # accords the Tensor class higher priority than an ndarray, or a
828 # numpy matrix.
829 # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
830 # mechanism, which allows more control over how Tensors interact
831 # with ndarrays.
832 __array_priority__ = 100
834 def __array__(self, dtype=None):
835 del dtype
836 raise NotImplementedError(
837 f"Cannot convert a symbolic tf.Tensor ({self.name}) to a numpy array."
838 f" This error may indicate that you're trying to pass a Tensor to"
839 f" a NumPy call, which is not supported.")
841 def __len__(self):
842 raise TypeError(f"len is not well defined for a symbolic Tensor "
843 f"({self.name}). Please call `x.shape` rather than "
844 f"`len(x)` for shape information.")
846 # TODO(mdan): This convoluted machinery is hard to maintain. Clean up.
847 @staticmethod
848 def _override_operator(operator, func):
849 _override_helper(Tensor, operator, func)
851 def __bool__(self):
852 """Dummy method to prevent a tensor from being used as a Python `bool`.
854 This overload raises a `TypeError` when the user inadvertently
855 treats a `Tensor` as a boolean (most commonly in an `if` or `while`
856 statement), in code that was not converted by AutoGraph. For example:
858 ```python
859 if tf.constant(True): # Will raise.
860 # ...
862 if tf.constant(5) < tf.constant(7): # Will raise.
863 # ...
864 ```
866 Raises:
867 `TypeError`.
868 """
869 self._disallow_bool_casting()
871 def __nonzero__(self):
872 """Dummy method to prevent a tensor from being used as a Python `bool`.
874 This is the Python 2.x counterpart to `__bool__()` above.
876 Raises:
877 `TypeError`.
878 """
879 self._disallow_bool_casting()
881 def eval(self, feed_dict=None, session=None):
882 """Evaluates this tensor in a `Session`.
884 Note: If you are not using `compat.v1` libraries, you should not need this,
885 (or `feed_dict` or `Session`). In eager execution (or within `tf.function`)
886 you do not need to call `eval`.
888 Calling this method will execute all preceding operations that
889 produce the inputs needed for the operation that produces this
890 tensor.
892 *N.B.* Before invoking `Tensor.eval()`, its graph must have been
893 launched in a session, and either a default session must be
894 available, or `session` must be specified explicitly.
896 Args:
897 feed_dict: A dictionary that maps `Tensor` objects to feed values. See
898 `tf.Session.run` for a description of the valid feed values.
899 session: (Optional.) The `Session` to be used to evaluate this tensor. If
900 none, the default session will be used.
902 Returns:
903 A numpy array corresponding to the value of this tensor.
904 """
905 return _eval_using_default_session(self, feed_dict, self.graph, session)
907 @deprecation.deprecated(None, "Use ref() instead.")
908 def experimental_ref(self):
909 return self.ref()
def ref(self):
  # tf.Variable also has the same ref() API. If you update the
  # documentation here, please update tf.Variable.ref() as well.
  """Returns a hashable reference object to this Tensor.

  Use this to store tensors in sets or as dictionary keys. Since TensorFlow
  2.0, `tensor.__hash__()` is no longer available, so tensors themselves
  cannot be placed in a set/dictionary.

  The following will raise an exception starting 2.0

  >>> x = tf.constant(5)
  >>> y = tf.constant(10)
  >>> z = tf.constant(10)
  >>> tensor_set = {x, y, z}
  Traceback (most recent call last):
    ...
  TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.
  >>> tensor_dict = {x: 'five', y: 'ten'}
  Traceback (most recent call last):
    ...
  TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.

  Instead, we can use `tensor.ref()`.

  >>> tensor_set = {x.ref(), y.ref(), z.ref()}
  >>> x.ref() in tensor_set
  True
  >>> tensor_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'}
  >>> tensor_dict[y.ref()]
  'ten'

  The reference object also provides a `.deref()` function that returns the
  original Tensor.

  >>> x = tf.constant(5)
  >>> x.ref().deref()
  <tf.Tensor: shape=(), dtype=int32, numpy=5>
  """
  reference = object_identity.Reference(self)
  return reference
def __tf_tracing_type__(self, signature_context):
  """Returns the `TensorSpec` that represents this tensor during tracing.

  For `resource` and `variant` tensors, the spec's dtype is rebuilt so that it
  also carries the tensor's handle data. For every other dtype, `self.dtype`
  is used unchanged. `signature_context` is not consulted.
  """
  spec_dtype = self.dtype
  if spec_dtype in (dtypes.resource, dtypes.variant):
    handle_data = handle_data_util.get_handle_data(self)
    spec_dtype = dtypes.DType(spec_dtype._type_enum, handle_data)  # pylint: disable=protected-access
  return tensor_spec.TensorSpec(self.shape, spec_dtype)
def __tf_tensor__(
    self, dtype: Optional[dtypes.DType] = None, name: Optional[str] = None
) -> "Tensor":
  """Returns `self` after checking that it can satisfy a requested dtype.

  Args:
    dtype: Optional requested dtype. When given, it must be compatible with
      `self.dtype`.
    name: Optional name, used only to prefix the error message.

  Returns:
    This tensor, unchanged.

  Raises:
    ValueError: If `dtype` is given and is not compatible with `self.dtype`.
  """
  compatible = dtype is None or dtype.is_compatible_with(self.dtype)
  if not compatible:
    message = (f"Tensor conversion requested dtype {dtype.name} "
               f"for Tensor with dtype {self.dtype.name}: {self!r}")
    raise ValueError(_add_error_prefix(message, name=name))
  return self
def GraphTensor(op, value_index, dtype):
  """Builds a graph-mode `Tensor` for output `value_index` of `op`.

  Args:
    op: An `Operation`. `Operation` that computes this tensor.
    value_index: An `int`. Index of the operation's endpoint that produces this
      tensor.
    dtype: A `DType`. Type of elements stored in this tensor.

  Returns:
    A Tensor object.

  Raises:
    TypeError: If the op is not an `Operation`.
  """
  # pylint: disable=protected-access
  tensor = Tensor()
  tensor._init(op, value_index, dtype)
  tensor._dtype = dtypes.as_dtype(dtype)

  # The shape cache starts empty; the `shape` accessor fills it in later.
  tensor._shape_val = None
  tensor._name = None
  tensor._id = uid()
  # pylint: enable=protected-access
  return tensor
def _create_graph_constant(
    value, dtype, shape, name, verify_shape, allow_broadcast
):
  """Adds a `Const` op holding `value` to the default graph.

  Any registered op callbacks are invoked on the new constant, and they may
  substitute a different output tensor.

  Args:
    value: The value to embed; forwarded to `tensor_util.make_tensor_proto`.
    dtype: Requested dtype, forwarded to `make_tensor_proto`.
    shape: Requested shape, forwarded to `make_tensor_proto`.
    name: Name for the new op (also passed to the callbacks as `op_name`).
    verify_shape: Forwarded to `make_tensor_proto`.
    allow_broadcast: Forwarded to `make_tensor_proto`.

  Returns:
    The `Const` op's output tensor, or the single tensor returned by the op
    callbacks when they return a non-`None` result.
  """
  graph = get_default_graph()
  proto = tensor_util.make_tensor_proto(
      value, dtype=dtype, shape=shape, verify_shape=verify_shape,
      allow_broadcast=allow_broadcast)
  tensor_value = attr_value_pb2.AttrValue()
  tensor_value.tensor.CopyFrom(proto)
  # The dtype attr mirrors whatever dtype make_tensor_proto settled on.
  dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
  attrs = {"value": tensor_value, "dtype": dtype_value}
  const_op = graph._create_op_internal(  # pylint: disable=protected-access
      "Const", [], [dtype_value.type], attrs=attrs, name=name)
  const_tensor = const_op.outputs[0]

  if not op_callbacks.should_invoke_op_callbacks():
    return const_tensor
  # TODO(b/147670703): Once the special-op creation code paths
  # are unified. Remove this `if` block.
  callback_outputs = op_callbacks.invoke_op_callbacks(
      "Const", tuple(), attrs, (const_tensor,), op_name=name, graph=graph)
  if callback_outputs is None:
    return const_tensor
  [const_tensor] = callback_outputs
  return const_tensor
class _EagerTensorBase(Tensor, internal.NativeObject, core_tf_types.Value):
  """Base class for EagerTensor.

  `pywrap_tfe.TFE_Py_InitEagerTensor` builds the concrete `EagerTensor` class
  as a subclass of this one. The hooks that raise `NotImplementedError` here
  (`_numpy_internal`, `_datatype_enum`, `_shape_tuple`, `_rank`,
  `_num_elements`, `_copy_to_device`, `backing_device`) are presumably
  supplied by that C-defined subclass. TODO confirm.

  Graph-only members (`op`, `graph`, `name`, `value_index`, `consumers`,
  `_add_consumer`, `_as_node_def_input`, `_as_tf_output`, `eval`) raise,
  because they are undefined when eager execution is enabled.
  """

  # __complex__, __int__, __float__ and __index__ may copy the tensor to CPU and
  # only work for scalars; values are cast as per numpy.
  def __complex__(self):
    return complex(self._numpy())

  def __int__(self):
    return int(self._numpy())

  def __long__(self):
    # Python 2 hook. The builtin `long` does not exist in Python 3 (calling it
    # raised NameError); Python 3 `int` is unbounded and plays the same role.
    return int(self._numpy())

  def __float__(self):
    return float(self._numpy())

  def __index__(self):
    return self._numpy().__index__()

  def __bool__(self):
    return bool(self._numpy())

  __nonzero__ = __bool__

  def __format__(self, format_spec):
    """Formats the value, preferring a custom summarizer when one applies."""
    if self._prefer_custom_summarizer():
      return self._summarize_value().__format__(format_spec)
    elif self.dtype.is_numpy_compatible:
      # Not numpy_text here, otherwise the __format__ behaves differently.
      return self._numpy().__format__(format_spec)
    else:
      return "<unprintable>".__format__(format_spec)

  def __reduce__(self):
    # Pickles as a call to `convert_to_tensor` on the tensor's numpy value.
    return convert_to_tensor, (self._numpy(),)

  def __copy__(self):
    # Eager Tensors are immutable so it's safe to return themselves as a copy.
    return self

  def __deepcopy__(self, memo):
    # Eager Tensors are immutable so it's safe to return themselves as a copy.
    del memo
    return self

  def __str__(self):
    return "tf.Tensor(%s, shape=%s, dtype=%s)" % (
        value_text(self, is_repr=False), self.shape, self.dtype.name)

  def __repr__(self):
    return "<tf.Tensor: shape=%s, dtype=%s, %s>" % (
        self.shape, self.dtype.name, value_text(self, is_repr=True))

  def __len__(self):
    """Returns the length of the first dimension in the Tensor."""
    if not self.shape.ndims:
      raise TypeError("Scalar tensor has no `len()`")
    # pylint: disable=protected-access
    try:
      return self._shape_tuple()[0]
    except core._NotOkStatusException as e:
      raise core._status_to_exception(e) from None

  def __array__(self, dtype=None):
    """NumPy array protocol: returns the value, cast when `dtype` is truthy."""
    a = self._numpy()
    if not dtype:
      return a

    return np.array(a, dtype=dtype)

  def __hash__(self) -> int:
    # EagerTensors are never hashable.
    raise TypeError("Tensor is unhashable. "
                    "Instead, use tensor.ref() as the key.")

  def _numpy_internal(self):
    raise NotImplementedError()

  def _numpy(self):
    """Returns `_numpy_internal()`, translating C status errors to exceptions."""
    try:
      return self._numpy_internal()
    except core._NotOkStatusException as e:  # pylint: disable=protected-access
      raise core._status_to_exception(e) from None  # pylint: disable=protected-access

  @property
  def dtype(self):
    # Note: using the intern table directly here as this is
    # performance-sensitive in some models.
    return dtypes._INTERN_TABLE[self._datatype_enum()]  # pylint: disable=protected-access

  def numpy(self):
    """Copy of the contents of this Tensor into a NumPy array or scalar.

    Unlike NumPy arrays, Tensors are immutable, so this method has to copy
    the contents to ensure safety. Use `memoryview` to get a readonly
    view of the contents without doing a copy:

    >>> t = tf.constant([42])
    >>> np.array(memoryview(t))
    array([42], dtype=int32)

    Note that `memoryview` is only zero-copy for Tensors on CPU. If a Tensor
    is on GPU, it will have to be transferred to CPU first in order for
    `memoryview` to work.

    Returns:
      A NumPy array of the same shape and dtype or a NumPy scalar, if this
      Tensor has rank 0.

    Raises:
      ValueError: If the dtype of this Tensor does not have a compatible
        NumPy dtype.
    """
    # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
    maybe_arr = self._numpy()  # pylint: disable=protected-access
    return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr

  @property
  def backing_device(self):
    """Returns the name of the device holding this tensor's memory.

    `.backing_device` is usually the same as `.device`, which returns
    the device on which the kernel of the operation that produced this tensor
    ran. However, some operations can produce tensors on a different device
    (e.g., an operation that executes on the GPU but produces output tensors
    in host memory).
    """
    raise NotImplementedError()

  def _datatype_enum(self):
    raise NotImplementedError()

  def _shape_tuple(self):
    """The shape of this Tensor, as a tuple.

    This is more performant than tuple(shape().as_list()) as it avoids
    two list and one object creation. Marked private for now as from an API
    perspective, it would be better to have a single performant way of
    getting a shape rather than exposing shape() and shape_tuple()
    (and heaven forbid, shape_list() etc. as well!). Punting on that for now,
    but ideally one would work things out and remove the need for this method.

    Returns:
      tuple with the shape.
    """
    raise NotImplementedError()

  def _rank(self):
    """Integer rank of this Tensor.

    Unlike regular Tensors, the rank is always known for EagerTensors.

    This is more performant than len(self._shape_tuple())

    Returns:
      Integer rank
    """
    raise NotImplementedError()

  def _num_elements(self):
    """Number of elements of this Tensor.

    Unlike regular Tensors, the number of elements is always known for
    EagerTensors.

    This is more performant than tensor.shape.num_elements

    Returns:
      Long - num elements in the tensor
    """
    raise NotImplementedError()

  def _copy_to_device(self, device_name):  # pylint: disable=redefined-outer-name
    raise NotImplementedError()

  @staticmethod
  def _override_operator(name, func):
    # Installs `func` as attribute `name` on the base class, so every
    # subclass (including the concrete EagerTensor) inherits it.
    setattr(_EagerTensorBase, name, func)

  def _copy_nograd(self, ctx=None, device_name=None):
    """Copies tensor to dest device, but doesn't record the operation.

    Args:
      ctx: Context to use; defaults to `context.context()`.
      device_name: Destination device; defaults to `ctx.device_name`.

    Returns:
      The new tensor on the destination device.
    """
    # Creates a new tensor on the dest device.
    if ctx is None:
      ctx = context.context()
    if device_name is None:
      device_name = ctx.device_name
    # pylint: disable=protected-access
    try:
      ctx.ensure_initialized()
      new_tensor = self._copy_to_device(device_name)
    except core._NotOkStatusException as e:
      raise core._status_to_exception(e) from None
    return new_tensor

  def _copy(self, ctx=None, device_name=None):
    """Copies tensor to dest device.

    When executing eagerly, the copy is also recorded on the tape, with a
    gradient that copies the upstream gradient back to this tensor's device.
    """
    new_tensor = self._copy_nograd(ctx, device_name)
    # Record the copy on tape and define backprop copy as well.
    if context.executing_eagerly():
      self_device = self.device

      def grad_fun(dresult):
        return [
            dresult._copy(device_name=self_device)
            if hasattr(dresult, "_copy") else dresult
        ]

      record.record_operation("_copy", [new_tensor], [self], grad_fun)
    return new_tensor
    # pylint: enable=protected-access

  @property
  def shape(self):
    """The `TensorShape` of this tensor, computed once and then cached."""
    if self._tensor_shape is None:  # pylint: disable=access-member-before-definition
      # pylint: disable=protected-access
      try:
        # `_tensor_shape` is declared and defined in the definition of
        # `EagerTensor`, in C.
        self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
      except core._NotOkStatusException as e:
        raise core._status_to_exception(e) from None

    return self._tensor_shape

  def get_shape(self):
    """Alias of Tensor.shape."""
    return self.shape

  def _shape_as_list(self):
    """The shape of the tensor as a list."""
    return list(self._shape_tuple())

  @deprecation.deprecated(
      None, "Use tf.identity with explicit device placement instead.")
  def cpu(self):
    """A copy of this Tensor with contents backed by host memory."""
    return self._copy(context.context(), "CPU:0")

  @deprecation.deprecated(None, "Use tf.identity instead.")
  def gpu(self, gpu_index=0):
    """A copy of this Tensor with contents backed by memory on the GPU.

    Args:
      gpu_index: Identifies which GPU to place the contents on the returned
        Tensor in.

    Returns:
      A GPU-memory backed Tensor object initialized with the same contents
      as this Tensor.
    """
    return self._copy(context.context(), "GPU:" + str(gpu_index))

  def set_shape(self, shape):
    """Checks `shape` against this tensor's shape; never modifies the tensor.

    Raises:
      ValueError: If `shape` is not compatible with `self.shape`.
    """
    if not self.shape.is_compatible_with(shape):
      raise ValueError(f"Tensor's shape {self.shape} is not compatible "
                       f"with supplied shape {shape}.")

  # Methods not supported / implemented for Eager Tensors.
  @property
  def op(self):
    raise AttributeError(
        "Tensor.op is undefined when eager execution is enabled.")

  @property
  def graph(self):
    raise AttributeError(
        "Tensor.graph is undefined when eager execution is enabled.")

  @property
  def name(self):
    raise AttributeError(
        "Tensor.name is undefined when eager execution is enabled.")

  @property
  def value_index(self):
    raise AttributeError(
        "Tensor.value_index is undefined when eager execution is enabled.")

  def consumers(self):
    raise NotImplementedError(
        "Tensor.consumers is undefined when eager execution is enabled.")

  def _add_consumer(self, consumer):
    raise NotImplementedError(
        "_add_consumer not supported when eager execution is enabled.")

  def _as_node_def_input(self):
    raise NotImplementedError(
        "_as_node_def_input not supported when eager execution is enabled.")

  def _as_tf_output(self):
    raise NotImplementedError(
        "_as_tf_output not supported when eager execution is enabled.")

  def eval(self, feed_dict=None, session=None):
    raise NotImplementedError(
        "eval is not supported when eager execution is enabled, "
        "is .numpy() what you're looking for?")

  def __tf_tensor__(
      self, dtype: Optional[dtypes.DType] = None, name: Optional[str] = None
  ) -> Tensor:
    """Converts this eager tensor for use in the current execution mode.

    When not executing eagerly, the tensor is captured into the default graph,
    which must be building a function. Otherwise, the conversion is delegated
    to `Tensor.__tf_tensor__`.

    Raises:
      RuntimeError: If not executing eagerly and the default graph is not
        building a function.
    """
    if not context.executing_eagerly():
      graph = get_default_graph()
      if not graph.building_function:
        raise RuntimeError(
            _add_error_prefix(
                "Attempting to capture an EagerTensor without "
                "building a function.",
                name=name))
      return graph.capture(self, name=name)
    return super().__tf_tensor__(dtype, name)

  def _capture_as_const(self, name):
    """Capture the EagerTensor to a graph constant tensor.

    Returns:
      The new graph constant, or `None` if the value cannot be expressed as a
      single constant.
    """
    # Clear control dependencies so the constant does not pick up any.
    with control_dependencies(None):
      constant_value = tensor_util.constant_value(self)
      if constant_value is None:
        # Some eager tensors, e.g. parallel tensors, are not convertible to
        # a single constant. Return None in this case and the caller graph
        # would create a placeholder instead.
        return None

      const_tensor = _create_graph_constant(
          constant_value, dtype=self.dtype, shape=self.shape, name=name,
          verify_shape=False, allow_broadcast=True)
      return const_tensor
# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
# registers it with the current module.
# It is exposed as an __internal__ api for now (b/171081052), though we
# expect it to be eventually covered by tf Tensor types and typing.
# NOTE: this runs at import time and must come after _EagerTensorBase is
# fully defined, since that class object is passed in here.
EagerTensor = tf_export("__internal__.EagerTensor", v1=[])(
    pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase))
def _add_error_prefix(msg, *, name=None):
  """Prepends `"<name>: "` to `msg` when a `name` is supplied.

  Args:
    msg: The error message text.
    name: Optional prefix. When `None`, `msg` is returned unchanged.

  Returns:
    The message, prefixed with `name` if one was given.
  """
  if name is None:
    return msg
  return f"{name}: {msg}"
def pack_eager_tensors(tensors, ctx=None):
  """Pack multiple `EagerTensor`s of the same dtype and shape.

  Every element must match `tensors[0]` in dtype and shape. For resource
  tensors, every element must also match its handle data. When `tensors[0]`'s
  handle data is not `None`, it is attached to the packed result. The packed
  tensor is recorded on the tape with a gradient function that raises, so
  gradients cannot flow through this op.

  Args:
    tensors: a list of EagerTensors to pack.
    ctx: context.context(). Defaults to the current context when `None`.

  Returns:
    A packed EagerTensor.

  Raises:
    TypeError: If `tensors` is not a list, or an element is not an
      `EagerTensor`.
    ValueError: If `tensors` is empty, or its elements disagree on dtype,
      shape, or (for resource tensors) handle data.
  """
  if not isinstance(tensors, list):
    raise TypeError(f"tensors must be a list, but got a {type(tensors)}")

  if not tensors:
    raise ValueError("Cannot pack an empty list of tensors.")

  # pylint: disable=protected-access
  first = tensors[0]
  dtype = first.dtype
  shape = first.shape
  handle_data = first._handle_data
  is_resource = dtype == dtypes.resource
  for i, t in enumerate(tensors):
    if not isinstance(t, EagerTensor):
      raise TypeError(f"All tensors being packed must be EagerTensor. "
                      f"Found an item of type {type(t)}.")

    if t.dtype != dtype:
      raise ValueError(
          f"All tensors being packed should have the same dtype {dtype}, "
          f"but the {i}-th tensor is of dtype {t.dtype}")
    if t.shape != shape:
      raise ValueError(
          f"All tensors being packed should have the same shape {shape}, "
          f"but the {i}-th tensor is of shape {t.shape}")
    if is_resource and t._handle_data != handle_data:
      raise ValueError(
          f"All tensors being packed should have the same handle data "
          f"{handle_data}, "
          f"but the {i}-th tensor is of handle data {t._handle_data}")
  # pylint: enable=protected-access

  if ctx is None:
    ctx = context.context()

  packed_tensor = ctx.pack_eager_tensors(tensors)
  # Propagate handle data (e.g. for resource variables) to the packed tensor.
  if handle_data is not None:
    packed_tensor._handle_data = handle_data  # pylint: disable=protected-access

  def grad_fun(_):
    raise ValueError(
        "Computing gradients through pack_eager_tensors is not supported.")

  record.record_operation("pack_eager_tensors", [packed_tensor], tensors,
                          grad_fun)

  return packed_tensor
@profiler_trace.trace_wrapper("convert_to_tensor")
def convert_to_tensor(
    value,
    dtype=None,
    name=None,
    as_ref=False,
    preferred_dtype=None,
    dtype_hint=None,
    # TODO(b/268347915): Remove argument.
    ctx=None,  # pylint: disable=unused-argument
    accepted_result_types=(Tensor,),
):
  """Implementation of the public convert_to_tensor.

  Delegates to `tensor_conversion_registry.convert`. `dtype_hint` is used only
  when `preferred_dtype` is falsy, and `ctx` is ignored.
  """
  # TODO(b/142518781): Fix all call-sites and remove redundant arg
  if not preferred_dtype:
    preferred_dtype = dtype_hint
  return tensor_conversion_registry.convert(
      value,
      dtype,
      name,
      as_ref,
      preferred_dtype,
      accepted_result_types,
  )


# Alias: both names refer to the same function object.
internal_convert_to_tensor = convert_to_tensor
def internal_convert_n_to_tensor(values,
                                 dtype=None,
                                 name=None,
                                 as_ref=False,
                                 preferred_dtype=None,
                                 # TODO(b/268347915): Remove argument.
                                 ctx=None):  # pylint: disable=unused-argument
  """Converts `values` to a list of `Tensor` objects.

  Args:
    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `"%s_%d" % (name, i)`.
    as_ref: True if the caller wants the results as ref tensors.
    preferred_dtype: Optional element type for the returned tensors, used when
      dtype is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so preferred_dtype can be used as a soft
      preference. If the conversion to `preferred_dtype` is not possible, this
      argument has no effect.
    ctx: Unused. Present for API backwards compatibility.

  Returns:
    A list of `Tensor` and/or `IndexedSlices` objects.

  Raises:
    TypeError: If `values` is not a sequence, or if no conversion function is
      registered for an element in `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, collections_abc.Sequence):
    raise TypeError("values must be a sequence.")
  return [
      convert_to_tensor(
          value,
          dtype=dtype,
          name=None if name is None else "%s_%d" % (name, i),
          as_ref=as_ref,
          preferred_dtype=preferred_dtype)
      for i, value in enumerate(values)
  ]
def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
  """Converts `values` to a list of `Tensor` objects (never ref tensors).

  Args:
    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `"%s_%d" % (name, i)`.
    preferred_dtype: Optional element type for the returned tensors, used when
      dtype is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so preferred_dtype can be used as a soft
      preference. If the conversion to `preferred_dtype` is not possible, this
      argument has no effect.

  Returns:
    A list of `Tensor` and/or `IndexedSlices` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  return internal_convert_n_to_tensor(
      values,
      as_ref=False,
      dtype=dtype,
      name=name,
      preferred_dtype=preferred_dtype)
def convert_to_tensor_or_composite(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or `CompositeTensor`.

  A `CompositeTensor` is returned as-is (after a dtype compatibility check).
  Anything else goes through `convert_to_tensor()`, never as a ref tensor.

  Args:
    value: A `CompositeTensor` or an object that can be consumed by
      `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `CompositeTensor`.
    name: (Optional.) A name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` or `CompositeTensor`, based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  return internal_convert_to_tensor_or_composite(
      value, as_ref=False, dtype=dtype, name=name)
def internal_convert_to_tensor_or_composite(value,
                                            dtype=None,
                                            name=None,
                                            as_ref=False):
  """Converts the given object to a `Tensor` or `CompositeTensor`.

  Non-composite inputs are converted with `convert_to_tensor()`, which accepts
  either a `Tensor` or a `CompositeTensor` as the conversion result. A
  `CompositeTensor` input is returned unmodified, after its `dtype` attribute
  (if any) is checked against `dtype`.

  Args:
    value: A `CompositeTensor`, or an object that can be consumed by
      `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `CompositeTensor`.
    name: (Optional.) A name to use if a new `Tensor` is created.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A `Tensor` or `CompositeTensor`, based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  if not isinstance(value, composite_tensor.CompositeTensor):
    return convert_to_tensor(
        value,
        dtype=dtype,
        name=name,
        as_ref=as_ref,
        accepted_result_types=(Tensor, composite_tensor.CompositeTensor))

  value_dtype = getattr(value, "dtype", None)
  if dtype:
    requested = dtypes.as_dtype(dtype)
    if not requested.is_compatible_with(value_dtype):
      raise ValueError(f"Tensor conversion dtype mismatch. "
                       f"Requested dtype is {requested.name}, "
                       f"Tensor has dtype {value.dtype.name}: {value!r}")
  return value
def internal_convert_n_to_tensor_or_composite(values,
                                              dtype=None,
                                              name=None,
                                              as_ref=False):
  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.

  Any `CompositeTensor` objects in `values` are returned unmodified, and
  `None` entries are passed through as `None`.

  Args:
    values: A list of `None`, `CompositeTensor`, or objects that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
      `CompositeTensor`s.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `"%s_%d" % (name, i)`.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A list of `Tensor`, `CompositeTensor`, and/or `None` objects.

  Raises:
    TypeError: If `values` is not a sequence, or if no conversion function is
      registered for an element in `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, collections_abc.Sequence):
    raise TypeError("values must be a sequence.")
  return [
      None if value is None else internal_convert_to_tensor_or_composite(
          value,
          dtype=dtype,
          name=None if name is None else "%s_%d" % (name, i),
          as_ref=as_ref)
      for i, value in enumerate(values)
  ]
def convert_n_to_tensor_or_composite(values, dtype=None, name=None):
  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.

  Any `CompositeTensor` objects in `values` are returned unmodified, `None`
  entries stay `None`, and nothing is converted to a ref tensor.

  Args:
    values: A list of `None`, `CompositeTensor`, or objects that can be
      consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
      `CompositeTensor`s.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `"%s_%d" % (name, i)`.

  Returns:
    A list of `Tensor`, `CompositeTensor`, and/or `None` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  return internal_convert_n_to_tensor_or_composite(
      values, as_ref=False, dtype=dtype, name=name)
def _device_string(dev_spec):
  """Returns `dev_spec.to_string()` for a device spec, else `dev_spec` as-is."""
  return dev_spec.to_string() if pydev.is_device_spec(dev_spec) else dev_spec
def _NodeDef(op_type, name, attrs=None):
  """Create a NodeDef proto.

  Args:
    op_type: Value for the "op" attribute of the NodeDef proto.
    name: Value for the "name" attribute of the NodeDef proto.
    attrs: Optional dictionary mapping attribute names (strings) to
      `AttrValue` protos, each copied into the NodeDef's "attr" map.

  Returns:
    A node_def_pb2.NodeDef protocol buffer.
  """
  node_def = node_def_pb2.NodeDef(
      op=compat.as_bytes(op_type), name=compat.as_bytes(name))
  # A missing or empty `attrs` simply adds no attributes.
  for attr_name, attr_value in (attrs or {}).items():
    node_def.attr[attr_name].CopyFrom(attr_value)
  return node_def
# Copied from core/framework/node_def_util.cc
# TODO(mrry,josh11b): Consolidate this validation in C++ code.
#
# Op (node) names: the first character is an ASCII letter, digit or '.'; each
# later character is a letter, digit, '_', '.', '\', '/', '>' or '-'.
_VALID_OP_NAME_REGEX = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$")
# Scope names: any run (including the empty string) of the same characters
# allowed after the first character of an op name.
_VALID_SCOPE_NAME_REGEX = re.compile(r"^[A-Za-z0-9_.\\/>-]*$")
@tf_export("__internal__.create_c_op", v1=[])
@traceback_utils.filter_traceback
def _create_c_op(graph,
                 node_def,
                 inputs,
                 control_inputs,
                 op_def=None,
                 extract_traceback=True):
  """Creates a TF_Operation.

  Builds an operation description in `graph`'s C graph and sets its device,
  data inputs, control inputs and attributes from the arguments. It then
  finalizes the description with `TF_FinishOperation`.

  Args:
    graph: a `Graph`.
    node_def: `node_def_pb2.NodeDef` for the operation to create.
    inputs: A flattened list of `Tensor`s. This function handles grouping
      tensors into lists as per attributes in the `node_def`.
    control_inputs: A list of `Operation`s to set as control dependencies.
    op_def: Optional. `op_def_pb2.OpDef` for the operation to create. If not
      specified, is looked up from the `graph` using `node_def.op`.
    extract_traceback: if True, extract the current Python traceback to the
      TF_Operation.

  Returns:
    A wrapped TF_Operation*.

  Raises:
    ValueError: If `TF_FinishOperation` raises `errors.InvalidArgumentError`.
      The error is re-raised as `ValueError` with the same message for
      backwards compatibility.
  """
  if op_def is None:
    op_def = graph.op_def_for_type(node_def.op)  # pylint: disable=protected-access
  # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs.
  # Refactor so we don't have to do this here.
  inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.attr)
  # pylint: disable=protected-access
  with graph._c_graph.get() as c_graph:
    op_desc = pywrap_tf_session.TF_NewOperation(c_graph,
                                                compat.as_str(node_def.op),
                                                compat.as_str(node_def.name))
  if node_def.device:
    pywrap_tf_session.TF_SetDevice(op_desc, compat.as_str(node_def.device))
  # Add inputs
  for op_input in inputs:
    # List/tuple entries are added as a single list-typed input.
    if isinstance(op_input, (list, tuple)):
      pywrap_tf_session.TF_AddInputList(op_desc,
                                        [t._as_tf_output() for t in op_input])
    else:
      pywrap_tf_session.TF_AddInput(op_desc, op_input._as_tf_output())

  # Add control inputs
  for control_input in control_inputs:
    pywrap_tf_session.TF_AddControlInput(op_desc, control_input._c_op)
  # pylint: enable=protected-access

  # Add attrs
  for name, attr_value in node_def.attr.items():
    serialized = attr_value.SerializeToString()
    # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
    # It might be worth creating a convenient way to re-use the same status.
    pywrap_tf_session.TF_SetAttrValueProto(op_desc, compat.as_str(name),
                                           serialized)

  try:
    c_op = pywrap_tf_session.TF_FinishOperation(op_desc)
  except errors.InvalidArgumentError as e:
    # Convert to ValueError for backwards compatibility.
    raise ValueError(e.message)

  # Record the current Python stack trace as the creating stacktrace of this
  # TF_Operation.
  # NOTE(review): stacklevel=3 presumably skips this function and its
  # decorator wrapper so the caller's frame is recorded. TODO confirm.
  if extract_traceback:
    tf_stack.extract_stack_for_op(c_op, stacklevel=3)

  return c_op
1761@tf_export("Operation")
1762class Operation(pywrap_tf_session.PyOperation):
1763 """Represents a graph node that performs computation on tensors.
1765 An `Operation` is a node in a `tf.Graph` that takes zero or more `Tensor`
1766 objects as input, and produces zero or more `Tensor` objects as output.
1767 Objects of type `Operation` are created by calling a Python op constructor
1768 (such as `tf.matmul`) within a `tf.function` or under a `tf.Graph.as_default`
1769 context manager.
1771 For example, within a `tf.function`, `c = tf.matmul(a, b)` creates an
1772 `Operation` of type "MatMul" that takes tensors `a` and `b` as input, and
1773 produces `c` as output.
1775 If a `tf.compat.v1.Session` is used, an `Operation` of a `tf.Graph` can be
1776 executed by passing it to `tf.Session.run`. `op.run()` is a shortcut for
1777 calling `tf.compat.v1.get_default_session().run(op)`.
1778 """
1780 @classmethod
1781 def from_node_def(
1782 cls,
1783 node_def,
1784 g,
1785 inputs=None,
1786 output_types=None,
1787 control_inputs=None,
1788 input_types=None,
1789 original_op=None,
1790 op_def=None,
1791 ):
1792 r"""Creates an `Operation`.
1794 NOTE: This constructor validates the name of the `Operation` (passed
1795 as `node_def.name`). Valid `Operation` names match the following
1796 regular expression:
1798 [A-Za-z0-9.][A-Za-z0-9_.\\-/]*
1800 Args:
1801 node_def: `node_def_pb2.NodeDef`. `NodeDef` for the `Operation`. Used for
1802 attributes of `node_def_pb2.NodeDef`, typically `name`, `op`, and
1803 `device`. The `input` attribute is irrelevant here as it will be
1804 computed when generating the model.
1805 g: `Graph`. The parent graph.
1806 inputs: list of `Tensor` objects. The inputs to this `Operation`.
1807 output_types: list of `DType` objects. List of the types of the `Tensors`
1808 computed by this operation. The length of this list indicates the
1809 number of output endpoints of the `Operation`.
1810 control_inputs: list of operations or tensors from which to have a control
1811 dependency.
1812 input_types: List of `DType` objects representing the types of the tensors
1813 accepted by the `Operation`. By default uses `[x.dtype.base_dtype for x
1814 in inputs]`. Operations that expect reference-typed inputs must specify
1815 these explicitly.
1816 original_op: Optional. Used to associate the new `Operation` with an
1817 existing `Operation` (for example, a replica with the op that was
1818 replicated).
1819 op_def: Optional. The `op_def_pb2.OpDef` proto that describes the op type
1820 that this `Operation` represents.
1822 Raises:
1823 TypeError: if control inputs are not Operations or Tensors,
1824 or if `node_def` is not a `NodeDef`,
1825 or if `g` is not a `Graph`,
1826 or if `inputs` are not tensors,
1827 or if `inputs` and `input_types` are incompatible.
1828 ValueError: if the `node_def` name is not valid.
1830 Returns:
1831 Operation object.
1832 """
1833 if not isinstance(g, Graph):
1834 raise TypeError(f"Argument g must be a Graph. "
1835 f"Received an instance of type {type(g)}")
1837 if not isinstance(node_def, node_def_pb2.NodeDef):
1838 raise TypeError(f"Argument node_def must be a NodeDef. "
1839 f"Received an instance of type: {type(node_def)}.")
1840 if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0:
1841 raise ValueError(
1842 f"Cannot create a tensor proto whose content is larger than 2GB. "
1843 f"Size of tensor is {node_def.ByteSize()} bytes.")
1845 # TODO(mdan): This does not belong here. Graph::AddNode should handle it.
1846 if not _VALID_OP_NAME_REGEX.match(node_def.name):
1847 raise ValueError(
1848 f"`{node_def.name}` is not a valid node name. "
1849 f"Accepted names conform to Regex /{_VALID_OP_NAME_REGEX}/")
1851 # FIXME(b/225400189): output_types is unused. Consider remove it from
1852 # the argument list.
1853 del output_types
1855 if inputs is None:
1856 inputs = []
1857 elif not isinstance(inputs, list):
1858 raise TypeError(f"Argument inputs shall be a list of Tensors. "
1859 f"Received an instance of type {type(inputs)}")
1860 for a in inputs:
1861 if not isinstance(a, Tensor):
1862 raise TypeError(f"Items of argument inputs shall be Tensor. "
1863 f"Received an instance of type {type(a)}.")
1864 if input_types is None:
1865 input_types = [i.dtype.base_dtype for i in inputs]
1866 else:
1867 if not all(
1868 x.is_compatible_with(i.dtype) for i, x in zip(inputs, input_types)):
1869 raise TypeError("In op '%s', input types (%s) are not compatible "
1870 "with expected types (%s)" %
1871 (node_def.name, [i.dtype for i in inputs], input_types))
1873 # Build the list of control inputs.
1874 control_input_ops = []
1875 if control_inputs:
1876 for c in control_inputs:
1877 control_op = None
1878 if isinstance(c, Operation):
1879 control_op = c
1880 elif isinstance(c, (Tensor, internal.IndexedSlices)):
1881 control_op = c.op
1882 else:
1883 raise TypeError(f"Control input must be an Operation, "
1884 f"a Tensor, or IndexedSlices. "
1885 f"Received an instance of type {type(c)}.")
1886 control_input_ops.append(control_op)
1888 # Initialize c_op from node_def and other inputs
1889 c_op = _create_c_op(g, node_def, inputs, control_input_ops, op_def=op_def)
1890 self = Operation(c_op, GraphTensor)
1891 self._init(g)
1893 self._original_op = original_op
1895 # Post process for control flows.
1896 self._control_flow_post_processing(input_tensors=inputs)
1898 # Removes this frame from the Python traceback.
1899 # We adjust stacklevel directly to avoid triggering serialization.
1900 if self.traceback is not None:
1901 self.traceback._stacklevel += 1 # pylint: disable=protected-access
1903 return self
@classmethod
def _from_c_op(cls, c_op, g):
  """Wraps an existing TF_Operation in a Python `Operation`.

  Internal helper for ops that were created directly through the C API
  (for example by TF_ImportGraphDef) rather than from a `NodeDef`.

  Args:
    c_op: a TF_Operation.
    g: A Graph.

  Returns:
    an Operation object registered with `g`.
  """
  op = Operation(c_op, GraphTensor)
  op._init(g)  # pylint: disable=protected-access
  return op
def _init(self, graph):
  """Initializes Operation from a TF_Operation.

  Resets the Python-side bookkeeping attributes, captures the graph's current
  control flow context, calls `self._init_outputs()` (presumably builds the
  output tensors; not visible here), and finally registers the op with
  `graph`, storing the id returned by `graph._add_op` as `self._id_value`.

  Args:
    graph: The `Graph` this op belongs to.
  """
  self.graph = graph
  self._original_op = None

  # This will be set by self.inputs.
  self._inputs_val = None

  # List of _UserDevSpecs holding code location of device context manager
  # invocations and the users original argument to them.
  self._device_code_locations = None
  # Dict mapping op name to file and line information for op colocation
  # context managers.
  self._colocation_code_locations = None
  # Snapshot of the graph's control flow context at the time this op is
  # initialized.
  self._control_flow_context = self.graph._get_control_flow_context()  # pylint: disable=protected-access

  # Gradient function for this op. There are three ways to specify gradient
  # function, and first available gradient gets used, in the following order.
  # 1. self._gradient_function
  # 2. Gradient name registered by "_gradient_op_type" attribute.
  # 3. Gradient name registered by op.type.
  self._gradient_function = None

  # NOTE(review): keep this order. Outputs are initialized before the op is
  # handed to the graph; `_add_op` may rely on that (not visible here).
  self._init_outputs()
  self._id_value = self.graph._add_op(self)  # pylint: disable=protected-access
def _control_flow_post_processing(self, input_tensors=None):
  """Registers this op with its control flow context, if it has one.

  Every input's producing op is first checked with
  `control_flow_util.CheckInputFromValidContext`. Afterwards the op is added
  to `self._control_flow_context`, which may add new ops and change this op's
  inputs. `self.inputs` must be available before calling this method.

  Args:
    input_tensors: (Optional.) A list of `Tensors` corresponding to the inputs
      of this op, which should be equivalent to `self.inputs`. Pass this
      argument to avoid evaluating `self.inputs` unnecessarily.
  """
  tensors = self.inputs if input_tensors is None else input_tensors
  for tensor in tensors:
    control_flow_util.CheckInputFromValidContext(self, tensor.op)

  context = self._control_flow_context
  if context is not None:
    context.AddOp(self)
def colocation_groups(self):
  """Returns the list of colocation groups of the op.

  The groups are the `b"loc:@..."` entries of the op's `_class` attr. When
  that attr is missing, or contains no such entries, the op forms its own
  group and `[b"loc:@<op name>"]` is returned.
  """
  own_group = [compat.as_bytes("loc:@%s" % self.name)]
  try:
    class_attr = self.get_attr("_class")
  except ValueError:
    # No explicit colocation: the op is the root of its own group.
    return own_group

  explicit_groups = [entry for entry in class_attr
                     if entry.startswith(b"loc:@")]
  return explicit_groups or own_group
def values(self):
  """DEPRECATED: Use outputs.

  Returns:
    A tuple holding the elements of `self.outputs`.
  """
  outputs = self.outputs
  return tuple(outputs)
def _get_control_flow_context(self):
  """Returns the control flow context stored on this op.

  Returns:
    The context object, or None if the op has none.
  """
  context = self._control_flow_context
  return context
def _set_control_flow_context(self, ctx):
  """Replaces the control flow context stored on this op.

  Args:
    ctx: the new context object (may be None).
  """
  setattr(self, "_control_flow_context", ctx)
@property
def _id(self):
  """The unique integer id this op received when it was added to its graph."""
  op_id = self._id_value
  return op_id
@property
def device(self):
  """Name of the device this op has been assigned to, if any.

  Returns:
    The device name string, or an empty string if the op has not been
    assigned to a device.
  """
  c_op = self._c_op
  return pywrap_tf_session.TF_OperationDevice(c_op)
@property
def _device_assignments(self):
  """Code locations of the device context managers active at op creation.

  Each entry is a traceable_stack.TraceableObject. Its `.obj` is a string
  naming the assigned device, or information about the function that would
  be applied to this op to choose the device. Its `filename` and `lineno`
  record where the device context manager was entered.

  For example, if file_a contained:

    file_a.py:
      15: with tf.device('/gpu:0'):
      16:   node_b = tf.constant(4, name='NODE_B')

  then node_b.op._device_assignments would be [t_obj], where

    t_obj.obj -> '/gpu:0'
    t_obj.filename = 'file_a.py'
    t_obj.lineno = 15

  Returns:
    A list of traceable_stack.TraceableObject; empty if none were recorded.
  """
  locations = self._device_code_locations
  return locations if locations else []
@property
def _colocation_dict(self):
  """Code locations of the colocation context managers active at op creation.

  Maps the name of each node this op is colocated with to a
  traceable_stack.TraceableObject. That object records where the colocation
  context manager was entered. Its `obj` field is None, to avoid leaking
  private data.

  For example, if file_a contained:

    file_a.py:
      14: node_a = tf.constant(3, name='NODE_A')
      15: with tf.compat.v1.colocate_with(node_a):
      16:   node_b = tf.constant(4, name='NODE_B')

  then node_b.op._colocation_dict would be { 'NODE_A': t_obj }, where

    t_obj.obj -> None
    t_obj.filename = 'file_a.py'
    t_obj.lineno = 15

  Returns:
    A new {str: traceable_stack.TraceableObject} dict (a shallow copy, so
    callers may mutate it freely).
  """
  locations = self._colocation_code_locations
  if not locations:
    return {}
  return locations.copy()
@property
def _output_types(self):
  """This operation's output types as raw TF_DataType enum integers.

  Returns:
    One int per output endpoint, in output order. Each value is one of the
    TF_DataType enums defined in pywrap_tf_session.h, and the list length
    equals the number of output endpoints of the operation.
  """
  output_type_of = pywrap_tf_session.TF_OperationOutputType
  count = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
  return [int(output_type_of(self._tf_output(idx))) for idx in range(count)]
def _set_device(self, device):  # pylint: disable=redefined-outer-name
  """Assigns this operation to `device`.

  Args:
    device: a device string or device object; it is normalized with
      `_device_string` and converted to `str` before being applied.
  """
  device_str = compat.as_str(_device_string(device))
  self._set_device_from_string(device_str)
def _update_input(self, index, tensor):
  """Update the input to this operation at the given index.

  NOTE: This is for TF internal use only. Please don't use it.

  Args:
    index: the index of the input to update.
    tensor: the Tensor to be used as the input at the given index.

  Raises:
    TypeError: if tensor is not a Tensor,
      or if input tensor type is not convertible to dtype.
    ValueError: if the Tensor is from a different graph.
  """
  if not isinstance(tensor, Tensor):
    # Wrap in a 1-tuple so a tuple argument is printed rather than unpacked by
    # `%`. Unpacking would raise a misleading "not all arguments converted"
    # error instead of this message.
    raise TypeError("tensor must be a Tensor: %s" % (tensor,))

  _assert_same_graph(self, tensor)

  # Drop the cached `inputs` tuple; `self.inputs` rebuilds it lazily.
  self._inputs_val = None
  with self.graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
    pywrap_tf_session.UpdateEdge(
        c_graph,
        tensor._as_tf_output(),  # pylint: disable=protected-access
        self._tf_input(index))
def _add_while_inputs(self, tensors):
  """See AddWhileInputHack in python_api.h.

  NOTE: This is for TF internal use only. Please don't use it.

  Every element of `tensors` is validated before the graph is touched. An
  invalid element therefore no longer leaves the op with only some of the
  inputs added.

  Args:
    tensors: iterable of Tensors

  Raises:
    TypeError: if tensor is not a Tensor,
      or if input tensor type is not convertible to dtype.
    ValueError: if the Tensor is from a different graph.
  """
  # Materialize first so one-shot iterables survive the two passes below.
  tensors = list(tensors)
  for tensor in tensors:
    if not isinstance(tensor, Tensor):
      # 1-tuple so that a tuple argument is formatted, not unpacked, by `%`.
      raise TypeError("tensor must be a Tensor: %s" % (tensor,))
    _assert_same_graph(self, tensor)

  with self.graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
    for tensor in tensors:
      # Reset cached inputs; `self.inputs` rebuilds them lazily.
      self._inputs_val = None
      pywrap_tf_session.AddWhileInputHack(
          c_graph,  # pylint: disable=protected-access
          tensor._as_tf_output(),  # pylint: disable=protected-access
          self._c_op)
def __str__(self):
  """Returns `str()` of this op's `NodeDef` proto."""
  node_def = self.node_def
  return str(node_def)
def __repr__(self):
  """Returns a short `<tf.Operation 'name' type=Type>` description."""
  name, op_type = self.name, self.type
  return "<tf.Operation '%s' type=%s>" % (name, op_type)
def __tf_tensor__(self, dtype=None, name=None):
  """Always raises: an `Operation` cannot be converted to a `Tensor`.

  Raises:
    TypeError: unconditionally, naming the offending operation.
  """
  del dtype, name  # Unused; accepted only to match the conversion protocol.
  message = "can't convert Operation '{}' to Tensor".format(self.name)
  raise TypeError(message)
@property
def inputs(self):
  """The sequence of `Tensor` objects representing the data inputs of this op.

  The tuple is built on first access and cached in `self._inputs_val`. Code
  that rewires inputs resets the cache to None.
  """
  cached = self._inputs_val
  if cached is not None:
    return cached
  tf_outputs = pywrap_tf_session.GetOperationInputs(self._c_op)
  # pylint: disable=protected-access
  self._inputs_val = tuple(
      self.graph._get_tensor_by_tf_output(tf_output) for tf_output in tf_outputs)
  # pylint: enable=protected-access
  return self._inputs_val
@property
def _input_types(self):
  """The `DType` of each data input of this op, in input order."""
  result = []
  for idx in range(pywrap_tf_session.TF_OperationNumInputs(self._c_op)):
    enum_value = pywrap_tf_session.TF_OperationInputType(self._tf_input(idx))
    result.append(dtypes.as_dtype(enum_value))
  return result
@property
def traceback(self):
  """Returns the call stack recorded when this operation was constructed."""
  # FIXME(b/225423591): This object contains a dangling reference if _c_op
  # goes out of scope.
  c_op = self._c_op
  return pywrap_tf_session.TF_OperationGetStackTrace(c_op)
@property
def node_def(self):
  """A new `NodeDef` proto parsed from this op's serialized node definition."""
  serialized = self._node_def
  return node_def_pb2.NodeDef.FromString(serialized)
@property
def op_def(self):
  """A new `OpDef` proto parsed from this op's serialized op definition."""
  serialized = self._op_def
  return op_def_pb2.OpDef.FromString(serialized)
def _set_attr(self, attr_name, attr_value):
  """Private method used to set an attribute in the node_def.

  The serialized `attr_value` is copied into a temporary TF buffer. The
  buffer is always freed, even when setting the attr fails.
  """
  serialized = compat.as_bytes(attr_value.SerializeToString())
  buf = pywrap_tf_session.TF_NewBufferFromString(serialized)
  try:
    self._set_attr_with_buf(attr_name, buf)
  finally:
    pywrap_tf_session.TF_DeleteBuffer(buf)
def _set_attr_with_buf(self, attr_name, attr_buf):
  """Sets attr `attr_name` from a serialized value in a pre-allocated buffer.

  The buffer is not freed here; the caller owns it.
  """
  graph = self.graph
  with graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
    pywrap_tf_session.SetAttr(c_graph, self._c_op, attr_name, attr_buf)
def _set_func_attr(self, attr_name, func_name):
  """Private method used to set a function attribute in the node_def."""
  value = attr_value_pb2.AttrValue(
      func=attr_value_pb2.NameAttrList(name=func_name))
  self._set_attr(attr_name, value)
def _set_func_list_attr(self, attr_name, func_names):
  """Private method used to set a list(function) attribute in the node_def."""
  name_lists = [attr_value_pb2.NameAttrList(name=name) for name in func_names]
  value = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(func=name_lists))
  self._set_attr(attr_name, value)
def _set_type_list_attr(self, attr_name, types):
  """Private method used to set a list(type) attribute in the node_def.

  Args:
    attr_name: Name of the attr to set.
    types: Sequence of `DType` objects and/or raw `DataType` enum values.
      The two kinds may be mixed. An empty sequence leaves the attr
      untouched.
  """
  # NOTE: the parameter name `types` shadows the module-level `types` import
  # inside this method; it is kept for backward compatibility.
  if not types:
    return
  # Convert element-wise instead of deciding from types[0] alone. Otherwise
  # a mix of DTypes and raw enums fails (AttributeError or a proto TypeError).
  enums = [t.as_datatype_enum if isinstance(t, dtypes.DType) else t
           for t in types]
  types_list = attr_value_pb2.AttrValue.ListValue(type=enums)
  self._set_attr(attr_name, attr_value_pb2.AttrValue(list=types_list))
def _set_shape_list_attr(self, attr_name, shapes):
  """Private method used to set a list(shape) attribute in the node_def.

  Each element of `shapes` must provide `as_proto()`.
  """
  shape_protos = [shape.as_proto() for shape in shapes]
  value = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(shape=shape_protos))
  self._set_attr(attr_name, value)
def _clear_attr(self, attr_name):
  """Private method used to clear an attribute in the node_def."""
  graph = self.graph
  with graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
    pywrap_tf_session.ClearAttr(c_graph, self._c_op, attr_name)
def get_attr(self, name):
  """Returns the value of the attr of this op with the given `name`.

  Args:
    name: The name of the attr to fetch.

  Returns:
    The value of the attr, as a Python object:
    - `[]` if the `AttrValue` has no value set;
    - a `DType` for a `type` attr;
    - for a `list` attr, a Python list built from the first non-empty list
      field (`type` entries are converted to `DType`s), or `[]` if all are
      empty;
    - otherwise the raw proto field value.

  Raises:
    ValueError: If this op does not have an attr with the given `name`.
  """
  fields = ("s", "i", "f", "b", "type", "shape", "tensor", "func")
  try:
    with c_api_util.tf_buffer() as buf:
      pywrap_tf_session.TF_OperationGetAttrValueProto(self._c_op, name, buf)
      data = pywrap_tf_session.TF_GetBuffer(buf)
  except errors.InvalidArgumentError as e:
    # Convert to ValueError for backwards compatibility.
    raise ValueError(e.message)

  attr = attr_value_pb2.AttrValue()
  attr.ParseFromString(data)
  kind = attr.WhichOneof("value")

  if kind is None:
    return []
  if kind == "type":
    return dtypes.as_dtype(attr.type)
  if kind == "list":
    for field in fields:
      entries = getattr(attr.list, field)
      if not entries:
        continue
      if field == "type":
        return [dtypes.as_dtype(t) for t in entries]
      return list(entries)
    return []

  assert kind in fields, "Unsupported field type in " + str(attr)
  return getattr(attr, kind)
def _get_attr_type(self, name):
  """Returns the `DType` value of the attr of this op with the given `name`.

  Raises:
    ValueError: if the op has no such attr (converted from
      `errors.InvalidArgumentError`).
  """
  try:
    dtype_enum = pywrap_tf_session.TF_OperationGetAttrType(self._c_op, name)
  except errors.InvalidArgumentError as e:
    # Convert to ValueError for backwards compatibility.
    raise ValueError(e.message)
  return _DTYPES_INTERN_TABLE[dtype_enum]
def _get_attr_bool(self, name):
  """Returns the `bool` value of the attr of this op with the given `name`.

  Raises:
    ValueError: if the op has no such attr (converted from
      `errors.InvalidArgumentError`).
  """
  try:
    value = pywrap_tf_session.TF_OperationGetAttrBool(self._c_op, name)
  except errors.InvalidArgumentError as e:
    # Convert to ValueError for backwards compatibility.
    raise ValueError(e.message)
  return value
def _get_attr_int(self, name):
  """Returns the `int` value of the attr of this op with the given `name`.

  Raises:
    ValueError: if the op has no such attr (converted from
      `errors.InvalidArgumentError`).
  """
  try:
    value = pywrap_tf_session.TF_OperationGetAttrInt(self._c_op, name)
  except errors.InvalidArgumentError as e:
    # Convert to ValueError for backwards compatibility.
    raise ValueError(e.message)
  return value
def experimental_set_type(self, type_proto):
  """Sets the corresponding node's `experimental_type` field.

  See the description of `NodeDef.experimental_type` for more info.

  Args:
    type_proto: A FullTypeDef proto message. The root type_id of this object
      must be `TFT_PRODUCT` (or `TFT_UNSET`), even for ops which only have a
      single return value.

  Raises:
    ValueError: if the root `type_id` is neither `TFT_UNSET` nor
      `TFT_PRODUCT`.
  """
  # Validate before taking the graph lock; the check does not touch the graph.
  if type_proto.type_id not in (full_type_pb2.TFT_UNSET,
                                full_type_pb2.TFT_PRODUCT):
    # A single formatted message. Passing several positional args to
    # ValueError would render the message as a tuple repr.
    raise ValueError(
        f"error setting the type of {self.name}: expected TFT_UNSET or "
        f"TFT_PRODUCT, got {type_proto.type_id}")
  with self.graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
    pywrap_tf_session.SetFullType(c_graph, self._c_op,
                                  type_proto.SerializeToString())
def run(self, feed_dict=None, session=None):
  """Runs this operation in a `Session`.

  Running the op also executes every preceding operation needed to produce
  its inputs.

  *N.B.* The op's graph must already have been launched in a session before
  `Operation.run()` is invoked. Either a default session must be available,
  or `session` must be passed explicitly.

  Args:
    feed_dict: A dictionary that maps `Tensor` objects to feed values. See
      `tf.Session.run` for a description of the valid feed values.
    session: (Optional.) The `Session` to be used to run to this operation. If
      none, the default session will be used.
  """
  graph = self.graph
  _run_using_default_session(self, feed_dict, graph, session)
# TODO(b/185395742): Clean up usages of _gradient_registry
# Registry mapping op type names to gradient functions. `_gradient_registry`
# is a legacy alias for the same object.
_gradient_registry = registry.Registry("gradient")
gradient_registry = _gradient_registry
@tf_export("RegisterGradient")
class RegisterGradient(object):
  """Decorator that registers the gradient function for an op type.

  Use it only when defining a new op type. For an op with `m` inputs and `n`
  outputs, the gradient function receives the original `Operation` and `n`
  `Tensor` objects (the gradients with respect to each output of the op). It
  returns `m` `Tensor` objects (the partial gradients with respect to each
  input of the op).

  For example, if operations of type `"Sub"` take two inputs `x` and `y` and
  return the single output `x - y`, their gradient function would be
  registered as:

  ```python
  @tf.RegisterGradient("Sub")
  def _sub_grad(unused_op, grad):
    return grad, tf.negative(grad)
  ```

  The decorator argument `op_type` is the string type of an operation. It
  corresponds to the `OpDef.name` field of the proto that defines the
  operation.
  """

  __slots__ = ["_op_type"]

  def __init__(self, op_type):
    """Creates a decorator that registers gradients for `op_type`.

    Args:
      op_type: The string type of an operation. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.

    Raises:
      TypeError: If `op_type` is not string.
    """
    if isinstance(op_type, str):
      self._op_type = op_type
      return
    raise TypeError("op_type must be a string")

  def __call__(self, f):
    """Registers `f` as the gradient function for `op_type`; returns `f`."""
    op_type = self._op_type
    gradient_registry.register(f, op_type)
    return f
@deprecation.deprecated_endpoints("NotDifferentiable", "NoGradient")
@tf_export("no_gradient", v1=["no_gradient", "NotDifferentiable", "NoGradient"])
def no_gradient(op_type):
  """Specifies that ops of type `op_type` are not differentiable.

  Do *not* use this for operations that have a well-defined gradient which
  simply has not been implemented yet.

  Call it only when defining a new op type, for ops such as `tf.size()` that
  are not differentiable. For example:

  ```python
  tf.no_gradient("Size")
  ```

  The gradient computed for `op_type` will then propagate zeros.

  For ops that have a well-defined gradient that is not yet implemented,
  make no declaration at all. An error *must* then be thrown if an attempt
  is made to request the gradient.

  Args:
    op_type: The string type of an operation. This corresponds to the
      `OpDef.name` field for the proto that defines the operation.

  Raises:
    TypeError: If `op_type` is not a string.
  """
  if isinstance(op_type, str):
    # A `None` gradient function marks the op type as non-differentiable.
    gradient_registry.register(None, op_type)
  else:
    raise TypeError("op_type must be a string")
# Deprecated aliases for `no_gradient`, kept for backward compatibility;
# they will eventually be removed.
NoGradient = NotDifferentiable = no_gradient
def get_gradient_function(op):
  """Returns the function that computes gradients for "op".

  Ops without inputs have no gradient function (None). Otherwise the first
  available of these is used:
    1. `op._gradient_function`, if set;
    2. the registry entry for the op's `_gradient_op_type` attr;
    3. the registry entry for `op.type`.
  """
  if not op.inputs:
    return None

  explicit_fn = op._gradient_function  # pylint: disable=protected-access
  if explicit_fn:
    return explicit_fn

  try:
    lookup_type = op.get_attr("_gradient_op_type")
  except ValueError:
    # No override attr: fall back to the op's own type.
    lookup_type = op.type
  return gradient_registry.lookup(lookup_type)
def set_shape_and_handle_data_for_outputs(_):
  """No op. TODO(b/74620627): Remove this.

  Kept only so existing callers keep working; the argument is ignored.
  """
  return None
class OpStats(object):
  """A holder for statistics about an operator.

  This class holds information about the resource requirements for an op,
  including the size of its weight parameters on-disk and how many FLOPS it
  requires to execute forward inference.

  If you define a new operation, you can create a function that will return a
  set of information about its usage of the CPU and disk space when serialized.
  The function itself takes a Graph object that's been set up so you can call
  methods like get_tensor_by_name to help calculate the results, and a NodeDef
  argument.
  """

  __slots__ = ["_statistic_type", "_value"]

  def __init__(self, statistic_type, value=None):
    """Sets up the initial placeholders for the statistics.

    Args:
      statistic_type: String naming the statistic (e.g. "flops").
      value: Initial value; None means "no data yet".
    """
    self.statistic_type = statistic_type
    self.value = value

  @property
  def statistic_type(self):
    return self._statistic_type

  @statistic_type.setter
  def statistic_type(self, statistic_type):
    self._statistic_type = statistic_type

  @property
  def value(self):
    return self._value

  @value.setter
  def value(self, value):
    self._value = value

  def __iadd__(self, other):
    """Accumulates `other.value` into this object in place.

    A None value means "no data". If this value is None it is replaced by
    `other.value`; if `other.value` is None this value is left unchanged.

    Args:
      other: Another `OpStats` with the same `statistic_type`.

    Returns:
      self.

    Raises:
      ValueError: if the two objects track different statistic types.
    """
    if other.statistic_type != self.statistic_type:
      # `other` is being added to `self`, so name the types in that order.
      raise ValueError("Can't add an OpStat of type %s to one of %s." %
                       (other.statistic_type, self.statistic_type))
    if self.value is None:
      self.value = other.value
    elif other.value is not None:
      self.value += other.value
    return self
# Registry of statistics functions, keyed by "<op_type>,<statistic_type>"
# (populated by RegisterStatistics, queried by get_stats_for_node_def).
_stats_registry = registry.Registry("statistical functions")
class RegisterStatistics(object):
  """Decorator that registers the statistics function for an op type.

  Apply it to a function so that, for a given op type, it reports the
  resources used by an instance of that operator, in the form of an OpStats
  object.

  Well-known types of statistics include these so far:

  - flops: When running a graph, the bulk of the computation happens doing
    numerical calculations like matrix multiplications. This type allows a node
    to return how many floating-point operations it takes to complete. The
    total number of FLOPs for a graph is a good guide to its expected latency.

  You can add your own statistics just by picking a new type string, registering
  functions for the ops you care about, and then calling get_stats_for_node_def.

  If a statistic for an op is registered multiple times, a KeyError will be
  raised.

  Because statistics are counted per op, they are not suitable for model
  parameters (capacity). Those should be counted only once, even when shared
  by multiple ops (e.g. RNN).

  For example, you can define a new metric called doohickey for a Foo operation
  by placing this in your code:

  ```python
  @ops.RegisterStatistics("Foo", "doohickey")
  def _calc_foo_bojangles(unused_graph, unused_node_def):
    return ops.OpStats("doohickey", 20)
  ```

  Then in client code you can retrieve the value by making this call:

  ```python
  doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey")
  ```

  If the NodeDef is for an op with a registered doohickey function, you'll get
  back the calculated amount in doohickey.value, or None if it's not defined.
  """

  __slots__ = ["_op_type", "_statistic_type"]

  def __init__(self, op_type, statistic_type):
    """Stores `op_type` and `statistic_type` after validating both.

    Raises:
      TypeError: if either argument is not a string or contains a comma
        (the comma separates the two parts of the registry key).
    """
    if not isinstance(op_type, str):
      raise TypeError("op_type must be a string.")
    if "," in op_type:
      raise TypeError("op_type must not contain a comma.")
    if not isinstance(statistic_type, str):
      raise TypeError("statistic_type must be a string.")
    if "," in statistic_type:
      raise TypeError("statistic_type must not contain a comma.")
    self._op_type = op_type
    self._statistic_type = statistic_type

  def __call__(self, f):
    """Registers "f" as the statistics function for "op_type"."""
    registry_key = self._op_type + "," + self._statistic_type
    _stats_registry.register(f, registry_key)
    return f
def get_stats_for_node_def(graph, node, statistic_type):
  """Looks up the node's statistics function in the registry and calls it.

  This function takes a Graph object and a NodeDef from a GraphDef. If a
  statistics function is registered for the node's op type, it is called and
  its result returned. Otherwise an empty OpStats for `statistic_type` is
  returned.

  Note that a LookupError raised by the statistics function itself is also
  caught and yields the empty OpStats.

  Args:
    graph: A Graph object that's been set up with the node's graph.
    node: A NodeDef describing the operator.
    statistic_type: A string identifying the statistic we're interested in.

  Returns:
    An OpStats object containing information about resource usage.
  """
  try:
    stats_fn = _stats_registry.lookup(node.op + "," + statistic_type)
    return stats_fn(graph, node)
  except LookupError:
    return OpStats(statistic_type)
def name_from_scope_name(name):
  """Returns the name of an op given the name of its scope.

  Args:
    name: the name of the scope.

  Returns:
    `name` with one trailing "/" removed, if present; falsy inputs are
    returned unchanged.
  """
  if not name or name[-1] != "/":
    return name
  return name[:-1]
# Group indices, presumably for `Graph._group_lock` (a two-group GroupLock):
# one group for code that creates or mutates ops, one for Session.run calls.
# NOTE(review): confirm against the lock's call sites.
_MUTATION_LOCK_GROUP = 0
_SESSION_RUN_LOCK_GROUP = 1
@tf_contextlib.contextmanager
def resource_creator_scope(resource_type, resource_creator):
  """Enters `_resource_creator_scope` on the current default graph.

  Args:
    resource_type: forwarded unchanged to the graph's scope.
    resource_creator: forwarded unchanged to the graph's scope.
  """
  default_graph = get_default_graph()
  # pylint: disable=protected-access
  with default_graph._resource_creator_scope(resource_type, resource_creator):
    yield
  # pylint: enable=protected-access
2645@tf_export("Graph")
2646class Graph(pywrap_tf_session.PyGraph):
2647 """A TensorFlow computation, represented as a dataflow graph.
2649 Graphs are used by `tf.function`s to represent the function's computations.
2650 Each graph contains a set of `tf.Operation` objects, which represent units of
2651 computation; and `tf.Tensor` objects, which represent the units of data that
2652 flow between operations.
2654 ### Using graphs directly (deprecated)
2656 A `tf.Graph` can be constructed and used directly without a `tf.function`, as
2657 was required in TensorFlow 1, but this is deprecated and it is recommended to
2658 use a `tf.function` instead. If a graph is directly used, other deprecated
2659 TensorFlow 1 classes are also required to execute the graph, such as a
2660 `tf.compat.v1.Session`.
2662 A default graph can be registered with the `tf.Graph.as_default` context
2663 manager. Then, operations will be added to the graph instead of being executed
2664 eagerly. For example:
2666 ```python
2667 g = tf.Graph()
2668 with g.as_default():
2669 # Define operations and tensors in `g`.
2670 c = tf.constant(30.0)
2671 assert c.graph is g
2672 ```
2674 `tf.compat.v1.get_default_graph()` can be used to obtain the default graph.
2676 Important note: This class *is not* thread-safe for graph construction. All
2677 operations should be created from a single thread, or external
2678 synchronization must be provided. Unless otherwise specified, all methods
2679 are not thread-safe.
2681 A `Graph` instance supports an arbitrary number of "collections"
2682 that are identified by name. For convenience when building a large
2683 graph, collections can store groups of related objects: for
2684 example, the `tf.Variable` uses a collection (named
2685 `tf.GraphKeys.GLOBAL_VARIABLES`) for
2686 all variables that are created during the construction of a graph. The caller
2687 may define additional collections by specifying a new name.
2688 """
def __init__(self):
  """Creates a new, empty Graph.

  Initializes the Python-side bookkeeping on top of the base class set up by
  `super().__init__()`: locks, the name-uniquification map, device /
  colocation / control-dependency stacks, collections, per-op-type override
  maps, the function library, tensor-handle maps and several caches. When
  `tf2.enabled()` is true, `switch_to_thread_local()` is called last so the
  stacks are taken from `self._thread_local` instead of the graph-wide ones.
  """
  super().__init__()
  # Protects core state that can be returned via public accessors.
  # Thread-safety is provided on a best-effort basis to support buggy
  # programs, and is not guaranteed by the public `tf.Graph` API.
  #
  # NOTE(mrry): This does not protect the various stacks. A warning will
  # be reported if these are used from multiple threads
  self._lock = threading.RLock()
  # The group lock synchronizes Session.run calls with methods that create
  # and mutate ops (e.g. Graph.create_op()). This synchronization is
  # necessary because it's illegal to modify an operation after it's been run.
  # The group lock allows any number of threads to mutate ops at the same time
  # but if any modification is going on, all Session.run calls have to wait.
  # Similarly, if one or more Session.run calls are going on, all mutate ops
  # have to wait until all Session.run calls have finished.
  self._group_lock = lock_util.GroupLock(num_groups=2)
  # Maps a name used in the graph to the next id to use for that name.
  self._names_in_use = {}
  self._stack_state_is_thread_local = False
  self._thread_local = threading.local()
  # Functions that will be applied to choose a device if none is specified.
  # In TF2.x or after switch_to_thread_local(),
  # self._thread_local._device_function_stack is used instead.
  self._graph_device_function_stack = traceable_stack.TraceableStack()
  # Default original_op applied to new ops.
  self._default_original_op = None
  # Current control flow context. It could be either CondContext or
  # WhileContext defined in ops/control_flow_ops.py
  self._control_flow_context = None
  # A new node will depend on the union of all of the nodes in the stack.
  # In TF2.x or after switch_to_thread_local(),
  # self._thread_local._control_dependencies_stack is used instead.
  self._graph_control_dependencies_stack = []
  # Arbitrary collections of objects.
  self._collections = {}
  # The graph-level random seed
  self._seed = None
  # A dictionary of attributes that should be applied to all ops.
  self._attr_scope_map = {}
  # A map from op type to the kernel label that should be used.
  self._op_to_kernel_label_map = {}
  # A map from op type to an alternative op type that should be used when
  # computing gradients.
  self._gradient_override_map = {}
  # A map from op type to a gradient function that should be used instead.
  self._gradient_function_map = {}
  # True if the graph is considered "finalized". In that case no
  # new operations can be added.
  self._finalized = False
  # Functions defined in the graph, keyed by name in insertion order.
  self._functions = collections.OrderedDict()
  # Default GraphDef versions
  self._graph_def_versions = versions_pb2.VersionDef(
      producer=versions.GRAPH_DEF_VERSION,
      min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)
  self._building_function = False
  # Stack of colocate_with ops. In TF2.x or after switch_to_thread_local(),
  # self._thread_local._colocation_stack is used instead.
  self._graph_colocation_stack = traceable_stack.TraceableStack()
  # Set of tensors that are dangerous to feed!
  self._unfeedable_tensors = object_identity.ObjectIdentitySet()
  # Set of operations that are dangerous to fetch!
  self._unfetchable_ops = set()
  # A map of tensor handle placeholder to tensor dtype.
  self._handle_feeders = {}
  # A map from tensor handle to its read op.
  self._handle_readers = {}
  # A map from tensor handle to its move op.
  self._handle_movers = {}
  # A map from tensor handle to its delete op.
  self._handle_deleters = {}
  # Allow optimizers and other objects to pseudo-uniquely key graphs (this key
  # will be shared when defining function graphs, for example, so optimizers
  # being called inside function definitions behave as if they were seeing the
  # actual outside graph).
  self._graph_key = "graph-key-%d/" % (uid(),)
  # A string with the last reduction method passed to
  # losses.compute_weighted_loss(), or None. This is required only for
  # backward compatibility with Estimator and optimizer V1 use cases.
  self._last_loss_reduction = None
  # Flag that is used to indicate whether loss has been scaled by optimizer.
  # If this flag has been set, then estimator uses it to scale loss back
  # before reporting. This is required only for backward compatibility with
  # Estimator and optimizer V1 use cases.
  self._is_loss_scaled_by_optimizer = False
  # Default resource container; an empty string means "no container".
  self._container = ""

  # The current AutomaticControlDependencies context manager.
  self.experimental_acd_manager = None
  # Set to True if this graph is being built in an
  # AutomaticControlDependencies context.
  # Deprecated: use acd_manager instead.
  self._add_control_dependencies = False

  # Cache for OpDef protobufs retrieved via the C API.
  self._op_def_cache = {}
  # Cache for constant results of `broadcast_gradient_args()`. The keys are
  # tuples of fully-defined shapes: (x_shape_tuple, y_shape_tuple), and the
  # values are tuples of reduction indices: (rx, ry).
  self._bcast_grad_args_cache = {}
  # Cache for constant results of `reduced_shape()`. The keys are pairs of
  # tuples: (input_shape_tuple, reduction_indices_tuple), and the values
  # are pairs of tuples: (output_shape_kept_dims, tile_scaling).
  self._reduced_shape_cache = {}

  # Must run last: it switches to the thread-local versions of the stacks
  # initialized above.
  if tf2.enabled():
    self.switch_to_thread_local()
# The Python `Graph` object is itself the C graph these days. Several call
# sites still ask for `_c_graph` explicitly, so this alias stays until they
# have been migrated.
@property
def _c_graph(self):
  """Alias that returns this graph object (which is also the C graph)."""
  c_graph = self
  return c_graph
def __enter__(self):
  """Enters the context, binding `as` targets to this same graph object."""
  entered = self
  return entered
def __exit__(self, *args):
  """Leaves the context without any cleanup.

  Returns None, so any exception raised in the `with` body propagates.
  """
  del args  # Exception type/value/traceback are intentionally ignored.
  return None
def get(self):
  """Returns this graph object; paired with `__enter__`/`__exit__` for `with`."""
  graph = self
  return graph
# Note: this method is private because the API of tf.Graph() is public and
# frozen, and this functionality is still not ready for public visibility.
@tf_contextlib.contextmanager
def _variable_creator_scope(self, creator, priority=100):
  """Context manager that installs `creator` as a variable creator.

  Args:
    creator: A callable taking `next_creator` and `kwargs`. See the
      `tf.variable_creator_scope` docstring.
    priority: Creators with a higher `priority` are called first. Creators
      that share a priority are called inner-to-outer.

  Yields:
    Nothing; the scope only temporarily replaces the per-thread creator
    stack.

  Raises:
    RuntimeError: If variable creator scopes are not properly nested.
  """
  # Reading the property captures the current stack and, on first use,
  # creates self._thread_local._variable_creator_stack.
  outer_stack = self._variable_creator_stack
  # `sorted` is stable. Entries end up in ascending priority (higher
  # priorities later in the list), and equal priorities keep registration
  # order.
  inner_stack = sorted([*outer_stack, (priority, creator)],
                       key=lambda entry: entry[0])
  self._thread_local._variable_creator_stack = inner_stack  # pylint: disable=protected-access
  try:
    yield
  finally:
    # If someone else replaced our stack and did not restore it, the scopes
    # were interleaved rather than nested; refuse to unwind silently.
    if self._thread_local._variable_creator_stack is not inner_stack:  # pylint: disable=protected-access
      raise RuntimeError(
          "Exiting variable_creator_scope without proper nesting.")
    self._thread_local._variable_creator_stack = outer_stack  # pylint: disable=protected-access
# TODO(b/192405401): unify resource_creator_scope with variable_creator_scope.
# pylint: disable=protected-access
@tf_contextlib.contextmanager
def _resource_creator_scope(self, resource_type, creator):
  """Scope which defines a resource creation function used by some resource.

  The resource should be a subclass of CapturableResource with a class method
  `cls._resource_type`, the output of which is what the `resource_type`
  argument should be. By default, `cls._resource_type` returns the class name,
  `cls.__name__`. Given a scope, creators being added with the same
  `resource_type` argument will be composed together to apply to all classes
  with this `_resource_type`.

  `creator` is expected to be a function with the following signature:

  ```
    def resource_creator(next_creator, *a, **kwargs)
  ```

  The creator is supposed to eventually call the next_creator to create an
  instance if it does want to create an instance and not call
  the class initialization method directly. This helps make creators
  composable. A creator may choose to create multiple instances, return
  already existing instances, or simply register that an instance was created
  and defer to the next creator in line. Creators can also modify keyword
  arguments seen by the next creators.

  Valid keyword arguments in `kwargs` depend on the specific resource
  class. For StaticHashTable, this may be:
  * initializer: The table initializer to use.
  * default_value: The value to use if a key is missing in the table.
  * name: Optional name for the table, default to None.

  Args:
    resource_type: the output of the resource class's `_resource_type`
      method, or a list/tuple of such values. In the latter case `creator` is
      registered for each of them.
    creator: the passed creator for the resource.

  Yields:
    A scope in which the creator is active.

  Raises:
    RuntimeError: If resource_creator_scope is exited without proper nesting.
  """
  # This step keeps a reference to the existing stack, and it also initializes
  # self._thread_local._resource_creator_stack if it doesn't exist yet.
  old = self._resource_creator_stack
  # Copy the mapping and each per-type creator list so the appends below never
  # mutate `old`. The creators themselves are shared, not deep-copied. A
  # deep copy of a bound-method (or callable-object) creator would also
  # deep-copy the object it is bound to. State updates would then land on a
  # throwaway copy, and objects holding locks or similar state cannot be
  # deep-copied at all.
  new = collections.defaultdict(
      list, {key: list(creators) for key, creators in old.items()})
  if isinstance(resource_type, (list, tuple)):
    for r in resource_type:
      new[r].append(creator)
  else:
    new[resource_type].append(creator)
  self._thread_local._resource_creator_stack = new
  try:
    yield
  finally:
    if self._thread_local._resource_creator_stack is not new:
      raise RuntimeError(
          "Exiting resource_creator_scope without proper nesting.")
    self._thread_local._resource_creator_stack = old
@property
def _resource_creator_stack(self):
  """Per-thread mapping from resource type to its registered creators.

  Created lazily, on first access from a given thread, as a
  `collections.defaultdict(list)`.
  """
  try:
    return self._thread_local._resource_creator_stack
  except AttributeError:
    fresh_stack = collections.defaultdict(list)
    self._thread_local._resource_creator_stack = fresh_stack
    return fresh_stack

@_resource_creator_stack.setter
def _resource_creator_stack(self, resource_creator_stack):
  # Replaces this thread's mapping wholesale.
  local_state = self._thread_local
  local_state._resource_creator_stack = resource_creator_stack
# pylint: enable=protected-access
# Note: this method is private because the API of tf.Graph() is public and
# frozen, and this functionality is still not ready for public visibility.
@property
def _variable_creator_stack(self):
  """Per-thread list of `(priority, creator)` entries, created lazily."""
  local_state = self._thread_local
  if not hasattr(local_state, "_variable_creator_stack"):
    local_state._variable_creator_stack = []  # pylint: disable=protected-access

  # The live list is returned on purpose, not a defensive copy. A common
  # save/restore pattern is:
  #   saved = graph._variable_creator_stack
  #   graph._variable_creator_stack = new_stack
  #   ...  # Some code
  #   graph._variable_creator_stack = saved
  # If that ran inside `variable_scope.variable_creator_scope(creator)` and
  # the getter had handed out a copy, the scope would find a different object
  # on exit than the one it installed. It would then raise "Exiting
  # variable_creator_scope without proper nesting" even though nesting was
  # correct.
  return local_state._variable_creator_stack  # pylint: disable=protected-access

@_variable_creator_stack.setter
def _variable_creator_stack(self, variable_creator_stack):
  # Replaces this thread's creator stack wholesale.
  local_state = self._thread_local
  local_state._variable_creator_stack = variable_creator_stack  # pylint: disable=protected-access
def _check_not_finalized(self):
  """Guards mutating operations against a finalized graph.

  Raises:
    RuntimeError: If `finalize()` has been called on this graph.
  """
  if not self._finalized:
    return
  raise RuntimeError("Graph is finalized and cannot be modified.")
@property
def graph_def_versions(self):
  # pylint: disable=line-too-long
  """The GraphDef version information of this graph.

  The value is parsed from the serialized `_version_def` each time it is
  read. For the meaning of each version field, see
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto).

  Returns:
    A `VersionDef`.
  """
  # pylint: enable=line-too-long
  serialized_versions = self._version_def
  return versions_pb2.VersionDef.FromString(serialized_versions)
@property
def seed(self):
  """The graph-level random seed of this graph."""
  current_seed = self._seed
  return current_seed

@seed.setter
def seed(self, seed):
  # Stored as-is; no validation is performed here.
  self._seed = seed
@property
def finalized(self):
  """Whether `finalize()` has made this graph read-only."""
  is_final = self._finalized
  return is_final
def finalize(self):
  """Marks this graph as read-only.

  Once finalized, operations can no longer be added to the graph. This is
  useful for guaranteeing that a graph shared between threads (for example
  by a `tf.compat.v1.train.QueueRunner`) is not modified further.
  """
  setattr(self, "_finalized", True)
def _unsafe_unfinalize(self):
  """Reverses `finalize`, making the graph mutable again.

  Internal interface.

  NOTE: Unfinalizing can hurt performance, particularly with many threads,
  and doing so while a Session is still using the graph may lead to
  undefined behavior. Close every session that uses this graph before
  calling this method.
  """
  setattr(self, "_finalized", False)
def _get_control_flow_context(self):
  """Fetches the control flow context currently active in this graph.

  Returns:
    The context object, or None when no context has been set.
  """
  active_context = self._control_flow_context
  return active_context
def _set_control_flow_context(self, ctx):
  """Installs `ctx` as the graph's current control flow context.

  Args:
    ctx: The context object to make current (may be None).
  """
  setattr(self, "_control_flow_context", ctx)
def _copy_functions_to_graph_def(self, graph_def, starting_bytesize):
  """Appends this graph's function library to `graph_def`.

  Each registered function's `cached_definition` is added to
  `graph_def.library.function`. Functions that carry a truthy
  `grad_func_name` also get a `GradientDef` appended to
  `graph_def.library.gradient`.

  Args:
    graph_def: The `GraphDef` whose library is extended.
    starting_bytesize: Byte count already used before any function is added;
      it seeds the running size total.

  Raises:
    ValueError: If the running total reaches 2GB (or becomes negative).
  """
  size_limit = 1 << 31
  running_size = starting_bytesize
  for func in self._functions.values():
    running_size += func.cached_definition.ByteSize()
    if not 0 <= running_size < size_limit:
      raise ValueError("GraphDef cannot be larger than 2GB.")
    graph_def.library.function.extend([func.cached_definition])
    if not getattr(func, "grad_func_name", None):
      continue
    gradient_def = function_pb2.GradientDef()
    gradient_def.function_name = func.name
    gradient_def.gradient_func = func.grad_func_name
    graph_def.library.gradient.extend([gradient_def])
def _as_graph_def(self, from_version=None, add_shapes=False):
  # pylint: disable=line-too-long
  """Returns a serialized `GraphDef` representation of this graph.

  The serialized `GraphDef` can be imported into another `Graph`
  (using `tf.import_graph_def`) or used with the
  [C++ Session API](https://chromium.googlesource.com/external/github.com/tensorflow/tensorflow/+/r0.10/tensorflow/g3doc/api_docs/cc/index.md).

  This method is thread-safe: serialization and shape annotation happen
  while holding `self._lock`.

  Args:
    from_version: Optional. Documented as: if set, return a `GraphDef`
      containing only the nodes added since the graph's `version` had this
      value. NOTE(review): this body never reads `from_version`, so the
      whole graph is serialized regardless. Confirm before relying on the
      filtering.
    add_shapes: If true, add an "_output_shapes" list attr with inferred
      output shapes to each top-level node that has outputs. For each library
      function whose defining object exposes `graph` (and `graph.inputs`), it
      also:
      - fills `_input_shapes` if the function_def does not already have it;
      - copies resource handle data onto resource-typed input and output
        arg defs;
      - annotates function nodes found in that graph with "_output_shapes".

  Returns:
    A tuple `(graph_def, version)`:
    - `graph_def` is the
      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
      protocol buffer. Its `library` field is cleared when it holds no
      functions.
    - `version` is `self.version` at the time of the call.

  Raises:
    ValueError: Documented for a `graph_def` that would be too large.
      NOTE(review): nothing in this body raises it directly; presumably it
      would come from the C API serialization. Verify.
    AssertionError: If a function's existing `_input_shapes` count is
      neither 0 nor the number of its graph's inputs (see TODO below).
  """
  # pylint: enable=line-too-long
  with self._lock:
    with c_api_util.tf_buffer() as buf:
      with self._c_graph.get() as c_graph:
        pywrap_tf_session.TF_GraphToGraphDef(c_graph, buf)
      data = pywrap_tf_session.TF_GetBuffer(buf)
    graph = graph_pb2.GraphDef()
    graph.ParseFromString(compat.as_bytes(data))
    # Strip the experimental library field iff it's empty.
    if not graph.library.function:
      graph.ClearField("library")

    if add_shapes:
      # Top-level nodes: record the static shape of every output.
      for node in graph.node:
        op = self._get_operation_by_name(node.name)
        if op.outputs:
          node.attr["_output_shapes"].list.shape.extend(
              [output.get_shape().as_proto() for output in op.outputs])
      # Library functions: annotate inputs, resource args and inner nodes.
      for function_def in graph.library.function:
        defined_function = self._functions[function_def.signature.name]
        try:
          func_graph = defined_function.graph
        except AttributeError:
          # _DefinedFunction doesn't have a graph, _EagerDefinedFunction
          # does. Both rely on ops.py, so we can't really isinstance check
          # them.
          continue
        input_shapes = function_def.attr["_input_shapes"]
        try:
          func_graph_inputs = func_graph.inputs
        except AttributeError:
          continue
        # TODO(b/141471245): Fix the inconsistency when inputs of func graph
        # are appended during gradient computation of while/cond.
        assert len(input_shapes.list.shape) in [0, len(func_graph_inputs)]
        # If the function_def has inputs already filled out, skip this step.
        if not input_shapes.list.shape:
          for input_tensor, arg_def in zip(func_graph_inputs,
                                           function_def.signature.input_arg):
            input_shapes.list.shape.add().CopyFrom(
                input_tensor.get_shape().as_proto())
            if input_tensor.dtype == dtypes.resource:
              _copy_handle_data_to_arg_def(input_tensor, arg_def)

        # Resource outputs get their handle data copied unconditionally.
        for output_tensor, arg_def in zip(func_graph.outputs,
                                          function_def.signature.output_arg):
          if output_tensor.dtype == dtypes.resource:
            _copy_handle_data_to_arg_def(output_tensor, arg_def)

        for node in function_def.node_def:
          try:
            op = func_graph.get_operation_by_name(node.name)
          except KeyError:
            # Node has no counterpart in the function graph; leave it as-is.
            continue
          outputs = op.outputs

          if op.type == "StatefulPartitionedCall":
            # Filter out any extra outputs (possibly added by function
            # backpropagation rewriting).
            num_outputs = len(node.attr["Tout"].list.type)
            outputs = outputs[:num_outputs]

          node.attr["_output_shapes"].list.shape.extend(
              [output.get_shape().as_proto() for output in outputs])

  return graph, self.version
def as_graph_def(self, from_version=None, add_shapes=False):
  # pylint: disable=line-too-long
  """Returns a serialized `GraphDef` representation of this graph.

  The result can be imported into another `Graph` with
  `tf.import_graph_def`, or used with the
  [C++ Session API](../../api_docs/cc/index.md).

  This method is thread-safe.

  Args:
    from_version: Optional. Forwarded to `_as_graph_def`, which documents it
      as selecting only the nodes added since the graph's `version` had this
      value. The base implementation's body does not currently read it.
    add_shapes: If true, adds an "_output_shapes" list attr to each node with
      the inferred shapes of each of its outputs.

  Returns:
    A
    [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
    protocol buffer.

  Raises:
    ValueError: If the `graph_def` would be too large.
  """
  # pylint: enable=line-too-long
  graph_def, _unused_version = self._as_graph_def(from_version, add_shapes)
  return graph_def
def _is_function(self, name):
  """Reports whether `name` is registered in this graph's function library.

  Args:
    name: Function name, as `str` or `bytes`.

  Returns:
    True if a function with that name has been added, otherwise False.
  """
  lookup_key = compat.as_str(name)
  return lookup_key in self._functions
def _get_function(self, name):
  """Looks up a function registered in this graph by name.

  Args:
    name: Function name, as `str` or `bytes`.

  Returns:
    The function object stored for `name` by `_add_function`, or None if
    no such function is registered.
  """
  lookup_key = compat.as_str(name)
  return self._functions.get(lookup_key)
def _add_function_recursive(self, function, overwrite=False):
  """Adds `function`, then every function in its `graph`, to this graph.

  Args:
    function: The function object to add. If it has a `graph` attribute, the
      functions in `function.graph._functions` are added as well.
    overwrite: If True, functions whose names are already registered are
      removed and re-added. If False, already-registered names are skipped.
  """

  def _register(candidate):
    # Add `candidate`, replacing or skipping an existing entry per `overwrite`.
    already_registered = self._is_function(candidate.name)
    if already_registered and not overwrite:
      return
    if already_registered:
      self._remove_function(candidate.name)
    self._add_function(candidate)

  _register(function)
  if not hasattr(function, "graph"):
    return
  for nested in function.graph._functions.values():  # pylint: disable=protected-access
    _register(nested)
def _add_function(self, function):
  """Registers `function` in this graph's function library.

  After registration, the function's name can be passed in place of an op
  type to `Graph.create_op()`.

  Args:
    function: A function object exposing `name` and `_c_func` (e.g. a
      `_DefinedFunction`). Deprecated `_DefinedFunction`s may additionally
      carry `grad_func_name`, `python_grad_func` and `_grad_func`.

  Raises:
    RuntimeError: If the graph is finalized.
    ValueError: If `function` specifies its gradient twice, i.e. both
      `grad_func_name` and `python_grad_func` are set. Errors raised by the
      underlying `TF_GraphCopyFunction` call propagate unchanged.
  """
  self._check_not_finalized()

  name = function.name
  # A deprecated _DefinedFunction may name its gradient or supply it as a
  # Python function, but not both.
  named_gradient = getattr(function, "grad_func_name", None)
  if named_gradient and getattr(function, "python_grad_func", None):
    raise ValueError("Gradient defined twice for function %s" % name)

  # Copy the function (and, for the deprecated path, its gradient) into the
  # C graph.
  # pylint: disable=protected-access
  with self._c_graph.get() as c_graph, function._c_func.get() as func:
    if not getattr(function, "_grad_func", None):
      pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, None)
    else:
      # For deprecated _DefinedFunction.
      with function._grad_func._c_func.get() as gradient:
        pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, gradient)
  # pylint: enable=protected-access

  self._functions[compat.as_str(name)] = function

  # Graphs containing functions need GraphDef consumers of version >= 12.
  if self._graph_def_versions.min_consumer < 12:
    self._graph_def_versions.min_consumer = 12
def _remove_function(self, name):
  """Deletes the function called `name` from this graph's library.

  Args:
    name: Name of a previously added function, as `str` or `bytes`.

  Raises:
    RuntimeError: If the graph is finalized.
    ValueError: If no function called `name` is registered.
  """
  self._check_not_finalized()
  if self._is_function(name):
    with self._c_graph.get() as c_graph:
      pywrap_tf_session.TF_GraphRemoveFunction(c_graph, compat.as_bytes(name))
    del self._functions[compat.as_str(name)]
    return
  raise ValueError(f"Function {name!r} is not found in {self!r}.")
@property
def building_function(self):
  """True iff this graph is being used to represent a function."""
  is_function_graph = self._building_function
  return is_function_graph
# Helper functions to create operations.
@deprecated_args(None,
                 "Shapes are always computed; don't use the compute_shapes "
                 "as it has no effect.", "compute_shapes")
@traceback_utils.filter_traceback
def create_op(
    self,
    op_type,
    inputs,
    dtypes=None,  # pylint: disable=redefined-outer-name
    input_types=None,
    name=None,
    attrs=None,
    op_def=None,
    compute_shapes=True,
    compute_device=True):
  """Creates an `Operation` in this graph.

  This is the low-level entry point for adding an op. Most programs never
  call it directly; they use Python op constructors such as `tf.constant()`,
  which add ops to the default graph.

  Args:
    op_type: The `Operation` type to create, i.e. the `OpDef.name` of the
      proto that defines the operation.
    inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
    dtypes: (Optional) A list of `DType` objects that will be the types of the
      tensors that the operation produces.
    input_types: (Optional.) A list of `DType`s that will be the types of the
      tensors that the operation consumes. By default, uses the base `DType`
      of each input in `inputs`. Operations that expect reference-typed inputs
      must specify `input_types` explicitly.
    name: (Optional.) A string name for the operation. If not specified, a
      name is generated based on `op_type`.
    attrs: (Optional.) A dictionary where the key is the attribute name (a
      string) and the value is the respective `attr` attribute of the
      `NodeDef` proto that will represent the operation (an `AttrValue`
      proto).
    op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
      the operation will have.
    compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
      computed).
    compute_device: (Optional.) If True, device functions will be executed to
      compute the device property of the Operation.

  Raises:
    TypeError: if any of the inputs is not a `Tensor`.
    ValueError: if colocation conflicts with existing device assignment.

  Returns:
    An `Operation` object.
  """
  del compute_shapes  # Deprecated; intentionally ignored.
  # Reject the first input that is not a `Tensor`, reporting its position.
  offending = next(
      ((position, value) for position, value in enumerate(inputs)
       if not isinstance(value, Tensor)),
      None)
  if offending is not None:
    raise TypeError("Input #%d is not a tensor: %s" % offending)
  return self._create_op_internal(op_type, inputs, dtypes, input_types, name,
                                  attrs, op_def, compute_device)
def _create_op_internal(
    self,
    op_type,
    inputs,
    dtypes=None,  # pylint: disable=redefined-outer-name
    input_types=None,
    name=None,
    attrs=None,
    op_def=None,
    compute_device=True):
  """Creates an `Operation` in this graph.

  Same as `Graph.create_op()`, minus the deprecation wrapper and the
  input-type check.

  Args:
    op_type: The `Operation` type to create, i.e. the `OpDef.name` of the
      proto that defines the operation.
    inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
    dtypes: (Optional) A list of `DType` objects that will be the types of the
      tensors that the operation produces.
    input_types: (Optional.) A list of `DType`s that will be the types of the
      tensors that the operation consumes. By default, uses the base `DType`
      of each input in `inputs`. Operations that expect reference-typed inputs
      must specify `input_types` explicitly.
    name: (Optional.) A string name for the operation. If not specified, a
      name is generated based on `op_type`. A name ending in '/' is treated
      as a name scope and used verbatim, minus the trailing '/'.
    attrs: (Optional.) A dictionary where the key is the attribute name (a
      string) and the value is the respective `attr` attribute of the
      `NodeDef` proto that will represent the operation (an `AttrValue`
      proto).
    op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
      the operation will have.
    compute_device: (Optional.) If True, device functions will be executed to
      compute the device property of the Operation.

  Raises:
    RuntimeError: If the graph is finalized.
    ValueError: if colocation conflicts with existing device assignment.

  Returns:
    An `Operation` object.
  """
  self._check_not_finalized()
  if name is None:
    name = op_type
  # A trailing '/' marks a name scope: use it as-is (without the slash)
  # instead of uniquifying it.
  is_scope_name = bool(name) and name[-1] == "/"
  name = name_from_scope_name(name) if is_scope_name else self.unique_name(name)

  node_def = _NodeDef(op_type, name, attrs)

  control_inputs = self._control_dependencies_for_inputs(
      {tensor.op for tensor in inputs})
  # `_create_op_helper` mutates the freshly built Operation. Holding
  # `_mutation_lock` keeps a Session.run call from slipping in between
  # creation and mutation.
  with self._mutation_lock():
    new_op = Operation.from_node_def(
        node_def,
        self,
        inputs=inputs,
        output_types=dtypes,
        control_inputs=control_inputs,
        input_types=input_types,
        original_op=self._default_original_op,
        op_def=op_def,
    )
    self._create_op_helper(new_op, compute_device=compute_device)
  return new_op
def _create_op_from_tf_operation(self, c_op, compute_device=True):
  """Wraps an already-existing TF_Operation in a Python `Operation`.

  This is the counterpart of `create_op()` for TF_Operations that the C API
  produced indirectly (e.g. via TF_ImportGraphDef or TF_FinishWhile). The
  returned Operation keeps `c_op` as its `_c_op`.

  Neither `Operation._control_flow_post_processing` nor
  `Graph._control_dependencies_for_inputs` is invoked here, since the
  inputs may not exist yet. Callers must run them afterwards.

  Args:
    c_op: a wrapped TF_Operation
    compute_device: (Optional.) If True, device functions will be executed to
      compute the device property of the Operation.

  Returns:
    An `Operation` object.
  """
  self._check_not_finalized()
  wrapped_op = Operation._from_c_op(c_op=c_op, g=self)  # pylint: disable=protected-access
  # Reserve the op's (lower-cased) name unless it is already tracked. A name
  # can already be present without having been used, e.g. when a name_scope
  # of that name was opened but created no nodes; leaving it as-is is fine.
  # TODO(skyewm): make the C API guarantee no name conflicts.
  self._names_in_use.setdefault(wrapped_op.name.lower(), 1)
  self._create_op_helper(wrapped_op, compute_device=compute_device)
  return wrapped_op
def _create_op_helper(self, op, compute_device=True):
  """Common logic for creating an op in this graph.

  Applies graph-level state to a newly created `op`, in this order:
  1. attributes from `_attr_scope_map`;
  2. the kernel label;
  3. the gradient function;
  4. the gradient override type;
  5. control-dependency tracking;
  6. device functions;
  7. the colocation snapshot, plus `_class` colocation groups and device;
  8. the default `container` attribute.

  Args:
    op: The freshly created `Operation` to finish configuring.
    compute_device: If True, apply the graph's device functions to `op`.

  Raises:
    TypeError: If a callable in `_attr_scope_map` returns something other
      than None or an `AttrValue`.
  """
  # Apply any additional attributes requested. Do not overwrite any existing
  # attributes.
  for key, value in self._attr_scope_map.items():
    try:
      op.get_attr(key)
    except ValueError:
      # `key` is not yet set on the op (get_attr raised), so it may be added.
      if callable(value):
        value = value(op.node_def)
        if not isinstance(value, (type(None), attr_value_pb2.AttrValue)):
          raise TypeError(
              "Callable for scope map key '%s' must return either None or "
              "an AttrValue protocol buffer; but it returned: %s" %
              (key, value))
      # Falsy values (e.g. None returned by a callable) are skipped.
      if value:
        op._set_attr(key, value)  # pylint: disable=protected-access

  # Apply a kernel label if one has been specified for this op type.
  try:
    kernel_label = self._op_to_kernel_label_map[op.type]
    op._set_attr("_kernel",  # pylint: disable=protected-access
                 attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label)))
  except KeyError:
    pass

  # None when no gradient function is registered for this op type.
  op._gradient_function = self._gradient_function_map.get(op.type)  # pylint: disable=protected-access

  # Apply the overriding op type for gradients if one has been specified for
  # this op type.
  try:
    mapped_op_type = self._gradient_override_map[op.type]
    op._set_attr("_gradient_op_type",  # pylint: disable=protected-access
                 attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type)))
  except KeyError:
    pass

  self._record_op_seen_by_control_dependencies(op)

  if compute_device:
    self._apply_device_functions(op)

  # Snapshot the colocation stack metadata before we might generate error
  # messages using it. Note that this snapshot depends on the actual stack
  # and is independent of the op's _class attribute.
  # pylint: disable=protected-access
  op._colocation_code_locations = self._snapshot_colocation_stack_metadata()
  # pylint: enable=protected-access

  if self._colocation_stack:
    all_colocation_groups = []
    is_device_set = False
    for colocation_op in self._colocation_stack.peek_objs():
      # Objects without `colocation_groups()` contribute no groups.
      try:
        all_colocation_groups.extend(colocation_op.colocation_groups())
      except AttributeError:
        pass
      # The first colocation op (in peek_objs() order) with a device wins.
      if colocation_op.device and not is_device_set:
        # pylint: disable=protected-access
        op._set_device(colocation_op.device)
        # pylint: enable=protected-access
        is_device_set = True

    # De-duplicate and sort so the `_class` attr is deterministic.
    all_colocation_groups = sorted(set(all_colocation_groups))
    # pylint: disable=protected-access
    op._set_attr(
        "_class",
        attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
    # pylint: enable=protected-access

  # Sets "container" attribute if
  # (1) self._container is non-empty
  # (2) the op is stateful ("is_stateful" is set in OpDef)
  # (3) "container" attribute is in OpDef
  # (4) "container" attribute is empty/falsy
  if self._container and op._is_stateful:  # pylint: disable=protected-access
    try:
      container_attr = op.get_attr("container")
    except ValueError:
      # "container" attribute is not in OpDef
      pass
    else:
      if not container_attr:
        op._set_attr("container", attr_value_pb2.AttrValue(  # pylint: disable=protected-access
            s=compat.as_bytes(self._container)))
def _add_new_tf_operations(self, compute_devices=True):
  """Builds Python `Operation`s for TF_Operations created behind our back.

  The C API can create TF_Operations outside the Operation constructor (for
  example TF_ImportGraphDef or TF_FinishWhile). This method makes sure every
  TF_Operation in the underlying TF_Graph gets a matching Operation.

  Args:
    compute_devices: (Optional.) If True, device functions will be executed to
      compute the device properties of each new Operation.

  Returns:
    A list of the new `Operation` objects.
  """
  self._check_not_finalized()

  # Wrap every pending TF_Operation before looking at any inputs: an op may
  # have been created before the ops that feed it.
  created_ops = []
  for c_op in self.new_operations():
    created_ops.append(
        self._create_op_from_tf_operation(c_op, compute_device=compute_devices))

  # pylint: disable=protected-access
  for created in created_ops:
    created._add_control_inputs(
        self._control_dependencies_for_inputs(created.inputs))
    created._control_flow_post_processing()
  # pylint: enable=protected-access

  return created_ops
def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
  """Returns the `Operation` or `Tensor` in this graph that `obj` refers to.

  Validates that `obj` denotes an element of this graph and raises an
  informative error if it does not. This is the canonical way to resolve or
  validate an external reference (for example from the Session API) to one of
  the allowed element types.

  May be called concurrently from multiple threads.

  Args:
    obj: A `Tensor`, an `Operation`, or the name of a tensor or operation.
      May also be any object with an `_as_graph_element()` method returning
      one of these. That method may run while the graph's lock is held, so it
      must not modify the graph.
    allow_tensor: If true, `obj` may refer to a `Tensor`.
    allow_operation: If true, `obj` may refer to an `Operation`.

  Returns:
    The `Tensor` or `Operation` in the Graph corresponding to `obj`.

  Raises:
    TypeError: If `obj` is not of a type that can be converted.
    ValueError: If `obj` is of an appropriate type but invalid, e.g. a
      malformed name.
    KeyError: If `obj` is not an object in the graph.
  """
  if not self._finalized:
    # The graph can still change, so resolve while holding the lock.
    with self._lock:
      return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  # Finalized graphs take the lock-free path.
  return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
  """See `Graph.as_graph_element()` for details.

  Callers are expected to hold `self._lock` unless the graph is finalized.

  Args:
    obj: A `Tensor`, an `Operation`, a tensor/operation name (`str` or
      `bytes`), or an object with an `_as_graph_element()` method.
    allow_tensor: Whether `obj` may resolve to a `Tensor`.
    allow_operation: Whether `obj` may resolve to an `Operation`.

  Returns:
    The `Tensor` or `Operation` in this graph that `obj` refers to.

  Raises:
    ValueError: If both `allow_*` flags are False, if a name is malformed
      (including a negative output index), if a name refers to the wrong
      kind of element, or if a `Tensor`/`Operation` belongs to another graph.
    KeyError: If a well-formed name does not exist in this graph.
    TypeError: If `obj` is of an unsupported type.
  """
  # Most of this function diagnoses what an API user might be doing wrong so
  # that a helpful error message can be produced. The branches share context,
  # which is why it is not split up further.
  if allow_tensor and allow_operation:
    types_str = "Tensor or Operation"
  elif allow_tensor:
    types_str = "Tensor"
  elif allow_operation:
    types_str = "Operation"
  else:
    raise ValueError("allow_tensor and allow_operation can't both be False.")

  # Unwrap objects that convert themselves via `_as_graph_element()`.
  temp_obj = _as_graph_element(obj)
  if temp_obj is not None:
    obj = temp_obj

  # If obj appears to be a name...
  if isinstance(obj, compat.bytes_or_text_types):
    name = compat.as_str(obj)

    if ":" in name and allow_tensor:
      # Looks like a Tensor name ("<op_name>:<output_index>") and can be one.
      # Only ValueError is caught: it covers a wrong number of ":"-separated
      # parts (tuple unpacking) and a non-integer index (int()).
      try:
        op_name, out_n = name.split(":")
        out_n = int(out_n)
        if out_n < 0:
          # Python indexing would otherwise silently resolve e.g. "op:-1"
          # to the op's last output.
          raise ValueError("Negative output index %d." % out_n)
      except ValueError as exc:
        raise ValueError("The name %s looks a like a Tensor name, but is "
                         "not a valid one. Tensor names must be of the "
                         "form \"<op_name>:<output_index>\"." % repr(name)
                        ) from exc
      try:
        op = self._get_operation_by_name(op_name)
      except KeyError as exc:
        raise KeyError(
            "The name %s refers to a Tensor which does not "
            "exist. The operation, %s, does not exist in the "
            "graph." % (repr(name), repr(op_name))
        ) from exc

      outputs = op.outputs
      try:
        return outputs[out_n]
      except IndexError as exc:
        raise KeyError("The name %s refers to a Tensor which does not "
                       "exist. The operation, %s, exists but only has "
                       "%s outputs." %
                       (repr(name), repr(op_name), len(outputs))) from exc

    elif ":" in name and not allow_tensor:
      # Looks like a Tensor name but can't be a Tensor.
      raise ValueError("Name %s appears to refer to a Tensor, not a %s." %
                       (repr(name), types_str))

    elif ":" not in name and allow_operation:
      # Looks like an Operation name and can be an Operation.
      try:
        op = self._get_operation_by_name(name)
      except KeyError as exc:
        raise KeyError(
            "The name %s refers to an Operation not in the graph."
            % repr(name)
        ) from exc
      return op

    elif ":" not in name and not allow_operation:
      # Looks like an Operation name but can't be an Operation. The lookup
      # only decides which error message to produce.
      try:
        self._get_operation_by_name(name)
      except KeyError:
        err_msg = ("The name %s looks like an (invalid) Operation name, "
                   "not a %s." % (repr(name), types_str))
      else:
        # Yep, it's an Operation name.
        err_msg = ("The name %s refers to an Operation, not a %s." %
                   (repr(name), types_str))
      err_msg += (" Tensor names must be of the form "
                  "\"<op_name>:<output_index>\".")
      raise ValueError(err_msg)

  elif isinstance(obj, Tensor) and allow_tensor:
    # obj is already the element it refers to; only check graph membership.
    if obj.graph is not self:
      raise ValueError("Tensor %s is not an element of this graph." % obj)
    return obj
  elif isinstance(obj, Operation) and allow_operation:
    # obj is already the element it refers to; only check graph membership.
    if obj.graph is not self:
      raise ValueError("Operation %s is not an element of this graph." % obj)
    return obj
  else:
    # We give up!
    raise TypeError("Can not convert a %s into a %s." %
                    (type(obj).__name__, types_str))
def get_operation_by_name(self, name):
  """Looks up the `Operation` called `name` in this graph.

  May be called concurrently from multiple threads.

  Args:
    name: The name of the `Operation` to return.

  Returns:
    The `Operation` with the given `name`.

  Raises:
    TypeError: If `name` is not a `str`.
    ValueError: If `name` looks like a tensor name (contains ':') or is
      otherwise rejected by `as_graph_element`.
    KeyError: If `name` does not correspond to an operation in this graph.
  """
  # Only `str` names are accepted here.
  if isinstance(name, str):
    return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
  raise TypeError("Operation names are strings (or similar), not %s." %
                  type(name).__name__)
def _get_operation_by_tf_operation(self, tf_oper):
  """Returns the Python `Operation` for the wrapped C-API `tf_oper`.

  The lookup is done by the operation's name.
  """
  return self._get_operation_by_name(
      pywrap_tf_session.TF_OperationName(tf_oper))
def get_tensor_by_name(self, name):
  """Looks up the `Tensor` called `name` in this graph.

  May be called concurrently from multiple threads.

  Args:
    name: The name of the `Tensor` to return, of the form
      "<op_name>:<output_index>".

  Returns:
    The `Tensor` with the given `name`.

  Raises:
    TypeError: If `name` is not a `str`.
    ValueError: If `name` is not a well-formed tensor name (for example it
      has no ':' and so looks like an operation name).
    KeyError: If `name` does not correspond to a tensor in this graph.
  """
  # Only `str` names are accepted here.
  if isinstance(name, str):
    return self.as_graph_element(name, allow_tensor=True,
                                 allow_operation=False)
  raise TypeError("Tensor names are strings (or similar), not %s." %
                  type(name).__name__)
3725 def _get_tensor_by_tf_output(self, tf_output):
3726 """Returns the `Tensor` representing `tf_output`.
3728 Note that there is only one such `Tensor`, i.e. multiple calls to this
3729 function with the same TF_Output value will always return the same `Tensor`
3730 object.
3732 Args:
3733 tf_output: A wrapped `TF_Output` (the C API equivalent of `Tensor`).
3735 Returns:
3736 The `Tensor` that represents `tf_output`.
3737 """
3738 op = self._get_operation_by_tf_operation(tf_output.oper)
3739 return op.outputs[tf_output.index]
def op_def_for_type(self, type):  # pylint: disable=redefined-builtin
  """Returns the `OpDef` proto for the op type string `type`.

  Parsed protos are memoized in `self._op_def_cache`.
  """
  # No lock is taken: individual dict lookups and insertions are atomic in
  # CPython. At worst two racing threads both parse the same OpDef.
  cache = self._op_def_cache
  if type not in cache:
    cache[type] = op_def_pb2.OpDef.FromString(self._op_def_for_type(type))
  return cache[type]
def as_default(self):
  """Returns a context manager that makes this `Graph` the default graph.

  Use this when creating several graphs in one process. A global default
  graph is provided for convenience, and all ops are added to it unless a
  new graph is created and made the default explicitly.

  Inside a `with g.as_default():` block, newly created ops are added to `g`.
  When the block exits, the previously default graph becomes the default
  again. Defaults form a stack, so `as_default` calls may be nested.

  The default graph is a property of the current thread. A newly created
  thread that should use this graph as its default must itself enter
  `with g.as_default():` in the thread's function.

  The following two examples are equivalent:

  ```python
  # 1. Using Graph.as_default():
  g = tf.Graph()
  with g.as_default():
    c = tf.constant(5.0)
    assert c.graph is g

  # 2. Constructing and making default:
  with tf.Graph().as_default() as g:
    c = tf.constant(5.0)
    assert c.graph is g
  ```

  Even when eager execution is enabled, ops created under this context
  manager are added to the graph instead of being executed eagerly.

  Returns:
    A context manager for using this graph as the default graph.
  """
  controller = _default_graph_stack.get_controller(self)
  return controller
@property
def collections(self):
  """The names of all collections known to this graph, as a new list."""
  return list(self._collections.keys())
def add_to_collection(self, name, value):
  """Appends `value` to the collection called `name`.

  Collections are lists rather than sets, so adding the same value twice
  stores it twice. A missing collection is created on first use.

  Args:
    name: The key of the collection. The `GraphKeys` class contains many
      standard names for collections.
    value: The value to append to the collection.
  """  # pylint: disable=g-doc-exception
  self._check_not_finalized()
  with self._lock:
    self._collections.setdefault(name, []).append(value)
def add_to_collections(self, names, value):
  """Stores `value` in each of the collections named by `names`.

  Collections are lists, not sets, so a value can appear in a collection
  several times. Duplicates within `names` are collapsed, so each named
  collection receives `value` once per call. Existing membership of `value`
  in those collections is not checked.

  `names` may be any iterable of collection keys. A single `str` (or `bytes`)
  is treated as one collection name rather than being iterated element by
  element.

  Args:
    names: The keys for the collections to add to. The `GraphKeys` class
      contains many standard names for collections.
    value: The value to add to the collections.
  """
  if isinstance(names, (str, bytes)):
    unique_names = (names,)
  else:
    # dict.fromkeys de-duplicates like set() but keeps first-seen order, so
    # collections are updated in a deterministic order regardless of the
    # string hash seed.
    unique_names = dict.fromkeys(names)
  for name in unique_names:
    self.add_to_collection(name, value)
def get_collection_ref(self, name):
  """Returns the live list backing the collection called `name`.

  The list itself is returned, not a copy, so callers can modify the
  collection in place. If the collection does not exist (or maps to None),
  an empty list is stored under `name` and returned.

  Unlike `get_collection()`, which returns a fresh copy and never creates a
  collection, this method may create one.

  Args:
    name: The key for the collection. For example, the `GraphKeys` class
      contains many standard names for collections.

  Returns:
    The list of values in the collection with the given `name`, or an empty
    list if no value has been added to that collection.
  """  # pylint: disable=g-doc-exception
  with self._lock:
    existing = self._collections.get(name)
    if existing is not None:
      return existing
    created = []
    self._collections[name] = created
    return created
def get_collection(self, name, scope=None):
  """Returns a copy of the values in the collection called `name`.

  Unlike `get_collection_ref()`, which hands back the live list, this
  returns a new list on every call and never creates a collection.

  Args:
    name: The key for the collection. For example, the `GraphKeys` class
      contains many standard names for collections.
    scope: (Optional.) A string. If supplied, only items whose `name`
      attribute matches `scope` under `re.match` are kept. Items without a
      `name` attribute are then dropped. Because `re.match` anchors at the
      start, a `scope` without special tokens acts as a prefix filter.

  Returns:
    The values of the collection, in insertion order, optionally filtered by
    `scope`. Returns an empty list if the collection does not exist.
  """  # pylint: disable=g-doc-exception
  with self._lock:
    items = self._collections.get(name, None)
    if items is None:
      return []
    if scope is None:
      return list(items)
    pattern = re.compile(scope)
    matched = []
    for item in items:
      try:
        item_name = item.name
      except AttributeError:
        # Collection items with no name are ignored.
        continue
      if pattern.match(item_name):
        matched.append(item)
    return matched
def get_all_collection_keys(self):
  """Returns the `str` keys of the collections used in this graph."""
  with self._lock:
    return list(filter(lambda key: isinstance(key, str), self._collections))
def clear_collection(self, name):
  """Removes the collection `name` and all of its values, if it exists.

  Args:
    name: The key for the collection. The `GraphKeys` class contains many
      standard names for collections.
  """
  self._check_not_finalized()
  with self._lock:
    self._collections.pop(name, None)
@tf_contextlib.contextmanager
def _original_op(self, op):
  """Context manager that makes `op` the originator of ops created inside it.

  An op may have an 'original_op' property naming the op it was derived from.
  For example, a replica op is based on the op that was replicated, and a
  gradient op is based on the op that was differentiated. While this context
  is active, `self._default_original_op` is `op`. The previous value is
  restored on exit, including when the block raises.

  Args:
    op: The Operation that all ops created in this scope will have as their
      original op.

  Yields:
    Nothing.
  """
  saved_original_op, self._default_original_op = (
      self._default_original_op, op)
  try:
    yield
  finally:
    self._default_original_op = saved_original_op
3946 @property
3947 def _name_stack(self):
3948 # This may be called from a thread where name_stack doesn't yet exist.
3949 if not hasattr(self._thread_local, "_name_stack"):
3950 self._thread_local._name_stack = ""
3951 return self._thread_local._name_stack
@_name_stack.setter
def _name_stack(self, name_stack):
  # Store the scope on the thread-local object so each thread has its own.
  setattr(self._thread_local, "_name_stack", name_stack)
# pylint: disable=g-doc-return-or-yield,line-too-long
@tf_contextlib.contextmanager
def name_scope(self, name):
  """Returns a context manager that creates hierarchical names for operations.

  The graph keeps a stack of name scopes. `with name_scope(...):` pushes a
  new entry for the duration of the block. `name` is interpreted as follows:

  * A string that does not end with '/' opens a new scope, and its name is
    prefixed to every operation created inside the block. A name that has
    been used before is made unique via `self.unique_name(name)`.
  * A scope string previously captured with `with g.name_scope(...) as
    scope:` (such strings end with '/') is treated as an absolute scope,
    which makes it possible to re-enter an existing scope.
  * `None` or the empty string resets to the top-level (empty) scope.

  For example:

  ```python
  with tf.Graph().as_default() as g:
    c = tf.constant(5.0, name="c")
    assert c.op.name == "c"
    c_1 = tf.constant(6.0, name="c")
    assert c_1.op.name == "c_1"

    # Creates a scope called "nested"
    with g.name_scope("nested") as scope:
      nested_c = tf.constant(10.0, name="c")
      assert nested_c.op.name == "nested/c"

      # Creates a nested scope called "inner".
      with g.name_scope("inner"):
        nested_inner_c = tf.constant(20.0, name="c")
        assert nested_inner_c.op.name == "nested/inner/c"

      # Create a nested scope called "inner_1".
      with g.name_scope("inner"):
        nested_inner_1_c = tf.constant(30.0, name="c")
        assert nested_inner_1_c.op.name == "nested/inner_1/c"

        # Treats `scope` as an absolute name scope, and
        # switches to the "nested/" scope.
        with g.name_scope(scope):
          nested_d = tf.constant(40.0, name="d")
          assert nested_d.op.name == "nested/d"

          with g.name_scope(""):
            e = tf.constant(50.0, name="e")
            assert e.op.name == "e"
  ```

  The scope's own name can be captured with `with g.name_scope(...) as
  scope:`. It can then be used to name an operation that represents the
  overall result of the ops in the scope. For example:

  ```python
  inputs = tf.constant(...)
  with g.name_scope('my_layer') as scope:
    weights = tf.Variable(..., name="weights")
    biases = tf.Variable(..., name="biases")
    affine = tf.matmul(inputs, weights) + biases
    output = tf.nn.relu(affine, name=scope)
  ```

  NOTE: `name` is validated. Valid scope names match one of the following
  regular expressions:

    [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root)
    [A-Za-z0-9_.\\-/]* (for other scopes)

  Args:
    name: A name for the scope.

  Returns:
    A context manager that installs `name` as a new name scope.

  Raises:
    ValueError: If `name` is not a valid scope name, according to the rules
      above.
  """
  if name:
    if isinstance(name, compat.bytes_or_text_types):
      name = compat.as_str(name)
    if self._name_stack:
      # Nested scopes may start with characters ('-', '\', '/', '_') that
      # are not allowed as the first character of a root op name, so the
      # looser scope-name regex applies.
      if not _VALID_SCOPE_NAME_REGEX.match(name):
        raise ValueError(
            f"'{name}' is not a valid scope name. A scope name has to match "
            f"the following pattern: {_VALID_SCOPE_NAME_REGEX.pattern}")
    elif not _VALID_OP_NAME_REGEX.match(name):
      # Root scopes must satisfy the stricter op-name regex.
      raise ValueError(
          f"'{name}' is not a valid root scope name. A root scope name has "
          f"to match the following pattern: {_VALID_OP_NAME_REGEX.pattern}")

  saved_stack = self._name_stack
  if not name:
    # Both None and "" reset to the top-level (empty) scope.
    next_stack, yielded_scope = "", ""
  elif name[-1] == "/":
    # Absolute scope: re-enter it verbatim (minus the trailing '/').
    next_stack, yielded_scope = name_from_scope_name(name), name
  else:
    next_stack = self.unique_name(name)
    yielded_scope = next_stack + "/"
  self._name_stack = next_stack
  try:
    yield yielded_scope
  finally:
    self._name_stack = saved_stack

# pylint: enable=g-doc-return-or-yield,line-too-long
def unique_name(self, name, mark_as_used=True):
  """Returns a name, derived from `name`, that is unique within this graph.

  Note: You rarely need to call `unique_name()` directly. Most of
  the time you just need to create `with g.name_scope()` blocks to
  generate structured names.

  The result is prefixed with the current name scope (joined with "/"). A
  numeric suffix `_<n>` is appended if the name has been used before. Names
  are compared case-insensitively, so "foo" and "Foo" collide. Structured
  names like these help identify operations when debugging a graph: they
  show up in error messages from the TensorFlow runtime and in visualization
  tools such as TensorBoard.

  With `mark_as_used=True` (the default), the returned name is also recorded
  as taken. With `mark_as_used=False` the name is only computed, which lets a
  caller preview the name that would be created.

  Args:
    name: The name for an operation.
    mark_as_used: Whether to mark this name as being used.

  Returns:
    A string to be passed to `create_op()` that will be used
    to name the operation being created.
  """
  prefix = self._name_stack
  full_name = prefix + "/" + name if prefix else name

  used = self._names_in_use
  # Case-insensitive bookkeeping (e.g. foo == Foo).
  key = full_name.lower()
  count = used.get(key, 0)
  if mark_as_used:
    used[key] = count + 1
  if count <= 0:
    return full_name

  # Probe "<key>_<count>", "<key>_<count+1>", ... until a free key is found,
  # so that an explicit earlier request for e.g. "name_1" is not reused.
  base_key = key
  while key in used:
    key = "%s_%d" % (base_key, count)
    count += 1
  if mark_as_used:
    used[key] = 1

  # Keep the caller's original capitalization in the returned name.
  return "%s_%d" % (full_name, count - 1)
def get_name_scope(self):
  """Returns the name scope that is currently active.

  For example:

  ```python
  with tf.name_scope('scope1'):
    with tf.name_scope('scope2'):
      print(tf.compat.v1.get_default_graph().get_name_scope())
  ```
  would print the string `scope1/scope2`.

  Returns:
    A string representing the current name scope.
  """
  current_scope = self._name_stack
  return current_scope
@tf_contextlib.contextmanager
def _colocate_with_for_gradient(self, op, gradient_uid,
                                ignore_existing=False):
  """Like `colocate_with`, and also notifies the enclosing context.

  When `gradient_uid` is not None and `_get_enclosing_context(self)` returns
  a context, that context's `EnterGradientColocation` and
  `ExitGradientColocation` hooks bracket the body of the `with` block.
  """
  with self.colocate_with(op, ignore_existing):
    enclosing = (None if gradient_uid is None
                 else _get_enclosing_context(self))
    if enclosing is None:
      yield
      return
    enclosing.EnterGradientColocation(op, gradient_uid)
    try:
      yield
    finally:
      enclosing.ExitGradientColocation(op, gradient_uid)
@tf_contextlib.contextmanager
def colocate_with(self, op, ignore_existing=False):
  """Returns a context manager that specifies an op to colocate with.

  Note: this function is not for public use, only for internal libraries.

  For example:

  ```python
  a = tf.Variable([1.0])
  with g.colocate_with(a):
    b = tf.constant(1.0)
    c = tf.add(a, b)
  ```

  `b` and `c` will always be colocated with `a`, no matter where `a`
  is eventually placed.

  **NOTE** Using a colocation scope resets any existing device constraints:
  the device function stack is replaced by an empty one for the duration of
  the context and restored on exit.

  If `op` is `None` then `ignore_existing` must be `True` and the new
  scope resets all colocation and device constraints.

  Args:
    op: The op to colocate all created ops with, or `None`.
    ignore_existing: If true, only applies colocation of this op within the
      context, rather than applying all colocation properties on the stack.
      If `op` is `None`, this value must be `True`.

  Raises:
    ValueError: if op is None but ignore_existing is False.

  Yields:
    A context manager that specifies the op with which to colocate
    newly created ops.
  """
  if op is None and not ignore_existing:
    raise ValueError("Trying to reset colocation (op is None) but "
                     "ignore_existing is not True")
  # `_op_to_colocate_with` may return a second object, which is pushed onto
  # the colocation stack alongside `op` (and popped with it on exit).
  op, device_only_candidate = _op_to_colocate_with(op, self)

  # By default, colocate_with resets the device function stack,
  # since colocate_with is typically used in specific internal
  # library functions where colocation is intended to be "stronger"
  # than device functions.
  #
  # In the future, a caller may specify that device_functions win
  # over colocation, in which case we can add support.
  device_fn_tmp = self._device_function_stack
  self._device_function_stack = traceable_stack.TraceableStack()

  # With ignore_existing, start from an empty colocation stack; the previous
  # one is saved in `current_stack` and reinstated in the `finally` below.
  if ignore_existing:
    current_stack = self._colocation_stack
    self._colocation_stack = traceable_stack.TraceableStack()

  if op is not None:
    # offset refers to the stack frame used for storing code location.
    # We use 4, the sum of 1 to use our caller's stack frame and 3
    # to jump over layers of context managers above us.
    self._colocation_stack.push_obj(op, offset=4)
    if device_only_candidate is not None:
      self._colocation_stack.push_obj(device_only_candidate, offset=4)
  elif not ignore_existing:
    # NOTE(review): only reachable if `_op_to_colocate_with` maps a non-None
    # op to None (the same condition was rejected at the top). If it fires,
    # the device function stack (and colocation stack) swapped above are NOT
    # restored, because this raise happens before the `try`.
    raise ValueError("Trying to reset colocation (op is None) but "
                     "ignore_existing is not True")
  try:
    yield
  finally:
    # Restore device function stack
    self._device_function_stack = device_fn_tmp
    # Pop exactly what was pushed above, in reverse order.
    if op is not None:
      self._colocation_stack.pop_obj()
      if device_only_candidate is not None:
        self._colocation_stack.pop_obj()

    # Reset the colocation stack if requested.
    if ignore_existing:
      self._colocation_stack = current_stack
def _add_device_to_stack(self, device_name_or_function, offset=0):
  """Pushes a device spec onto the device function stack outside a `with`.

  Args:
    device_name_or_function: A device name string, a device function, or
      None; it is wrapped in a `_UserDeviceSpec`.
    offset: Extra stack frames to skip when the stack records the caller's
      code location.

  Returns:
    The `_UserDeviceSpec` that was pushed.
  """
  spec = _UserDeviceSpec(device_name_or_function)
  # The extra +1 skips this helper's own frame.
  self._device_function_stack.push_obj(spec, offset=offset + 1)
  return spec
@tf_contextlib.contextmanager
def device(self, device_name_or_function):
  # pylint: disable=line-too-long
  """Returns a context manager that specifies the default device to use.

  The `device_name_or_function` argument may either be a device name
  string, a device function, or None:

  * If it is a device name string, all operations constructed in
    this context will be assigned to the device with that name, unless
    overridden by a nested `device()` context.
  * If it is a function, it will be treated as a function from
    Operation objects to device name strings, and invoked each time
    a new Operation is created. The Operation will be assigned to
    the device with the returned name.
  * If it is None, all `device()` invocations from the enclosing context
    will be ignored.

  For information about the valid syntax of device name strings, see
  the documentation in
  [`DeviceNameUtils`](https://www.tensorflow.org/code/tensorflow/core/util/device_name_utils.h).

  For example:

  ```python
  with g.device('/device:GPU:0'):
    # All operations constructed in this context will be placed
    # on GPU 0.
    with g.device(None):
      # All operations constructed in this context will have no
      # assigned device.

  # Defines a function from `Operation` to device string.
  def matmul_on_gpu(n):
    if n.type == "MatMul":
      return "/device:GPU:0"
    else:
      return "/cpu:0"

  with g.device(matmul_on_gpu):
    # All operations of type "MatMul" constructed in this context
    # will be placed on GPU 0; all other operations will be placed
    # on CPU 0.
  ```

  **N.B.** The device scope may be overridden by op wrappers or
  other library code. For example, a variable assignment op
  `v.assign()` must be colocated with the `tf.Variable` `v`, and
  incompatible device scopes will be ignored.

  Args:
    device_name_or_function: The device name or function to use in the
      context.

  Yields:
    A context manager that specifies the default device to use for newly
    created ops.

  Raises:
    RuntimeError: If device scopes are not properly nested.
  """
  # offset=2 is forwarded to the stack's code-location tracking; presumably
  # it skips this generator and the context-manager wrapper so the user's
  # `with` line is recorded. TODO confirm against TraceableStack.
  self._add_device_to_stack(device_name_or_function, offset=2)
  # Remember the spec just pushed so that on exit we can verify that nested
  # device scopes were exited in LIFO order.
  old_top_of_stack = self._device_function_stack.peek_top_obj()
  try:
    yield
  finally:
    new_top_of_stack = self._device_function_stack.peek_top_obj()
    if old_top_of_stack is not new_top_of_stack:
      raise RuntimeError("Exiting device scope without proper scope nesting.")
    self._device_function_stack.pop_obj()
4323 def _apply_device_functions(self, op):
4324 """Applies the current device function stack to the given operation."""
4325 # Apply any device functions in LIFO order, so that the most recently
4326 # pushed function has the first chance to apply a device to the op.
4327 # We apply here because the result can depend on the Operation's
4328 # signature, which is computed in the Operation constructor.
4329 # pylint: disable=protected-access
4330 prior_device_string = None
4331 for device_spec in self._device_function_stack.peek_objs():
4332 if device_spec.is_null_merge:
4333 continue
4335 if device_spec.function is None:
4336 break
4338 device_string = device_spec.string_merge(op)
4340 # Take advantage of the fact that None is a singleton and Python interns
4341 # strings, since identity checks are faster than equality checks.
4342 if device_string is not prior_device_string:
4343 op._set_device_from_string(device_string)
4344 prior_device_string = device_string
4345 op._device_code_locations = self._snapshot_device_function_stack_metadata()
4346 # pylint: enable=protected-access
# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def container(self, container_name):
  """Returns a context manager that specifies the resource container to use.

  Stateful operations, such as variables and queues, can keep their state on
  devices so that it can be shared by multiple processes. A resource
  container is a string name under which these stateful operations are
  tracked. These resources can be released or cleared with
  `tf.Session.reset()`.

  For example:

  ```python
  with g.container('experiment0'):
    # All stateful Operations constructed in this context will be placed
    # in resource container "experiment0".
    v1 = tf.Variable([1.0])
    v2 = tf.Variable([2.0])
    with g.container("experiment1"):
      # All stateful Operations constructed in this context will be
      # placed in resource container "experiment1".
      v3 = tf.Variable([3.0])
      q1 = tf.queue.FIFOQueue(10, tf.float32)
    # All stateful Operations constructed in this context will be
    # created in the "experiment0".
    v4 = tf.Variable([4.0])
    q2 = tf.queue.FIFOQueue(20, tf.float32)
    with g.container(""):
      # All stateful Operations constructed in this context will be
      # placed in the default resource container.
      v5 = tf.Variable([5.0])
      q3 = tf.queue.FIFOQueue(30, tf.float32)

  # Resets container "experiment0", after which the state of v1, v2, v4, q2
  # will become undefined (such as uninitialized).
  tf.Session.reset(target, ["experiment0"])
  ```

  Args:
    container_name: container name string.

  Yields:
    The container name (`container_name`) while the context is active. The
    previous container is restored on exit.
  """
  previous_container, self._container = self._container, container_name
  try:
    yield self._container
  finally:
    self._container = previous_container

# pylint: enable=g-doc-return-or-yield
class _ControlDependenciesController(object):
  """Context manager for `control_dependencies()`.

  Instances are created by `Graph.control_dependencies()`. Entering one
  pushes it onto the owning graph's control-dependencies stack and exiting
  pops it again. While active, the controller records (via `add_op`) the ops
  created inside its block so the graph can skip redundant control inputs.
  """

  def __init__(self, graph, control_inputs):
    """Creates a new `_ControlDependenciesController`.

    Blocks created by `with tf.control_dependencies()` normally nest, as
    described in the documentation for `control_dependencies()`.

    Args:
      graph: The graph that this controller is managing.
      control_inputs: List of ops to use as control inputs in addition to the
        current control dependencies. Because of uniquification this may be
        empty even if the caller passed ops. `None` means "start a new, empty
        set of control dependencies": on entry the graph's current
        control-dependencies stack and its control flow context (an
        additional mechanism for adding control dependencies) are both
        cleared, and both are restored on exit.
    """
    self._graph = graph
    self._new_stack = control_inputs is None
    self._control_inputs_val = [] if self._new_stack else control_inputs
    # Ops (or Tensor refs) created while this controller was active.
    self._seen_nodes = set()
    # Graph state stashed by __enter__ when `_new_stack` is set.
    self._old_stack = None
    self._old_control_flow_context = None

  # pylint: disable=protected-access

  def __enter__(self):
    graph = self._graph
    if not self._new_stack:
      graph._push_control_dependencies_controller(self)
      return
    # Start from a clean slate: stash and clear both the dependency stack and
    # the control flow context; __exit__ puts them back.
    self._old_stack = graph._control_dependencies_stack
    graph._control_dependencies_stack = []
    self._old_control_flow_context = graph._get_control_flow_context()
    graph._set_control_flow_context(None)
    graph._push_control_dependencies_controller(self)

  def __exit__(self, unused_type, unused_value, unused_traceback):
    graph = self._graph
    graph._pop_control_dependencies_controller(self)
    if not self._new_stack:
      return
    # Undo exactly what __enter__ stashed.
    graph._control_dependencies_stack = self._old_stack
    graph._set_control_flow_context(self._old_control_flow_context)

  # pylint: enable=protected-access

  @property
  def control_inputs(self):
    """The ops this controller contributes as control inputs."""
    return self._control_inputs_val

  def add_op(self, op):
    """Records `op` (an Operation or Tensor) as created inside this block."""
    self._seen_nodes.add(self._seen_key(op))

  def op_in_group(self, op):
    """Returns True if `op` was previously recorded via `add_op`."""
    return self._seen_key(op) in self._seen_nodes

  @staticmethod
  def _seen_key(op):
    # Tensors are stored via their `ref()` handle rather than directly
    # (presumably because Tensor objects are not usable as set keys -- confirm).
    return op.ref() if isinstance(op, Tensor) else op
def _push_control_dependencies_controller(self, controller):
  """Makes `controller` the innermost active control-dependencies scope."""
  deps_stack = self._control_dependencies_stack
  deps_stack.append(controller)
def _pop_control_dependencies_controller(self, controller):
  """Removes `controller`, which must be the innermost active scope."""
  deps_stack = self._control_dependencies_stack
  # Scopes must be exited in LIFO order.
  assert deps_stack[-1] is controller
  deps_stack.pop()
def _current_control_dependencies(self):
  """Returns the set union of `control_inputs` over all active controllers."""
  return {
      dep
      for active in self._control_dependencies_stack
      for dep in active.control_inputs
  }
def _control_dependencies_for_inputs(self, input_ops):
  """Computes the control inputs for an op whose data inputs are `input_ops`.

  The result should give an execution equivalent to adding every control
  input in `self._control_dependencies_stack` to the new op. Some of those
  are pruned, because ops created inside the same
  `with control_dependencies(...):` block may already carry the dependency
  through their data edges.

  Args:
    input_ops: The data input ops for an op to be created.

  Returns:
    A list of control inputs for the op to be created.
  """
  control_inputs = []
  for controller in self._control_dependencies_stack:
    # If a data input was itself created under this controller, it already
    # depends on this controller's inputs, so the new op is "dominated" and
    # needs no explicit edges for this controller.
    if any(controller.op_in_group(op) for op in input_ops):
      continue
    # Drop control inputs that are already data inputs.
    # NOTE(mrry): Transitive data dependencies are not tracked, so some
    # redundant control inputs may still be added.
    control_inputs.extend(
        dep for dep in controller.control_inputs if dep not in input_ops)
  return control_inputs
def _record_op_seen_by_control_dependencies(self, op):
  """Tells every active controller that `op` was created inside its block.

  This lets `_control_dependencies_for_inputs` treat ops consuming `op` as
  already depending on each controller's control inputs.

  Args:
    op: An Operation.
  """
  for active in self._control_dependencies_stack:
    active.add_op(op)
def control_dependencies(self, control_inputs):
  """Returns a context manager that specifies control dependencies.

  Use with the `with` keyword to specify that all operations constructed
  within the context should have control dependencies on
  `control_inputs`. For example:

  ```python
  with g.control_dependencies([a, b, c]):
    # `d` and `e` will only run after `a`, `b`, and `c` have executed.
    d = ...
    e = ...
  ```

  Multiple calls to `control_dependencies()` can be nested, and in
  that case a new `Operation` will have control dependencies on the union
  of `control_inputs` from all active contexts.

  ```python
  with g.control_dependencies([a, b]):
    # Ops constructed here run after `a` and `b`.
    with g.control_dependencies([c, d]):
      # Ops constructed here run after `a`, `b`, `c`, and `d`.
  ```

  You can pass None to clear the control dependencies:

  ```python
  with g.control_dependencies([a, b]):
    # Ops constructed here run after `a` and `b`.
    with g.control_dependencies(None):
      # Ops constructed here run normally, not waiting for either `a` or `b`.
      with g.control_dependencies([c, d]):
        # Ops constructed here run after `c` and `d`, also not waiting
        # for either `a` or `b`.
  ```

  Passing `None` also clears the graph's current control flow context for
  the duration of the block; both are restored on exit.

  *N.B.* The control dependencies context applies *only* to ops that
  are constructed within the context. Merely using an op or tensor
  in the context does not add a control dependency. The following
  example illustrates this point:

  ```python
  # WRONG
  def my_func(pred, tensor):
    t = tf.matmul(tensor, tensor)
    with tf.control_dependencies([pred]):
      # The matmul op is created outside the context, so no control
      # dependency will be added.
      return t

  # RIGHT
  def my_func(pred, tensor):
    with tf.control_dependencies([pred]):
      # The matmul op is created in the context, so a control dependency
      # will be added.
      return tf.matmul(tensor, tensor)
  ```

  Also note that though execution of ops created under this scope will trigger
  execution of the dependencies, the ops created under this scope might still
  be pruned from a normal tensorflow graph. For example, in the following
  snippet of code the dependencies are never executed:

  ```python
    loss = model.loss()
    with tf.control_dependencies(dependencies):
      loss = loss + tf.constant(1)  # note: dependencies ignored in the
                                    # backward pass
    return tf.gradients(loss, model.variables)
  ```

  This is because evaluating the gradient graph does not require evaluating
  the constant(1) op created in the forward pass.

  Inputs that are already control dependencies of an enclosing
  `control_dependencies` block, or that repeat within `control_inputs`, are
  added only once. Each newly added op is tagged with the
  `_has_manual_control_dependencies` attribute.

  Args:
    control_inputs: A list of `Operation` or `Tensor` objects which must be
      executed or computed before running the operations defined in the
      context. Can also be `None` to clear the control dependencies.
      `IndexedSlices`, and objects exposing both `_handle` and `op` (e.g.
      resource variables), are replaced by their `.op`.

  Returns:
   A context manager that specifies control dependencies for all
   operations constructed within the context.

  Raises:
    TypeError: If `control_inputs` is not a list of `Operation` or
      `Tensor` objects.
  """
  if control_inputs is None:
    return self._ControlDependenciesController(self, None)
  # First convert the inputs to ops, and deduplicate them.
  # NOTE(mrry): Other than deduplication, we do not currently track direct
  #   or indirect dependencies between control_inputs, which may result in
  #   redundant control inputs.
  control_ops = []
  # Control inputs already contributed by enclosing controllers; anything in
  # this set is skipped so it is not added a second time.
  current = self._current_control_dependencies()
  for c in control_inputs:
    # The hasattr(handle) is designed to match ResourceVariables. This is so
    # control dependencies on a variable or on an unread variable don't
    # trigger reads.
    if (isinstance(c, internal.IndexedSlices) or
        (hasattr(c, "_handle") and hasattr(c, "op"))):
      c = c.op
    c = self.as_graph_element(c)
    if isinstance(c, Tensor):
      # A Tensor is represented by the Operation that produces it.
      c = c.op
    elif not isinstance(c, Operation):
      raise TypeError("Control input must be Operation or Tensor: %s" % c)
    if c not in current:
      control_ops.append(c)
      # Also track it in `current` so repeats within `control_inputs`
      # itself are dropped.
      current.add(c)
      # Mark this op with an attribute indicating that it is used as a manual
      # control dep in order to allow tracking how common utilization of
      # manual control deps in graphs run through the MLIR Bridge are. See
      # go/manual-control-dependencies-bridge for details.
      # pylint: disable=protected-access
      c._set_attr("_has_manual_control_dependencies",
                  attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access
  return self._ControlDependenciesController(self, control_ops)
4651 # pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _attr_scope(self, attr_map):
  """EXPERIMENTAL: A context manager for setting attributes on operators.

  This context manager can be used to add additional
  attributes to operators within the scope of the context.

  For example:

     with ops.Graph().as_default() as g:
       f_1 = Foo()  # No extra attributes
       with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}):
         f_2 = Foo()  # Additional attribute _a=False
         with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}):
           f_3 = Foo()  # Additional attribute _a=True
           with g._attr_scope({"_a": None}):
             f_4 = Foo()  # No additional attributes.

  Mapping a name to `None` removes that attribute for the duration of the
  context; this is a no-op if the attribute was not set. On exit, every name
  in `attr_map` is restored to the value it had on entry, or removed if it
  had none. All entries are validated before anything is installed, so a
  `TypeError` leaves the graph's attribute scope unchanged.

  Args:
    attr_map: A dictionary mapping attr name strings to AttrValue protocol
      buffers, callables that emit AttrValue protocol buffers, or None.

  Returns:
    A context manager that sets the attributes to be added to one or more
    ops created in that context.

  Raises:
    TypeError: If attr_map is not a dictionary mapping strings to AttrValue
      protobufs, callables emitting them, or None.
  """
  if not isinstance(attr_map, dict):
    raise TypeError("attr_map must be a dictionary mapping "
                    "strings to AttrValue protocol buffers")
  # Validate every entry before touching self._attr_scope_map: raising
  # halfway through the install loop would leave earlier entries installed
  # with no `finally` to undo them.
  for name, attr in attr_map.items():
    if not (isinstance(name, str) and
            (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or
             callable(attr))):
      raise TypeError("attr_map must be a dictionary mapping "
                      "strings to AttrValue protocol buffers or "
                      "callables that emit AttrValue protocol buffers")
  # The saved_attrs dictionary stores any currently-set attributes that
  # will be overridden by this context manager.
  saved_attrs = {}
  # Install the given attributes.
  for name, attr in attr_map.items():
    try:
      saved_attrs[name] = self._attr_scope_map[name]
    except KeyError:
      pass
    if attr is None:
      # pop() rather than del: clearing an attribute that was never set must
      # not raise KeyError.
      self._attr_scope_map.pop(name, None)
    else:
      self._attr_scope_map[name] = attr
  try:
    yield  # The code within the context runs here.
  finally:
    # Remove the attributes set for this context, and restore any saved
    # attributes.
    for name in attr_map:
      if name in saved_attrs:
        self._attr_scope_map[name] = saved_attrs[name]
      else:
        self._attr_scope_map.pop(name, None)
4715 # pylint: enable=g-doc-return-or-yield
4717 # pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _kernel_label_map(self, op_to_kernel_label_map):
  """EXPERIMENTAL: A context manager for setting kernel labels.

  This context manager can be used to select particular
  implementations of kernels within the scope of the context.

  For example:

      with ops.Graph().as_default() as g:
        f_1 = Foo()  # Uses the default registered kernel for the Foo op.
        with g._kernel_label_map({"Foo": "v_2"}):
          f_2 = Foo()  # Uses the registered kernel with label "v_2"
                       # for the Foo op.
          with g._kernel_label_map({"Foo": "v_3"}):
            f_3 = Foo()  # Uses the registered kernel with label "v_3"
                         # for the Foo op.
            with g._kernel_label_map({"Foo": ""}):
              f_4 = Foo()  # Uses the default registered kernel
                           # for the Foo op.

  On exit, each op type in `op_to_kernel_label_map` is restored to the label
  it had on entry, or removed if it had none. All entries are validated
  before any label is installed, so a `TypeError` leaves the graph unchanged.

  Args:
    op_to_kernel_label_map: A dictionary mapping op type strings to kernel
      label strings.

  Returns:
    A context manager that sets the kernel label to be used for one or more
    ops created in that context.

  Raises:
    TypeError: If op_to_kernel_label_map is not a dictionary mapping
      strings to strings.
  """
  if not isinstance(op_to_kernel_label_map, dict):
    raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
                    "strings to strings")
  # Validate every entry first, so a bad entry cannot leave earlier labels
  # installed with no `finally` to undo them.
  for op_type, label in op_to_kernel_label_map.items():
    if not (isinstance(op_type, str) and
            isinstance(label, str)):
      raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
                      "strings to strings")
  # The saved_labels dictionary stores any currently-set labels that
  # will be overridden by this context manager.
  saved_labels = {}
  # Install the given labels.
  for op_type, label in op_to_kernel_label_map.items():
    try:
      saved_labels[op_type] = self._op_to_kernel_label_map[op_type]
    except KeyError:
      pass
    self._op_to_kernel_label_map[op_type] = label
  try:
    yield  # The code within the context runs here.
  finally:
    # Remove the labels set for this context, and restore any saved labels.
    for op_type in op_to_kernel_label_map:
      if op_type in saved_labels:
        self._op_to_kernel_label_map[op_type] = saved_labels[op_type]
      else:
        self._op_to_kernel_label_map.pop(op_type, None)
4778 # pylint: enable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _override_gradient_function(self, gradient_function_map):
  """Specify gradient function for the given op type.

  Installs `gradient_function_map` as the graph's gradient-function map for
  the duration of the context and resets it to an empty dict on exit.
  Nesting is not supported: entering while a non-empty map is installed
  trips the assertion below.
  """
  # This is an internal API and we don't need nested context for this.
  # TODO(mdan): make it a proper context manager.
  previously_installed = self._gradient_function_map
  assert not previously_installed
  self._gradient_function_map = gradient_function_map
  try:
    yield
  finally:
    self._gradient_function_map = dict()
4793 # pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def gradient_override_map(self, op_type_map):
  """EXPERIMENTAL: A context manager for overriding gradient functions.

  This context manager can be used to override the gradient function
  that will be used for ops within the scope of the context.

  For example:

  ```python
  @tf.RegisterGradient("CustomSquare")
  def _custom_square_grad(op, grad):
    # ...

  with tf.Graph().as_default() as g:
    c = tf.constant(5.0)
    s_1 = tf.square(c)  # Uses the default gradient for tf.square.
    with g.gradient_override_map({"Square": "CustomSquare"}):
      s_2 = tf.square(c)  # Uses _custom_square_grad to compute the
                          # gradient of s_2.
  ```

  On exit, each op type in `op_type_map` is restored to the mapping it had
  on entry, or removed if it had none. All entries are validated before any
  mapping is installed, so a `TypeError` leaves the graph unchanged.

  Args:
    op_type_map: A dictionary mapping op type strings to alternative op type
      strings.

  Returns:
    A context manager that sets the alternative op type to be used for one
    or more ops created in that context.

  Raises:
    TypeError: If `op_type_map` is not a dictionary mapping strings to
      strings.
  """
  if not isinstance(op_type_map, dict):
    raise TypeError("op_type_map must be a dictionary mapping "
                    "strings to strings")
  # Validate every entry first, so a bad entry cannot leave earlier mappings
  # installed with no `finally` to undo them.
  for op_type, mapped_op_type in op_type_map.items():
    if not (isinstance(op_type, str) and
            isinstance(mapped_op_type, str)):
      raise TypeError("op_type_map must be a dictionary mapping "
                      "strings to strings")
  # The saved_mappings dictionary stores any currently-set mappings that
  # will be overridden by this context manager.
  saved_mappings = {}
  # Install the given mappings.
  for op_type, mapped_op_type in op_type_map.items():
    try:
      saved_mappings[op_type] = self._gradient_override_map[op_type]
    except KeyError:
      pass
    self._gradient_override_map[op_type] = mapped_op_type
  try:
    yield  # The code within the context runs here.
  finally:
    # Remove the mappings set for this context, and restore any saved ones.
    for op_type in op_type_map:
      if op_type in saved_mappings:
        self._gradient_override_map[op_type] = saved_mappings[op_type]
      else:
        self._gradient_override_map.pop(op_type, None)
4855 # pylint: enable=g-doc-return-or-yield
def prevent_feeding(self, tensor):
  """Marks the given `tensor` as unfeedable in this graph.

  Afterwards `is_feedable(tensor)` returns `False`.
  """
  unfeedable = self._unfeedable_tensors
  unfeedable.add(tensor)
def is_feedable(self, tensor):
  """Returns `True` if and only if `tensor` was never passed to `prevent_feeding`."""
  blocked = tensor in self._unfeedable_tensors
  return not blocked
def prevent_fetching(self, op):
  """Marks the given `op` as unfetchable in this graph.

  Afterwards `is_fetchable` returns `False` for `op` and its output tensors.
  """
  unfetchable = self._unfetchable_ops
  unfetchable.add(op)
def is_fetchable(self, tensor_or_op):
  """Returns `True` if and only if `tensor_or_op` is fetchable.

  A Tensor is judged by the Operation that produces it.
  """
  op = tensor_or_op.op if isinstance(tensor_or_op, Tensor) else tensor_or_op
  return op not in self._unfetchable_ops
def switch_to_thread_local(self):
  """Make device, colocation and dependencies stacks thread-local.

  Device, colocation and dependencies stacks are not thread-local by default.
  If multiple threads access them, then the state is shared. This means that
  one thread may affect the behavior of another thread.

  After this method is called, the stacks become thread-local. If multiple
  threads access them, then the state is not shared. Each thread uses its own
  value; a thread doesn't affect other threads by mutating such a stack.

  The initial value for every thread's stack is set to the current value
  of the stack when `switch_to_thread_local()` was first called. Each
  thread's copy is taken lazily, from the graph-level stack, the first time
  that thread reads it. Calling this method again has no further effect.
  """
  if not self._stack_state_is_thread_local:
    self._stack_state_is_thread_local = True
@property
def _device_function_stack(self):
  """The device function stack: per-thread when thread-local mode is on.

  In thread-local mode, a thread reading the stack for the first time gets
  its own copy of the graph-level stack.
  """
  if not self._stack_state_is_thread_local:
    return self._graph_device_function_stack
  local = self._thread_local
  if not hasattr(local, "_device_function_stack"):
    local._device_function_stack = (  # pylint: disable=protected-access
        self._graph_device_function_stack.copy())
  return local._device_function_stack  # pylint: disable=protected-access
@property
def _device_functions_outer_to_inner(self):
  """The `.function` of each user device spec, reversed from `peek_objs()` order.

  `peek_objs()` yields inner-to-outer, so the result runs from the outermost
  scope to the innermost.
  """
  specs = self._device_function_stack.peek_objs()
  return [spec.function for spec in specs][::-1]
def _snapshot_device_function_stack_metadata(self):
  """Return device function stack as a list of TraceableObjects.

  Returns:
    One metadata copy per entry of `peek_traceable_objs()`, in that order.
    Each copy's `.obj` is replaced with the `display_name` of the user's
    argument to Graph.device; its filename and lineno members point to the
    code location where Graph.device was called directly or indirectly by
    the user. The entries on the stack itself are not modified.
  """
  def _with_display_name(traceable):
    duplicate = traceable.copy_metadata()
    duplicate.obj = traceable.obj.display_name
    return duplicate

  entries = self._device_function_stack.peek_traceable_objs()
  return [_with_display_name(entry) for entry in entries]
@_device_function_stack.setter
def _device_function_stack(self, device_function_stack):
  """Replaces this thread's stack (thread-local mode) or the graph's stack."""
  if not self._stack_state_is_thread_local:
    self._graph_device_function_stack = device_function_stack
    return
  self._thread_local._device_function_stack = device_function_stack  # pylint: disable=protected-access
@property
def _colocation_stack(self):
  """The colocation stack: per-thread when thread-local mode is on.

  Outside thread-local mode this is the graph-level stack itself. In
  thread-local mode, a thread reading it for the first time gets its own
  copy of the graph-level stack.
  """
  if not self._stack_state_is_thread_local:
    return self._graph_colocation_stack
  local = self._thread_local
  if not hasattr(local, "_colocation_stack"):
    local._colocation_stack = (  # pylint: disable=protected-access
        self._graph_colocation_stack.copy())
  return local._colocation_stack  # pylint: disable=protected-access
def _snapshot_colocation_stack_metadata(self):
  """Maps each colocation entry's `obj.name` to a metadata copy of the entry."""
  snapshot = {}
  for entry in self._colocation_stack.peek_traceable_objs():
    key = entry.obj.name
    snapshot[key] = entry.copy_metadata()
  return snapshot
@_colocation_stack.setter
def _colocation_stack(self, colocation_stack):
  """Replaces this thread's stack (thread-local mode) or the graph's stack."""
  if not self._stack_state_is_thread_local:
    self._graph_colocation_stack = colocation_stack
    return
  self._thread_local._colocation_stack = colocation_stack  # pylint: disable=protected-access
@property
def _control_dependencies_stack(self):
  """The control-dependencies stack: per-thread when thread-local mode is on.

  In thread-local mode, a thread reading it for the first time gets a
  shallow copy of the graph-level list.
  """
  if not self._stack_state_is_thread_local:
    return self._graph_control_dependencies_stack
  local = self._thread_local
  if not hasattr(local, "_control_dependencies_stack"):
    local._control_dependencies_stack = (
        self._graph_control_dependencies_stack[:])
  return local._control_dependencies_stack
@_control_dependencies_stack.setter
def _control_dependencies_stack(self, control_dependencies):
  """Replaces this thread's stack (thread-local mode) or the graph's stack."""
  if not self._stack_state_is_thread_local:
    self._graph_control_dependencies_stack = control_dependencies
    return
  self._thread_local._control_dependencies_stack = control_dependencies
@property
def _distribution_strategy_stack(self):
  """Per-thread stack of distribution strategy contexts.

  It is created as an empty list on a thread's first access.
  """
  local = self._thread_local
  try:
    return local._distribution_strategy_stack  # pylint: disable=protected-access
  except AttributeError:
    local._distribution_strategy_stack = []  # pylint: disable=protected-access
    return local._distribution_strategy_stack  # pylint: disable=protected-access
@_distribution_strategy_stack.setter
def _distribution_strategy_stack(self, _distribution_strategy_stack):
  """Replaces the calling thread's distribution strategy stack."""
  local = self._thread_local
  local._distribution_strategy_stack = _distribution_strategy_stack  # pylint: disable=protected-access
@property
def _global_distribute_strategy_scope(self):
  """For implementing `tf.distribute.set_strategy()`.

  Per-thread value, initialized to (and stored as) `None` on a thread's
  first access.
  """
  local = self._thread_local
  try:
    return local.distribute_strategy_scope
  except AttributeError:
    local.distribute_strategy_scope = None
    return None
@_global_distribute_strategy_scope.setter
def _global_distribute_strategy_scope(self, distribute_strategy_scope):
  """Sets the calling thread's global distribution strategy scope."""
  local = self._thread_local
  local.distribute_strategy_scope = distribute_strategy_scope
def _mutation_lock(self):
  """Returns the `_MUTATION_LOCK_GROUP` member of the graph's group lock.

  Use it to guard code that creates & mutates ops. See the comment for
  self._group_lock for more info.
  """
  group_lock = self._group_lock
  return group_lock.group(_MUTATION_LOCK_GROUP)
def _session_run_lock(self):
  """Returns the `_SESSION_RUN_LOCK_GROUP` member of the graph's group lock.

  Use it to guard code for Session.run. See the comment for self._group_lock
  for more info.
  """
  group_lock = self._group_lock
  return group_lock.group(_SESSION_RUN_LOCK_GROUP)
5027# TODO(agarwal): currently device directives in an outer eager scope will not
5028# apply to inner graph mode code. Fix that.
@tf_export(v1=["device"])
def device(device_name_or_function):
  """Wrapper for `Graph.device()` using the default graph.

  See `tf.Graph.device` for more details.

  Three cases are handled:
    * Eager execution: the device name is forwarded to `context.device`;
      passing a function raises `RuntimeError`.
    * Graph building while `executing_eagerly_outside_functions()` is true:
      the default graph's device scope is entered. For non-callable
      arguments, the eager context's device scope is entered as well.
    * Otherwise: `get_default_graph().device(...)` is returned.

  Args:
    device_name_or_function: The device name or function to use in the context.

  Returns:
    A context manager that specifies the default device to use for newly
    created ops.

  Raises:
    RuntimeError: If eager execution is enabled and a function is passed in.
  """
  if context.executing_eagerly():
    if callable(device_name_or_function):
      raise RuntimeError(
          "tf.device does not support functions when eager execution "
          "is enabled.")
    return context.device(device_name_or_function)

  if not executing_eagerly_outside_functions():
    return get_default_graph().device(device_name_or_function)

  @tf_contextlib.contextmanager
  def _graph_and_eager_scope(spec):
    # Callables are applied only to the graph scope; `context.device` is
    # entered only for non-callable specs.
    with get_default_graph().device(spec):
      if callable(spec):
        yield
      else:
        with context.device(spec):
          yield

  return _graph_and_eager_scope(device_name_or_function)
@tf_export("device", v1=[])
def device_v2(device_name):
  """Specifies the device for ops created/executed in this context.

  This function specifies the device to be used for ops created/executed in a
  particular context. Nested contexts will inherit and also create/execute
  their ops on the specified device. If a specific device is not required,
  consider not using this function so that a device can be automatically
  assigned. In general the use of this function is optional. `device_name` can
  be fully specified, as in "/job:worker/task:1/device:cpu:0", or partially
  specified, containing only a subset of the "/"-separated fields. Any fields
  which are specified will override device annotations from outer scopes.

  For example:

  ```python
  with tf.device('/job:foo'):
    # ops created here have devices with /job:foo
    with tf.device('/job:bar/task:0/device:gpu:2'):
      # ops created here have the fully specified device above
    with tf.device('/device:gpu:1'):
      # ops created here have the device '/job:foo/device:gpu:1'
  ```

  Args:
    device_name: The device name to use in the context.

  Returns:
    A context manager that specifies the default device to use for newly
    created ops.

  Raises:
    RuntimeError: If a function is passed in.
  """
  if not callable(device_name):
    return device(device_name)
  raise RuntimeError("tf.device does not support functions.")
@tf_export(v1=["container"])
def container(container_name):
  """Wrapper for `Graph.container()` using the default graph.

  Args:
    container_name: The container string to use in the context.

  Returns:
    A context manager that specifies the default container to use for newly
    created stateful ops.
  """
  graph = get_default_graph()
  return graph.container(container_name)
def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False):
  """Returns a context manager that places new ops alongside `op`.

  In eager mode, a non-None `op` becomes `device(op.device)`. If `op` has no
  `device` attribute, it is first passed through `convert_to_tensor`. A `None`
  `op` yields a no-op context manager.

  In graph mode, an `EagerTensor` is mapped to a device scope on its device
  while a function is being built, and raises `ValueError` otherwise. Any
  other `op` is delegated to the default graph's
  `_colocate_with_for_gradient` along with `gradient_uid` and
  `ignore_existing`.
  """
  if not context.executing_eagerly():
    default_graph = get_default_graph()
    if isinstance(op, EagerTensor):
      if not default_graph.building_function:
        raise ValueError("Encountered an Eager-defined Tensor during graph "
                         "construction, but a function was not being built.")
      return default_graph.device(op.device)
    return default_graph._colocate_with_for_gradient(  # pylint: disable=protected-access
        op, gradient_uid=gradient_uid, ignore_existing=ignore_existing)

  if op is None:
    return NullContextmanager()
  if not hasattr(op, "device"):
    op = convert_to_tensor(op)
  return device(op.device)
5140# Internal interface to colocate_with. colocate_with has been deprecated from
5141# public API. There are still a few internal uses of colocate_with. Add internal
5142# only API for those uses to avoid deprecation warning.
def colocate_with(op, ignore_existing=False):
  """Internal, non-deprecated colocation helper (no gradient uid)."""
  return _colocate_with_for_gradient(
      op, gradient_uid=None, ignore_existing=ignore_existing)
@deprecation.deprecated(
    date=None, instructions="Colocations handled automatically by placer.")
@tf_export(v1=["colocate_with"])
def _colocate_with(op, ignore_existing=False):
  """Deprecated public `tf.compat.v1.colocate_with`; forwards to `colocate_with`."""
  return colocate_with(op, ignore_existing=ignore_existing)
@tf_export("control_dependencies")
def control_dependencies(control_inputs):
  """Wrapper for `Graph.control_dependencies()` using the default graph.

  See `tf.Graph.control_dependencies` for more details.

  In TensorFlow 2 with eager and/or Autograph, you should not need this method
  most of the time, as ops execute in the expected order thanks to automatic
  control dependencies. Only use it to manually control ordering, for example as
  a workaround to known issues such as `tf.function` with `tf.debugging.assert*`
  and `tf.py_function`.
  For example:

  >>> @tf.function(
  ...   input_signature=[tf.TensorSpec([None, None], tf.float32),
  ...                    tf.TensorSpec([None, None], tf.float32)])
  ... def my_assert_func_1(x, bias):
  ...   # `tf.function` attempts to execute `tf.math.add` in parallel to
  ...   # `assert_equal`. As a result an error can get raised from `tf.math.add`
  ...   # without triggering the assertion error.
  ...   tf.assert_equal(tf.shape(x)[1],
  ...                   tf.shape(bias)[1],
  ...                   message='bad shape')
  ...   return x + bias

  >>> # Error raised in either `add` or `assert`
  >>> my_assert_func_1(tf.ones((2, 5)), tf.ones((2, 7)))
  Traceback (most recent call last):
     ...
  InvalidArgumentError: ...


  >>> @tf.function(
  ...   input_signature=[tf.TensorSpec([None, None], tf.float32),
  ...                    tf.TensorSpec([None, None], tf.float32)])
  ... def my_assert_func_2(x, bias):
  ...   with tf.control_dependencies(
  ...       [tf.assert_equal(tf.shape(x)[1],
  ...                       tf.shape(bias)[1],
  ...                       message='bad shape')]):
  ...     return x + bias

  >>> # Error raised in `assert`
  >>> my_assert_func_2(tf.ones((2, 5)), tf.ones((2, 7)))
  Traceback (most recent call last):
     ...
  InvalidArgumentError: ...

  When eager execution is enabled, any callable object in the `control_inputs`
  list will be called, and a no-op context manager is returned.

  Args:
    control_inputs: A list of `Operation` or `Tensor` objects which must be
      executed or computed before running the operations defined in the context.
      Can also be `None` to clear the control dependencies. If eager execution
      is enabled, any callable object in the `control_inputs` list will be
      called.

  Returns:
   A context manager that specifies control dependencies for all
   operations constructed within the context.
  """
  if not context.executing_eagerly():
    return get_default_graph().control_dependencies(control_inputs)
  # Eagerly there is no graph to attach dependencies to; just run any pending
  # callables now.
  for control in control_inputs or ():
    if callable(control):
      control()
  return NullContextmanager()
# TODO(b/271463878): Remove in favor of direct references to `stack`.
# Module-level alias bound to the same function object as
# `stack.get_default_session` (presumably kept for callers that still import
# it from this module -- verify before removing).
get_default_session = stack.get_default_session
def _eval_using_default_session(tensors, feed_dict, graph, session=None):
  """Evaluates `tensors` with an explicit session or the default session.

  Args:
    tensors: A single Tensor, or a list of Tensor objects.
    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
      numpy ndarrays, TensorProtos, or strings.
    graph: The graph in which the tensors are defined.
    session: (Optional) A different session to use to evaluate "tensors".

  Returns:
    Whatever `session.run(tensors, feed_dict)` returns: a single numpy
    ndarray for a single tensor, or a list of ndarrays matching "tensors".

  Raises:
    ValueError: If no default session is available; the default session
      does not have "graph" as its graph; or if "session" is specified,
      and it does not have "graph" as its graph.
  """
  explicit = session is not None
  if not explicit:
    session = stack.get_default_session()
    if session is None:
      raise ValueError("Cannot evaluate tensor using `eval()`: No default "
                       "session is registered. Use `with "
                       "sess.as_default()` or pass an explicit session to "
                       "`eval(session=sess)`")
  if session.graph is not graph:
    if explicit:
      raise ValueError("Cannot use the given session to evaluate tensor: "
                       "the tensor's graph is different from the session's "
                       "graph.")
    raise ValueError("Cannot use the default session to evaluate tensor: "
                     "the tensor's graph is different from the session's "
                     "graph. Pass an explicit session to "
                     "`eval(session=sess)`.")
  return session.run(tensors, feed_dict)
def _run_using_default_session(operation, feed_dict, graph, session=None):
  """Runs `operation` with an explicit session or the default session.

  Args:
    operation: The Operation to be run.
    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
      numpy ndarrays, TensorProtos, or strings.
    graph: The graph in which "operation" is defined.
    session: (Optional) A different session to use to run "operation".

  Raises:
    ValueError: If no default session is available; the default session
      does not have "graph" as its graph; or if "session" is specified,
      and it does not have "graph" as its graph.
  """
  explicit = session is not None
  if not explicit:
    session = stack.get_default_session()
    if session is None:
      raise ValueError("Cannot execute operation using `run()`: No default "
                       "session is registered. Use `with "
                       "sess.as_default():` or pass an explicit session to "
                       "`run(session=sess)`")
  if session.graph is not graph:
    if explicit:
      raise ValueError("Cannot use the given session to execute operation: "
                       "the operation's graph is different from the session's "
                       "graph.")
    raise ValueError("Cannot use the default session to execute operation: "
                     "the operation's graph is different from the "
                     "session's graph. Pass an explicit session to "
                     "run(session=sess).")
  session.run(operation, feed_dict)
class _DefaultGraphStack(stack.DefaultStack):  # pylint: disable=protected-access
  """A thread-local stack of objects for providing an implicit default graph.

  When no graph has been pushed onto the stack, a lazily created global
  default graph is returned instead.
  """

  def __init__(self):
    super(_DefaultGraphStack, self).__init__()
    # Created on first use by `_GetGlobalDefaultGraph()`; cleared by `reset()`.
    self._global_default_graph = None

  def get_default(self):
    """Override that returns a global default if the stack is empty."""
    if self.stack:
      return self.stack[-1]
    # Delegate to the single lazy-creation path so that both accessors agree
    # on when a new global graph is created (an explicit `is None` check,
    # rather than relying on the truthiness of a Graph object).
    return self._GetGlobalDefaultGraph()

  def _GetGlobalDefaultGraph(self):
    """Returns the global default graph, creating it on first use."""
    if self._global_default_graph is None:
      # TODO(mrry): Perhaps log that the default graph is being used, or set
      #   provide some other feedback to prevent confusion when a mixture of
      #   the global default graph and an explicit graph are combined in the
      #   same process.
      self._global_default_graph = Graph()
    return self._global_default_graph

  def reset(self):
    """Resets the underlying stack and drops the global default graph."""
    super(_DefaultGraphStack, self).reset()
    self._global_default_graph = None

  @tf_contextlib.contextmanager
  def get_controller(self, default):
    """Makes `default` the default graph for the duration of the context.

    Also records a context-switch entry (whether `default` is building a
    function, its `as_default` entry function, and its device function stack)
    on `context.context().context_switches`, and enters graph mode. The
    context-switch entry is popped on exit, even if the body raises.

    Args:
      default: The graph to install as the default graph.

    Yields:
      The value yielded by the base class' `get_controller(default)`.
    """
    context.context().context_switches.push(default.building_function,
                                            default.as_default,
                                            default._device_function_stack)  # pylint: disable=protected-access
    try:
      with super(_DefaultGraphStack,
                 self).get_controller(default) as g, context.graph_mode():
        yield g
    finally:
      # If an exception is raised here it may be hiding a related exception in
      # the try-block (just above).
      context.context().context_switches.pop()
# Module-level instance of the default graph stack. `get_default_graph()`,
# `reset_default_graph()`, `has_default_graph()`, the init_scope helpers and
# `enable_eager_execution_internal()` all read or mutate this object.
_default_graph_stack = _DefaultGraphStack()
# Shared helper used in init_scope and executing_eagerly_outside_functions
# to obtain the outermost context that is not building a function, and the
# innermost non-empty device stack.
def _get_outer_context_and_inner_device_stack():
  """Get the outermost context not building a function.

  Returns:
    A pair `(outer_context, device_stack)`. `outer_context` is a callable
    whose result is entered as a context manager to switch into the selected
    context. `device_stack` is the default graph's device function stack; if
    that is empty, it is taken from the context-switch entries visited while
    searching for `outer_context`.

  Raises:
    RuntimeError: If the default graph stack is empty but the global default
      graph is building a function, or if no outer context could be found.
  """
  graph = get_default_graph()
  device_stack = graph._device_function_stack  # pylint: disable=protected-access
  enter_fn = None

  if _default_graph_stack.stack:
    # Walk the context switches starting from the most recently pushed entry
    # and stop at the first one that is not building a function. While doing
    # so, adopt an entry's device stack whenever the current one is empty.
    for switch in reversed(context.context().context_switches.stack):
      if not device_stack:
        device_stack = switch.device_stack
      if not switch.is_building_function:
        enter_fn = switch.enter_context_fn
        break
  else:
    # If the default graph stack is empty, then we cannot be building a
    # function. Install the global graph (which, in this case, is also the
    # default graph) as the outer context.
    if graph.building_function:
      raise RuntimeError("The global graph is building a function.")
    enter_fn = graph.as_default

  if enter_fn is None:
    # As a last resort, obtain the global default graph; it is held by the
    # graph stack object even when it is not on the stack itself.
    enter_fn = _default_graph_stack._GetGlobalDefaultGraph().as_default  # pylint: disable=protected-access

  if enter_fn is None:
    # Sanity check; this shouldn't be triggered.
    raise RuntimeError("All graphs are building functions, and no "
                       "eager context was previously active.")

  return enter_fn, device_stack
# pylint: disable=g-doc-return-or-yield,line-too-long
@tf_export("init_scope")
@tf_contextlib.contextmanager
def init_scope():
  """A context manager that lifts ops out of control-flow scopes and function-building graphs.

  There is often a need to lift variable initialization ops out of control-flow
  scopes, function-building graphs, and gradient tapes. Entering an
  `init_scope` is a mechanism for satisfying these desiderata. In particular,
  entering an `init_scope` has three effects:

    (1) All control dependencies are cleared the moment the scope is entered;
        this is equivalent to entering the context manager returned from
        `control_dependencies(None)`, which has the side-effect of exiting
        control-flow scopes like `tf.cond` and `tf.while_loop`.

    (2) All operations that are created while the scope is active are lifted
        into the lowest context on the `context_stack` that is not building a
        graph function. Here, a context is defined as either a graph or an eager
        context. Every context switch, i.e., every installation of a graph as
        the default graph and every switch into eager mode, is logged in a
        thread-local stack called `context_switches`; the log entry for a
        context switch is popped from the stack when the context is exited.
        Entering an `init_scope` is equivalent to crawling up
        `context_switches`, finding the first context that is not building a
        graph function, and entering it. A caveat is that if graph mode is
        enabled but the default graph stack is empty, then entering an
        `init_scope` will simply install a fresh graph as the default one.

    (3) The gradient tape is paused while the scope is active.

  The current name scope is carried over into the lifted context. When lifting
  into a graph, the innermost non-empty device stack is installed on that graph
  for the duration of the scope and restored on exit.

  When eager execution is enabled, code inside an init_scope block runs with
  eager execution enabled even when tracing a `tf.function`. For example:

  ```python
  tf.compat.v1.enable_eager_execution()

  @tf.function
  def func():
    # A function constructs TensorFlow graphs,
    # it does not execute eagerly.
    assert not tf.executing_eagerly()
    with tf.init_scope():
      # Initialization runs with eager execution enabled
      assert tf.executing_eagerly()
  ```

  Raises:
    RuntimeError: if graph state is incompatible with this initialization.
  """
  # pylint: enable=g-doc-return-or-yield,line-too-long

  if context.executing_eagerly():
    # Fastpath: when already executing eagerly, the only effect applied is
    # pausing gradient recording.
    with record.stop_recording():
      yield
  else:
    # Retrieve the active name scope: entering an `init_scope` preserves
    # the name scope of the current context.
    scope = get_default_graph().get_name_scope()
    if scope and scope[-1] != "/":
      # Names that end with trailing slashes are treated by `name_scope` as
      # absolute.
      scope = scope + "/"

    outer_context, innermost_nonempty_device_stack = (
        _get_outer_context_and_inner_device_stack())

    # Only set when lifting into a graph; `finally` uses them to put the outer
    # graph's original device stack back.
    outer_graph = None
    outer_device_stack = None
    try:
      with outer_context(), name_scope(
          scope, skip_on_eager=False), control_dependencies(
              None), record.stop_recording():
        # No-op by default; replaced by `context.device(<raw string>)` below
        # when lifting into eager mode with a device string on the stack.
        context_manager = NullContextmanager
        context_manager_input = None
        if not context.executing_eagerly():
          # The device stack is preserved when lifting into a graph. Eager
          # execution doesn't implement device stacks and in particular it
          # doesn't support device functions, so in general it's not possible
          # to do the same when lifting into the eager context.
          outer_graph = get_default_graph()
          outer_device_stack = outer_graph._device_function_stack  # pylint: disable=protected-access
          outer_graph._device_function_stack = innermost_nonempty_device_stack  # pylint: disable=protected-access
        elif innermost_nonempty_device_stack is not None:
          # Lifting into eager: pick the first entry from `peek_objs()` that
          # carries a raw device string. An entry whose `function` is None ends
          # the search without selecting a device (presumably a `device(None)`
          # reset -- TODO confirm).
          for device_spec in innermost_nonempty_device_stack.peek_objs():
            if device_spec.function is None:
              break
            if device_spec.raw_string:
              context_manager = context.device
              context_manager_input = device_spec.raw_string
              break
            # It is currently not possible to have a device function in V2,
            # but in V1 we are unable to apply device functions in eager mode.
            # This means that we will silently skip some of the entries on the
            # device stack in V1 + eager mode.

        with context_manager(context_manager_input):
          yield
    finally:
      # If an exception is raised here it may be hiding a related exception in
      # try-block (just above).
      if outer_graph is not None:
        outer_graph._device_function_stack = outer_device_stack  # pylint: disable=protected-access
@tf_export(v1=["executing_eagerly_outside_functions"])
def executing_eagerly_outside_functions():
  """Returns True if executing eagerly, even if inside a graph function.

  This function inspects the outermost context of the program and reports
  whether it is in eager mode. In contrast, `tf.executing_eagerly()` checks the
  current context and therefore returns `False` within a `tf.function` body.
  This makes it useful for building libraries that behave differently in the
  eager runtime and in the (deprecated) v1 session runtime.

  Example:

  >>> tf.compat.v1.enable_eager_execution()
  >>> @tf.function
  ... def func():
  ...   # A function constructs TensorFlow graphs, it does not execute eagerly,
  ...   # but the outermost context is still eager.
  ...   assert not tf.executing_eagerly()
  ...   return tf.compat.v1.executing_eagerly_outside_functions()
  >>> func()
  <tf.Tensor: shape=(), dtype=bool, numpy=True>

  Returns:
    boolean, whether the outermost context is in eager mode.
  """
  if not context.executing_eagerly():
    # Temporarily enter the outermost non-function context and ask it.
    enter_outer, _ = _get_outer_context_and_inner_device_stack()
    with enter_outer():
      return context.executing_eagerly()
  return True
@tf_export("inside_function", v1=[])
def inside_function():
  """Indicates whether the caller code is executing inside a `tf.function`.

  Returns:
    Boolean, True if the caller code is executing inside a `tf.function`
    rather than eagerly. This is the `building_function` flag of the current
    default graph.

  Example:

  >>> tf.inside_function()
  False
  >>> @tf.function
  ... def f():
  ...   print(tf.inside_function())
  >>> f()
  True
  """
  current_graph = get_default_graph()
  return current_graph.building_function
@tf_export(v1=["enable_eager_execution"])
def enable_eager_execution(config=None, device_policy=None,
                           execution_mode=None):
  """Enables eager execution for the lifetime of this program.

  Eager execution provides an imperative interface to TensorFlow. With eager
  execution enabled, TensorFlow functions execute operations immediately (as
  opposed to adding them to a graph to be executed later in a
  `tf.compat.v1.Session`) and return concrete values (as opposed to symbolic
  references to a node in a computational graph).

  For example:

  ```python
  tf.compat.v1.enable_eager_execution()

  # After eager execution is enabled, operations are executed as they are
  # defined and Tensor objects hold concrete values, which can be accessed as
  # `numpy.ndarray`s through the numpy() method.
  assert tf.multiply(6, 7).numpy() == 42
  ```

  Eager execution cannot be enabled after TensorFlow APIs have been used to
  create or execute graphs. It is typically recommended to invoke this function
  at program startup and not in a library (as most libraries should be usable
  both with and without eager execution).

  If eager execution is already the default execution mode, this call records
  API usage and returns without validating or applying the arguments.

  @compatibility(TF2)
  This function is not necessary if you are using TF2. Eager execution is
  enabled by default.
  @end_compatibility

  Args:
    config: (Optional.) A `tf.compat.v1.ConfigProto` to use to configure the
      environment in which operations are executed. Note that
      `tf.compat.v1.ConfigProto` is also used to configure graph execution (via
      `tf.compat.v1.Session`) and many options within `tf.compat.v1.ConfigProto`
      are not implemented (or are irrelevant) when eager execution is enabled.
    device_policy: (Optional.) Policy controlling how operations requiring
      inputs on a specific device (e.g., GPU 0) handle inputs on a different
      device (e.g. GPU 1 or CPU). When set to None, an appropriate value will
      be picked automatically. The value picked may change between TensorFlow
      releases.
      Valid values:
      - DEVICE_PLACEMENT_EXPLICIT: raises an error if the
        placement is not correct.
      - DEVICE_PLACEMENT_WARN: copies the tensors which are not
        on the right device but logs a warning.
      - DEVICE_PLACEMENT_SILENT: silently copies the tensors.
        Note that this may hide performance problems as there is no notification
        provided when operations are blocked on the tensor being copied between
        devices.
      - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies
        int32 tensors, raising errors on the other ones.
    execution_mode: (Optional.) Policy controlling how operations dispatched are
      actually executed. When set to None, an appropriate value will be picked
      automatically. The value picked may change between TensorFlow releases.
      Valid values:
      - SYNC: executes each operation synchronously.
      - ASYNC: executes each operation asynchronously. These
        operations may return "non-ready" handles.

  Raises:
    TypeError: If `config` is neither None nor a `tf.compat.v1.ConfigProto`.
    ValueError: If eager execution is enabled after creating/executing a
      TensorFlow graph, or if options provided conflict with a previous call
      to this function.
  """
  _api_usage_gauge.get_cell().set(True)
  logging.vlog(1, "Enabling eager execution")
  if context.default_execution_mode == context.EAGER_MODE:
    # Eager is already the default execution mode; nothing to initialize.
    return None
  return enable_eager_execution_internal(
      config=config,
      device_policy=device_policy,
      execution_mode=execution_mode,
      server_def=None)
@tf_export(v1=["disable_eager_execution"])
def disable_eager_execution():
  """Disables eager execution.

  This function can only be called before any Graphs, Ops, or Tensors have been
  created.

  It switches the default execution mode to graph mode and, if an eager context
  already exists, also marks the current thread of that context as non-eager.

  @compatibility(TF2)
  This function is not necessary if you are using TF2. Eager execution is
  enabled by default. If you want to use Graph mode please consider
  [tf.function](https://www.tensorflow.org/api_docs/python/tf/function).
  @end_compatibility
  """
  _api_usage_gauge.get_cell().set(False)
  logging.vlog(1, "Disabling eager execution")
  context.default_execution_mode = context.GRAPH_MODE
  existing_context = context.context_safe()
  if existing_context is None:
    return
  existing_context._thread_local_data.is_eager = False  # pylint: disable=protected-access
def enable_eager_execution_internal(config=None,
                                    device_policy=None,
                                    execution_mode=None,
                                    server_def=None):
  """Enables eager execution for the lifetime of this program.

  Most of the doc string for enable_eager_execution is relevant here as well.

  Args:
    config: See enable_eager_execution doc string
    device_policy: See enable_eager_execution doc string
    execution_mode: See enable_eager_execution doc string
    server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution on
      remote devices. GrpcServers need to be started by creating an identical
      server_def to this, and setting the appropriate task_indexes, so that the
      servers can communicate. It will then be possible to execute operations on
      remote devices.

  Raises:
    TypeError: If `config` is neither None nor a `config_pb2.ConfigProto`.
    ValueError: If `device_policy` or `execution_mode` is not a recognized
      value; if graph mode has already been used (a global default graph
      exists) while the default execution mode is graph mode; or if the
      options conflict with those of an already-created eager context.
  """
  if config is not None and not isinstance(config, config_pb2.ConfigProto):
    raise TypeError("config must be a tf.ConfigProto, but got %s" %
                    type(config))
  if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT,
                           context.DEVICE_PLACEMENT_WARN,
                           context.DEVICE_PLACEMENT_SILENT,
                           context.DEVICE_PLACEMENT_SILENT_FOR_INT32):
    raise ValueError("device_policy must be one of None, DEVICE_PLACEMENT_*")
  if execution_mode not in (None, context.SYNC, context.ASYNC):
    raise ValueError("execution_mode must be one of None, SYNC, ASYNC")
  if context.default_execution_mode == context.GRAPH_MODE:
    # A non-None global default graph means graph-building APIs were already
    # used, so it is too late to switch to eager.
    graph_mode_has_been_used = (
        _default_graph_stack._global_default_graph is not None)  # pylint: disable=protected-access
    if graph_mode_has_been_used:
      raise ValueError(
          "tf.enable_eager_execution must be called at program startup.")
  context.default_execution_mode = context.EAGER_MODE
  # pylint: disable=protected-access
  with context._context_lock:
    if context._context is None:
      context._set_context_locked(context.Context(
          config=config,
          device_policy=device_policy,
          execution_mode=execution_mode,
          server_def=server_def))
    # The device policy and execution mode are value constants, so compare
    # them by value (`!=`); identity comparison of such values is
    # implementation-dependent. The config proto keeps its identity check.
    elif ((config is not None and config is not context._context._config) or
          (device_policy is not None and
           device_policy != context._context._device_policy) or
          (execution_mode is not None and
           execution_mode != context._context._execution_mode)):
      raise ValueError(
          "Trying to change the options of an active eager"
          " execution. Context config: %s, specified config:"
          " %s. Context device policy: %s, specified device"
          " policy: %s. Context execution mode: %s,"
          " specified execution mode %s." %
          (context._context._config, config, context._context._device_policy,
           device_policy, context._context._execution_mode, execution_mode))
    else:
      # We already created everything, so update the thread local data.
      context._context._thread_local_data.is_eager = True

  # Monkey patch to get rid of an unnecessary conditional since the context is
  # now initialized.
  context.context = context.context_safe
def eager_run(main=None, argv=None):
  """Runs the program with an optional main function and argv list.

  The program will run with eager execution enabled: this calls
  `enable_eager_execution()` and then hands control to `absl.app.run`.

  Example:
  ```python
  import tensorflow as tf
  from tensorflow.python.framework import ops

  def main(_):
    u = tf.constant(6.0)
    v = tf.constant(7.0)
    print(u * v)

  if __name__ == "__main__":
    ops.eager_run(main)
  ```

  Args:
    main: the main function to run; passed to `absl.app.run`.
    argv: the arguments to pass to it; forwarded to `absl.app.run`.
  """
  enable_eager_execution()
  app.run(main, argv)
@tf_export(v1=["reset_default_graph"])
def reset_default_graph():
  """Clears the default graph stack and resets the global default graph.

  NOTE: The default graph is a property of the current thread. This
  function applies only to the current thread. Calling this function while
  a `tf.compat.v1.Session` or `tf.compat.v1.InteractiveSession` is active will
  result in undefined behavior. Using any previously created `tf.Operation` or
  `tf.Tensor` objects after calling this function will result in undefined
  behavior.

  @compatibility(TF2)
  `reset_default_graph` does not work with either eager execution or
  `tf.function`, and you should not invoke it directly. To migrate code that
  uses Graph-related functions to TF2, rewrite the code without them. See the
  [migration guide](https://www.tensorflow.org/guide/migrate) for more
  description about the behavior and semantic changes between Tensorflow 1 and
  Tensorflow 2.
  @end_compatibility

  Raises:
    AssertionError: If this function is called within a nested graph.
  """
  if _default_graph_stack.is_cleared():
    _default_graph_stack.reset()
  else:
    raise AssertionError("Do not use tf.reset_default_graph() to clear "
                         "nested graphs. If you need a cleared graph, "
                         "exit the nesting and create a new graph.")
@tf_export(v1=["get_default_graph"])
def get_default_graph():
  """Returns the default graph for the current thread.

  The result is the innermost graph for which a `Graph.as_default()` context
  has been entered. If no such context is active, a global default graph is
  returned (and created on first use).

  NOTE: The default graph is a property of the current thread. If you
  create a new thread, and wish to use the default graph in that
  thread, you must explicitly add a `with g.as_default():` in that
  thread's function.

  @compatibility(TF2)
  `get_default_graph` does not work with either eager execution or
  `tf.function`, and you should not invoke it directly. To migrate code that
  uses Graph-related functions to TF2, rewrite the code without them. See the
  [migration guide](https://www.tensorflow.org/guide/migrate) for more
  description about the behavior and semantic changes between Tensorflow 1 and
  Tensorflow 2.
  @end_compatibility

  Returns:
    The default `Graph` being used in the current thread.
  """
  graph_stack = _default_graph_stack
  return graph_stack.get_default()
def has_default_graph():
  """Returns True if there is a default graph.

  Only graphs pushed onto the default graph stack count; the lazily created
  global default graph does not.
  """
  return bool(_default_graph_stack.stack)
# Exported due to b/171079555
@tf_export("__internal__.get_name_scope", v1=[])
def get_name_scope():
  """Returns the current name scope in the default_graph.

  For example:

  ```python
  with tf.name_scope('scope1'):
    with tf.name_scope('scope2'):
      print(tf.get_name_scope())
  ```
  would print the string `scope1/scope2`.

  In eager mode the scope comes from the eager context, with any trailing
  slashes removed.

  Returns:
    A string representing the current name scope.
  """
  if not context.executing_eagerly():
    return get_default_graph().get_name_scope()
  eager_scope_name = context.context().scope_name
  return eager_scope_name.rstrip("/")
def _assert_same_graph(original_item, item):
  """Fail if the 2 items are from different graphs.

  Items without a (truthy) `graph` attribute are never considered mismatched.

  Args:
    original_item: Original item to check against.
    item: Item to check.

  Raises:
    ValueError: if graphs do not match.
  """
  reference_graph = getattr(original_item, "graph", None)
  candidate_graph = getattr(item, "graph", None)
  # Nothing to compare unless both items report a graph.
  if not (reference_graph and candidate_graph):
    return
  if reference_graph is candidate_graph:
    return
  raise ValueError(
      "%s must be from the same graph as %s (graphs are %s and %s)." %
      (item, original_item, candidate_graph, reference_graph))
def _get_graph_from_inputs(op_input_list, graph=None):
  """Returns the appropriate graph to use for the given inputs.

  The graph in which an Operation should be constructed is chosen as follows:

  1. If the default graph is being used to construct a function, it is
     returned immediately (and `op_input_list` is not inspected).
  2. If "graph" is given explicitly, every graph-element input in
     "op_input_list" must belong to it.
  3. Otherwise, the graph of the first Operation- or Tensor-valued input in
     "op_input_list" is selected, and all other such inputs must share it.
  4. If no graph was given or inferred, the default graph is used.

  Args:
    op_input_list: An iterable of inputs to an operation, which may include
      `Tensor`, `Operation`, and other objects that may be converted to a graph
      element. Generators are supported; the iterable is consumed once.
    graph: (Optional) The explicit graph to use.

  Raises:
    TypeError: If graph is not a Graph, or if op_input_list is not iterable.
    ValueError: If a graph is explicitly passed and not all inputs are from it,
      or if the inputs are from multiple graphs, or we could not find a graph
      and there was no default graph.

  Returns:
    The appropriate graph to use for the given inputs.
  """
  current_default_graph = get_default_graph()
  if current_default_graph.building_function:
    return current_default_graph

  op_input_list = tuple(op_input_list)  # Handle generators correctly
  if graph and not isinstance(graph, Graph):
    raise TypeError("Input graph needs to be a Graph: %s" % (graph,))

  # The first input that determined `graph` (only when it was not supplied
  # explicitly); kept to produce a more informative mismatch error.
  first_element = None
  for candidate in op_input_list:
    # TODO(josh11b): Subclasses of Tensor are deliberately routed through
    # `_as_graph_element` rather than accepted directly. Clean this up.
    is_graph_object = isinstance(candidate, (Operation, internal.NativeObject))
    is_tensor_subclass = (
        isinstance(candidate, Tensor) and type(candidate) != Tensor)  # pylint: disable=unidiomatic-typecheck
    if is_graph_object and not is_tensor_subclass:
      element = candidate
    else:
      element = _as_graph_element(candidate)

    if element is None:
      continue
    if not graph:
      first_element = element
      graph = getattr(element, "graph", None)
    elif first_element is not None:
      _assert_same_graph(first_element, element)
    elif element.graph is not graph:
      raise ValueError("%s is not from the passed-in graph." % element)

  # Fall back to the default graph, which is always available.
  return graph or current_default_graph
@tf_export(v1=["GraphKeys"])
class GraphKeys(object):
  """Standard names to use for graph collections.

  The standard library uses various well-known names to collect and
  retrieve values associated with a graph. For example, the
  `tf.Optimizer` subclasses default to optimizing the variables
  collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is
  specified, but it is also possible to pass an explicit list of
  variables.

  The following standard keys are defined:

  * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared
    across a distributed environment (model variables are a subset of these).
    See `tf.compat.v1.global_variables` for more details.
    Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`,
    and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`.
  * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each
    machine. Usually used for temporary variables, like counters.
  * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the
    model for inference (feed forward).
  * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will
    be trained by an optimizer. See `tf.compat.v1.trainable_variables`
    for more details.
  * `SUMMARIES`: the summary `Tensor` objects that have been created in the
    graph. See `tf.compat.v1.summary.merge_all` for more details.
  * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
    produce input for a computation. See
    `tf.compat.v1.train.start_queue_runners` for more details.
  * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also
    keep moving averages. See `tf.compat.v1.moving_average_variables`
    for more details.
  * `REGULARIZATION_LOSSES`: regularization losses collected during graph
    construction.

  The following standard keys are _defined_, but their collections are **not**
  automatically populated as many of the others are:

  * `WEIGHTS`
  * `BIASES`
  * `ACTIVATIONS`
  """

  # ---- Variable collections ----
  # Global variables, shared across machines; the default collection for every
  # variable that is not local.
  GLOBAL_VARIABLES = "variables"
  # Variables local to one machine; these are not saved or restored.
  LOCAL_VARIABLES = "local_variables"
  # Local variables that accumulate internal state for tf.metrics.*.
  METRIC_VARIABLES = "metric_variables"
  # Variables defined by layers as part of the model.
  MODEL_VARIABLES = "model_variables"
  # Variables the optimizers will train.
  TRAINABLE_VARIABLES = "trainable_variables"

  # ---- Other collections ----
  # Summary tensors.
  SUMMARIES = "summaries"
  # QueueRunner objects.
  QUEUE_RUNNERS = "queue_runners"
  # Table initializers.
  TABLE_INITIALIZERS = "table_initializer"
  # Asset file paths; an asset is an external resource such as a vocabulary
  # file.
  ASSET_FILEPATHS = "asset_filepaths"
  # Variables for which moving averages are kept.
  MOVING_AVERAGE_VARIABLES = "moving_average_variables"
  # Regularization losses gathered while the graph is constructed.
  REGULARIZATION_LOSSES = "regularization_losses"
  # Concatenated sharded variables.
  CONCATENATED_VARIABLES = "concatenated_variables"
  # Savers.
  SAVERS = "savers"
  # Weights.
  WEIGHTS = "weights"
  # Biases.
  BIASES = "biases"
  # Activations.
  ACTIVATIONS = "activations"
  # Update ops.
  UPDATE_OPS = "update_ops"
  # Losses.
  LOSSES = "losses"
  # BaseSaverBuilder.SaveableObject instances used for checkpointing.
  SAVEABLE_OBJECTS = "saveable_objects"
  # Shared resources that must be initialized once per cluster.
  RESOURCES = "resources"
  # Shared resources that must be initialized once per session.
  LOCAL_RESOURCES = "local_resources"
  # Trainable resource-style variables.
  TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"

  # ---- Keys identifying particular ops ----
  INIT_OP = "init_op"
  LOCAL_INIT_OP = "local_init_op"
  READY_OP = "ready_op"
  READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op"
  SUMMARY_OP = "summary_op"
  GLOBAL_STEP = "global_step"

  # Counts the evaluations performed during a single evaluation run.
  EVAL_STEP = "eval_step"
  TRAIN_OP = "train_op"

  # ---- Control flow contexts ----
  COND_CONTEXT = "cond_context"
  WHILE_CONTEXT = "while_context"

  # Names of v2 summaries.
  _SUMMARY_COLLECTION = "_SUMMARY_V2"

  # Every collection key that tracks variables.
  _VARIABLE_COLLECTIONS = [
      GLOBAL_VARIABLES,
      LOCAL_VARIABLES,
      METRIC_VARIABLES,
      MODEL_VARIABLES,
      TRAINABLE_VARIABLES,
      MOVING_AVERAGE_VARIABLES,
      CONCATENATED_VARIABLES,
      TRAINABLE_RESOURCE_VARIABLES,
  ]

  # Streaming model ports.
  # NOTE(yuanbyu): internal and experimental.
  _STREAMING_MODEL_PORTS = "streaming_model_ports"

  @decorator_utils.classproperty
  @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.")
  def VARIABLES(cls):  # pylint: disable=no-self-argument
    """Deprecated alias of `GLOBAL_VARIABLES`."""
    return cls.GLOBAL_VARIABLES
def dismantle_graph(graph):
  """Breaks reference cycles held by a `Graph`.

  Useful so that the garbage collector does not have to run after a temporary
  `Graph` is no longer needed. The graph's function registry is emptied first,
  and then `graph.Dismantle()` is called.

  Args:
    graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
      after this function runs.
  """
  function_registry = graph._functions  # pylint: disable=protected-access
  function_registry.clear()
  graph.Dismantle()
@tf_export(v1=["add_to_collection"])
def add_to_collection(name, value):
  """Wrapper for `Graph.add_to_collection()` using the default graph.

  See `tf.Graph.add_to_collection`
  for more details.

  Args:
    name: The key for the collection. For example, the `GraphKeys` class
      contains many standard names for collections.
    value: The value to add to the collection.

  @compatibility(eager)
  Collections are only supported in eager when variables are created inside
  an EagerVariableStore (e.g. as part of a layer or template).
  @end_compatibility
  """
  graph = get_default_graph()
  graph.add_to_collection(name, value)
@tf_export(v1=["add_to_collections"])
def add_to_collections(names, value):
  """Wrapper for `Graph.add_to_collections()` using the default graph.

  See `tf.Graph.add_to_collections`
  for more details.

  Args:
    names: The keys of the collections to add `value` to. The `GraphKeys`
      class contains many standard names for collections.
    value: The value to add to the collections.

  @compatibility(eager)
  Collections are only supported in eager when variables are created inside
  an EagerVariableStore (e.g. as part of a layer or template).
  @end_compatibility
  """
  graph = get_default_graph()
  graph.add_to_collections(names, value)
@tf_export(v1=["get_collection_ref"])
def get_collection_ref(key):
  """Wrapper for `Graph.get_collection_ref()` using the default graph.

  See `tf.Graph.get_collection_ref`
  for more details.

  Args:
    key: The key for the collection. For example, the `GraphKeys` class contains
      many standard names for collections.

  Returns:
    The list of values in the collection with the given `key`, or an empty
    list if no value has been added to that collection. Note that this returns
    the collection list itself, which can be modified in place to change the
    collection.

  @compatibility(eager)
  Collections are not supported when eager execution is enabled.
  @end_compatibility
  """
  graph = get_default_graph()
  return graph.get_collection_ref(key)
@tf_export(v1=["get_collection"])
def get_collection(key, scope=None):
  """Wrapper for `Graph.get_collection()` using the default graph.

  See `tf.Graph.get_collection`
  for more details.

  Args:
    key: The key for the collection. For example, the `GraphKeys` class contains
      many standard names for collections.
    scope: (Optional.) If supplied, the resulting list is filtered to include
      only items whose `name` attribute matches using `re.match`. Items without
      a `name` attribute are never returned if a scope is supplied. The choice
      of `re.match` means that a `scope` without special tokens filters by
      prefix.

  Returns:
    The list of values in the collection with the given `key`, or
    an empty list if no value has been added to that collection. The
    list contains the values in the order under which they were
    collected.

  @compatibility(eager)
  Collections are not supported when eager execution is enabled.
  @end_compatibility
  """
  graph = get_default_graph()
  return graph.get_collection(key, scope)
def get_all_collection_keys():
  """Returns the keys of every collection used in the default graph.

  Thin wrapper that forwards to `get_all_collection_keys()` on the graph
  returned by `get_default_graph()`.
  """
  graph = get_default_graph()
  return graph.get_all_collection_keys()
def name_scope(name, default_name=None, values=None, skip_on_eager=True):
  """Internal-only entry point for `name_scope*`.

  Internal ops do not use the public API and instead rely on
  `ops.name_scope` regardless of the execution mode. This function picks the
  concrete context manager as follows:

  * graph mode: `internal_name_scope_v1`;
  * eager mode with `skip_on_eager=True`: `NullContextmanager` (no scoping);
  * eager mode otherwise: if `values` contains an object whose exact type is
    `Tensor`, that tensor's `graph.name_scope` is used; else `name_scope_v2`.

  Args:
    name: The name argument that is passed to the op function.
    default_name: The default name to use if the `name` argument is `None`.
    values: The list of `Tensor` arguments that are passed to the op function.
    skip_on_eager: Indicates to return NullContextmanager if executing eagerly.
      By default this is True since naming tensors and operations in eager mode
      have little use and cause unnecessary performance overhead. However, it is
      important to preserve variable names since they are often useful for
      debugging and saved models.

  Returns:
    `name_scope*` context manager.
  """
  if not context.executing_eagerly():
    return internal_name_scope_v1(name, default_name, values)
  if skip_on_eager:
    return NullContextmanager()

  scope_name = name if name is not None else default_name

  # A graph tensor among `values` takes precedence over the eager context.
  # TODO(slebedev): this is Keras-specific and should be removed.
  graph_tensor = None
  if values:
    for candidate in values:
      if type(candidate) == Tensor:  # pylint: disable=unidiomatic-typecheck
        graph_tensor = candidate
        break
  if graph_tensor is not None:
    return graph_tensor.graph.name_scope(scope_name)

  return name_scope_v2(scope_name or "")
class internal_name_scope_v1(object):  # pylint: disable=invalid-name
  """Graph-only version of `name_scope_v1`."""

  @property
  def name(self):
    return self._name

  def __init__(self, name, default_name=None, values=None):
    """Initialize the context manager.

    Args:
      name: The name argument that is passed to the op function.
      default_name: The default name to use if the `name` argument is `None`.
      values: The list of `Tensor` arguments that are passed to the op function.

    Raises:
      TypeError: if `default_name` is passed in but not a string.
    """
    if not (default_name is None or isinstance(default_name, str)):
      raise TypeError(
          "`default_name` type (%s) is not a string type. You likely meant to "
          "pass this into the `values` kwarg." % type(default_name))
    self._name = default_name if name is None else name
    self._default_name = default_name
    self._values = values

  def __enter__(self):
    """Start the scope block.

    If `values` come from a graph other than the current default graph (and
    the default graph is not building a function), that graph is made the
    default for the duration of the block.

    Returns:
      The scope name.

    Raises:
      ValueError: if neither `name` nor `default_name` is provided
        but `values` are.
    """
    if self._name is None and self._values is not None:
      # We only raise an error if values is not None (provided) because
      # currently tf.name_scope(None) (values=None then) is sometimes used as
      # an idiom to reset to top scope.
      raise ValueError(
          "At least one of name (%s) and default_name (%s) must be provided."
          % (self._name, self._default_name))

    g = get_default_graph()
    if self._values and not g.building_function:
      # Specialize based on the knowledge that `_get_graph_from_inputs()`
      # ignores `inputs` when building a function.
      g_from_inputs = _get_graph_from_inputs(self._values)
      if g_from_inputs is not g:
        g = g_from_inputs
        self._g_manager = g.as_default()
        self._g_manager.__enter__()
      else:
        self._g_manager = None
    else:
      self._g_manager = None

    try:
      self._name_scope = g.name_scope(self._name)
      return self._name_scope.__enter__()
    except BaseException:
      # A failed __enter__ is never followed by __exit__, so undo the graph
      # switch made above before propagating the error.
      if self._g_manager is not None:
        self._g_manager.__exit__(*sys.exc_info())
      raise

  def __exit__(self, *exc_info):
    # Exit the graph context manager even if leaving the name scope raises;
    # otherwise the graph made default in __enter__ (via `as_default()`)
    # would never be exited.
    try:
      self._name_scope.__exit__(*exc_info)
    finally:
      if self._g_manager is not None:
        self._g_manager.__exit__(*exc_info)
# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
@tf_export(v1=["name_scope"])
class name_scope_v1(object):  # pylint: disable=invalid-name
  """A context manager for use when defining a Python op.

  This context manager validates that the given `values` are from the
  same graph, makes that graph the default graph, and pushes a
  name scope in that graph (see
  `tf.Graph.name_scope`
  for more details on that).

  For example, to define a new Python op called `my_op`:

  ```python
  def my_op(a, b, c, name=None):
    with tf.name_scope(name, "MyOp", [a, b, c]) as scope:
      a = tf.convert_to_tensor(a, name="a")
      b = tf.convert_to_tensor(b, name="b")
      c = tf.convert_to_tensor(c, name="c")
      # Define some computation that uses `a`, `b`, and `c`.
      return foo_op(..., name=scope)
  ```
  """

  __slots__ = ["_name", "_name_scope"]

  def __init__(self, name, default_name=None, values=None):
    """Initialize the context manager.

    Args:
      name: The name argument that is passed to the op function.
      default_name: The default name to use if the `name` argument is `None`.
      values: The list of `Tensor` arguments that are passed to the op function.

    Raises:
      TypeError: if `default_name` is passed in but not a string.
    """
    self._name = name if name is not None else default_name
    # Delegate to the internal dispatcher, forcing a real scope even in eager
    # mode (skip_on_eager=False).
    self._name_scope = name_scope(
        name, default_name, values, skip_on_eager=False)

  @property
  def name(self):
    return self._name

  def __enter__(self):
    return self._name_scope.__enter__()

  def __exit__(self, *exc_info):
    return self._name_scope.__exit__(*exc_info)
@tf_export("get_current_name_scope", v1=[])
def get_current_name_scope():
  """Returns current full name scope specified by `tf.name_scope(...)`s.

  For example,
  ```python
  with tf.name_scope("outer"):
    tf.get_current_name_scope()  # "outer"

    with tf.name_scope("inner"):
      tf.get_current_name_scope()  # "outer/inner"
  ```

  In other words, `tf.get_current_name_scope()` returns the prefix that would
  be prepended to the name of an op created at that point.

  Note that `@tf.function` resets the name scope stack as shown below.

  ```
  with tf.name_scope("outer"):

    @tf.function
    def foo(x):
      with tf.name_scope("inner"):
        return tf.add(x, x)  # Op name is "inner/Add", not "outer/inner/Add"
  ```
  """
  ctx = context.context()
  if not ctx.executing_eagerly():
    return get_default_graph().get_name_scope()
  # The eager scope name carries a trailing "/"; strip it for the public API.
  return ctx.scope_name.rstrip("/")
@tf_export("name_scope", v1=[])
class name_scope_v2(object):
  """A context manager for use when defining a Python op.

  This context manager pushes a name scope, which will make the name of all
  operations added within it have a prefix.

  For example, to define a new Python op called `my_op`:

  ```python
  def my_op(a, b, c, name=None):
    with tf.name_scope("MyOp") as scope:
      a = tf.convert_to_tensor(a, name="a")
      b = tf.convert_to_tensor(b, name="b")
      c = tf.convert_to_tensor(c, name="c")
      # Define some computation that uses `a`, `b`, and `c`.
      return foo_op(..., name=scope)
  ```

  When executed, the Tensors `a`, `b`, `c`, will have names `MyOp/a`, `MyOp/b`,
  and `MyOp/c`.

  Inside a `tf.function`, if the scope name already exists, the name will be
  made unique by appending `_n`. For example, calling `my_op` the second time
  will generate `MyOp_1/a`, etc.
  """

  __slots__ = ["_name", "_exit_fns"]

  def __init__(self, name):
    """Initialize the context manager.

    Args:
      name: The prefix to use on all names created within the name scope.

    Raises:
      ValueError: If name is not a string.
    """
    if not isinstance(name, str):
      raise ValueError("name for name_scope must be a string.")
    self._name = name
    # Stack of callbacks undoing each pending __enter__; __exit__ pops one.
    self._exit_fns = []

  @property
  def name(self):
    return self._name

  def __enter__(self):
    """Start the scope block.

    Returns:
      The scope name.
    """
    ctx = context.context()
    if not ctx.executing_eagerly():
      graph_scope = get_default_graph().name_scope(self._name)
      entered_scope_name = graph_scope.__enter__()
      self._exit_fns.append(graph_scope.__exit__)
      return entered_scope_name

    # Eager mode: names are never auto-incremented. A trailing "/" marks a
    # fully specified scope that replaces, rather than nests under, the
    # current one (matching Graph.name_scope); it also prevents
    # auto-incrementing.
    previous_scope = ctx.scope_name
    requested = self._name
    if not requested:
      new_scope = ""
    elif requested[-1] == "/":
      new_scope = requested
    else:
      prefix = previous_scope if previous_scope else ""
      new_scope = prefix + requested + "/"
    ctx.scope_name = new_scope

    def _restore_name_scope(*_):
      ctx.scope_name = previous_scope

    self._exit_fns.append(_restore_name_scope)
    return new_scope

  def __exit__(self, type_arg, value_arg, traceback_arg):
    exit_fn = self._exit_fns.pop()
    exit_fn(type_arg, value_arg, traceback_arg)
    return False  # Never suppress exceptions raised inside the block.

  def __getstate__(self):
    return self._name, self._exit_fns

  def __setstate__(self, state):
    self._name, self._exit_fns = state[0], state[1]
def strip_name_scope(name, export_scope):
  """Removes name scope from a name.

  Handles the forms `export_scope/x`, `export_scope///x`, `^export_scope/x`
  and `loc:@export_scope/x`; only the first occurrence is stripped.

  Args:
    name: A `string` name.
    export_scope: Optional `string`. Name scope to remove. Matched literally,
      not as a regular expression; a single trailing "/" is ignored.

  Returns:
    Name with name scope removed, or the original name if export_scope
    is None (or empty), or if `name`/`export_scope` cannot be processed.
  """
  if not export_scope:
    return name
  if export_scope[-1] == "/":
    export_scope = export_scope[:-1]

  try:
    # Strips export_scope/, export_scope///,
    # ^export_scope/, loc:@export_scope/.
    # `re.escape` makes regex metacharacters that can appear in scope names
    # (e.g. ".") match literally instead of acting as wildcards.
    str_to_replace = (
        r"([\^]|loc:@|^)" + re.escape(export_scope) + r"[\/]+(.*)")
    return re.sub(str_to_replace, r"\1\2", compat.as_str(name), count=1)
  except TypeError as e:
    # If the name is not of a type we can process, simply return it.
    logging.warning(e)
    return name
def prepend_name_scope(name, import_scope):
  """Prepends name scope to a name.

  A leading "^" (control input) or "loc:@" (colocation) marker on `name` is
  kept in front of the added scope.

  Args:
    name: A `string` name.
    import_scope: Optional `string`. Name scope to add; a single trailing "/"
      is ignored.

  Returns:
    Name with name scope added, or the original name if import_scope
    is None (or empty), or if `name`/`import_scope` cannot be processed.
  """
  if not import_scope:
    return name
  if import_scope[-1] == "/":
    import_scope = import_scope[:-1]

  try:
    str_to_replace = r"([\^]|loc:@|^)(.*)"
    # Build the replacement with a function instead of a template string. In
    # a template, a scope beginning with a digit (e.g. "1a") would turn r"\1"
    # into the invalid group reference r"\11", and backslashes in the scope
    # would be parsed as escapes.
    return re.sub(
        str_to_replace,
        lambda m: m.group(1) + import_scope + "/" + m.group(2),
        compat.as_str(name))
  except TypeError as e:
    # If the name is not of a type we can process, simply return it.
    logging.warning(e)
    return name
# pylint: disable=g-doc-return-or-yield
# pylint: disable=not-context-manager
@tf_export(v1=["op_scope"])
@tf_contextlib.contextmanager
def op_scope(values, name, default_name=None):
  """DEPRECATED. Same as name_scope above, just different argument order.

  Logs a deprecation warning on every use, then enters
  `name_scope(name, default_name=default_name, values=values)` and yields the
  resulting scope name.
  """
  # `warning` rather than the deprecated `warn` alias.
  logging.warning("tf.op_scope(values, name, default_name) is deprecated,"
                  " use tf.name_scope(name, default_name, values)")
  with name_scope(name, default_name=default_name, values=values) as scope:
    yield scope
# Registry of `(proto_type, to_proto, from_proto)` tuples keyed by collection
# name. Written by `register_proto_function` and read (by tuple index) by
# `get_collection_proto_type`, `get_to_proto_function` and
# `get_from_proto_function`.
_proto_function_registry = registry.Registry("proto functions")
def register_proto_function(collection_name,
                            proto_type=None,
                            to_proto=None,
                            from_proto=None):
  """Registers `to_proto` and `from_proto` functions for collection_name.

  `to_proto` converts a Python object to the corresponding protocol buffer
  and returns that buffer. `from_proto` converts a protocol buffer back into
  a Python object and returns the object.

  Args:
    collection_name: Name of the collection.
    proto_type: Protobuf type, such as `saver_pb2.SaverDef`,
      `variable_pb2.VariableDef`, `queue_runner_pb2.QueueRunnerDef`.
    to_proto: Function that implements Python object to protobuf conversion.
    from_proto: Function that implements protobuf to Python object conversion.

  Raises:
    TypeError: If a truthy `to_proto` or `from_proto` is not callable
      (`to_proto` is checked first).
  """
  for converter, label in ((to_proto, "to_proto"), (from_proto, "from_proto")):
    if converter and not callable(converter):
      raise TypeError(label + " must be callable.")

  entry = (proto_type, to_proto, from_proto)
  _proto_function_registry.register(entry, collection_name)
def get_collection_proto_type(collection_name):
  """Returns the proto_type registered for collection_name, or None.

  `None` is returned when the registry lookup raises `LookupError`.
  """
  try:
    proto_type = _proto_function_registry.lookup(collection_name)[0]
  except LookupError:
    proto_type = None
  return proto_type
def get_to_proto_function(collection_name):
  """Returns the to_proto function registered for collection_name, or None.

  `None` is returned when the registry lookup raises `LookupError`.
  """
  try:
    to_proto = _proto_function_registry.lookup(collection_name)[1]
  except LookupError:
    to_proto = None
  return to_proto
def get_from_proto_function(collection_name):
  """Returns the from_proto function registered for collection_name, or None.

  `None` is returned when the registry lookup raises `LookupError`.
  """
  try:
    from_proto = _proto_function_registry.lookup(collection_name)[2]
  except LookupError:
    from_proto = None
  return from_proto
def _op_to_colocate_with(v, graph):
  """Operation object corresponding to v to use for colocation constraints.

  Args:
    v: The object to colocate with: `None`, an `Operation`, an object with a
      `handle` attribute that is a `Tensor` (e.g. a resource variable), or
      anything accepted by `convert_to_tensor`.
    graph: The graph the constraint is being built in. Only consulted
      (`building_function` / `capture`) when `v` has a `Tensor` handle.

  Returns:
    A pair `(op, device_only_candidate)`. `op` is the `Operation` to colocate
    with, or `None` when `v` is `None`. `device_only_candidate` is `None`
    except when `v` has a `Tensor` handle, in which case it is a lightweight
    object carrying only `v.device` and `v.name`.
  """
  if v is None:
    return None, None
  if isinstance(v, Operation):
    return v, None

  # We always want to colocate with the reference op.
  # When 'v' is a ResourceVariable, the reference op is the handle creating op.
  #
  # What this should be is:
  # if isinstance(v, ResourceVariable):
  #   return v.handle.op, v
  # However, that would require a circular import dependency.
  # As of October 2018, there were attempts underway to remove
  # colocation constraints altogether. Assuming that will
  # happen soon, perhaps this hack to work around the circular
  # import dependency is acceptable.
  if hasattr(v, "handle") and isinstance(v.handle, Tensor):
    # A throwaway function object is used only as a mutable attribute holder
    # for `device` and `name`.
    device_only_candidate = lambda: None
    device_only_candidate.device = v.device
    device_only_candidate.name = v.name
    if graph.building_function:
      # Capture the handle into `graph` first and colocate with the op of the
      # captured tensor.
      return graph.capture(v.handle).op, device_only_candidate
    else:
      return v.handle.op, device_only_candidate
  if isinstance(v, EagerTensor) and not context.executing_eagerly():
    # NOTE(review): this computes the same result as the final `else`. It
    # appears to exist so that EagerTensors seen in graph mode are converted
    # instead of taking the `NativeObject` branch below (presumably
    # EagerTensor is a NativeObject) -- confirm before simplifying.
    return convert_to_tensor(v, as_ref=True).op, None
  elif isinstance(v, internal.NativeObject):
    return v.op, None
  else:
    return convert_to_tensor(v, as_ref=True).op, None
# Helper functions for op wrapper modules generated by `python_op_gen`.


def to_raw_op(f):
  """Make a given op wrapper function `f` raw.

  Raw op wrappers can only be called with keyword arguments.

  Args:
    f: An op wrapper function to make raw.

  Returns:
    Raw `f`.
  """
  # Copy `f` to get a new `__dict__`, otherwise `tf_export` will fail
  # due to double-registration.
  raw_f = types.FunctionType(f.__code__, f.__globals__, f.__name__,
                             f.__defaults__, f.__closure__)
  # `FunctionType` does not accept keyword-only defaults, so carry them over
  # explicitly; otherwise a keyword-only parameter with a default would become
  # required on the copy. A fresh dict keeps the copy independent of `f`.
  if f.__kwdefaults__ is not None:
    raw_f.__kwdefaults__ = dict(f.__kwdefaults__)
  return kwarg_only(raw_f)
def raise_from_not_ok_status(e, name):
  """Re-raises a not-OK status `e` as the matching Python exception.

  Mutates `e.message` by appending " name: <name>" (an empty name when `name`
  is None), then raises the exception built by `core._status_to_exception`.
  `from None` suppresses chaining to the original exception.
  """
  name_suffix = "" if name is None else str(name)
  e.message += " name: " + name_suffix
  converted = core._status_to_exception(e)  # pylint: disable=protected-access
  raise converted from None
def add_exit_callback_to_default_func_graph(fn):
  """Add a callback to run when the default function graph goes out of scope.

  Usage:

  ```python
  @tf.function
  def fn(x, v):
    expensive = expensive_object(v)
    add_exit_callback_to_default_func_graph(lambda: expensive.release())
    return g(x, expensive)

  fn(x=tf.constant(...), v=...)
  # `expensive` has been released.
  ```

  Args:
    fn: A callable that takes no arguments and whose output is ignored.
      To be executed when exiting func graph scope.

  Raises:
    RuntimeError: If executed when the current default graph is not a FuncGraph,
      or not currently executing in function creation mode (e.g., if inside
      an init_scope).
  """
  graph = get_default_graph()
  # pylint: disable=protected-access
  if graph._building_function:
    graph._add_scope_exit_callback(fn)
    return
  # pylint: enable=protected-access
  raise RuntimeError(
      "Cannot add scope exit callbacks when not building a function. "
      "Default graph: {}".format(graph))
6692def _reconstruct_sequence_inputs(op_def, inputs, attrs):
6693 """Regroups a flat list of input tensors into scalar and sequence inputs.
6695 Args:
6696 op_def: The `op_def_pb2.OpDef` (for knowing the input types)
6697 inputs: a list of input `Tensor`s to the op.
6698 attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define
6699 how long each sequence is)
6701 Returns:
6702 A list of `Tensor`s (corresponding to scalar inputs) and lists of
6703 `Tensor`s (corresponding to sequence inputs).
6704 """
6705 grouped_inputs = []
6706 i = 0
6707 for input_arg in op_def.input_arg:
6708 if input_arg.number_attr:
6709 input_len = attrs[input_arg.number_attr].i
6710 is_sequence = True
6711 elif input_arg.type_list_attr:
6712 input_len = len(attrs[input_arg.type_list_attr].list.type)
6713 is_sequence = True
6714 else:
6715 input_len = 1
6716 is_sequence = False
6718 if is_sequence:
6719 grouped_inputs.append(inputs[i:i + input_len])
6720 else:
6721 grouped_inputs.append(inputs[i])
6722 i += input_len
6724 assert i == len(inputs)
6725 return grouped_inputs
# Module-level switch; set to True by `enable_numpy_style_type_promotion`.
_numpy_style_type_promotion = False


def enable_numpy_style_type_promotion():
  """Switches on NumPy's type promotion rules.

  Used for enabling NumPy behavior on methods for TF NumPy. This only sets the
  module-level `_numpy_style_type_promotion` flag to True; no function in this
  section turns it back off.
  """
  global _numpy_style_type_promotion
  _numpy_style_type_promotion = True
# Module-level switch; set to True by `enable_numpy_style_slicing`.
_numpy_style_slicing = False


def enable_numpy_style_slicing():
  """Switches on NumPy's rules for slicing Tensors.

  Used for enabling NumPy behavior on slicing for TF NumPy. This only sets the
  module-level `_numpy_style_slicing` flag to True; no function in this
  section turns it back off.
  """
  global _numpy_style_slicing
  _numpy_style_slicing = True
6752class _TensorIterator(object):
6753 """Iterates over the leading dim of a Tensor. Performs no error checks."""
6755 __slots__ = ["_tensor", "_index", "_limit"]
6757 def __init__(self, tensor, dim0):
6758 self._tensor = tensor
6759 self._index = 0
6760 self._limit = dim0
6762 def __iter__(self):
6763 return self
6765 def __next__(self):
6766 if self._index == self._limit:
6767 raise StopIteration
6768 result = self._tensor[self._index]
6769 self._index += 1
6770 return result
6772 next = __next__ # python2.x compatibility.
def set_int_list_attr(op, attr_name, ints):
  """TF internal method used to set a list(int) attribute in the node_def.

  Wraps `ints` in an `AttrValue.ListValue` and stores it on `op` under
  `attr_name` via `op._set_attr`.
  """
  attr_value = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(i=ints))
  op._set_attr(attr_name, attr_value)  # pylint:disable=protected-access
6781def _get_enclosing_context(graph):
6782 # pylint: disable=protected-access
6783 if graph is None:
6784 return None
6786 if graph._control_flow_context is not None:
6787 return graph._control_flow_context
6789 if graph.building_function and hasattr(graph, "outer_graph"):
6790 return _get_enclosing_context(graph.outer_graph)
# TODO(b/271463878): Remove in favor of direct references to `handle_data_util`.
# Backwards-compatible module-level alias; new code should call
# `handle_data_util.get_resource_handle_data` directly.
get_resource_handle_data = handle_data_util.get_resource_handle_data
def _copy_handle_data_to_arg_def(tensor, arg_def):
  """Copies the first resource handle shape/dtype of `tensor` into `arg_def`.

  Does nothing when `tensor`'s handle data has no `shape_and_type` entries.
  Otherwise appends one entry to `arg_def.handle_data` holding the dtype and
  shape of the first entry; any further entries are not copied.
  """
  handle_data = handle_data_util.get_resource_handle_data(tensor)
  if not handle_data.shape_and_type:
    return
  first = handle_data.shape_and_type[0]
  proto = arg_def.handle_data.add()
  proto.dtype = first.dtype
  proto.shape.CopyFrom(first.shape)
# This will be replaced by a concrete implementation in a future CL.
@tf_export("__internal__.SymbolicTensor")
class SymbolicTensor:
  """Placeholder type for symbolic tensors; it currently defines no members."""
@tf_export("is_symbolic_tensor", v1=["is_symbolic_tensor"])
def is_symbolic_tensor(tensor):
  """Test if `tensor` is a symbolic Tensor.

  Args:
    tensor: a tensor-like object

  Returns:
    True if `tensor` is a symbolic tensor (not an eager tensor).
  """
  # Deliberately an exact type comparison: instances of `Tensor` subclasses
  # (presumably including eager tensor types) are not reported as symbolic.
  tensor_type = type(tensor)
  return tensor_type == Tensor  # pylint: disable=unidiomatic-typecheck