Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/resource_variable_ops.py: 27%
897 statements
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ops to use variables as resources."""
17# pylint: disable=g-bad-name
18import contextlib
19import functools
20import weakref
22import numpy as np
24from tensorflow.core.framework import attr_value_pb2
25from tensorflow.core.framework import variable_pb2
26from tensorflow.core.function import trace_type
27from tensorflow.core.protobuf import struct_pb2
28from tensorflow.python.checkpoint import tensor_callable
29from tensorflow.python.client import pywrap_tf_session
30from tensorflow.python.compat import compat as forward_compat
31from tensorflow.python.eager import context
32from tensorflow.python.eager import record
33from tensorflow.python.eager import tape
34from tensorflow.python.framework import auto_control_deps_utils as acd
35from tensorflow.python.framework import composite_tensor
36from tensorflow.python.framework import composite_tensor_gradient
37from tensorflow.python.framework import constant_op
38from tensorflow.python.framework import cpp_shape_inference_pb2
39from tensorflow.python.framework import dtypes
40from tensorflow.python.framework import errors
41from tensorflow.python.framework import indexed_slices
42from tensorflow.python.framework import ops
43from tensorflow.python.framework import tensor as tensor_module
44from tensorflow.python.framework import tensor_conversion_registry
45from tensorflow.python.framework import tensor_shape
46from tensorflow.python.framework import tensor_spec
47from tensorflow.python.ops import array_ops
48from tensorflow.python.ops import gen_array_ops
49from tensorflow.python.ops import gen_resource_variable_ops
50from tensorflow.python.ops import gen_state_ops
51from tensorflow.python.ops import handle_data_util
52from tensorflow.python.ops import math_ops
53from tensorflow.python.ops import state_ops
54from tensorflow.python.ops import variables
55# go/tf-wildcard-import
56# pylint: disable=wildcard-import
57from tensorflow.python.ops.gen_resource_variable_ops import *
58# pylint: enable=wildcard-import
59from tensorflow.python.saved_model import nested_structure_coder
60from tensorflow.python.trackable import base as trackable
61from tensorflow.python.types import core
62from tensorflow.python.util import _pywrap_utils
63from tensorflow.python.util import compat
64from tensorflow.python.util.deprecation import deprecated
65from tensorflow.python.util.tf_export import tf_export
67acd.register_read_only_resource_op("ReadVariableOp")
68acd.register_read_only_resource_op("VariableShape")
69acd.register_read_only_resource_op("ResourceGather")
70acd.register_read_only_resource_op("ResourceGatherNd")
71acd.register_read_only_resource_op("_ReadVariablesOp")
73# TODO(allenl): Remove this alias and migrate callers.
74get_resource_handle_data = handle_data_util.get_resource_handle_data
77def get_eager_safe_handle_data(handle):
78 """Get the data handle from the Tensor `handle`."""
79 assert isinstance(handle, ops.Tensor)
81 if isinstance(handle, ops.EagerTensor):
82 return handle._handle_data # pylint: disable=protected-access
83 else:
84 return get_resource_handle_data(handle)
87def _set_handle_shapes_and_types(tensor, handle_data, graph_mode):
88 """Sets the shape inference result HandleData on tensor.
90 Args:
91 tensor: A `Tensor` or `EagerTensor`.
92 handle_data: A `CppShapeInferenceResult.HandleData`.
93 graph_mode: A python bool.
94 """
95 tensor._handle_data = handle_data # pylint: disable=protected-access
96 if not graph_mode:
97 return
99 # Not an EagerTensor, so a graph tensor.
100 shapes, types = zip(
101 *[(pair.shape, pair.dtype) for pair in handle_data.shape_and_type])
102 ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
103 shapes = [
104 [d.size for d in s.dim] # pylint: disable=g-complex-comprehension
105 if not s.unknown_rank else None for s in shapes
106 ]
107 with tensor._op.graph._c_graph.get() as c_graph: # pylint: disable=protected-access
108 pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
109 c_graph,
110 tensor._as_tf_output(), # pylint: disable=protected-access
111 shapes,
112 ranks,
113 types)
116def _combine_handle_data(handle, initial_value):
117 """Concats HandleData from tensors `handle` and `initial_value`.
119 Args:
120 handle: A `Tensor` of dtype `resource`.
121 initial_value: A `Tensor`.
123 Returns:
124 A `CppShapeInferenceResult.HandleData`. If `initial_value` has dtype
125 `variant`, the `HandleData` contains the concatenation of the shape_and_type
126 from both `handle` and `initial_value`.
128 Raises:
129 RuntimeError: If handle, which was returned by VarHandleOp, either has
130 no handle data, or its len(handle_data.shape_and_type) != 1.
131 """
132 assert handle.dtype == dtypes.resource
134 variable_handle_data = get_eager_safe_handle_data(handle)
136 if initial_value.dtype != dtypes.variant:
137 return variable_handle_data
139 extra_handle_data = get_eager_safe_handle_data(initial_value)
140 if extra_handle_data is not None and extra_handle_data.is_set:
141 if (variable_handle_data is None or not variable_handle_data.is_set or
142 len(variable_handle_data.shape_and_type) != 1):
143 raise RuntimeError(
144 "Expected VarHandleOp to return a length==1 shape_and_type, "
145 f"but saw: '{variable_handle_data}'")
146 variable_handle_data.shape_and_type.extend(extra_handle_data.shape_and_type)
147 return variable_handle_data
150def _variable_handle_from_shape_and_dtype(shape,
151 dtype,
152 shared_name,
153 name,
154 graph_mode,
155 initial_value=None):
156 """Create a variable handle, copying in handle data from `initial_value`."""
157 container = ops.get_default_graph()._container # pylint: disable=protected-access
158 if container is None:
159 container = ""
160 shape = tensor_shape.as_shape(shape)
161 dtype = dtypes.as_dtype(dtype)
162 if not graph_mode:
163 if shared_name is not None:
164 raise errors.InternalError(
165 node_def=None,
166 op=None,
167 message="Using an explicit shared_name is "
168 "not allowed when executing eagerly.")
169 shared_name = context.anonymous_name()
171 handle = gen_resource_variable_ops.var_handle_op(
172 shape=shape,
173 dtype=dtype,
174 shared_name=shared_name,
175 name=name,
176 container=container)
177 if initial_value is None:
178 initial_value = handle
179 if graph_mode:
180 full_handle_data = _combine_handle_data(handle, initial_value)
181 _set_handle_shapes_and_types(handle, full_handle_data, graph_mode)
182 return handle
183 else:
184 handle_data = handle_data_util.create_handle_data(shape, dtype)
185 if initial_value is not None and initial_value.dtype == dtypes.variant:
186 extra_handle_data = get_eager_safe_handle_data(initial_value)
187 if extra_handle_data is not None and extra_handle_data.is_set:
188 if (not handle_data.is_set or len(handle_data.shape_and_type) != 1):
189 raise RuntimeError(
190 "Expected VarHandleOp to return a length==1 shape_and_type, "
191 f"but saw: '{handle_data}'")
192 handle_data.shape_and_type.extend(extra_handle_data.shape_and_type)
194 _set_handle_shapes_and_types(handle, handle_data, graph_mode)
195 return handle
198def eager_safe_variable_handle(initial_value, shape, shared_name, name,
199 graph_mode):
200 """Creates a variable handle with information to do shape inference.
202 The dtype is read from `initial_value` and stored in the returned
203 resource tensor's handle data.
205 If `initial_value.dtype == tf.variant`, we additionally extract the handle
206 data (if any) from `initial_value` and append it to the `handle_data`.
207 In this case, the returned tensor's handle data is in the form
209 ```
210 is_set: true
211 shape_and_type {
212 shape {
213 // initial_value.shape
214 }
215 dtype: DT_VARIANT
216 }
217 shape_and_type {
218 // handle_data(initial_value).shape_and_type[0]
219 }
220 shape_and_type {
221 // handle_data(initial_value).shape_and_type[1]
222 }
223 ...
224 ```
226 Ops that read from this tensor, such as `ReadVariableOp` and
227 `AssignVariableOp`, know that `handle_data(handle).shape_and_type[1:]`
228 correspond to the handle data of the variant(s) stored in the Variable.
230 Args:
231 initial_value: A `Tensor`.
232 shape: The shape of the handle data. Can be `TensorShape(None)` (i.e.
233 unknown shape).
234 shared_name: A string.
235 name: A string.
236 graph_mode: A python bool.
238 Returns:
239 The handle, a `Tensor` of type `resource`.
240 """
241 dtype = initial_value.dtype.base_dtype
242 return _variable_handle_from_shape_and_dtype(shape, dtype, shared_name, name,
243 graph_mode, initial_value)
246@contextlib.contextmanager
247def _handle_graph(handle):
248 # Note: might have an eager tensor but not be executing eagerly when building
249 # functions.
250 if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor) or
251 ops.has_default_graph()):
252 yield
253 else:
254 with handle.graph.as_default():
255 yield
258class EagerResourceDeleter:
259 """An object which cleans up a resource handle.
261 An alternative to defining a __del__ method on an object. The intended use is
262 that ResourceVariables or other objects with resource handles will maintain a
263 single reference to this object. When the parent object is collected, this
264 object will be too. Even if the parent object is part of a reference cycle,
265 the cycle will be collectable.
266 """
268 __slots__ = ["_handle", "_handle_device", "_context"]
270 def __init__(self, handle, handle_device):
271 if not isinstance(handle, ops.Tensor):
272 raise ValueError(
273 (f"Passed handle={handle} to EagerResourceDeleter. Was expecting "
274 f"the handle to be a `tf.Tensor`."))
275 self._handle = handle
276 self._handle_device = handle_device
277 # This is held since the __del__ function runs an op, and if the context()
278 # is collected before this object, there will be a segfault when running the
279 # op.
280 self._context = context.context()
282 def __del__(self):
283 # Resources follow object-identity when executing eagerly, so it is safe to
284 # delete the resource we have a handle to.
285 try:
286 # A packed EagerTensor doesn't own any resource.
287 if isinstance(self._handle, ops.EagerTensor) and self._handle.is_packed:
288 return
289 # This resource was created in eager mode. However, this destructor may be
290 # running in graph mode (especially during unit tests). To clean up
291 # successfully, we switch back into eager mode temporarily.
292 with context.eager_mode():
293 with ops.device(self._handle_device):
294 gen_resource_variable_ops.destroy_resource_op(
295 self._handle, ignore_lookup_error=True)
296 except TypeError:
297 # Suppress some exceptions, mainly for the case when we're running on
298 # module deletion. Things that can go wrong include the context module
299 # already being unloaded, self._handle._handle_data no longer being
300 # valid, and so on. Printing warnings in these cases is silly
301 # (exceptions raised from __del__ are printed as warnings to stderr).
302 pass # 'NoneType' object is not callable when the handle has been
303 # partially unloaded.
304 except AttributeError:
305 pass # 'NoneType' object has no attribute 'eager_mode' when context has
306 # been unloaded. Will catch other module unloads as well.
309def shape_safe_assign_variable_handle(handle, shape, value, name=None):
310 """Helper that checks shape compatibility and assigns variable."""
311 with _handle_graph(handle):
312 value_tensor = ops.convert_to_tensor(value)
313 shape.assert_is_compatible_with(value_tensor.shape)
314 return gen_resource_variable_ops.assign_variable_op(
315 handle, value_tensor, name=name)
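# A minimal usage sketch (not from the original source), assuming TF 2.x eager
# execution; the helper refuses assignments whose shape is incompatible with
# the variable's static shape:
#   v = tf.Variable([1.0, 2.0])
#   shape_safe_assign_variable_handle(v.handle, v.shape, [3.0, 4.0])  # ok
#   shape_safe_assign_variable_handle(v.handle, v.shape, [3.0])  # ValueError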
318def _maybe_set_handle_data(dtype, handle, tensor):
319 if dtype == dtypes.variant:
320 # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
321 # variant's handle data. Extract it.
322 handle_data = get_eager_safe_handle_data(handle)
323 if handle_data.is_set and len(handle_data.shape_and_type) > 1:
324 tensor._handle_data = ( # pylint: disable=protected-access
325 cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
326 is_set=True, shape_and_type=handle_data.shape_and_type[1:]))
329def variable_accessed(variable):
330 """Records that `variable` was accessed for the tape and FuncGraph."""
331 if hasattr(ops.get_default_graph(), "watch_variable"):
332 ops.get_default_graph().watch_variable(variable)
333 if variable.trainable:
334 tape.variable_accessed(variable)
337def default_variable_creator_v2(next_creator=None, **kwargs):
338 """Default variable creator."""
339 assert next_creator is None
340 initial_value = kwargs.get("initial_value", None)
341 trainable = kwargs.get("trainable", None)
342 validate_shape = kwargs.get("validate_shape", True)
343 caching_device = kwargs.get("caching_device", None)
344 name = kwargs.get("name", None)
345 variable_def = kwargs.get("variable_def", None)
346 dtype = kwargs.get("dtype", None)
347 import_scope = kwargs.get("import_scope", None)
348 constraint = kwargs.get("constraint", None)
349 distribute_strategy = kwargs.get("distribute_strategy", None)
350 synchronization = kwargs.get("synchronization", None)
351 aggregation = kwargs.get("aggregation", None)
352 shape = kwargs.get("shape", None)
353 experimental_enable_variable_lifting = kwargs.get(
354 "experimental_enable_variable_lifting", None)
356 return ResourceVariable(
357 initial_value=initial_value,
358 trainable=trainable,
359 validate_shape=validate_shape,
360 caching_device=caching_device,
361 name=name,
362 dtype=dtype,
363 constraint=constraint,
364 variable_def=variable_def,
365 import_scope=import_scope,
366 distribute_strategy=distribute_strategy,
367 synchronization=synchronization,
368 aggregation=aggregation,
369 shape=shape,
370 experimental_enable_variable_lifting=experimental_enable_variable_lifting,
371 )
374variables.default_variable_creator_v2 = default_variable_creator_v2
377class BaseResourceVariable(variables.Variable, core.Tensor):
378 """A python variable from an existing handle."""
380 # TODO(wangpeng): Deprecate `constraint` when callers no longer pass it in.
381 def __init__( # pylint: disable=super-init-not-called
382 self,
383 trainable=None,
384 shape=None,
385 dtype=None,
386 handle=None,
387 constraint=None,
388 synchronization=None,
389 aggregation=None,
390 distribute_strategy=None,
391 name=None,
392 unique_id=None,
393 handle_name=None,
394 graph_element=None,
395 initial_value=None,
396 initializer_op=None,
397 is_initialized_op=None,
398 cached_value=None,
399 save_slice_info=None,
400 caching_device=None,
401 in_graph_mode=None,
402 validate_shape=True,
403 **unused_kwargs):
404 """Creates a variable from a handle.
406 Args:
407 trainable: If `True`, GradientTapes automatically watch uses of this
408 Variable.
409 shape: The variable's shape. This shape can be set to tf.TensorShape(None)
410 in order to assign values of different shapes to this variable.
411 Otherwise (i.e. if the shape is fully determined), it will trigger run
412 time checks to ensure that each assignment is of the same shape.
413 dtype: The variable's dtype.
414 handle: The variable's handle
415 constraint: An optional projection function to be applied to the variable
416 after being updated by an `Optimizer` (e.g. used to implement norm
417 constraints or value constraints for layer weights). The function must
418 take as input the unprojected Tensor representing the value of the
419 variable and return the Tensor for the projected value (which must have
420 the same shape). Constraints are not safe to use when doing asynchronous
421 distributed training.
422 synchronization: Indicates when a distributed variable will be
423 aggregated. Accepted values are constants defined in the class
424 `tf.VariableSynchronization`. By default the synchronization is set to
425 `AUTO` and the current `DistributionStrategy` chooses when to
426 synchronize.
427 aggregation: Indicates how a distributed variable will be aggregated.
428 Accepted values are constants defined in the class
429 `tf.VariableAggregation`.
430 distribute_strategy: The distribution strategy this variable was created
431 under.
432 name: The name for this variable.
433 unique_id: Internal. Unique ID for this variable's handle.
434 handle_name: The name for the variable's handle.
435 graph_element: Optional, required only in session.run-mode. Pre-created
436 tensor which reads this variable's value.
437 initial_value: Optional. Variable's initial value.
438 initializer_op: Operation which assigns the variable's initial value.
439 is_initialized_op: Pre-created operation to check whether this variable is
440 initialized.
441 cached_value: Pre-created operation to read this variable in a specific
442 device.
443 save_slice_info: Metadata for variable partitioning.
444 caching_device: Optional device string or function describing where the
445 Variable should be cached for reading. Defaults to the Variable's
446 device. If not `None`, caches on another device. Typical use is to
447 cache on the device where the Ops using the Variable reside, to
448 deduplicate copying through `Switch` and other conditional statements.
449 in_graph_mode: whether we are executing in TF1 graph mode. If None, will
450 detect within the function. This is to avoid repeated init_scope()
451 context entries, which can add up.
452 validate_shape: If `False`, allows the variable to be initialized with a
453 value of unknown shape. If `True`, the default, the shape of
454 `initial_value` must be known.
455 """
456 if in_graph_mode is None:
457 with ops.init_scope():
458 self._in_graph_mode = not context.executing_eagerly()
459 else:
460 self._in_graph_mode = in_graph_mode
461 synchronization, aggregation, trainable = (
462 variables.validate_synchronization_aggregation_trainable(
463 synchronization, aggregation, trainable, name))
464 self._trainable = trainable
465 self._synchronization = synchronization
466 self._aggregation = aggregation
467 self._save_slice_info = save_slice_info
468 self._initial_value = initial_value
469 self._initializer_op = initializer_op
470 self._is_initialized_op = is_initialized_op
471 self._graph_element = graph_element
472 self._caching_device = caching_device
473 self._cached_value = cached_value
474 self._distribute_strategy = distribute_strategy
475 # Store the graph key so optimizers know how to only retrieve variables from
476 # this graph. Guaranteed to be the same as the eager graph_key.
477 self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
478 self._shape = tensor_shape.as_shape(shape)
479 self._dtype = dtypes.as_dtype(dtype)
480 self._handle = handle
481 self._unique_id = unique_id
482 if handle_name is None:
483 self._handle_name = "Variable:0"
484 else:
485 self._handle_name = handle_name + ":0"
486 self._constraint = constraint
487 self._cached_shape_as_list = None
488 self._validate_shape = validate_shape
490 def __repr__(self):
491 if context.executing_eagerly() and not self._in_graph_mode:
492 # If we cannot read the value for any reason (e.g. variable uninitialized
493 # during tf.function tracing), still produce a __repr__. Note that for
494 # async eager, errors due to uninitialized variables will raise in
495 # ops.value_text when the handle is resolved, so we need to keep that
496 # under the try...except if we want to suppress them.
497 try:
498 with ops.device(self.device):
499 value_text = ops.value_text(self.read_value(), is_repr=True)
500 except: # pylint: disable=bare-except
501 value_text = "numpy=<unavailable>"
503 return "<tf.Variable '%s' shape=%s dtype=%s, %s>" % (
504 self.name, self.get_shape(), self.dtype.name, value_text)
505 else:
506 return "<tf.Variable '%s' shape=%s dtype=%s>" % (
507 self.name, self.get_shape(), self.dtype.name)
509 def __tf_tracing_type__(self, signature_context):
510 alias_id = signature_context.alias_global_id(self._handle._id) # pylint:disable=protected-access
511 # TODO(xjun): Create variable placeholders directly from VariableSpec
512 # without using original values.
513 signature_context.add_placeholder(alias_id, self)
514 return VariableSpec(shape=self.shape,
515 dtype=self.dtype,
516 trainable=self.trainable,
517 alias_id=alias_id)
519 @contextlib.contextmanager
520 def _assign_dependencies(self):
521 """Makes assignments depend on the cached value, if any.
523 This prevents undefined behavior with reads not ordered wrt writes.
525 Yields:
526 None.
527 """
528 if self._cached_value is not None:
529 with ops.control_dependencies([self._cached_value]):
530 yield
531 else:
532 yield
534 def __array__(self, dtype=None):
535 """Allows direct conversion to a numpy array.
537 >>> np.array(tf.Variable([1.0]))
538 array([1.], dtype=float32)
540 Returns:
541 The variable value as a numpy array.
542 """
543 # You can't return `self.numpy()` here because for scalars
544 # that raises:
545 # ValueError: object __array__ method not producing an array
546 # Even `self.read_value().__array__()` and `self.read_value()._numpy()` give
547 # the same error. The `EagerTensor` class must be doing something behind the
548 # scenes to make `np.array(tf.constant(1))` work.
549 return np.asarray(self.numpy(), dtype=dtype)
551 def __nonzero__(self):
552 return self.__bool__()
554 def __bool__(self):
555 return bool(self.read_value())
557 def __copy__(self):
558 return self
560 def __deepcopy__(self, memo):
561 if not context.executing_eagerly():
562 raise NotImplementedError(
563 "__deepcopy__() is only available when eager execution is enabled.")
564 copied_variable = ResourceVariable(
565 initial_value=self.read_value(),
566 trainable=self._trainable,
567 constraint=self._constraint,
568 dtype=self._dtype,
569 name=self._shared_name,
570 distribute_strategy=self._distribute_strategy,
571 synchronization=self.synchronization,
572 aggregation=self.aggregation)
573 memo[self._unique_id] = copied_variable
574 return copied_variable
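  # A minimal usage sketch (not from the original source), assuming TF 2.x
  # eager execution as required by the check above:
  #   import copy
  #   v = tf.Variable([1.0])
  #   w = copy.deepcopy(v)
  #   w.assign([2.0])  # `w` is an independent copy; `v` stays [1.0]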
576 @property
577 def dtype(self):
578 """The dtype of this variable."""
579 return self._dtype
581 @property
582 def device(self):
583 """The device this variable is on."""
584 return self.handle.device
586 @property
587 def graph(self):
588 """The `Graph` of this variable."""
589 return self.handle.graph
591 @property
592 def name(self):
593 """The name of the handle for this variable."""
594 return self._handle_name
596 @property
597 def shape(self):
598 """The shape of this variable."""
599 return self._shape
601 def set_shape(self, shape):
602 self._shape = self._shape.merge_with(shape)
604 def _shape_as_list(self):
605 if self.shape.ndims is None:
606 return None
607 return [dim.value for dim in self.shape.dims]
609 def _shape_tuple(self):
610 shape = self._shape_as_list()
611 if shape is None:
612 return None
613 return tuple(shape)
615 @property
616 def create(self):
617 """The op responsible for initializing this variable."""
618 if not self._in_graph_mode:
619 raise RuntimeError("This operation is not supported "
620 "when eager execution is enabled.")
621 return self._initializer_op
623 @property
624 def handle(self):
625 """The handle by which this variable can be accessed."""
626 return self._handle
628 def value(self):
629 """A cached operation which reads the value of this variable."""
630 if self._cached_value is not None:
631 return self._cached_value
632 with ops.colocate_with(None, ignore_existing=True):
633 return self._read_variable_op()
635 def _as_graph_element(self):
636 """Conversion function for Graph.as_graph_element()."""
637 return self._graph_element
639 @property
640 def initializer(self):
641 """The op responsible for initializing this variable."""
642 return self._initializer_op
644 @property
645 def initial_value(self):
646 """Returns the Tensor used as the initial value for the variable."""
647 if context.executing_eagerly():
648 raise RuntimeError("This property is not supported "
649 "when eager execution is enabled.")
650 return self._initial_value
652 @property
653 def constraint(self):
654 """Returns the constraint function associated with this variable.
656 Returns:
657 The constraint function that was passed to the variable constructor.
658 Can be `None` if no constraint was passed.
659 """
660 return self._constraint
662 @property
663 def op(self):
664 """The op for this variable."""
665 return self.handle.op
667 @property
668 def trainable(self):
669 return self._trainable
671 @property
672 def synchronization(self):
673 return self._synchronization
675 @property
676 def aggregation(self):
677 return self._aggregation
679 def eval(self, session=None):
680 """Evaluates and returns the value of this variable."""
681 if context.executing_eagerly():
682 raise RuntimeError("This operation is not supported "
683 "when eager execution is enabled.")
684 return self._graph_element.eval(session=session)
686 def numpy(self):
687 if context.executing_eagerly():
688 return self.read_value().numpy()
689 raise NotImplementedError(
690 "numpy() is only available when eager execution is enabled.")
692 @deprecated(None, "Prefer Dataset.range instead.")
693 def count_up_to(self, limit):
694 """Increments this variable until it reaches `limit`.
696 When that Op is run it tries to increment the variable by `1`. If
697 incrementing the variable would bring it above `limit` then the Op raises
698 the exception `OutOfRangeError`.
700 If no error is raised, the Op outputs the value of the variable before
701 the increment.
703 This is essentially a shortcut for `count_up_to(self, limit)`.
705 Args:
706 limit: value at which incrementing the variable raises an error.
708 Returns:
709 A `Tensor` that will hold the variable value before the increment. If no
710 other Op modifies this variable, the values produced will all be
711 distinct.
712 """
713 return gen_state_ops.resource_count_up_to(
714 self.handle, limit=limit, T=self.dtype)
716 def _export_to_saved_model_graph(self, object_map=None, tensor_map=None,
717 options=None, **kwargs):
718 """For implementing `Trackable`."""
719 new_variable = None
720 if options.experimental_variable_policy._save_variable_devices(): # pylint:disable=protected-access
721 with ops.device(self.device):
722 new_variable = copy_to_graph_uninitialized(self)
723 else:
724 new_variable = copy_to_graph_uninitialized(self)
725 object_map[self] = new_variable
726 tensor_map[self.handle] = new_variable.handle
727 return [self.handle]
729 def _serialize_to_tensors(self):
730 """Implements Trackable._serialize_to_tensors."""
732 def _read_variable_closure():
733 v = self
734 with ops.device(v.device):
735 if context.executing_eagerly() and not v.is_initialized():
736 # A SaveSpec tensor value of `None` indicates that the variable is
737 # uninitialized.
738 return None
739 # Read the variable without making a copy to limit memory usage.
740 x = v.read_value_no_copy()
741 # To allow variables placed on non-CPU devices to be checkpointed,
742 # we copy them to CPU on the same machine first.
743 with ops.device("/device:CPU:0"):
744 return array_ops.identity(x)
746 return {
747 trackable.VARIABLE_VALUE_KEY:
748 tensor_callable.Callable(
749 _read_variable_closure, dtype=self.dtype, device=self.device)
750 }
752 def _restore_from_tensors(self, restored_tensors):
753 """Implements Trackable._restore_from_tensors."""
754 with ops.device(self.device):
755 restored_tensor = array_ops.identity(
756 restored_tensors[trackable.VARIABLE_VALUE_KEY])
757 try:
758 assigned_variable = shape_safe_assign_variable_handle(
759 self.handle, self.shape, restored_tensor)
760 except ValueError as e:
761 raise ValueError(
762 f"Received incompatible tensor with shape {restored_tensor.shape} "
763 f"when attempting to restore variable with shape {self.shape} "
764 f"and name {self.name}.") from e
765 return assigned_variable
767 def _read_variable_op(self, no_copy=False):
768 """Reads the value of the variable.
770 If the variable is in copy-on-read mode and `no_copy` is True, the variable
771 is converted to copy-on-write mode before it is read.
773 Args:
774 no_copy: Whether to prevent a copy of the variable.
776 Returns:
777 The value of the variable.
778 """
779 variable_accessed(self)
781 def read_and_set_handle(no_copy):
782 if no_copy and forward_compat.forward_compatible(2022, 5, 3):
783 gen_resource_variable_ops.disable_copy_on_read(self.handle)
784 result = gen_resource_variable_ops.read_variable_op(
785 self.handle, self._dtype)
786 _maybe_set_handle_data(self._dtype, self.handle, result)
787 return result
789 if getattr(self, "_caching_device", None) is not None:
790 with ops.colocate_with(None, ignore_existing=True):
791 with ops.device(self._caching_device):
792 result = read_and_set_handle(no_copy)
793 else:
794 result = read_and_set_handle(no_copy)
796 if not context.executing_eagerly():
797 # Note that if a control flow context is active the input of the read op
798 # might not actually be the handle. This line bypasses it.
799 record.record_operation(
800 "ReadVariableOp", [result], [self.handle],
801 backward_function=lambda x: [x],
802 forward_function=lambda x: [x])
803 return result
805 def read_value(self):
806 """Constructs an op which reads the value of this variable.
808 Should be used when there are multiple reads, or when it is desirable to
809 read the value only after some condition is true.
811 Returns:
812 The value of the variable.
813 """
814 with ops.name_scope("Read"):
815 value = self._read_variable_op()
816 # Return an identity so it can get placed on whatever device the context
817 # specifies instead of the device where the variable is.
818 return array_ops.identity(value)
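  # A minimal usage sketch (not from the original source), assuming TF 2.x
  # eager execution; `read_value` returns a snapshot Tensor, so later
  # assignments do not affect it:
  #   v = tf.Variable(3.0)
  #   t = v.read_value()
  #   v.assign(4.0)
  #   t.numpy()  # -> 3.0, while v.numpy() -> 4.0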
820 def read_value_no_copy(self):
821 """Constructs an op which reads the value of this variable without copy.
823 The variable is read without making a copy even when it has been sparsely
824 accessed. Variables in copy-on-read mode will be converted to copy-on-write
825 mode.
827 Returns:
828 The value of the variable.
829 """
830 with ops.name_scope("Read"):
831 value = self._read_variable_op(no_copy=True)
832 # Return an identity so it can get placed on whatever device the context
833 # specifies instead of the device where the variable is.
834 return array_ops.identity(value)
836 def sparse_read(self, indices, name=None):
837 """Reads the value of this variable sparsely, using `gather`."""
838 with ops.name_scope("Gather" if name is None else name) as name:
839 variable_accessed(self)
840 value = gen_resource_variable_ops.resource_gather(
841 self.handle, indices, dtype=self._dtype, name=name)
843 if self._dtype == dtypes.variant:
844 # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
845 # variant's handle data. Extract it.
846 handle_data = get_eager_safe_handle_data(self.handle)
847 if handle_data.is_set and len(handle_data.shape_and_type) > 1:
848 value._handle_data = ( # pylint: disable=protected-access
849 cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
850 is_set=True, shape_and_type=handle_data.shape_and_type[1:]))
851 return array_ops.identity(value)
853 return value
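  # A minimal usage sketch (not from the original source), assuming TF 2.x
  # eager execution; equivalent to `tf.gather(v, indices)`:
  #   v = tf.Variable([10.0, 20.0, 30.0])
  #   v.sparse_read([0, 2])  # -> [10.0, 30.0]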
855 def gather_nd(self, indices, name=None):
856 """Reads the value of this variable sparsely, using `gather_nd`."""
857 with ops.name_scope("GatherNd" if name is None else name) as name:
858 if self.trainable:
859 variable_accessed(self)
860 value = gen_resource_variable_ops.resource_gather_nd(
861 self.handle, indices, dtype=self._dtype, name=name)
863 return array_ops.identity(value)
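  # A minimal usage sketch (not from the original source), assuming TF 2.x
  # eager execution; equivalent to `tf.gather_nd(v, indices)`:
  #   v = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
  #   v.gather_nd([[0, 1], [1, 0]])  # -> [2.0, 3.0]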
865 def to_proto(self, export_scope=None):
866 """Converts a `ResourceVariable` to a `VariableDef` protocol buffer.
868 Args:
869 export_scope: Optional `string`. Name scope to remove.
871 Raises:
872 RuntimeError: If run in EAGER mode.
874 Returns:
875 A `VariableDef` protocol buffer, or `None` if the `Variable` is not
876 in the specified name scope.
877 """
878 if context.executing_eagerly():
879 raise RuntimeError("This operation is not supported "
880 "when eager execution is enabled.")
881 if export_scope is None or self.handle.name.startswith(export_scope):
882 var_def = variable_pb2.VariableDef()
883 var_def.variable_name = ops.strip_name_scope(self.handle.name,
884 export_scope)
885 if self._initial_value is not None:
886 # This is inside an if-statement for backwards compatibility, since
887 # self._initial_value might be None for variables constructed from old
888 # protos.
889 var_def.initial_value_name = ops.strip_name_scope(
890 self._initial_value.name, export_scope)
891 var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
892 export_scope)
893 if self._cached_value is not None:
894 var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name,
895 export_scope)
896 else:
897 # Store the graph_element here
898 var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
899 export_scope)
900 var_def.is_resource = True
901 var_def.trainable = self.trainable
902 var_def.synchronization = self.synchronization.value
903 var_def.aggregation = self.aggregation.value
904 if self._save_slice_info:
905 var_def.save_slice_info_def.MergeFrom(
906 self._save_slice_info.to_proto(export_scope=export_scope))
907 return var_def
908 else:
909 return None
911 @staticmethod
912 def from_proto(variable_def, import_scope=None):
913 if context.executing_eagerly():
914 raise RuntimeError("This operation is not supported "
915 "when eager execution is enabled.")
916 return ResourceVariable(
917 variable_def=variable_def, import_scope=import_scope)
919 __array_priority__ = 100
921 def is_initialized(self, name=None):
922 """Checks whether a resource variable has been initialized.
924 Outputs boolean scalar indicating whether the tensor has been initialized.
926 Args:
927 name: A name for the operation (optional).
929 Returns:
930 A `Tensor` of type `bool`.
931 """
932 return gen_resource_variable_ops.var_is_initialized_op(self.handle, name)
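  # A minimal usage sketch (not from the original source), assuming TF 2.x
  # eager execution, where variables are initialized on creation:
  #   v = tf.Variable(1.0)
  #   bool(v.is_initialized())  # -> True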
934 def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
935 """Subtracts a value from this variable.
937 Args:
938 delta: A `Tensor`. The value to subtract from this variable.
939 use_locking: If `True`, use locking during the operation.
940 name: The name to use for the operation.
941 read_value: A `bool`. Whether to read and return the new value of the
942 variable or not.
944 Returns:
945 If `read_value` is `True`, this method will return the new value of the
946 variable after the assignment has completed. Otherwise, when in graph mode
947 it will return the `Operation` that does the assignment, and when in eager
948 mode it will return `None`.
949 """
950 # TODO(apassos): this here and below is not atomic. Consider making it
951 # atomic if there's a way to do so without a performance cost for those who
952 # don't need it.
953 with _handle_graph(self.handle), self._assign_dependencies():
954 assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
955 self.handle,
956 ops.convert_to_tensor(delta, dtype=self.dtype),
957 name=name)
958 if read_value:
959 return self._lazy_read(assign_sub_op)
960 return assign_sub_op
962 def assign_add(self, delta, use_locking=None, name=None, read_value=True):
963 """Adds a value to this variable.
965 Args:
966 delta: A `Tensor`. The value to add to this variable.
967 use_locking: If `True`, use locking during the operation.
968 name: The name to use for the operation.
969 read_value: A `bool`. Whether to read and return the new value of the
970 variable or not.
972 Returns:
973 If `read_value` is `True`, this method will return the new value of the
974 variable after the assignment has completed. Otherwise, when in graph mode
975 it will return the `Operation` that does the assignment, and when in eager
976 mode it will return `None`.
977 """
978 with _handle_graph(self.handle), self._assign_dependencies():
979 assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
980 self.handle,
981 ops.convert_to_tensor(delta, dtype=self.dtype),
982 name=name)
983 if read_value:
984 return self._lazy_read(assign_add_op)
985 return assign_add_op
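  # A minimal usage sketch (not from the original source) for `assign_add` and
  # `assign_sub`, assuming TF 2.x eager execution:
  #   v = tf.Variable(1.0)
  #   v.assign_add(2.0)  # v -> 3.0
  #   v.assign_sub(0.5)  # v -> 2.5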
987 def _lazy_read(self, op):
988 variable_accessed(self)
989 return _UnreadVariable(
990 handle=self.handle,
991 dtype=self.dtype,
992 shape=self._shape,
993 in_graph_mode=self._in_graph_mode,
994 parent_op=op,
995 unique_id=self._unique_id)
997 def assign(self, value, use_locking=None, name=None, read_value=True):
998 """Assigns a new value to this variable.
1000 Args:
1001 value: A `Tensor`. The new value for this variable.
1002 use_locking: If `True`, use locking during the assignment.
1003 name: The name to use for the assignment.
1004 read_value: A `bool`. Whether to read and return the new value of the
1005 variable or not.
1007 Returns:
1008 If `read_value` is `True`, this method will return the new value of the
1009 variable after the assignment has completed. Otherwise, when in graph mode
1010 it will return the `Operation` that does the assignment, and when in eager
1011 mode it will return `None`.
1012 """
1013 # Note: not depending on the cached value here since this can be used to
1014 # initialize the variable.
1015 with _handle_graph(self.handle):
1016 value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
1017 if not self._shape.is_compatible_with(value_tensor.shape):
1018 if self.name is None:
1019 tensor_name = ""
1020 else:
1021 tensor_name = " " + str(self.name)
1022 raise ValueError(
1023 (f"Cannot assign value to variable '{tensor_name}': Shape mismatch."
1024 f"The variable shape {self._shape}, and the "
1025 f"assigned value shape {value_tensor.shape} are incompatible."))
1026 kwargs = {}
1027 if forward_compat.forward_compatible(2022, 3, 23):
1028 # If the shape is fully defined, we do a runtime check with the shape of
1029 # value.
1030 validate_shape = self._validate_shape and self._shape.is_fully_defined()
1031 kwargs["validate_shape"] = validate_shape
1032 assign_op = gen_resource_variable_ops.assign_variable_op(
1033 self.handle, value_tensor, name=name, **kwargs)
1034 if read_value:
1035 return self._lazy_read(assign_op)
1036 return assign_op
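  # A minimal usage sketch (not from the original source), assuming TF 2.x
  # eager execution; an incompatible shape raises the ValueError built above:
  #   v = tf.Variable([1.0, 2.0])
  #   v.assign([3.0, 4.0])  # ok
  #   v.assign([3.0])       # raises ValueError (shape mismatch)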
1038 def __reduce__(self):
1039 # The implementation mirrors that of __deepcopy__.
1040 return functools.partial(
1041 ResourceVariable,
1042 initial_value=self.numpy(),
1043 trainable=self.trainable,
1044 name=self._shared_name,
1045 dtype=self.dtype,
1046 constraint=self.constraint,
1047 distribute_strategy=self._distribute_strategy), ()
1049 def scatter_sub(self, sparse_delta, use_locking=False, name=None):
1050 """Subtracts `tf.IndexedSlices` from this variable.
1052 Args:
1053 sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
1054 use_locking: If `True`, use locking during the operation.
1055 name: the name of the operation.
1057 Returns:
1058 The updated variable.
1060 Raises:
1061 TypeError: if `sparse_delta` is not an `IndexedSlices`.
1062 """
1063 if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
1064 raise TypeError(f"Argument `sparse_delta` must be a "
1065 f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1066 return self._lazy_read(
1067 gen_resource_variable_ops.resource_scatter_sub(
1068 self.handle,
1069 sparse_delta.indices,
1070 ops.convert_to_tensor(sparse_delta.values, self.dtype),
1071 name=name))
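  # A minimal usage sketch (not from the original source), assuming TF 2.x
  # eager execution and `tf.IndexedSlices(values, indices)`:
  #   v = tf.Variable([1.0, 2.0, 3.0, 4.0])
  #   v.scatter_sub(tf.IndexedSlices([1.0, 1.0], [1, 3]))  # v -> [1.0, 1.0, 3.0, 3.0]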
1073 def scatter_add(self, sparse_delta, use_locking=False, name=None):
1074 """Adds `tf.IndexedSlices` to this variable.
1076 Args:
1077 sparse_delta: `tf.IndexedSlices` to be added to this variable.
1078 use_locking: If `True`, use locking during the operation.
1079 name: the name of the operation.
1081 Returns:
1082 The updated variable.
1084 Raises:
1085 TypeError: if `sparse_delta` is not an `IndexedSlices`.
1086 """
1087 if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
1088 raise TypeError(f"Argument `sparse_delta` must be a "
1089 f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1090 return self._lazy_read(
1091 gen_resource_variable_ops.resource_scatter_add(
1092 self.handle,
1093 sparse_delta.indices,
1094 ops.convert_to_tensor(sparse_delta.values, self.dtype),
1095 name=name))
1097 def scatter_max(self, sparse_delta, use_locking=False, name=None):
1098 """Updates this variable with the max of `tf.IndexedSlices` and itself.
1100 Args:
1101 sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
1102 variable.
1103 use_locking: If `True`, use locking during the operation.
1104 name: the name of the operation.
1106 Returns:
1107 The updated variable.
1109 Raises:
1110 TypeError: if `sparse_delta` is not an `IndexedSlices`.
1111 """
1112 if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
1113 raise TypeError(f"Argument `sparse_delta` must be a "
1114 f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1115 return self._lazy_read(
1116 gen_resource_variable_ops.resource_scatter_max(
1117 self.handle,
1118 sparse_delta.indices,
1119 ops.convert_to_tensor(sparse_delta.values, self.dtype),
1120 name=name))
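  # A minimal usage sketch (not from the original source), assuming TF 2.x
  # eager execution; takes the element-wise max at the given indices:
  #   v = tf.Variable([1.0, 2.0, 3.0])
  #   v.scatter_max(tf.IndexedSlices([5.0, 0.0], [0, 2]))  # v -> [5.0, 2.0, 3.0]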
1122 def scatter_min(self, sparse_delta, use_locking=False, name=None):
1123 """Updates this variable with the min of `tf.IndexedSlices` and itself.
1125 Args:
1126 sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
1127 variable.
1128 use_locking: If `True`, use locking during the operation.
1129 name: the name of the operation.
1131 Returns:
1132 The updated variable.
1134 Raises:
1135 TypeError: if `sparse_delta` is not an `IndexedSlices`.
1136 """
1137 if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
1138 raise TypeError(f"Argument `sparse_delta` must be a "
1139 f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1140 return self._lazy_read(
1141 gen_resource_variable_ops.resource_scatter_min(
1142 self.handle,
1143 sparse_delta.indices,
1144 ops.convert_to_tensor(sparse_delta.values, self.dtype),
1145 name=name))
1147 def scatter_mul(self, sparse_delta, use_locking=False, name=None):
1148 """Multiply this variable by `tf.IndexedSlices`.
1150 Args:
1151 sparse_delta: `tf.IndexedSlices` to multiply this variable by.
1152 use_locking: If `True`, use locking during the operation.
1153 name: the name of the operation.
1155 Returns:
1156 The updated variable.
1158 Raises:
1159 TypeError: if `sparse_delta` is not an `IndexedSlices`.
1160 """
1161 if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
1162 raise TypeError(f"Argument `sparse_delta` must be a "
1163 f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1164 return self._lazy_read(
1165 gen_resource_variable_ops.resource_scatter_mul(
1166 self.handle,
1167 sparse_delta.indices,
1168 ops.convert_to_tensor(sparse_delta.values, self.dtype),
1169 name=name))
1171 def scatter_div(self, sparse_delta, use_locking=False, name=None):
1172 """Divide this variable by `tf.IndexedSlices`.
1174 Args:
1175 sparse_delta: `tf.IndexedSlices` to divide this variable by.
1176 use_locking: If `True`, use locking during the operation.
1177 name: the name of the operation.
1179 Returns:
1180 The updated variable.
1182 Raises:
1183 TypeError: if `sparse_delta` is not an `IndexedSlices`.
1184 """
1185 if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
1186 raise TypeError(f"Argument `sparse_delta` must be a "
1187 f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1188 return self._lazy_read(
1189 gen_resource_variable_ops.resource_scatter_div(
1190 self.handle,
1191 sparse_delta.indices,
1192 ops.convert_to_tensor(sparse_delta.values, self.dtype),
1193 name=name))
1195 def scatter_update(self, sparse_delta, use_locking=False, name=None):
1196 """Assigns `tf.IndexedSlices` to this variable.
1198 Args:
1199 sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
1200 use_locking: If `True`, use locking during the operation.
1201 name: the name of the operation.
1203 Returns:
1204 The updated variable.
1206 Raises:
1207 TypeError: if `sparse_delta` is not an `IndexedSlices`.
1208 """
1209 if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
1210 raise TypeError(f"Argument `sparse_delta` must be a "
1211 f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1212 return self._lazy_read(
1213 gen_resource_variable_ops.resource_scatter_update(
1214 self.handle,
1215 sparse_delta.indices,
1216 ops.convert_to_tensor(sparse_delta.values, self.dtype),
1217 name=name))
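  # A minimal usage sketch (not from the original source), assuming TF 2.x
  # eager execution; overwrites the selected elements:
  #   v = tf.Variable([1.0, 2.0, 3.0])
  #   v.scatter_update(tf.IndexedSlices([9.0, 8.0], [0, 2]))  # v -> [9.0, 2.0, 8.0]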
1219 def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
1220 """Assigns `tf.IndexedSlices` to this variable batch-wise.
1222 Analogous to `batch_gather`. This assumes that this variable and the
1223 sparse_delta IndexedSlices have a series of leading dimensions that are the
1224 same for all of them, and the updates are performed on the last dimension of
1225 indices. In other words, the dimensions should be the following:
1227 `num_prefix_dims = sparse_delta.indices.ndims - 1`
1228 `batch_dim = num_prefix_dims + 1`
1229 `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
1230 batch_dim:]`
1232 where
1234 `sparse_delta.updates.shape[:num_prefix_dims]`
1235 `== sparse_delta.indices.shape[:num_prefix_dims]`
1236 `== var.shape[:num_prefix_dims]`
1238 And the operation performed can be expressed as:
1240 `var[i_1, ..., i_n,
1241 sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
1242 i_1, ..., i_n, j]`
1244 When sparse_delta.indices is a 1D tensor, this operation is equivalent to
1245 `scatter_update`.
1247 To avoid this operation one can loop over the first `ndims` dimensions of the
1248 variable and use `scatter_update` on the subtensors that result from slicing
1249 the first dimension. This is a valid option for `ndims = 1`, but less
1250 efficient than this implementation.
1252 Args:
1253 sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
1254 use_locking: If `True`, use locking during the operation.
1255 name: the name of the operation.
1257 Returns:
1258 The updated variable.
1260 Raises:
1261 TypeError: if `sparse_delta` is not an `IndexedSlices`.
1262 """
1263 if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
1264 raise TypeError(f"Argument `sparse_delta` must be a "
1265 f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1266 return self._lazy_read(
1267 state_ops.batch_scatter_update(
1268 self,
1269 sparse_delta.indices,
1270 sparse_delta.values,
1271 use_locking=use_locking,
1272 name=name))
1274 def scatter_nd_sub(self, indices, updates, name=None):
1275 """Applies sparse subtraction to individual values or slices in a Variable.
1277 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1279 `indices` must be integer tensor, containing indices into `ref`.
1280 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1282 The innermost dimension of `indices` (with length `K`) corresponds to
1283 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1284 dimension of `ref`.
1286 `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1288 ```
1289 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1290 ```
1292 For example, say we want to subtract 4 scattered elements from a rank-1 tensor
1293 with 8 elements. In Python, that update would look like this:
1295 ```python
1296 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1297 indices = tf.constant([[4], [3], [1] ,[7]])
1298 updates = tf.constant([9, 10, 11, 12])
1299 op = ref.scatter_nd_sub(indices, updates)
1300 with tf.compat.v1.Session() as sess:
1301 print(sess.run(op))
1302 ```
1304 The resulting update to ref would look like this:
1306 [1, -9, 3, -6, -4, 6, 7, -4]
1308 See `tf.scatter_nd` for more details about how to make updates to
1309 slices.
1311 Args:
1312 indices: The indices to be used in the operation.
1313 updates: The values to be used in the operation.
1314 name: the name of the operation.
1316 Returns:
1317 The updated variable.
1318 """
1319 return self._lazy_read(
1320 gen_state_ops.resource_scatter_nd_sub(
1321 self.handle,
1322 indices,
1323 ops.convert_to_tensor(updates, self.dtype),
1324 name=name))
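  # The docstring example above, rewritten as a sketch (not from the original
  # source) for TF 2.x eager execution instead of a `tf.compat.v1.Session`:
  #   v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  #   v.scatter_nd_sub([[4], [3], [1], [7]], [9, 10, 11, 12])
  #   v.numpy()  # -> [1, -9, 3, -6, -4, 6, 7, -4]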
1326 def scatter_nd_add(self, indices, updates, name=None):
1327 """Applies sparse addition to individual values or slices in a Variable.
1329 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1331 `indices` must be integer tensor, containing indices into `ref`.
1332 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1334 The innermost dimension of `indices` (with length `K`) corresponds to
1335 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1336 dimension of `ref`.
1338 `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1340 ```
1341 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1342 ```
1344 For example, say we want to add 4 scattered elements to a rank-1 tensor with
1345 8 elements. In Python, that update would look like this:
1347 ```python
1348 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1349 indices = tf.constant([[4], [3], [1] ,[7]])
1350 updates = tf.constant([9, 10, 11, 12])
1351 add = ref.scatter_nd_add(indices, updates)
1352 with tf.compat.v1.Session() as sess:
1353 print(sess.run(add))
1354 ```
1356 The resulting update to ref would look like this:
1358 [1, 13, 3, 14, 14, 6, 7, 20]
1360 See `tf.scatter_nd` for more details about how to make updates to
1361 slices.
1363 Args:
1364 indices: The indices to be used in the operation.
1365 updates: The values to be used in the operation.
1366 name: the name of the operation.
1368 Returns:
1369 The updated variable.
1370 """
1371 return self._lazy_read(
1372 gen_state_ops.resource_scatter_nd_add(
1373 self.handle,
1374 indices,
1375 ops.convert_to_tensor(updates, self.dtype),
1376 name=name))
1378 def scatter_nd_update(self, indices, updates, name=None):
1379 """Applies sparse assignment to individual values or slices in a Variable.
1381 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1383 `indices` must be integer tensor, containing indices into `ref`.
1384 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1386 The innermost dimension of `indices` (with length `K`) corresponds to
1387 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1388 dimension of `ref`.
1390 `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1392 ```
1393 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1394 ```
1396 For example, say we want to update 4 scattered elements in a rank-1 tensor
1397 with 8 elements. In Python, that update would look like this:
1399 ```python
1400 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1401 indices = tf.constant([[4], [3], [1] ,[7]])
1402 updates = tf.constant([9, 10, 11, 12])
1403 op = ref.scatter_nd_update(indices, updates)
1404 with tf.compat.v1.Session() as sess:
1405 print(sess.run(op))
1406 ```
1408 The resulting update to ref would look like this:
1410 [1, 11, 3, 10, 9, 6, 7, 12]
1412 See `tf.scatter_nd` for more details about how to make updates to
1413 slices.
1415 Args:
1416 indices: The indices to be used in the operation.
1417 updates: The values to be used in the operation.
1418 name: the name of the operation.
1420 Returns:
1421 The updated variable.
1422 """
1423 return self._lazy_read(
1424 gen_state_ops.resource_scatter_nd_update(
1425 self.handle,
1426 indices,
1427 ops.convert_to_tensor(updates, self.dtype),
1428 name=name))
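  # The docstring example above, rewritten as a sketch (not from the original
  # source) for TF 2.x eager execution instead of a `tf.compat.v1.Session`:
  #   v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  #   v.scatter_nd_update([[4], [3], [1], [7]], [9, 10, 11, 12])
  #   v.numpy()  # -> [1, 11, 3, 10, 9, 6, 7, 12]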
1430 def scatter_nd_max(self, indices, updates, name=None):
1431 """Updates this variable with the max of `tf.IndexedSlices` and itself.
1433 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1435 `indices` must be integer tensor, containing indices into `ref`.
1436 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1438 The innermost dimension of `indices` (with length `K`) corresponds to
1439 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1440 dimension of `ref`.
1442 `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1444 ```
1445 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1446 ```
1448 See `tf.scatter_nd` for more details about how to make updates to
1449 slices.
1451 Args:
1452 indices: The indices to be used in the operation.
1453 updates: The values to be used in the operation.
1454 name: the name of the operation.
1456 Returns:
1457 The updated variable.
1458 """
1459 return self._lazy_read(
1460 gen_state_ops.resource_scatter_nd_max(
1461 self.handle,
1462 indices,
1463 ops.convert_to_tensor(updates, self.dtype),
1464 name=name))
1466 def scatter_nd_min(self, indices, updates, name=None):
1467 """Updates this variable with the min of `tf.IndexedSlices` and itself.
1469 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1471 `indices` must be integer tensor, containing indices into `ref`.
1472 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1474 The innermost dimension of `indices` (with length `K`) corresponds to
1475 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1476 dimension of `ref`.
1478 `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1480 ```
1481 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1482 ```
1484 See `tf.scatter_nd` for more details about how to make updates to
1485 slices.
1487 Args:
1488 indices: The indices to be used in the operation.
1489 updates: The values to be used in the operation.
1490 name: the name of the operation.
1492 Returns:
1493 The updated variable.
1494 """
1495 return self._lazy_read(
1496 gen_state_ops.resource_scatter_nd_min(
1497 self.handle,
1498 indices,
1499 ops.convert_to_tensor(updates, self.dtype),
1500 name=name))
1502 def _write_object_proto(self, proto, options):
1503 """Writes additional information of the variable into the SavedObject proto.
1505 Subclasses of ResourceVariables could choose to override this method to
1506 customize extra information to provide when saving a SavedModel.
1508 Ideally, this should contain the logic in
1509 write_object_proto_for_resource_variable but `DistributedValue` is an
1510 outlier at the moment. Once `DistributedValue` becomes a proper
1511 ResourceVariable, we should remove the helper method below.
1513 Args:
1514 proto: `SavedObject` proto to update.
1515 options: A `SaveOptions` instance that configures save behavior.
1516 """
1517 write_object_proto_for_resource_variable(self, proto, options)
1519 def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
1520 end_mask, ellipsis_mask, new_axis_mask,
1521 shrink_axis_mask):
1522 with _handle_graph(self.handle), self._assign_dependencies():
1523 return self._lazy_read(
1524 gen_array_ops.resource_strided_slice_assign(
1525 ref=self.handle,
1526 begin=begin,
1527 end=end,
1528 strides=strides,
1529 value=ops.convert_to_tensor(value, dtype=self.dtype),
1530 name=name,
1531 begin_mask=begin_mask,
1532 end_mask=end_mask,
1533 ellipsis_mask=ellipsis_mask,
1534 new_axis_mask=new_axis_mask,
1535 shrink_axis_mask=shrink_axis_mask))
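  # A minimal usage sketch (not from the original source), assuming TF 2.x
  # eager execution; this op backs sliced assignment on a variable:
  #   v = tf.Variable([1.0, 2.0, 3.0])
  #   v[1].assign(20.0)  # v -> [1.0, 20.0, 3.0]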
1537 def __complex__(self):
1538 return complex(self.value().numpy())
1540 def __int__(self):
1541 return int(self.value().numpy())
1543 def __long__(self):
1544 return long(self.value().numpy())
1546 def __float__(self):
1547 return float(self.value().numpy())
1549 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
1550 del name
1551 if dtype is not None and not dtype.is_compatible_with(self.dtype):
1552 raise ValueError(
1553 f"Incompatible type conversion requested to type {dtype.name} for "
1554 f"`tf.Variable of type {self.dtype.name}. (Variable: {self})")
1555 if as_ref:
1556 return self.read_value().op.inputs[0]
1557 else:
1558 return self.value()
1560 def __iadd__(self, unused_other):
1561 raise RuntimeError("`variable += value` with `tf.Variable`s is not "
1562 "supported. Use `variable.assign_add(value)` to modify "
1563 "the variable, or `out = variable + value` if you "
1564 "need to get a new output Tensor.")
1566 def __isub__(self, unused_other):
1567 raise RuntimeError("`variable -= value` with `tf.Variable`s is not "
1568 "supported. Use `variable.assign_sub(value)` to modify "
1569 "the variable, or `out = variable * value` if you "
1570 "need to get a new output Tensor.")
1572 def __imul__(self, unused_other):
1573 raise RuntimeError("`var *= value` with `tf.Variable`s is not "
1574 "supported. Use `var.assign(var * value)` to modify "
1575 "the variable, or `out = var * value` if you "
1576 "need to get a new output Tensor.")
1578 def __idiv__(self, unused_other):
1579 raise RuntimeError("`var /= value` with `tf.Variable`s is not "
1580 "supported. Use `var.assign(var / value)` to modify "
1581 "the variable, or `out = var / value` if you "
1582 "need to get a new output Tensor.")
1584 def __itruediv__(self, unused_other):
1585 raise RuntimeError("`var /= value` with `tf.Variable`s is not "
1586 "supported. Use `var.assign(var / value)` to modify "
1587 "the variable, or `out = var / value` if you "
1588 "need to get a new output Tensor.")
1590 def __irealdiv__(self, unused_other):
1591 raise RuntimeError("`var /= value` with `tf.Variable`s is not "
1592 "supported. Use `var.assign(var / value)` to modify "
1593 "the variable, or `out = var / value` if you "
1594 "need to get a new output Tensor.")
1596 def __ipow__(self, unused_other):
1597 raise RuntimeError("`var **= value` with `tf.Variable`s is not "
1598 "supported. Use `var.assign(var ** value)` to modify "
1599 "the variable, or `out = var ** value` if you "
1600 "need to get a new output Tensor.")
1603class ResourceVariableGradient(
1604 composite_tensor_gradient.CompositeTensorGradient):
1605 """CompositeTensorGradient protocol for ResourceVariable."""
1607 # TODO(b/246997907): update this method to return value.handle.
1608 def get_gradient_components(self, value):
1609 """Returns the components of `value` that should be included in gradients.
1611 For a ResourceVariable, its gradient component is its handle tensor.
1612 For now, we return the ResourceVariable because the gradient infrastructure
1613 has special logic to handle ResourceVariables. We should remove that
1614 special logic and return the handle tensor.
1616 Args:
1617 value: A `ResourceVariable`.
1619 Returns:
1620 `value` itself.
1621 """
1622 return value
1624 def replace_gradient_components(self, value, component_grads):
1625 """Replaces the gradient components in `value` with `component_grads`.
1627 The gradient of a ResourceVariable is either None or a Tensor. So we don't
1628 need `value`'s TypeSpec or non-gradient components in this method.
1630 Args:
1631 value: A `ResourceVariable` with its gradient components compatible with
1632 `component_grads`.
1633 component_grads: A `Tensor` or None as the gradient result.
1635 Returns:
1636 The `component_grads`, which is either a `Tensor` or None.
1637 """
1638 return component_grads
1641class ResourceVariable(BaseResourceVariable, composite_tensor.CompositeTensor):
1642 """Variable based on resource handles.
1644 See the [Variables How To](https://tensorflow.org/guide/variables)
1645 for a high level overview.
1647 A `ResourceVariable` allows you to maintain state across subsequent calls to
1648 session.run.
1650 The `ResourceVariable` constructor requires an initial value for the variable,
1651 which can be a `Tensor` of any type and shape. The initial value defines the
1652 type and shape of the variable. After construction, the type and shape of
1653 the variable are fixed. The value can be changed using one of the assign
1654 methods.
1656 Just like any `Tensor`, variables created with
1657 `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the
1658 graph. Additionally, all the operators overloaded for the `Tensor` class are
1659 carried over to variables, so you can also add nodes to the graph by just
1660 doing arithmetic on variables.
1662 Unlike ref-based variables, a ResourceVariable has well-defined semantics. Each
1663 usage of a ResourceVariable in a TensorFlow graph adds a read_value operation
1664 to the graph. The Tensors returned by a read_value operation are guaranteed to
1665 see all modifications to the value of the variable made by any operation on
1666 which the read_value depends (whether directly, indirectly, or via a control
1667 dependency), and are guaranteed not to see any modification made by operations
1668 that depend on the read_value operation. Updates from operations that have no
1669 dependency relationship with the read_value operation may or may not be
1670 visible to read_value.
1672 For example, if there is more than one assignment to a ResourceVariable in
1673 a single session.run call, there is a well-defined value for each operation
1674 that uses the variable's value, provided the assignments and the read are
1675 connected by edges in the graph. Consider the following example, in which
1676 two writes can cause tf.Variable and tf.ResourceVariable to behave differently:
1678 ```python
1679 a = tf.Variable(1.0, use_resource=True)
1680 a.initializer.run()
1682 assign = a.assign(2.0)
1683 with tf.control_dependencies([assign]):
1684 b = a.read_value()
1685 with tf.control_dependencies([b]):
1686 other_assign = a.assign(3.0)
1687 with tf.control_dependencies([other_assign]):
1688 # Will print 2.0 because the value was read before other_assign ran. If
1689 # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed.
1690 tf.compat.v1.Print(b, [b]).eval()
1691 ```
1692 """
1694 def __init__(
1695 self, # pylint: disable=super-init-not-called
1696 initial_value=None,
1697 trainable=None,
1698 collections=None,
1699 validate_shape=True, # pylint: disable=unused-argument
1700 caching_device=None,
1701 name=None,
1702 dtype=None,
1703 variable_def=None,
1704 import_scope=None,
1705 constraint=None,
1706 distribute_strategy=None,
1707 synchronization=None,
1708 aggregation=None,
1709 shape=None,
1710 handle=None,
1711 experimental_enable_variable_lifting=None,
1712 ):
1713 """Creates a variable.
1715 Args:
1716 initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1717 which is the initial value for the Variable. Can also be a callable with
1718 no argument that returns the initial value when called. (Note that
1719 initializer functions from init_ops.py must first be bound to a shape
1720 before being used here.)
1721 trainable: If `True`, the default, also adds the variable to the graph
1722 collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
1723 the default list of variables to use by the `Optimizer` classes.
1724 Defaults to `True`, unless `synchronization` is set to `ON_READ`, in
1725 which case it defaults to `False`.
1726 collections: List of graph collections keys. The new variable is added to
1727 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1728 validate_shape: If `False`, allows the variable to be initialized with a
1729 value of unknown shape. If `True`, the default, the shape of
1730 `initial_value` must be known.
1731 caching_device: Optional device string or function describing where the
1732 Variable should be cached for reading. Defaults to the Variable's
1733 device. If not `None`, caches on another device. Typical use is to
1734 cache on the device where the Ops using the Variable reside, to
1735 deduplicate copying through `Switch` and other conditional statements.
1736 name: Optional name for the variable. Defaults to `'Variable'` and gets
1737 uniquified automatically.
1738 dtype: If set, initial_value will be converted to the given type. If None,
1739 either the datatype will be kept (if initial_value is a Tensor) or
1740 float32 will be used (if it is a Python object convertible to a Tensor).
1741 variable_def: `VariableDef` protocol buffer. If not None, recreates the
1742 `ResourceVariable` object with its contents. `variable_def` and other
1743 arguments (except for import_scope) are mutually exclusive.
1744 import_scope: Optional `string`. Name scope to add to the
1745 ResourceVariable. Only used when `variable_def` is provided.
1746 constraint: An optional projection function to be applied to the variable
1747 after being updated by an `Optimizer` (e.g. used to implement norm
1748 constraints or value constraints for layer weights). The function must
1749 take as input the unprojected Tensor representing the value of the
1750 variable and return the Tensor for the projected value (which must have
1751 the same shape). Constraints are not safe to use when doing asynchronous
1752 distributed training.
1753 distribute_strategy: The tf.distribute.Strategy this variable is being
1754 created inside of.
1755 synchronization: Indicates when a distributed variable will be
1756 aggregated. Accepted values are constants defined in the class
1757 `tf.VariableSynchronization`. By default the synchronization is set to
1758 `AUTO` and the current `DistributionStrategy` chooses when to
1759 synchronize.
1760 aggregation: Indicates how a distributed variable will be aggregated.
1761 Accepted values are constants defined in the class
1762 `tf.VariableAggregation`.
1763 shape: (optional) The shape of this variable. If None, the shape of
1764 `initial_value` will be used. When setting this argument to
1765 `tf.TensorShape(None)` (representing an unspecified shape), the variable
1766 can be assigned with values of different shapes.
1767 handle: (optional) The handle of a `tf.Variable`. If provided, only
1768 `trainable`, `shape`, `dtype`, and `handle` will be used to construct
1769 this `tf.Variable`.
1770 experimental_enable_variable_lifting: Whether to lift the variable out if
1771 it's in a `tf.function`. Default is `True`. When this argument
1772 is `True`, variable creation will follow the behavior and
1773 restrictions described
1774 [here](https://www.tensorflow.org/guide/function#creating_tfvariables).
1775 If this argument is `False`, that description doesn't apply,
1776 and you can freely create and use the variable in the
1777 `tf.function`, as if it's a "mutable `tf.Tensor`". You can't
1778 return the variable though.
1780 Raises:
1781 ValueError: If the initial value is not specified, or does not have a
1782 shape and `validate_shape` is `True`.
1784 @compatibility(eager)
1785 When Eager Execution is enabled, the default for the `collections` argument
1786 is `None`, which signifies that this `Variable` will not be added to any
1787 collections.
1788 @end_compatibility
1789 """
1790 if variable_def:
1791 if initial_value is not None:
1792 raise ValueError(f"The variable_def and initial_value args to "
1793 f"`tf.Variable` are mutually exclusive, but got both: "
1794 f"variable_def={variable_def},\n"
1795 f"initial_value={initial_value}")
1796 if context.executing_eagerly():
1797 raise ValueError(f"Creating a `tf.Variable` with a `variable_def` arg "
1798 f"is not supported when eager execution is enabled. "
1799 f"Got: variable_def={variable_def}")
1800 self._init_from_proto(
1801 variable_def,
1802 import_scope=import_scope,
1803 validate_shape=validate_shape)
1804 elif handle is not None:
1805 self._init_from_handle(trainable=trainable,
1806 shape=shape,
1807 dtype=dtype,
1808 handle=handle)
1809 else:
1810 self._init_from_args(
1811 initial_value=initial_value,
1812 trainable=trainable,
1813 collections=collections,
1814 caching_device=caching_device,
1815 name=name,
1816 dtype=dtype,
1817 constraint=constraint,
1818 synchronization=synchronization,
1819 aggregation=aggregation,
1820 shape=shape,
1821 distribute_strategy=distribute_strategy,
1822 validate_shape=validate_shape,
1823 experimental_enable_variable_lifting=experimental_enable_variable_lifting,
1824 )
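# Illustrative sketch (not part of the library): typical construction paths
# for the constructor above. The unspecified-shape case mirrors the `shape`
# argument documented in the docstring.
#
#   v = tf.Variable(tf.zeros([2, 3]))                  # shape fixed to [2, 3]
#   w = tf.Variable([1.], shape=tf.TensorShape(None))  # shape left unspecified
#   w.assign([1., 2., 3.])                             # OK: shapes may differ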
1826 # CompositeTensor method
1827 @property
1828 def _type_spec(self):
1829 return VariableSpec.from_value(self)
1831 # CompositeTensor method
1832 def _shape_invariant_to_type_spec(self, shape):
1833 return VariableSpec(shape, self.dtype, self.trainable)
1835 # CompositeTensorGradient protocol
1836 __composite_gradient__ = ResourceVariableGradient()
1838 def _init_from_args(
1839 self,
1840 initial_value=None,
1841 trainable=None,
1842 collections=None,
1843 caching_device=None,
1844 name=None,
1845 dtype=None,
1846 constraint=None,
1847 synchronization=None,
1848 aggregation=None,
1849 distribute_strategy=None,
1850 shape=None,
1851 validate_shape=True,
1852 experimental_enable_variable_lifting=None,
1853 ):
1854 """Creates a variable.
1856 Args:
1857 initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1858 which is the initial value for the Variable. The initial value must have
1859 a shape specified unless `validate_shape` is set to False. Can also be a
1860 callable with no argument that returns the initial value when called.
1861 (Note that initializer functions from init_ops.py must first be bound to
1862 a shape before being used here.)
1863 trainable: If `True`, the default, also adds the variable to the graph
1864 collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
1865 the default list of variables to use by the `Optimizer` classes.
1866 Defaults to `True`, unless `synchronization` is set to `ON_READ`, in
1867 which case it defaults to `False`.
1868 collections: List of graph collections keys. The new variable is added to
1869 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1870 caching_device: Optional device string or function describing where the
1871 Variable should be cached for reading. Defaults to the Variable's
1872 device. If not `None`, caches on another device. Typical use is to
1873 cache on the device where the Ops using the Variable reside, to
1874 deduplicate copying through `Switch` and other conditional statements.
1875 name: Optional name for the variable. Defaults to `'Variable'` and gets
1876 uniquified automatically.
1877 dtype: If set, initial_value will be converted to the given type. If None,
1878 either the datatype will be kept (if initial_value is a Tensor) or
1879 float32 will be used (if it is a Python object convertible to a Tensor).
1880 constraint: An optional projection function to be applied to the variable
1881 after being updated by an `Optimizer` (e.g. used to implement norm
1882 constraints or value constraints for layer weights). The function must
1883 take as input the unprojected Tensor representing the value of the
1884 variable and return the Tensor for the projected value (which must have
1885 the same shape). Constraints are not safe to use when doing asynchronous
1886 distributed training.
1887 synchronization: Indicates when a distributed variable will be
1888 aggregated. Accepted values are constants defined in the class
1889 `tf.VariableSynchronization`. By default the synchronization is set to
1890 `AUTO` and the current `DistributionStrategy` chooses when to
1891 synchronize.
1892 aggregation: Indicates how a distributed variable will be aggregated.
1893 Accepted values are constants defined in the class
1894 `tf.VariableAggregation`.
1895 distribute_strategy: DistributionStrategy under which this variable was
1896 created.
1897 shape: (optional) The shape of this variable. If None, the shape of
1898 `initial_value` will be used. When setting this argument to
1899 `tf.TensorShape(None)` (representing an unspecified shape), the variable
1900 can be assigned with values of different shapes.
1901 validate_shape: If `False`, allows the variable to be initialized with a
1902 value of unknown shape. If `True`, the default, the shape of
1903 `initial_value` must be known.
1904 experimental_enable_variable_lifting: Whether to lift the variable out if
1905 it's in a `tf.function`. Default is `True`. When this argument
1906 is `True`, variable creation will follow the behavior and
1907 restrictions described
1908 [here](https://www.tensorflow.org/guide/function#creating_tfvariables).
1909 If this argument is `False`, that description doesn't apply,
1910 and you can freely create and use the variable in the
1911 `tf.function`, as if it's a "mutable `tf.Tensor`". You can't
1912 return the variable though.
1914 Raises:
1915 ValueError: If the initial value is not specified, or does not have a
1916 shape and `validate_shape` is `True`.
1918 @compatibility(eager)
1919 When Eager Execution is enabled, variables are never added to collections.
1920 It is not implicitly added to the `GLOBAL_VARIABLES` or
1921 `TRAINABLE_VARIABLES` collections, and the `collections` argument is
1922 ignored.
1923 @end_compatibility
1924 """
1925 synchronization, aggregation, trainable = (
1926 variables.validate_synchronization_aggregation_trainable(
1927 synchronization, aggregation, trainable, name))
1928 if experimental_enable_variable_lifting is None:
1929 experimental_enable_variable_lifting = True
1930 if initial_value is None:
1931 raise ValueError("The `initial_value` arg to `tf.Variable` must "
1932 "be specified except when you are not providing a "
1933 "`variable_def`. You provided neither.")
1934 init_from_fn = callable(initial_value)
1936 if isinstance(initial_value, ops.Tensor) and hasattr(
1937 initial_value, "graph") and initial_value.graph.building_function:
1938 raise ValueError(f"Argument `initial_value` ({initial_value}) could not "
1939 "be lifted out of a `tf.function`. "
1940 f"(Tried to create variable with name='{name}'). "
1941 "To avoid this error, when constructing `tf.Variable`s "
1942 "inside of `tf.function` you can create the "
1943 "`initial_value` tensor in a "
1944 "`tf.init_scope` or pass a callable `initial_value` "
1945 "(e.g., `tf.Variable(lambda : "
1946 "tf.truncated_normal([10, 40]))`). "
1947 "Please file a feature request if this "
1948 "restriction inconveniences you.")
1950 if collections is None:
1951 collections = [ops.GraphKeys.GLOBAL_VARIABLES]
1952 if not isinstance(collections, (list, tuple, set)):
1953 raise ValueError(
1954 f"collections argument to Variable constructor must be a list, "
1955 f"tuple, or set. Got {collections} of type {type(collections)}")
1956 if constraint is not None and not callable(constraint):
1957 raise ValueError(f"Argument `constraint` must be None or a callable. "
1958 f"a callable. Got a {type(constraint)}: {constraint}")
1960 if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
1961 collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
1962 with ops.init_scope():
1963 self._in_graph_mode = not context.executing_eagerly()
1964 if experimental_enable_variable_lifting:
1965 maybe_init_scope = ops.init_scope
1966 else:
1967 maybe_init_scope = contextlib.nullcontext
1968 with maybe_init_scope():
1969 with ops.name_scope(
1970 name,
1971 "Variable", [] if init_from_fn else [initial_value],
1972 skip_on_eager=False) as name:
1973 # pylint: disable=protected-access
1974 handle_name = ops.name_from_scope_name(name)
1975 if self._in_graph_mode:
1976 shared_name = handle_name
1977 unique_id = shared_name
1978 else:
1979 # When in eager mode, use a uid for the shared_name, to prevent
1980 # accidental sharing.
1981 unique_id = "%s_%d" % (handle_name, ops.uid())
1982 shared_name = None # Never shared
1983 # Use attr_scope and device(None) to simulate the behavior of
1984 # colocate_with when the variable we want to colocate with doesn't
1985 # yet exist.
1986 device_context_manager = (
1987 ops.device if self._in_graph_mode else ops.NullContextmanager)
1988 attr = attr_value_pb2.AttrValue(
1989 list=attr_value_pb2.AttrValue.ListValue(
1990 s=[compat.as_bytes("loc:@%s" % handle_name)]))
1991 with ops.get_default_graph()._attr_scope({"_class": attr}):
1992 with ops.name_scope("Initializer"), device_context_manager(None):
1993 if init_from_fn:
1994 initial_value = initial_value()
1995 if isinstance(initial_value, trackable.CheckpointInitialValue):
1996 self._maybe_initialize_trackable()
1997 self._update_uid = initial_value.checkpoint_position.restore_uid
1998 initial_value = initial_value.wrapped_value
1999 initial_value = ops.convert_to_tensor(
2000 initial_value, name="initial_value", dtype=dtype)
2001 if shape is not None:
2002 if not initial_value.shape.is_compatible_with(shape):
2003 raise ValueError(
2004 f"In this `tf.Variable` creation, the initial value's shape "
2005 f"({initial_value.shape}) is not compatible with "
2006 f"the explicitly supplied `shape` argument ({shape}).")
2007 else:
2008 shape = initial_value.shape
2009 handle = eager_safe_variable_handle(
2010 initial_value=initial_value,
2011 shape=shape,
2012 shared_name=shared_name,
2013 name=name,
2014 graph_mode=self._in_graph_mode)
2015 handle._parent_trackable = weakref.ref(self)
2016 handle._name = handle_name + ":0"
2017 handle._unique_id = unique_id
2018 # pylint: disable=protected-access
2019 if (self._in_graph_mode and initial_value is not None and
2020 initial_value.op._get_control_flow_context() is not None):
2021 raise ValueError(
2022 f"The `initial_value` passed to `tf.Variable` {name} is from "
2023 f"inside a control-flow construct, such as a loop or "
2024 f"conditional. When creating a "
2025 f"`tf.Variable` inside a loop or conditional, use a lambda as "
2026 f"the `initial_value`. Got: initial_value=({initial_value})")
2027 # pylint: enable=protected-access
2028 dtype = initial_value.dtype.base_dtype
2030 if self._in_graph_mode:
2031 with ops.name_scope("IsInitialized"):
2032 is_initialized_op = (
2033 gen_resource_variable_ops.var_is_initialized_op(handle))
2034 if initial_value is not None:
2035 # pylint: disable=g-backslash-continuation
2036 with ops.name_scope("Assign") as n, \
2037 ops.colocate_with(None, ignore_existing=True), \
2038 ops.device(handle.device):
2039 # pylint: disable=protected-access
2040 initializer_op = (
2041 gen_resource_variable_ops.assign_variable_op(
2042 handle,
2043 variables._try_guard_against_uninitialized_dependencies(
2044 name, initial_value),
2045 name=n))
2046 # pylint: enable=protected-access
2047 # pylint: enable=g-backslash-continuation
2048 with ops.name_scope("Read"):
2049 # Manually assign reads to the handle's device to avoid log
2050 # messages.
2051 with ops.device(handle.device):
2052 value = gen_resource_variable_ops.read_variable_op(handle, dtype)
2053 _maybe_set_handle_data(dtype, handle, value)
2054 graph_element = value
2055 if caching_device is not None:
2056 # Variables may be created in a tf.device() or ops.colocate_with()
2057 # context. At the same time, users would expect caching device to
2058 # be independent of this context, and/or would not expect the
2059 # current device context to be merged with the caching device
2060 # spec. Therefore we reset the colocation stack before creating
2061 # the cached value. Note that resetting the colocation stack will
2062 # also reset the device stack.
2063 with ops.colocate_with(None, ignore_existing=True):
2064 with ops.device(caching_device):
2065 cached_value = array_ops.identity(value)
2066 else:
2067 cached_value = None
2068 else:
2069 gen_resource_variable_ops.assign_variable_op(handle, initial_value)
2070 is_initialized_op = None
2071 initializer_op = None
2072 graph_element = None
2073 if caching_device:
2074 with ops.device(caching_device):
2075 cached_value = gen_resource_variable_ops.read_variable_op(
2076 handle, dtype)
2077 _maybe_set_handle_data(dtype, handle, cached_value)
2078 else:
2079 cached_value = None
2081 if cached_value is not None:
2082 # Store the variable object so that the original variable can be
2083 # accessed to generate functions that are compatible with SavedModel.
2084 cached_value._cached_variable = weakref.ref(self) # pylint: disable=protected-access
2086 if self._in_graph_mode:
2087 # Eager variables are only added to collections if they are part of an
2088 # eager variable store (otherwise in an interactive session they would
2089 # hog memory and cause OOM). This is done in ops/variable_scope.py.
2090 ops.add_to_collections(collections, self)
2091 elif ops.GraphKeys.GLOBAL_STEP in collections:
2092 ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
2093 initial_value = initial_value if self._in_graph_mode else None
2094 super(ResourceVariable, self).__init__(
2095 trainable=trainable,
2096 shape=shape,
2097 dtype=dtype,
2098 handle=handle,
2099 synchronization=synchronization,
2100 constraint=constraint,
2101 aggregation=aggregation,
2102 distribute_strategy=distribute_strategy,
2103 name=name,
2104 unique_id=unique_id,
2105 handle_name=handle_name,
2106 graph_element=graph_element,
2107 initial_value=initial_value,
2108 initializer_op=initializer_op,
2109 is_initialized_op=is_initialized_op,
2110 cached_value=cached_value,
2111 caching_device=caching_device,
2112 validate_shape=validate_shape,
2113 )
2115 def _init_from_proto(self,
2116 variable_def,
2117 import_scope=None,
2118 validate_shape=True):
2119 """Initializes from `VariableDef` proto."""
2120 # Note that init_from_proto is currently not supported in Eager mode.
2121 assert not context.executing_eagerly()
2122 self._in_graph_mode = True
2123 assert isinstance(variable_def, variable_pb2.VariableDef)
2124 if not variable_def.is_resource:
2125 raise ValueError(f"The `variable_def` you passed to `tf.Variable` is "
2126 f"Trying to restore a TF 1.x Reference Variable "
2127 f"as a TF 2.x ResourceVariable. This is unsupported. "
2128 f"Got variable_def={variable_def}")
2130 # Create from variable_def.
2131 g = ops.get_default_graph()
2132 self._handle = g.as_graph_element(
2133 ops.prepend_name_scope(
2134 variable_def.variable_name, import_scope=import_scope),
2135 allow_operation=False)
2136 self._shape = tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
2137 self._handle_name = self._handle.name
2138 self._unique_id = self._handle_name
2139 self._initializer_op = g.as_graph_element(
2140 ops.prepend_name_scope(
2141 variable_def.initializer_name, import_scope=import_scope))
2142 # Check whether initial_value_name exists for backwards compatibility.
2143 if (hasattr(variable_def, "initial_value_name") and
2144 variable_def.initial_value_name):
2145 self._initial_value = g.as_graph_element(
2146 ops.prepend_name_scope(
2147 variable_def.initial_value_name, import_scope=import_scope))
2148 else:
2149 self._initial_value = None
2150 synchronization, aggregation, trainable = (
2151 variables.validate_synchronization_aggregation_trainable(
2152 variable_def.synchronization, variable_def.aggregation,
2153 variable_def.trainable, variable_def.variable_name))
2154 self._synchronization = synchronization
2155 self._aggregation = aggregation
2156 self._trainable = trainable
2157 if variable_def.snapshot_name:
2158 snapshot = g.as_graph_element(
2159 ops.prepend_name_scope(
2160 variable_def.snapshot_name, import_scope=import_scope))
2161 if snapshot.op.type != "ReadVariableOp":
2162 self._cached_value = snapshot
2163 else:
2164 self._cached_value = None
2165 while snapshot.op.type != "ReadVariableOp":
2166 snapshot = snapshot.op.inputs[0]
2167 self._graph_element = snapshot
2168 else:
2169 self._cached_value = None
2170 # Legacy case for protos without the snapshot name; assume it's the
2171 # following.
2172 self._graph_element = g.get_tensor_by_name(self._handle.op.name +
2173 "/Read/ReadVariableOp:0")
2174 if variable_def.HasField("save_slice_info_def"):
2175 self._save_slice_info = variables.Variable.SaveSliceInfo(
2176 save_slice_info_def=variable_def.save_slice_info_def,
2177 import_scope=import_scope)
2178 else:
2179 self._save_slice_info = None
2180 self._caching_device = None
2181 self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
2182 self._constraint = None
2183 self._validate_shape = validate_shape
2185 def _init_from_handle(self,
2186 trainable=None,
2187 shape=None,
2188 dtype=None,
2189 handle=None):
2190 handle_data = get_eager_safe_handle_data(handle)
2191 if not handle_data.is_set:
2192 # The handle may not have the handle shape and dtype if it was created
2193 # using tf.placeholder.
2194 handle_data = handle_data_util.create_handle_data(shape, dtype)
2195 handle_data_util.set_handle_data(handle, handle_data)
2196 # pylint: disable=protected-access
2197 if hasattr(handle, "_name") and isinstance(handle._name, str):
2198 handle_name = handle._name[:-2] if handle._name.endswith(":0") else handle._name
2199 else:
2200 handle_name = None
2201 # pylint: enable=protected-access
2202 unique_id = getattr(handle, "_unique_id", None)
2203 super().__init__(
2204 trainable=trainable, shape=shape, dtype=dtype, handle=handle,
2205 unique_id=unique_id, handle_name=handle_name)
2208class UninitializedVariable(BaseResourceVariable):
2209 """A variable with no initializer."""
2211 def __init__( # pylint: disable=super-init-not-called
2212 self,
2213 trainable=None,
2214 caching_device=None,
2215 name=None,
2216 shape=None,
2217 dtype=None,
2218 constraint=None,
2219 synchronization=None,
2220 aggregation=None,
2221 extra_handle_data=None,
2222 distribute_strategy=None,
2223 **unused_kwargs):
2224 """Creates the variable handle.
2226 Args:
2227 trainable: If `True`, GradientTapes automatically watch uses of this
2228 Variable.
2229 caching_device: Optional device string or function describing where the
2230 Variable should be cached for reading. Defaults to the Variable's
2231 device. If not `None`, caches on another device. Typical use is to
2232 cache on the device where the Ops using the Variable reside, to
2233 deduplicate copying through `Switch` and other conditional statements.
2234 name: Optional name for the variable. Defaults to `'Variable'` and gets
2235 uniquified automatically.
2236 shape: The variable's shape.
2237 dtype: The variable's dtype.
2238 constraint: An optional projection function to be applied to the variable
2239 after being updated by an `Optimizer` (e.g. used to implement norm
2240 constraints or value constraints for layer weights). The function must
2241 take as input the unprojected Tensor representing the value of the
2242 variable and return the Tensor for the projected value (which must have
2243 the same shape). Constraints are not safe to use when doing asynchronous
2244 distributed training.
2245 synchronization: Indicates when a distributed variable will be
2246 aggregated. Accepted values are constants defined in the class
2247 `tf.VariableSynchronization`. By default the synchronization is set to
2248 `AUTO` and the current `DistributionStrategy` chooses when to
2249 synchronize.
2250 aggregation: Indicates how a distributed variable will be aggregated.
2251 Accepted values are constants defined in the class
2252 `tf.VariableAggregation`.
2253 extra_handle_data: Optional, another resource handle or Tensor with handle
2254 data to merge with `shape` and `dtype`.
2255 distribute_strategy: The tf.distribute.Strategy this variable is being
2256 created inside of.
2257 """
2258 with ops.init_scope():
2259 # Here we are detecting eagerness within an init_scope, so this will only
2260 # be true when we are running in TF1 graph mode.
2261 self._in_graph_mode = not context.executing_eagerly()
2262 with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
2263 handle_name = ops.name_from_scope_name(name)
2264 if self._in_graph_mode:
2265 shared_name = handle_name
2266 unique_id = shared_name
2267 else:
2268 unique_id = "%s_%d" % (handle_name, ops.uid())
2269 shared_name = None # Never shared
2270 handle = _variable_handle_from_shape_and_dtype(
2271 shape=shape,
2272 dtype=dtype,
2273 shared_name=shared_name,
2274 name=name,
2275 graph_mode=self._in_graph_mode,
2276 initial_value=extra_handle_data)
2277 handle._parent_trackable = weakref.ref(self)
2278 handle._name = handle_name + ":0"
2279 handle._unique_id = unique_id
2281 if self._in_graph_mode:
2282 # We only need to add the read_variable_op in TF1.
2283 with ops.name_scope("Read"):
2284 # Manually assign reads to the handle's device to avoid log
2285 # messages.
2286 with ops.device(handle.device):
2287 value = gen_resource_variable_ops.read_variable_op(handle, dtype)
2288 _maybe_set_handle_data(dtype, handle, value)
2289 graph_element = value
2290 ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
2291 # Do *not* add to TRAINABLE_VARIABLES here, even if self._trainable,
2292 # because retraining or frozen use of imported SavedModels is
2293 # controlled at higher levels of model building.
2294 else:
2295 graph_element = None
2296 super(UninitializedVariable, self).__init__(
2297 distribute_strategy=distribute_strategy,
2298 shape=shape,
2299 dtype=dtype,
2300 unique_id=unique_id,
2301 handle_name=handle_name,
2302 constraint=constraint,
2303 handle=handle,
2304 graph_element=graph_element,
2305 trainable=trainable,
2306 synchronization=synchronization,
2307 aggregation=aggregation,
2308 in_graph_mode=self._in_graph_mode)
2311_pywrap_utils.RegisterType("ResourceVariable", ResourceVariable)
2312math_ops._resource_variable_type = ResourceVariable # pylint: disable=protected-access
2315def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
2316 return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
2319# Register a conversion function which reads the value of the variable,
2320# allowing instances of the class to be used as tensors.
2321tensor_conversion_registry.register_tensor_conversion_function(
2322 BaseResourceVariable, _dense_var_to_tensor)
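# Illustrative sketch: because of the registration above, a resource variable
# can be used wherever a Tensor is expected; the conversion reads its value.
#
#   v = tf.Variable([1., 2.])
#   t = tf.convert_to_tensor(v)   # a read of the current value
#   s = tf.reduce_sum(v)          # ops implicitly convert the variable too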
2325class _UnreadVariable(BaseResourceVariable):
2326 """Represents a future for a read of a variable.
2328 Pretends to be the tensor if anyone looks.
2329 """
2331 def __init__(self, handle, dtype, shape, in_graph_mode, parent_op, unique_id):
2332 if isinstance(handle, ops.EagerTensor):
2333 handle_name = ""
2334 else:
2335 handle_name = handle.name
2336 # Only create a graph_element if we're in session.run-land as only
2337 # session.run requires a preexisting tensor to evaluate. Otherwise we can
2338 # avoid accidentally reading the variable.
2339 if context.executing_eagerly() or ops.inside_function():
2340 graph_element = None
2341 else:
2342 with ops.control_dependencies([parent_op]):
2343 graph_element = gen_resource_variable_ops.read_variable_op(
2344 handle, dtype)
2345 _maybe_set_handle_data(dtype, handle, graph_element)
2346 super(_UnreadVariable, self).__init__(
2347 handle=handle,
2348 shape=shape,
2349 handle_name=handle_name,
2350 unique_id=unique_id,
2351 dtype=dtype,
2352 graph_element=graph_element)
2353 self._parent_op = parent_op
2355 @property
2356 def name(self):
2357 if self._in_graph_mode:
2358 return self._parent_op.name
2359 else:
2360 return "UnreadVariable"
2362 def value(self):
2363 return self._read_variable_op()
2365 def read_value(self):
2366 return self._read_variable_op()
2368 def _read_variable_op(self):
2369 with ops.control_dependencies([self._parent_op]):
2370 result = gen_resource_variable_ops.read_variable_op(
2371 self._handle, self._dtype)
2372 _maybe_set_handle_data(self._dtype, self._handle, result)
2373 return result
2375 def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
2376 with ops.control_dependencies([self._parent_op]):
2377 return super(_UnreadVariable, self).assign_sub(delta, use_locking, name,
2378 read_value)
2380 def assign_add(self, delta, use_locking=None, name=None, read_value=True):
2381 with ops.control_dependencies([self._parent_op]):
2382 return super(_UnreadVariable, self).assign_add(delta, use_locking, name,
2383 read_value)
2385 def assign(self, value, use_locking=None, name=None, read_value=True):
2386 with ops.control_dependencies([self._parent_op]):
2387 return super(_UnreadVariable, self).assign(value, use_locking, name,
2388 read_value)
2390 def scatter_sub(self, sparse_delta, use_locking=False, name=None):
2391 with ops.control_dependencies([self._parent_op]):
2392 return super(_UnreadVariable, self).scatter_sub(sparse_delta, use_locking,
2393 name)
2395 def scatter_add(self, sparse_delta, use_locking=False, name=None):
2396 with ops.control_dependencies([self._parent_op]):
2397 return super(_UnreadVariable, self).scatter_add(sparse_delta, use_locking,
2398 name)
2400 def scatter_max(self, sparse_delta, use_locking=False, name=None):
2401 with ops.control_dependencies([self._parent_op]):
2402 return super(_UnreadVariable, self).scatter_max(sparse_delta, use_locking,
2403 name)
2405 def scatter_min(self, sparse_delta, use_locking=False, name=None):
2406 with ops.control_dependencies([self._parent_op]):
2407 return super(_UnreadVariable, self).scatter_min(sparse_delta, use_locking,
2408 name)
2410 def scatter_mul(self, sparse_delta, use_locking=False, name=None):
2411 with ops.control_dependencies([self._parent_op]):
2412 return super(_UnreadVariable, self).scatter_mul(sparse_delta, use_locking,
2413 name)
2415 def scatter_div(self, sparse_delta, use_locking=False, name=None):
2416 with ops.control_dependencies([self._parent_op]):
2417 return super(_UnreadVariable, self).scatter_div(sparse_delta, use_locking,
2418 name)
2420 def scatter_update(self, sparse_delta, use_locking=False, name=None):
2421 with ops.control_dependencies([self._parent_op]):
2422 return super(_UnreadVariable,
2423 self).scatter_update(sparse_delta, use_locking, name)
2425 def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
2426 with ops.control_dependencies([self._parent_op]):
2427 return super(_UnreadVariable,
2428 self).batch_scatter_update(sparse_delta, use_locking, name)
2430 def scatter_nd_sub(self, indices, updates, name=None):
2431 with ops.control_dependencies([self._parent_op]):
2432 return super(_UnreadVariable, self).scatter_nd_sub(indices, updates, name)
2434 def scatter_nd_add(self, indices, updates, name=None):
2435 with ops.control_dependencies([self._parent_op]):
2436 return super(_UnreadVariable, self).scatter_nd_add(indices, updates, name)
2438 def scatter_nd_update(self, indices, updates, name=None):
2439 with ops.control_dependencies([self._parent_op]):
2440 return super(_UnreadVariable,
2441 self).scatter_nd_update(indices, updates, name)
2443 def scatter_nd_max(self, indices, updates, name=None):
2444 with ops.control_dependencies([self._parent_op]):
2445 return super(_UnreadVariable, self).scatter_nd_max(indices, updates, name)
2447 def scatter_nd_min(self, indices, updates, name=None):
2448 with ops.control_dependencies([self._parent_op]):
2449 return super(_UnreadVariable, self).scatter_nd_min(indices, updates, name)
2451 @property
2452 def op(self):
2453 """The op for this variable."""
2454 return self._parent_op
2457@ops.RegisterGradient("ReadVariableOp")
2458def _ReadGrad(_, grad):
2459 """Gradient for read op."""
2460 return grad
2463def variable_shape(handle, out_type=dtypes.int32):
2464 handle_data = get_eager_safe_handle_data(handle)
2465 if handle_data is None or not handle_data.is_set:
2466 return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
2467 shape_proto = handle_data.shape_and_type[0].shape
2468 if shape_proto.unknown_rank or any(x.size == -1 for x in shape_proto.dim):
2469 return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
2470 return constant_op.constant([x.size for x in shape_proto.dim], dtype=out_type)
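# Illustrative sketch of the fast path above: when the handle carries fully
# defined static shape data, the shape is returned as a constant instead of
# emitting a VariableShape op.
#
#   v = tf.Variable(tf.zeros([2, 3]))
#   variable_shape(v.handle)          # constant [2, 3] (int32)
#   u = tf.Variable([1.], shape=tf.TensorShape(None))
#   variable_shape(u.handle)          # falls back to the VariableShape op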
2473@ops.RegisterGradient("ResourceGather")
2474def _GatherGrad(op, grad):
2475 """Gradient for gather op."""
2476 # Build appropriately shaped IndexedSlices
2477 handle = op.inputs[0]
2478 indices = op.inputs[1]
2479 params_shape = variable_shape(handle)
2480 size = array_ops.expand_dims(array_ops.size(indices), 0)
2481 values_shape = array_ops.concat([size, params_shape[1:]], 0)
2482 values = array_ops.reshape(grad, values_shape)
2483 indices = array_ops.reshape(indices, size)
2484 return (indexed_slices.IndexedSlices(values, indices, params_shape), None)
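# Shape sketch for the gradient above (illustrative): gathering from a params
# variable of shape [10, 4] with indices of shape [3] yields a gradient of
# shape [3, 4]; values_shape is [size(indices)] + params_shape[1:] = [3, 4],
# and the result is an IndexedSlices with dense_shape [10, 4]. No gradient
# flows to the indices, hence the trailing None.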
2487@tf_export("__internal__.ops.is_resource_variable", v1=[])
2488def is_resource_variable(var):
2489 """"Returns True if `var` is to be considered a ResourceVariable."""
2490 return isinstance(var, BaseResourceVariable) or hasattr(
2491 var, "_should_act_as_resource_variable")
2494def copy_to_graph_uninitialized(var):
2495 """Copies an existing variable to a new graph, with no initializer."""
2496 # Like ResourceVariable.__deepcopy__, but does not set an initializer on the
2497 # new variable.
2498 # pylint: disable=protected-access
2499 new_variable = UninitializedVariable(
2500 trainable=var.trainable,
2501 constraint=var._constraint,
2502 shape=var.shape,
2503 dtype=var.dtype,
2504 name=var._shared_name,
2505 synchronization=var.synchronization,
2506 aggregation=var.aggregation,
2507 extra_handle_data=var.handle)
2508 new_variable._maybe_initialize_trackable()
2509 # pylint: enable=protected-access
2510 return new_variable
2513ops.NotDifferentiable("Assert")
2514ops.NotDifferentiable("VarIsInitializedOp")
2515ops.NotDifferentiable("VariableShape")
2518 # TODO(b/246356867): This is a draft implementation. Currently VariableSpec is
2519 # the only class using these patterns. Move them to a separate file when necessary.
2520class StructurePattern:
2521 pass
2524class PLeaf(StructurePattern):
2525 """Represents a singleton leaf StructurePattern."""
2527 def __new__(cls):
2528 if not hasattr(cls, "instance"):
2529 cls.instance = super().__new__(cls)
2530 return cls.instance
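# Illustrative sketch: `__new__` above caches a single instance on the class,
# so every construction returns the same object.
#
#   PLeaf() is PLeaf()   # True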
2533class PList(StructurePattern):
2534 """Represents a list of StructurePatterns."""
2536 def __init__(self, *components):
2537 self.components = list(components)
2539 def __eq__(self, other):
2540 return isinstance(other, PList) and self.components == other.components
2543class VariableSpec(tensor_spec.DenseSpec):
2544 """Describes a tf.Variable.
2546 A `VariableSpec` provides metadata describing the `tf.Variable` objects
2547 accepted or returned by TensorFlow 2.x APIs.
2548 """
2550 __slots__ = ["trainable", "alias_id"]
2552 value_type = property(lambda self: ResourceVariable)
2554 def __init__(self, shape, dtype=dtypes.float32, trainable=True,
2555 alias_id=None):
2556 super(VariableSpec, self).__init__(shape, dtype=dtype)
2557 self.trainable = trainable
2558 self.alias_id = alias_id
2560 def is_compatible_with(self, spec_or_value):
2561 """Returns True if `spec_or_value` is compatible with this `VariableSpec`.
2563 `spec_or_value` is considered to be compatible with this `VariableSpec` if
2565 * `spec_or_value` is a `Variable` or `VariableSpec`,
2566 * their shapes are compatible,
2567 * their dtypes are the same,
2568 * they are both trainable or both not trainable, and
2569 * they share the same alias_id if `spec_or_value` is a `VariableSpec`.
2571 Example:
2573 >>> v = tf.Variable([1., 2., 3.])
2574 >>> spec = VariableSpec([None])
2575 >>> spec.is_compatible_with(v)
2576 True
2577 >>> v = tf.Variable(1)
2578 >>> spec.is_compatible_with(v)
2579 False
2581 Args:
2582 spec_or_value: A VariableSpec or Variable to compare against.
2584 Returns:
2585 True if `spec_or_value` is compatible with this `VariableSpec`.
2586 """
2587 if not isinstance(spec_or_value, (type(self), self.value_type)):
2588 return False
2589 compatible = (self.shape.is_compatible_with(spec_or_value.shape) and
2590 self.dtype == spec_or_value.dtype and
2591 self.trainable == spec_or_value.trainable)
2592 if isinstance(spec_or_value, type(self)):
2593 # alias_id must be the same to be compatible.
2594 return compatible and self.alias_id == spec_or_value.alias_id
2595 return compatible
2597 @classmethod
2598 def from_value(cls, value):
2599 """Creates a `VariableSpec` from the given `Variable`.
2601 `value`'s shape, dtype, and trainable attributes will be used to create
2602 the new `VariableSpec`.
2604 Example:
2606 >>> v = tf.Variable([1., 2., 3.])
2607 >>> VariableSpec.from_value(v)
2608 VariableSpec(shape=(3,), dtype=tf.float32, trainable=True, alias_id=None)
2610 Args:
2611 value: A Variable.
2613 Returns:
2614 A `VariableSpec` created from `value`.
2615 """
2616 return cls(value.shape, dtype=value.dtype, trainable=value.trainable)
2618 def _to_components(self, value):
2619 return [value.handle]
2621 def _from_components(self, components):
2622 if not isinstance(components, (list, tuple)):
2623 raise TypeError(f"Components of a ResourceVariable must be a list or "
2624 f"tuple, got f{components} instead.")
2625 if len(components) != 1:
2626 raise ValueError(f"Components of a ResourceVariable must only contain "
2627 f"its resource handle, got f{components} instead.")
2628 handle = components[0]
2629 if not isinstance(handle, ops.Tensor) or handle.dtype != dtypes.resource:
2630 raise ValueError(f"The handle of a ResourceVariable must be a resource "
2631 f"tensor, got {handle} instead.")
2632 return ResourceVariable(trainable=self.trainable,
2633 shape=self.shape,
2634 dtype=self.dtype,
2635 handle=handle)
2637 @property
2638 def _component_specs(self):
2639 return [tensor_spec.TensorSpec([], dtypes.resource)]
2641 def _serialize(self):
2642 return self.shape, self.dtype, self.trainable, self.alias_id
2644 # TraceType method
2645 def is_subtype_of(self, other):
2646 if type(self) is not type(other):
2647 return False
2649 # Remove this once we add alias_id to all CompositeTensors with
2650 # ResourceVariable components.
2651 if self.alias_id is None and other.alias_id is None:
2652 return super().is_subtype_of(other)
2654 if self.alias_id is None or other.alias_id is None:
2655 raise NotImplementedError(f"VariableSpec.is_subtype_of doesn't support "
2656 f"alias_id=None, got self: {self} and other: "
2657 f"{other}.")
2659 return super().is_subtype_of(other)
2661 # TraceType method
2662 def most_specific_common_supertype(self, others):
2663 if any(type(self) is not type(other) for other in others):
2664 return None
2666 # It is a special case for tf.nest, which often takes CompositeTensors and
2667 # converts to TypeSpecs internally, such as tf.nest.assert_same_structure.
2668 if (self.alias_id is None and
2669 all(other.alias_id is None for other in others)):
2670 return super().most_specific_common_supertype(others)
2672 if self.alias_id is None or any(other.alias_id is None for other in others):
2673 raise NotImplementedError(f"VariableSpec.most_specific_common_supertype "
2674 f"doesn't support alias_id=None, got self: "
2675 f"{self} and others: {others}.")
2677 return super().most_specific_common_supertype(others)
2679 # TraceType method
2680 def placeholder_value(self, placeholder_context):
2681 if placeholder_context.unnest_only:
2682 return self
2684 name = self.name or placeholder_context.naming_scope
2685 context_graph = placeholder_context.context_graph
2686 if placeholder_context.has_placeholder(self.alias_id):
2687 # Get reference to the existing variable if alias_id already
2688 # exists in the PlaceholderContext
2689 variable = placeholder_context.get_placeholder(self.alias_id)
2690 else:
2691 spec = tensor_spec.TensorSpec([], dtypes.resource)
2692 spec_context = trace_type.InternalPlaceholderContext(
2693 context_graph.outer_graph)
2694 spec_context.update_naming_scope(name)
2695 placeholder = spec.placeholder_value(spec_context)
2696 variable = self._from_components([placeholder])
2697 # (b/262771247) ShardedVariable breaks without this, and VariableSpecs
2698 # without alias_id are not TraceTypes.
2699 if self.alias_id is not None:
2700 placeholder_context.add_placeholder(self.alias_id, variable)
2701 # Capture the Variable's placeholder within the default graph of
2702 # the current thread.
2703 placeholder = context_graph.capture(variable.handle, name=name)
2704 placeholder.op._set_attr( # pylint: disable=protected-access
2705 "_user_specified_name",
2706 attr_value_pb2.AttrValue(s=compat.as_bytes(name)))
2707 return variable
2709 def _to_tensors(self, value):
2710 assert isinstance(value, BaseResourceVariable)
2711 return [value.handle]
2713 def _get_structure(self):
2714 # shape, dtype, trainable, and alias_id are all leaves.
2715 return PList(PLeaf(), PLeaf(), PLeaf(), PLeaf())
2717 def __repr__(self):
2718 return (f"{type(self).__name__}(shape={self.shape}, dtype={self.dtype!r}, "
2719 f"trainable={self.trainable!r}, alias_id={self.alias_id!r})")
2721 def __hash__(self):
2722 return hash((self.shape, self.dtype, self.trainable, self.alias_id))
2724 def __eq__(self, other):
2725 return (type(self) is type(other) and self.shape == other.shape and
2726 self.dtype == other.dtype and self.trainable == other.trainable and
2727 self.alias_id == other.alias_id)
2730nested_structure_coder.register_codec(
2731 nested_structure_coder.BuiltInTypeSpecCodec(
2732 VariableSpec, struct_pb2.TypeSpecProto.VARIABLE_SPEC
2733 )
2734)
2737_pywrap_utils.RegisterType("VariableSpec", VariableSpec)
2740def write_object_proto_for_resource_variable(resource_variable,
2741 proto,
2742 options,
2743 enforce_naming=True):
2744 """Writes additional information of the variable into the SavedObject proto.
2746 This allows users to define a `hook` to provide extra information about the
2747 variable to the SavedObject.
2749 For example, the DistributedVariable class fills in its components in the
2750 distributed context.
2752 Args:
2753 resource_variable: A `ResourceVariable` or `DistributedValue` that has the
2754 information to be saved into the proto.
2755 proto: `SavedObject` proto to update.
2756 options: A `SaveOption` instance that configures save behavior.
2757 enforce_naming: A bool determining whether to check that names end in the
2758 expected string ':0'.
2759 """
2760 proto.variable.SetInParent()
2761 if enforce_naming and not resource_variable.name.endswith(":0"):
2762 raise ValueError(f"Cowardly refusing to save variable "
2763 f"{resource_variable.name} because of "
2764 f"unexpected suffix in the name (expected ':0')"
2765 f"which won't be restored.")
2766 proto.variable.name = tensor_module.get_op_name(resource_variable.name)
2767 proto.variable.trainable = resource_variable.trainable
2768 proto.variable.dtype = resource_variable.dtype.as_datatype_enum
2769 proto.variable.synchronization = resource_variable.synchronization.value
2770 proto.variable.aggregation = resource_variable.aggregation.value
2771 proto.variable.shape.CopyFrom(resource_variable.shape.as_proto())
2772 if options.experimental_variable_policy._save_variable_devices( # pylint: disable=protected-access
2773 ):
2774 if hasattr(resource_variable, "device"):
2775 proto.variable.device = resource_variable.device