Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/variables.py: 41%
540 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Variable class."""
17import abc
18import enum
19import functools
20import itertools
21import os
23from tensorflow.core.framework import variable_pb2
24from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
25from tensorflow.python.eager import context
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_conversion_registry
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import array_ops_stack
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import gen_math_ops
34from tensorflow.python.ops import math_ops
35from tensorflow.python.ops import state_ops
36from tensorflow.python.trackable import base as trackable
37from tensorflow.python.util import _pywrap_utils
38from tensorflow.python.util import object_identity
39from tensorflow.python.util import tf_should_use
40from tensorflow.python.util import traceback_utils
41from tensorflow.python.util.deprecation import deprecated
42from tensorflow.python.util.deprecation import deprecated_args
43from tensorflow.python.util.tf_export import tf_export
46def default_variable_creator_v2(_, **kwds):
47 del kwds
48 raise NotImplementedError("resource_variable_ops needs to be imported")
51def _make_getter(captured_getter, captured_previous):
52 """To avoid capturing loop variables."""
54 def getter(**kwargs):
55 return captured_getter(captured_previous, **kwargs)
57 return getter
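# A minimal sketch (illustrative, assuming a hypothetical list `creators` of
# variable-creator functions) of why `_make_getter` is needed: closures built
# directly inside a loop would all capture the same loop variable by
# reference, so every getter would call the creator from the last iteration.
# Binding the current creator and previous getter as arguments freezes them
# per iteration:
#
#   previous = lambda **kws: default_variable_creator_v2(None, **kws)
#   for creator in creators:
#     previous = _make_getter(creator, previous)  # captures current values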
60@tf_export("VariableSynchronization")
61class VariableSynchronization(enum.Enum):
62 """Indicates when a distributed variable will be synced.
64 * `AUTO`: Indicates that the synchronization will be determined by the current
65 `DistributionStrategy` (e.g., with `MirroredStrategy` this would be
66 `ON_WRITE`).
67 * `NONE`: Indicates that there will only be one copy of the variable, so
68 there is no need to sync.
69 * `ON_WRITE`: Indicates that the variable will be updated across devices
70 every time it is written.
71 * `ON_READ`: Indicates that the variable will be aggregated across devices
72 when it is read (e.g., when checkpointing or when evaluating an op that uses
73 the variable).
75 Example:
76 >>> temp_grad=[tf.Variable([0.], trainable=False,
77 ... synchronization=tf.VariableSynchronization.ON_READ,
78 ... aggregation=tf.VariableAggregation.MEAN
79 ... )]
80 """
81 AUTO = 0
82 NONE = 1
83 ON_WRITE = 2
84 ON_READ = 3
87# LINT.IfChange
88@tf_export("VariableAggregation", v1=[])
89class VariableAggregationV2(enum.Enum):
90 """Indicates how a distributed variable will be aggregated.
92 `tf.distribute.Strategy` distributes a model by making multiple copies
93 (called "replicas") acting on different elements of the input batch in a
94 data parallel model. When performing some variable-update operation,
95 for example `var.assign_add(x)`, in a model, we need to resolve how to combine
96 the different values for `x` computed in the different replicas.
98 * `NONE`: This is the default, giving an error if you use a
99 variable-update operation with multiple replicas.
100 * `SUM`: Add the updates across replicas.
101 * `MEAN`: Take the arithmetic mean ("average") of the updates across replicas.
102 * `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same
103 update, but we only want to perform the update once. Used, e.g., for the
104 global step counter.
106 For example:
108 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
109 >>> with strategy.scope():
110 ... v = tf.Variable(5.0, aggregation=tf.VariableAggregation.MEAN)
111 >>> @tf.function
112 ... def update_fn():
113 ... return v.assign_add(1.0)
114 >>> strategy.run(update_fn)
115 PerReplica:{
116 0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>,
117 1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
118 }
120 """
121 NONE = 0
122 SUM = 1
123 MEAN = 2
124 ONLY_FIRST_REPLICA = 3
126 def __hash__(self):
127 return hash(self.value)
129 def __eq__(self, other):
130 if self is other:
131 return True
132 elif isinstance(other, VariableAggregation):
133 return int(self.value) == int(other.value)
134 else:
135 return False
138@tf_export(v1=["VariableAggregation"])
139class VariableAggregation(enum.Enum):
140 NONE = 0
141 SUM = 1
142 MEAN = 2
143 ONLY_FIRST_REPLICA = 3
144 ONLY_FIRST_TOWER = 3 # DEPRECATED
146 def __hash__(self):
147 return hash(self.value)
150# LINT.ThenChange(//tensorflow/core/framework/variable.proto)
151#
152# Note that we are currently relying on the integer values of the Python enums
153# matching the integer values of the proto enums.
155VariableAggregation.__doc__ = (
156 VariableAggregationV2.__doc__ +
157 "* `ONLY_FIRST_TOWER`: Deprecated alias for `ONLY_FIRST_REPLICA`.\n ")
160def validate_synchronization_aggregation_trainable(synchronization, aggregation,
161 trainable, name):
162 """Given user-provided variable properties, sets defaults and validates."""
163 if aggregation is None:
164 aggregation = VariableAggregation.NONE
165 else:
166 if not isinstance(aggregation,
167 (VariableAggregation, VariableAggregationV2)):
168 try:
169 aggregation = VariableAggregationV2(aggregation)
170 except ValueError:
171 raise ValueError(
172 "Invalid variable aggregation mode: {} for variable: {}".format(
173 aggregation, name))
174 if synchronization is None:
175 synchronization = VariableSynchronization.AUTO
176 else:
177 try:
178 synchronization = VariableSynchronization(synchronization)
179 except ValueError:
180 raise ValueError(
181 "Invalid variable synchronization mode: {} for variable: {}".format(
182 synchronization, name))
183 if trainable is None:
184 trainable = synchronization != VariableSynchronization.ON_READ
185 return synchronization, aggregation, trainable
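# Illustrative sketch of the defaulting rules implemented above (names refer
# to the enums defined in this module):
#
#   validate_synchronization_aggregation_trainable(None, None, None, "v")
#   # -> (VariableSynchronization.AUTO, VariableAggregation.NONE, True)
#
#   validate_synchronization_aggregation_trainable(
#       VariableSynchronization.ON_READ, None, None, "v")
#   # -> trainable defaults to False because ON_READ variables are not
#   #    trainable by default.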
188class VariableMetaclass(abc.ABCMeta):
189 """Metaclass to allow construction of tf.Variable to be overridden."""
191 @traceback_utils.filter_traceback
192 def __call__(cls, *args, **kwargs):
193 if hasattr(cls, "_variable_call") and callable(cls._variable_call):
194 variable_call = cls._variable_call(*args, **kwargs)
195 if variable_call is not None:
196 return variable_call
197 return super(VariableMetaclass, cls).__call__(*args, **kwargs)
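# Illustrative sketch (assuming the public `tf.variable_creator_scope` API,
# which pushes onto the creator stack consulted by `Variable._variable_call`
# below) of how this metaclass lets `tf.Variable(...)` construction be
# intercepted:
#
#   def logging_creator(next_creator, **kwargs):
#     print("creating variable:", kwargs.get("name"))
#     return next_creator(**kwargs)
#
#   with tf.variable_creator_scope(logging_creator):
#     v = tf.Variable(1.0)  # routed through logging_creator first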
200@tf_export("Variable", v1=[])
201# TODO(mdan): This should subclass core.Tensor, and not all its subclasses?
202class Variable(trackable.Trackable, metaclass=VariableMetaclass):
203 """See the [variable guide](https://tensorflow.org/guide/variable).
205 A variable maintains shared, persistent state manipulated by a program.
207 The `Variable()` constructor requires an initial value for the variable, which
208 can be a `Tensor` of any type and shape. This initial value defines the type
209 and shape of the variable. After construction, the type and shape of the
210 variable are fixed. The value can be changed using one of the assign methods.
212 >>> v = tf.Variable(1.)
213 >>> v.assign(2.)
214 <tf.Variable ... shape=() dtype=float32, numpy=2.0>
215 >>> v.assign_add(0.5)
216 <tf.Variable ... shape=() dtype=float32, numpy=2.5>
218 The `shape` argument to `Variable`'s constructor allows you to construct a
219 variable with a less defined shape than its `initial_value`:
221 >>> v = tf.Variable(1., shape=tf.TensorShape(None))
222 >>> v.assign([[1.]])
223 <tf.Variable ... shape=<unknown> dtype=float32, numpy=array([[1.]], ...)>
225 Just like any `Tensor`, variables created with `Variable()` can be used as
226 inputs to operations. Additionally, all the operators overloaded for the
227 `Tensor` class are carried over to variables.
229 >>> w = tf.Variable([[1.], [2.]])
230 >>> x = tf.constant([[3., 4.]])
231 >>> tf.matmul(w, x)
232 <tf.Tensor:... shape=(2, 2), ... numpy=
233 array([[3., 4.],
234 [6., 8.]], dtype=float32)>
235 >>> tf.sigmoid(w + x)
236 <tf.Tensor:... shape=(2, 2), ...>
238 When building a machine learning model it is often convenient to distinguish
239 between variables holding trainable model parameters and other variables such
240 as a `step` variable used to count training steps. To make this easier, the
241 variable constructor supports a `trainable=<bool>`
242 parameter. `tf.GradientTape` watches trainable variables by default:
244 >>> with tf.GradientTape(persistent=True) as tape:
245 ... trainable = tf.Variable(1.)
246 ... non_trainable = tf.Variable(2., trainable=False)
247 ... x1 = trainable * 2.
248 ... x2 = non_trainable * 3.
249 >>> tape.gradient(x1, trainable)
250 <tf.Tensor:... shape=(), dtype=float32, numpy=2.0>
251 >>> assert tape.gradient(x2, non_trainable) is None # Unwatched
253 Variables are automatically tracked when assigned to attributes of types
254 inheriting from `tf.Module`.
256 >>> m = tf.Module()
257 >>> m.v = tf.Variable([1.])
258 >>> m.trainable_variables
259 (<tf.Variable ... shape=(1,) ... numpy=array([1.], dtype=float32)>,)
261 This tracking then allows saving variable values to
262 [training checkpoints](https://www.tensorflow.org/guide/checkpoint), or to
263 [SavedModels](https://www.tensorflow.org/guide/saved_model) which include
264 serialized TensorFlow graphs.
266 Variables are often captured and manipulated by `tf.function`s. This works the
267 same way it would for an un-decorated function:
269 >>> v = tf.Variable(0.)
270 >>> read_and_decrement = tf.function(lambda: v.assign_sub(0.1))
271 >>> read_and_decrement()
272 <tf.Tensor: shape=(), dtype=float32, numpy=-0.1>
273 >>> read_and_decrement()
274 <tf.Tensor: shape=(), dtype=float32, numpy=-0.2>
276 Variables created inside a `tf.function` must be owned outside the function
277 and be created only once:
279 >>> class M(tf.Module):
280 ... @tf.function
281 ... def __call__(self, x):
282 ... if not hasattr(self, "v"): # Or set self.v to None in __init__
283 ... self.v = tf.Variable(x)
284 ... return self.v * x
285 >>> m = M()
286 >>> m(2.)
287 <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
288 >>> m(3.)
289 <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
290 >>> m.v
291 <tf.Variable ... shape=() dtype=float32, numpy=2.0>
293 See the `tf.function` documentation for details.
294 """
296 @deprecated_args(
297 None, "A variable's value can be manually cached by calling "
298 "tf.Variable.read_value() under a tf.device scope. The caching_device "
299 "argument does not work properly.", "caching_device")
300 def __init__(self,
301 initial_value=None,
302 trainable=None,
303 validate_shape=True,
304 caching_device=None,
305 name=None,
306 variable_def=None,
307 dtype=None,
308 import_scope=None,
309 constraint=None,
310 synchronization=VariableSynchronization.AUTO,
311 aggregation=VariableAggregation.NONE,
312 shape=None,
313 experimental_enable_variable_lifting=True,
314 ):
315 """Creates a new variable with value `initial_value`.
317 Args:
318 initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
319 which is the initial value for the Variable. The initial value must have
320 a shape specified unless `validate_shape` is set to False. Can also be a
321 callable with no argument that returns the initial value when called. In
322 that case, `dtype` must be specified. (Note that initializer functions
323 from init_ops.py must first be bound to a shape before being used here.)
324 trainable: If `True`, GradientTapes automatically watch uses of this
325 variable. Defaults to `True`, unless `synchronization` is set to
326 `ON_READ`, in which case it defaults to `False`.
327 validate_shape: If `False`, allows the variable to be initialized with a
328 value of unknown shape. If `True`, the default, the shape of
329 `initial_value` must be known.
330 caching_device: Note: This argument is only valid when using a v1-style
331 `Session`. Optional device string describing where the Variable should
332 be cached for reading. Defaults to the Variable's device. If not `None`,
333 caches on another device. Typical use is to cache on the device where
334 the Ops using the Variable reside, to deduplicate copying through
335 `Switch` and other conditional statements.
336 name: Optional name for the variable. Defaults to `'Variable'` and gets
337 uniquified automatically.
338 variable_def: `VariableDef` protocol buffer. If not `None`, recreates the
339 Variable object with its contents, referencing the variable's nodes in
340 the graph, which must already exist. The graph is not changed.
341 `variable_def` and the other arguments are mutually exclusive.
342 dtype: If set, initial_value will be converted to the given type. If
343 `None`, either the datatype will be kept (if `initial_value` is a
344 Tensor), or `convert_to_tensor` will decide.
345 import_scope: Optional `string`. Name scope to add to the `Variable`. Only
346 used when initializing from protocol buffer.
347 constraint: An optional projection function to be applied to the variable
348 after being updated by an `Optimizer` (e.g. used to implement norm
349 constraints or value constraints for layer weights). The function must
350 take as input the unprojected Tensor representing the value of the
351 variable and return the Tensor for the projected value (which must have
352 the same shape). Constraints are not safe to use when doing asynchronous
353 distributed training.
354 synchronization: Indicates when a distributed variable will be
355 synchronized. Accepted values are constants defined in the class
356 `tf.VariableSynchronization`. By default the synchronization is set to
357 `AUTO` and the current `DistributionStrategy` chooses when to
358 synchronize.
359 aggregation: Indicates how a distributed variable will be aggregated.
360 Accepted values are constants defined in the class
361 `tf.VariableAggregation`.
362 shape: (optional) The shape of this variable. If None, the shape of
363 `initial_value` will be used. When setting this argument to
364 `tf.TensorShape(None)` (representing an unspecified shape), the variable
365 can be assigned with values of different shapes.
366 experimental_enable_variable_lifting: Whether to lift the variable out if
367 it's in a `tf.function`. Default is `True`. When this argument
368 is `True`, variable creation will follow the behavior and
369 restrictions described
370 [here](https://www.tensorflow.org/guide/function#creating_tfvariables).
371 If this argument is `False`, that description doesn't apply,
372 and you can freely create and use the variable in the
373 `tf.function`, as if it's a "mutable `tf.Tensor`". You can't
374 return the variable though.
376 Raises:
377 ValueError: If both `variable_def` and initial_value are specified.
378 ValueError: If the initial value is not specified, or does not have a
379 shape and `validate_shape` is `True`.
380 """
381 raise NotImplementedError
383 def __repr__(self):
384 raise NotImplementedError
386 def value(self):
387 """Returns the last snapshot of this variable.
389 You usually do not need to call this method as all ops that need the value
390 of the variable call it automatically through a `convert_to_tensor()` call.
392 Returns a `Tensor` which holds the value of the variable. You can not
393 assign a new value to this tensor as it is not a reference to the variable.
395 To avoid copies, if the consumer of the returned value is on the same device
396 as the variable, this actually returns the live value of the variable, not
397 a copy. Updates to the variable are seen by the consumer. If the consumer
398 is on a different device it will get a copy of the variable.
400 Returns:
401 A `Tensor` containing the value of the variable.
402 """
403 raise NotImplementedError
405 def read_value(self):
406 """Returns the value of this variable, read in the current context.
408 Can be different from value() if it's on another device, with control
409 dependencies, etc.
411 Returns:
412 A `Tensor` containing the value of the variable.
413 """
414 raise NotImplementedError
416 def set_shape(self, shape):
417 """Overrides the shape for this variable.
419 Args:
420 shape: the `TensorShape` representing the overridden shape.
421 """
422 raise NotImplementedError
424 @property
425 def trainable(self):
426 raise NotImplementedError
428 @property
429 def synchronization(self):
430 raise NotImplementedError
432 @property
433 def aggregation(self):
434 raise NotImplementedError
436 def eval(self, session=None):
437 """In a session, computes and returns the value of this variable.
439 This is not a graph construction method, it does not add ops to the graph.
441 This convenience method requires a session where the graph
442 containing this variable has been launched. If no session is
443 passed, the default session is used. See `tf.compat.v1.Session` for more
444 information on launching a graph and on sessions.
446 ```python
447 v = tf.Variable([1, 2])
448 init = tf.compat.v1.global_variables_initializer()
450 with tf.compat.v1.Session() as sess:
451 sess.run(init)
452 # Usage passing the session explicitly.
453 print(v.eval(sess))
454 # Usage with the default session. The 'with' block
455 # above makes 'sess' the default session.
456 print(v.eval())
457 ```
459 Args:
460 session: The session to use to evaluate this variable. If none, the
461 default session is used.
463 Returns:
464 A numpy `ndarray` with a copy of the value of this variable.
465 """
466 raise NotImplementedError
468 @deprecated(
469 None, "Use Variable.read_value. Variables in 2.X are initialized "
470 "automatically both in eager and graph (inside tf.defun) contexts.")
471 def initialized_value(self):
472 """Returns the value of the initialized variable.
474 You should use this instead of the variable itself to initialize another
475 variable with a value that depends on the value of this variable.
477 ```python
478 # Initialize 'v' with a random tensor.
479 v = tf.Variable(tf.random.truncated_normal([10, 40]))
480 # Use `initialized_value` to guarantee that `v` has been
481 # initialized before its value is used to initialize `w`.
482 # The random values are picked only once.
483 w = tf.Variable(v.initialized_value() * 2.0)
484 ```
486 Returns:
487 A `Tensor` holding the value of this variable after its initializer
488 has run.
489 """
490 raise NotImplementedError
492 @property
493 def initial_value(self):
494 """Returns the Tensor used as the initial value for the variable.
496 Note that this is different from `initialized_value()` which runs
497 the op that initializes the variable before returning its value.
498 This method returns the tensor that is used by the op that initializes
499 the variable.
501 Returns:
502 A `Tensor`.
503 """
504 raise NotImplementedError
506 @property
507 def constraint(self):
508 """Returns the constraint function associated with this variable.
510 Returns:
511 The constraint function that was passed to the variable constructor.
512 Can be `None` if no constraint was passed.
513 """
514 raise NotImplementedError
516 def assign(self, value, use_locking=False, name=None, read_value=True):
517 """Assigns a new value to the variable.
519 This is essentially a shortcut for `assign(self, value)`.
521 Args:
522 value: A `Tensor`. The new value for this variable.
523 use_locking: If `True`, use locking during the assignment.
524 name: The name of the operation to be created
525 read_value: if True, will return something which evaluates to the new
526 value of the variable; if False will return the assign op.
528 Returns:
529 The updated variable. If `read_value` is false, instead returns None in
530 Eager mode and the assign op in graph mode.
531 """
532 raise NotImplementedError
534 def assign_add(self, delta, use_locking=False, name=None, read_value=True):
535 """Adds a value to this variable.
537 This is essentially a shortcut for `assign_add(self, delta)`.
539 Args:
540 delta: A `Tensor`. The value to add to this variable.
541 use_locking: If `True`, use locking during the operation.
542 name: The name of the operation to be created
543 read_value: if True, will return something which evaluates to the new
544 value of the variable; if False will return the assign op.
546 Returns:
547 The updated variable. If `read_value` is false, instead returns None in
548 Eager mode and the assign op in graph mode.
549 """
550 raise NotImplementedError
552 def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
553 """Subtracts a value from this variable.
555 This is essentially a shortcut for `assign_sub(self, delta)`.
557 Args:
558 delta: A `Tensor`. The value to subtract from this variable.
559 use_locking: If `True`, use locking during the operation.
560 name: The name of the operation to be created
561 read_value: if True, will return something which evaluates to the new
562 value of the variable; if False will return the assign op.
564 Returns:
565 The updated variable. If `read_value` is false, instead returns None in
566 Eager mode and the assign op in graph mode.
567 """
568 raise NotImplementedError
570 def scatter_sub(self, sparse_delta, use_locking=False, name=None):
571 """Subtracts `tf.IndexedSlices` from this variable.
573 Args:
574 sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
575 use_locking: If `True`, use locking during the operation.
576 name: the name of the operation.
578 Returns:
579 The updated variable.
581 Raises:
582 TypeError: if `sparse_delta` is not an `IndexedSlices`.
583 """
584 raise NotImplementedError
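  # Illustrative example (a sketch assuming eager execution and a concrete
  # resource-variable implementation of this abstract method):
  #
  #   v = tf.Variable([1., 2., 3., 4.])
  #   v.scatter_sub(tf.IndexedSlices(values=tf.constant([10., 20.]),
  #                                  indices=tf.constant([1, 3])))
  #   # v is now [1., -8., 3., -16.]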
586 def scatter_add(self, sparse_delta, use_locking=False, name=None):
587 """Adds `tf.IndexedSlices` to this variable.
589 Args:
590 sparse_delta: `tf.IndexedSlices` to be added to this variable.
591 use_locking: If `True`, use locking during the operation.
592 name: the name of the operation.
594 Returns:
595 The updated variable.
597 Raises:
598 TypeError: if `sparse_delta` is not an `IndexedSlices`.
599 """
600 raise NotImplementedError
602 def scatter_max(self, sparse_delta, use_locking=False, name=None):
603 """Updates this variable with the max of `tf.IndexedSlices` and itself.
605 Args:
606 sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
607 variable.
608 use_locking: If `True`, use locking during the operation.
609 name: the name of the operation.
611 Returns:
612 The updated variable.
614 Raises:
615 TypeError: if `sparse_delta` is not an `IndexedSlices`.
616 """
617 raise NotImplementedError
619 def scatter_min(self, sparse_delta, use_locking=False, name=None):
620 """Updates this variable with the min of `tf.IndexedSlices` and itself.
622 Args:
623 sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
624 variable.
625 use_locking: If `True`, use locking during the operation.
626 name: the name of the operation.
628 Returns:
629 The updated variable.
631 Raises:
632 TypeError: if `sparse_delta` is not an `IndexedSlices`.
633 """
634 raise NotImplementedError
636 def scatter_mul(self, sparse_delta, use_locking=False, name=None):
637 """Multiplies this variable by `tf.IndexedSlices`.
639 Args:
640 sparse_delta: `tf.IndexedSlices` to multiply this variable by.
641 use_locking: If `True`, use locking during the operation.
642 name: the name of the operation.
644 Returns:
645 The updated variable.
647 Raises:
648 TypeError: if `sparse_delta` is not an `IndexedSlices`.
649 """
650 raise NotImplementedError
652 def scatter_div(self, sparse_delta, use_locking=False, name=None):
653 """Divides this variable by `tf.IndexedSlices`.
655 Args:
656 sparse_delta: `tf.IndexedSlices` to divide this variable by.
657 use_locking: If `True`, use locking during the operation.
658 name: the name of the operation.
660 Returns:
661 The updated variable.
663 Raises:
664 TypeError: if `sparse_delta` is not an `IndexedSlices`.
665 """
666 raise NotImplementedError
668 def scatter_update(self, sparse_delta, use_locking=False, name=None):
669 """Assigns `tf.IndexedSlices` to this variable.
671 Args:
672 sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
673 use_locking: If `True`, use locking during the operation.
674 name: the name of the operation.
676 Returns:
677 The updated variable.
679 Raises:
680 TypeError: if `sparse_delta` is not an `IndexedSlices`.
681 """
682 raise NotImplementedError
684 def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
685 """Assigns `tf.IndexedSlices` to this variable batch-wise.
687 Analogous to `batch_gather`. This assumes that this variable and the
688 sparse_delta IndexedSlices have a series of leading dimensions that are the
689 same for all of them, and the updates are performed on the last dimension of
690 indices. In other words, the dimensions should be the following:
692 `num_prefix_dims = sparse_delta.indices.ndims - 1`
693 `batch_dim = num_prefix_dims + 1`
694 `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
695 batch_dim:]`
697 where
699 `sparse_delta.updates.shape[:num_prefix_dims]`
700 `== sparse_delta.indices.shape[:num_prefix_dims]`
701 `== var.shape[:num_prefix_dims]`
703 And the operation performed can be expressed as:
705 `var[i_1, ..., i_n,
706 sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
707 i_1, ..., i_n, j]`
709 When sparse_delta.indices is a 1D tensor, this operation is equivalent to
710 `scatter_update`.
712 To avoid this operation one can loop over the first `ndims` of the
713 variable and use `scatter_update` on the subtensors that result from slicing
714 the first dimension. This is a valid option for `ndims = 1`, but less
715 efficient than this implementation.
717 Args:
718 sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
719 use_locking: If `True`, use locking during the operation.
720 name: the name of the operation.
722 Returns:
723 The updated variable.
725 Raises:
726 TypeError: if `sparse_delta` is not an `IndexedSlices`.
727 """
728 raise NotImplementedError
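  # Worked shape example (a sketch assuming eager execution and a concrete
  # implementation) for the relationships described above:
  #
  #   v = tf.Variable(tf.zeros([2, 5]))              # rank 2, 2 batch rows
  #   delta = tf.IndexedSlices(
  #       values=tf.constant([[1., 2.], [3., 4.]]),  # indices.shape + v.shape[2:]
  #       indices=tf.constant([[0, 3], [1, 4]]))     # one index list per row
  #   v.batch_scatter_update(delta)
  #   # row 0: columns 0 and 3 become 1. and 2.
  #   # row 1: columns 1 and 4 become 3. and 4.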
730 def scatter_nd_sub(self, indices, updates, name=None):
731 """Applies sparse subtraction to individual values or slices in a Variable.
733 The variable has rank `P` and `indices` is a `Tensor` of rank `Q`.
735 `indices` must be an integer tensor containing indices into self.
736 It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
738 The innermost dimension of `indices` (with length `K`) corresponds to
739 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
740 dimension of self.
742 `updates` is `Tensor` of rank `Q-1+P-K` with shape:
744 ```
745 [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
746 ```
748 For example, say we want to subtract 4 scattered elements from a rank-1
749 tensor with 8 elements. In Python, that update would look like this:
751 ```python
752 v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
753 indices = tf.constant([[4], [3], [1] ,[7]])
754 updates = tf.constant([9, 10, 11, 12])
755 v.scatter_nd_sub(indices, updates)
756 print(v)
757 ```
759 After the update `v` would look like this:
761 [1, -9, 3, -6, -4, 6, 7, -4]
763 See `tf.scatter_nd` for more details about how to make updates to
764 slices.
766 Args:
767 indices: The indices to be used in the operation.
768 updates: The values to be used in the operation.
769 name: the name of the operation.
771 Returns:
772 The updated variable.
773 """
774 raise NotImplementedError
776 def scatter_nd_add(self, indices, updates, name=None):
777 """Applies sparse addition to individual values or slices in a Variable.
779 The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.
781 `indices` must be an integer tensor containing indices into self.
782 It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
784 The innermost dimension of `indices` (with length `K`) corresponds to
785 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
786 dimension of self.
788 `updates` is `Tensor` of rank `Q-1+P-K` with shape:
790 ```
791 [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
792 ```
794 For example, say we want to add 4 scattered elements to a rank-1 tensor
795 with 8 elements. In Python, that update would look like this:
797 ```python
798 v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
799 indices = tf.constant([[4], [3], [1] ,[7]])
800 updates = tf.constant([9, 10, 11, 12])
801 v.scatter_nd_add(indices, updates)
802 print(v)
803 ```
805 The resulting update to v would look like this:
807 [1, 13, 3, 14, 14, 6, 7, 20]
809 See `tf.scatter_nd` for more details about how to make updates to
810 slices.
812 Args:
813 indices: The indices to be used in the operation.
814 updates: The values to be used in the operation.
815 name: the name of the operation.
817 Returns:
818 The updated variable.
819 """
820 raise NotImplementedError
822 def scatter_nd_update(self, indices, updates, name=None):
823 """Applies sparse assignment to individual values or slices in a Variable.
825 The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.
827 `indices` must be an integer tensor containing indices into self.
828 It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
830 The innermost dimension of `indices` (with length `K`) corresponds to
831 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
832 dimension of self.
834 `updates` is `Tensor` of rank `Q-1+P-K` with shape:
836 ```
837 [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
838 ```
840 For example, say we want to assign 4 scattered elements to a rank-1
841 tensor with 8 elements. In Python, that update would look like this:
843 ```python
844 v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
845 indices = tf.constant([[4], [3], [1] ,[7]])
846 updates = tf.constant([9, 10, 11, 12])
847 v.scatter_nd_update(indices, updates)
848 print(v)
849 ```
851 The resulting update to v would look like this:
853 [1, 11, 3, 10, 9, 6, 7, 12]
855 See `tf.scatter_nd` for more details about how to make updates to
856 slices.
858 Args:
859 indices: The indices to be used in the operation.
860 updates: The values to be used in the operation.
861 name: the name of the operation.
863 Returns:
864 The updated variable.
865 """
866 raise NotImplementedError
868 def sparse_read(self, indices, name=None):
869 r"""Gather slices from `params` according to `indices`.
871 This function supports a subset of tf.gather; see tf.gather for details on
872 usage.
874 Args:
875 indices: The index `Tensor`. Must be one of the following types: `int32`,
876 `int64`. Must be in range `[0, params.shape[axis])`.
877 name: A name for the operation (optional).
879 Returns:
880 A `Tensor`. Has the same type as `params`.
881 """
882 raise AttributeError
884 def gather_nd(self, indices, name=None):
885 r"""Gather slices from `params` into a Tensor with shape specified by `indices`.
887 See tf.gather_nd for details.
889 Args:
890 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
891 Index tensor.
892 name: A name for the operation (optional).
894 Returns:
895 A `Tensor`. Has the same type as `params`.
896 """
897 raise AttributeError
899 @deprecated(None, "Prefer Dataset.range instead.")
900 def count_up_to(self, limit):
901 """Increments this variable until it reaches `limit`.
903 When that Op is run it tries to increment the variable by `1`. If
904 incrementing the variable would bring it above `limit` then the Op raises
905 the exception `OutOfRangeError`.
907 If no error is raised, the Op outputs the value of the variable before
908 the increment.
910 This is essentially a shortcut for `count_up_to(self, limit)`.
912 Args:
913 limit: value at which incrementing the variable raises an error.
915 Returns:
916 A `Tensor` that will hold the variable value before the increment. If no
917 other Op modifies this variable, the values produced will all be
918 distinct.
919 """
920 raise NotImplementedError
922 @deprecated(None,
923 "Prefer Variable.assign which has equivalent behavior in 2.X.")
924 def load(self, value, session=None):
925 """Load new value into this variable.
927 Writes new value to variable's memory. Doesn't add ops to the graph.
929 This convenience method requires a session where the graph
930 containing this variable has been launched. If no session is
931 passed, the default session is used. See `tf.compat.v1.Session` for more
932 information on launching a graph and on sessions.
934 ```python
935 v = tf.Variable([1, 2])
936 init = tf.compat.v1.global_variables_initializer()
938 with tf.compat.v1.Session() as sess:
939 sess.run(init)
940 # Usage passing the session explicitly.
941 v.load([2, 3], sess)
942 print(v.eval(sess)) # prints [2 3]
943 # Usage with the default session. The 'with' block
944 # above makes 'sess' the default session.
945 v.load([3, 4], sess)
946 print(v.eval()) # prints [3 4]
947 ```
949 Args:
950 value: New variable value
951 session: The session to use to evaluate this variable. If none, the
952 default session is used.
954 Raises:
955 ValueError: If no session is passed and no default session exists.
956 """
957 if context.executing_eagerly():
958 self.assign(value)
959 else:
960 session = session or ops.get_default_session()
961 if session is None:
962 raise ValueError(
963 "Either session argument should be provided or default session "
964 "should be established")
965 session.run(self.initializer, {self.initializer.inputs[1]: value})
967 # Conversion to tensor.
968 @staticmethod
969 def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): # pylint: disable=invalid-name
970 """Utility function for converting a Variable to a Tensor."""
971 _ = name
972 if dtype and not dtype.is_compatible_with(v.dtype):
973 raise ValueError(
974 f"Incompatible type conversion requested to type '{dtype.name}' for "
975 f"variable of type '{v.dtype.name}' (Variable: {v}).")
976 if as_ref:
977 return v._ref() # pylint: disable=protected-access
978 else:
979 return v.value()
981 @classmethod
982 def _OverloadAllOperators(cls): # pylint: disable=invalid-name
983 """Register overloads for all operators."""
984 for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
985 cls._OverloadOperator(operator)
986 # For slicing, bind getitem differently than a tensor (use SliceHelperVar
987 # instead)
988 # pylint: disable=protected-access
989 setattr(cls, "__getitem__", array_ops._SliceHelperVar)
991 @classmethod
992 def _OverloadOperator(cls, operator): # pylint: disable=invalid-name
993 """Defer an operator overload to `ops.Tensor`.
995 We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
997 Args:
998 operator: string. The operator name.
999 """
1000 # We can't use the overload mechanism on __eq__ & __ne__ since __eq__ is
1001 # called when adding a variable to sets. As a result we call a.value() which
1002 # causes infinite recursion when operating within a GradientTape
1003 # TODO(gjn): Consider removing this
1004 if operator == "__eq__" or operator == "__ne__":
1005 return
1007 tensor_oper = getattr(ops.Tensor, operator)
1009 def _run_op(a, *args, **kwargs):
1010 # pylint: disable=protected-access
1011 return tensor_oper(a.value(), *args, **kwargs)
1013 functools.update_wrapper(_run_op, tensor_oper)
1014 setattr(cls, operator, _run_op)
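  # Illustrative consequence of the forwarding above (assuming eager
  # execution): arithmetic on a variable reads its value and returns plain
  # tensors, e.g.
  #
  #   v = tf.Variable([1., 2.])
  #   (v * 2.0 + 1.0).numpy()  # -> array([3., 5.], dtype=float32)
  #   v[0]                     # slicing goes through array_ops._SliceHelperVar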
1016 def __hash__(self):
1017 if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access
1018 raise TypeError(
1019 "Variable is unhashable. "
1020 f"Instead, use variable.ref() as the key. (Variable: {self})")
1021 else:
1022 return id(self)
1024 # TODO(gjn): duplicate of math_ops.tensor_equals, consider removing
1025 def __eq__(self, other):
1026 """Compares two variables element-wise for equality."""
1027 if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access
1028 return gen_math_ops.equal(self, other, incompatible_shape_error=False)
1029 else:
1030 # In legacy graph mode, tensor equality is object equality
1031 return self is other
1033 # TODO(gjn): duplicate of math_ops.tensor_not_equals, consider removing
1034 def __ne__(self, other):
1035 """Compares two variables element-wise for inequality."""
1036 if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access
1037 return gen_math_ops.not_equal(self, other, incompatible_shape_error=False)
1038 else:
1039 # In legacy graph mode, tensor equality is object equality
1040 return self is not other
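  # Illustrative example (assuming TF2 eager execution, where tensor equality
  # is enabled): comparisons are element-wise rather than identity-based:
  #
  #   v = tf.Variable([1, 2])
  #   (v == tf.constant([1, 3])).numpy()  # -> array([ True, False])
  #   (v != tf.constant([1, 3])).numpy()  # -> array([False,  True])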
1042 def __iter__(self):
1043 """When executing eagerly, iterates over the value of the variable."""
1044 return iter(self.read_value())
1046 # NOTE(mrry): This enables the Variable's overloaded "right" binary
1047 # operators to run when the left operand is an ndarray, because it
1048 # accords the Variable class higher priority than an ndarray, or a
1049 # numpy matrix.
1050 # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
1051 # mechanism, which allows more control over how Variables interact
1052 # with ndarrays.
1053 __array_priority__ = 100
1055 @property
1056 def name(self):
1057 """The name of this variable."""
1058 raise NotImplementedError
1060 @property
1061 def _shared_name(self):
1062 """The shared name of the variable.
1064 Unlike name(), shared_name doesn't have the ":0" suffix. It is the
1065 user-specified name with the name scope prefix.
1067 Returns:
1068 variable name.
1069 """
1070 return self.name[:self.name.index(":")]
1072 @property
1073 def initializer(self):
1074 """The initializer operation for this variable."""
1075 raise NotImplementedError
1077 @property
1078 def device(self):
1079 """The device of this variable."""
1080 raise NotImplementedError
1082 @property
1083 def dtype(self):
1084 """The `DType` of this variable."""
1085 raise NotImplementedError
1087 @property
1088 def op(self):
1089 """The `Operation` of this variable."""
1090 raise NotImplementedError
1092 @property
1093 def graph(self):
1094 """The `Graph` of this variable."""
1095 raise NotImplementedError
1097 @property
1098 def shape(self):
1099 """The `TensorShape` of this variable.
1101 Returns:
1102 A `TensorShape`.
1103 """
1104 raise NotImplementedError
1106 def get_shape(self):
1107 """Alias of `Variable.shape`."""
1108 return self.shape
1110 def _gather_saveables_for_checkpoint(self):
1111 """For implementing `Trackable`. This object is saveable on its own."""
1112 return {trackable.VARIABLE_VALUE_KEY: self}
1114 def to_proto(self, export_scope=None):
1115 """Converts a `Variable` to a `VariableDef` protocol buffer.
1117 Args:
1118 export_scope: Optional `string`. Name scope to remove.
1120 Returns:
1121 A `VariableDef` protocol buffer, or `None` if the `Variable` is not
1122 in the specified name scope.
1123 """
1124 raise NotImplementedError
1126 @staticmethod
1127 def from_proto(variable_def, import_scope=None):
1128 """Returns a `Variable` object created from `variable_def`."""
1129 raise NotImplementedError
1131 def _set_save_slice_info(self, save_slice_info):
1132 """Sets the slice info for this `Variable`.
1134 Args:
1135 save_slice_info: A `Variable.SaveSliceInfo` object.
1136 """
1137 self._save_slice_info = save_slice_info
1139 def _get_save_slice_info(self):
1140 return self._save_slice_info
1142 @deprecated(None, "Use ref() instead.")
1143 def experimental_ref(self):
1144 return self.ref()
1146 def ref(self):
1147 # tf.Tensor also has the same ref() API. If you update the
1148 # documentation here, please update tf.Tensor.ref() as well.
1149 """Returns a hashable reference object to this Variable.
1151 The primary use case for this API is to put variables in a set/dictionary.
1152 We can't put variables in a set/dictionary as `variable.__hash__()` is no
1153 longer available as of TensorFlow 2.0.
1155 The following will raise an exception in TensorFlow 2.0 and later:
1157 >>> x = tf.Variable(5)
1158 >>> y = tf.Variable(10)
1159 >>> z = tf.Variable(10)
1160 >>> variable_set = {x, y, z}
1161 Traceback (most recent call last):
1162 ...
1163 TypeError: Variable is unhashable. Instead, use variable.ref() as the key.
1164 >>> variable_dict = {x: 'five', y: 'ten'}
1165 Traceback (most recent call last):
1166 ...
1167 TypeError: Variable is unhashable. Instead, use variable.ref() as the key.
1169 Instead, we can use `variable.ref()`.
1171 >>> variable_set = {x.ref(), y.ref(), z.ref()}
1172 >>> x.ref() in variable_set
1173 True
1174 >>> variable_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'}
1175 >>> variable_dict[y.ref()]
1176 'ten'
1178 Also, the reference object provides `.deref()` function that returns the
1179 original Variable.
1181 >>> x = tf.Variable(5)
1182 >>> x.ref().deref()
1183 <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=5>
1184 """
1185 return object_identity.Reference(self)
1187 @classmethod
1188 def _variable_call(
1189 cls,
1190 initial_value=None,
1191 trainable=None,
1192 validate_shape=True,
1193 caching_device=None,
1194 name=None,
1195 variable_def=None,
1196 dtype=None,
1197 import_scope=None,
1198 constraint=None,
1199 synchronization=VariableSynchronization.AUTO,
1200 aggregation=VariableAggregation.NONE,
1201 shape=None,
1202 experimental_enable_variable_lifting=None,
1203 **kwargs,
1204 ):
1205 """Variable class getter. Useful to force the signature."""
1206 if cls is not Variable:
1207 return None
1208 previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
1209 for _, getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
1210 previous_getter = _make_getter(getter, previous_getter)
1212 # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
1213 if aggregation is None:
1214 aggregation = VariableAggregation.NONE
1215 return previous_getter(
1216 initial_value=initial_value,
1217 trainable=trainable,
1218 validate_shape=validate_shape,
1219 caching_device=caching_device,
1220 name=name,
1221 variable_def=variable_def,
1222 dtype=dtype,
1223 import_scope=import_scope,
1224 constraint=constraint,
1225 synchronization=synchronization,
1226 aggregation=aggregation,
1227 shape=shape,
1228 experimental_enable_variable_lifting=experimental_enable_variable_lifting,
1229 )
1231 class SaveSliceInfo:
1232 """Information on how to save this Variable as a slice.
1234 Provides internal support for saving variables as slices of a larger
1235 variable. This API is not public and is subject to change.
1237 Available properties:
1239 * full_name
1240 * full_shape
1241 * var_offset
1242 * var_shape
1243 """
1245 def __init__(self,
1246 full_name=None,
1247 full_shape=None,
1248 var_offset=None,
1249 var_shape=None,
1250 save_slice_info_def=None,
1251 import_scope=None):
1252 """Create a `SaveSliceInfo`.
1254 Args:
1255 full_name: Name of the full variable of which this `Variable` is a
1256 slice.
1257 full_shape: Shape of the full variable, as a list of int.
1258 var_offset: Offset of this `Variable` into the full variable, as a list
1259 of int.
1260 var_shape: Shape of this `Variable`, as a list of int.
1261 save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`,
1262 recreates the SaveSliceInfo object from its contents. `save_slice_info_def`
1263 and other arguments are mutually exclusive.
1264 import_scope: Optional `string`. Name scope to add. Only used when
1265 initializing from protocol buffer.
1266 """
1267 if save_slice_info_def:
1268 assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
1269 self.full_name = ops.prepend_name_scope(
1270 save_slice_info_def.full_name, import_scope=import_scope)
1271 self.full_shape = list(save_slice_info_def.full_shape)
1272 self.var_offset = list(save_slice_info_def.var_offset)
1273 self.var_shape = list(save_slice_info_def.var_shape)
1274 else:
1275 self.full_name = full_name
1276 self.full_shape = full_shape
1277 self.var_offset = var_offset
1278 self.var_shape = var_shape
1280 @property
1281 def spec(self):
1282 """Computes the spec string used for saving."""
1283 full_shape_str = " ".join("%d" % d for d in self.full_shape) + " "
1284 sl_spec = ":".join(
1285 "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape))
1286 return full_shape_str + sl_spec
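    # Illustrative example: a slice holding rows 0-9 and columns 5-9 of a
    # [10, 20] variable yields the spec string "10 20 0,10:5,5":
    #
    #   info = Variable.SaveSliceInfo(full_name="v", full_shape=[10, 20],
    #                                 var_offset=[0, 5], var_shape=[10, 5])
    #   info.spec  # -> "10 20 0,10:5,5"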
1288 def to_proto(self, export_scope=None):
1289 """Returns a SaveSliceInfoDef() proto.
1291 Args:
1292 export_scope: Optional `string`. Name scope to remove.
1294 Returns:
1295 A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not
1296 in the specified name scope.
1297 """
1298 if (export_scope is None or self.full_name.startswith(export_scope)):
1299 save_slice_info_def = variable_pb2.SaveSliceInfoDef()
1300 save_slice_info_def.full_name = ops.strip_name_scope(
1301 self.full_name, export_scope)
1302 for i in self.full_shape:
1303 save_slice_info_def.full_shape.append(i)
1304 for i in self.var_offset:
1305 save_slice_info_def.var_offset.append(i)
1306 for i in self.var_shape:
1307 save_slice_info_def.var_shape.append(i)
1308 return save_slice_info_def
1309 else:
1310 return None
1313Variable._OverloadAllOperators() # pylint: disable=protected-access
1314_pywrap_utils.RegisterType("Variable", Variable)
1317def _try_guard_against_uninitialized_dependencies(name, initial_value):
1318 """Attempt to guard against dependencies on uninitialized variables.
1320 Replace references to variables in `initial_value` with references to the
1321 variable's initialized values. The initialized values are essentially
1322 conditional TensorFlow graphs that return a variable's value if it is
1323 initialized or its `initial_value` if it hasn't been initialized. This
1324 replacement is done on a best effort basis:
1326 - If the `initial_value` graph contains cycles, we don't do any
1327 replacements for that graph.
1328 - If the variables that `initial_value` depends on are not present in the
1329 `GLOBAL_VARIABLES` or `LOCAL_VARIABLES` we don't replace them.
1331 In these cases, it is up to the caller to ensure that the `initial_value`
1332 graph uses initialized variables or that they guard access to variables
1333 using their `initialized_value` method.
1335 Args:
1336 name: Variable name.
1337 initial_value: `Tensor`. The initial value.
1339 Returns:
1340 A `Tensor` suitable to initialize a variable.
1341 Raises:
1342 TypeError: If `initial_value` is not a `Tensor`.
1343 """
1344 if not isinstance(initial_value, ops.Tensor):
1345 raise TypeError("initial_value needs to be a Tensor: %s" % initial_value)
1347 # Don't modify initial_value if it contains any cyclic dependencies.
1348 if _has_cycle(initial_value.op, state={}):
1349 return initial_value
1350 return _safe_initial_value_from_tensor(name, initial_value, op_cache={})
1353_UNKNOWN, _STARTED, _FINISHED = range(3)
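# These three states implement the standard white/grey/black coloring for the
# depth-first search in `_has_cycle` below: reaching an op that is still
# `_STARTED` means a back edge, i.e. a cycle in the initializer graph.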
1356def _has_cycle(op, state):
1357 """Detect cycles in the dependencies of `initial_value`."""
1358 op_state = state.get(op.name, _UNKNOWN)
1359 if op_state == _STARTED:
1360 return True
1361 elif op_state == _FINISHED:
1362 return False
1364 state[op.name] = _STARTED
1365 for i in itertools.chain((i.op for i in op.inputs), op.control_inputs):
1366 if _has_cycle(i, state):
1367 return True
1368 state[op.name] = _FINISHED
1369 return False
1372def _safe_initial_value_from_tensor(name, tensor, op_cache):
1373 """Replace dependencies on variables with their initialized values.
1375 Args:
1376 name: Variable name.
1377 tensor: A `Tensor`. The tensor to replace.
1378 op_cache: A dict mapping operation names to `Operation`s. Used to memoize
1379 the results so as to avoid creating redundant operations.
1381 Returns:
1382 A `Tensor` compatible with `tensor`. Any inputs that lead to variable
1383 values will be replaced with a corresponding graph that uses the
1384 variable's initialized values. This is done on a best-effort basis. If no
1385 modifications need to be made then `tensor` will be returned unchanged.
1386 """
1387 op = tensor.op
1388 new_op = op_cache.get(op.name)
1389 if new_op is None:
1390 new_op = _safe_initial_value_from_op(name, op, op_cache)
1391 op_cache[op.name] = new_op
1392 return new_op.outputs[tensor.value_index]
1395def _safe_initial_value_from_op(name, op, op_cache):
1396 """Replace dependencies on variables with their initialized values.
1398 Args:
1399 name: Variable name.
1400 op: An `Operation`. The operation to replace.
1401 op_cache: A dict mapping operation names to `Operation`s. Used to memoize
1402 the results so as to avoid creating redundant operations.
1404 Returns:
1405 An `Operation` compatible with `op`. Any inputs that lead to variable
1406 values will be replaced with a corresponding graph that uses the
1407 variable's initialized values. This is done on a best-effort basis. If no
1408 modifications need to be made then `op` will be returned unchanged.
1409 """
1410 op_type = op.node_def.op
1411 if op_type in ("IsVariableInitialized", "VarIsInitializedOp",
1412 "ReadVariableOp", "If"):
1413 return op
1415 # Attempt to find the initialized_value of any variable reference / handles.
1416 # TODO(b/70206927): Fix handling of ResourceVariables.
1417 if op_type in ("Variable", "VariableV2", "VarHandleOp"):
1418 initialized_value = _find_initialized_value_for_variable(op)
1419 return op if initialized_value is None else initialized_value.op
1421 # Recursively build initializer expressions for inputs.
1422 modified = False
1423 new_op_inputs = []
1424 for op_input in op.inputs:
1425 new_op_input = _safe_initial_value_from_tensor(name, op_input, op_cache)
1426 new_op_inputs.append(new_op_input)
1427 modified = modified or (new_op_input != op_input)
1429 # If at least one input was modified, replace the op.
1430 if modified:
1431 new_op_type = op_type
1432 if new_op_type == "RefSwitch":
1433 new_op_type = "Switch"
1434 new_op_name = op.node_def.name + "_" + name
1435 new_op_name = new_op_name.replace(":", "_")
1436 return op.graph.create_op(
1437 new_op_type,
1438 new_op_inputs,
1439 op._output_types, # pylint: disable=protected-access
1440 name=new_op_name,
1441 attrs=op.node_def.attr)
1443 return op
1446def _find_initialized_value_for_variable(variable_op):
1447 """Find the initialized value for a variable op.
1449 To do so, lookup the variable op in the variables collection.
1451 Args:
1452 variable_op: A variable `Operation`.
1454 Returns:
1455 A `Tensor` representing the initialized value for the variable or `None`
1456 if the initialized value could not be found.
1457 """
1458 try:
1459 var_names = [variable_op.node_def.name, variable_op.node_def.name + ":0"]
1460 for collection_name in (ops.GraphKeys.GLOBAL_VARIABLES,
1461 ops.GraphKeys.LOCAL_VARIABLES):
1462 for var in variable_op.graph.get_collection(collection_name):
1463 if var.name in var_names:
1464 return var.initialized_value()
1465 except AttributeError:
1466 # Return None when an incomplete user-defined variable type was put in
1467 # the collection.
1468 return None
1469 return None
1472class PartitionedVariable:
1473 """A container for partitioned `Variable` objects.
1475 @compatibility(eager) `tf.PartitionedVariable` is not compatible with
1476 eager execution. Use `tf.Variable` instead which is compatible
1477 with both eager execution and graph construction. See [the
1478 TensorFlow Eager Execution
1479 guide](https://www.tensorflow.org/guide/eager#variables_and_optimizers)
1480 for details on how variables work in eager execution.
1481 @end_compatibility
1482 """
1484 def __init__(self, name, shape, dtype, variable_list, partitions):
1485 """Creates a new partitioned variable wrapper.
1487 Variables passed via the variable_list must contain a save_slice_info
1488 field. Concatenation and iteration are in lexicographic order according
1489 to the var_offset property of the save_slice_info.
1491 Args:
1492 name: String. Overall name of the variables.
1493 shape: List of integers. Overall shape of the variables.
1494 dtype: Type of the variables.
1495 variable_list: List of `Variable` that comprise this partitioned variable.
1496 partitions: List of integers. Number of partitions for each dimension.
1498 Raises:
1499 TypeError: If `variable_list` is not a list of `Variable` objects, or
1500 `partitions` is not a list.
1501 ValueError: If `variable_list` is empty, or the `Variable` shape
1502 information does not match `shape`, or `partitions` has invalid values.
1503 """
1504 if not isinstance(variable_list, (list, tuple)):
1505 raise TypeError("variable_list is not a list or tuple: %s" %
1506 variable_list)
1507 if not isinstance(partitions, (list, tuple)):
1508 raise TypeError("partitions is not a list or tuple: %s" % partitions)
1509 if not all(p >= 1 for p in partitions):
1510 raise ValueError("partition values must be positive: %s" % partitions)
1511 if not variable_list:
1512 raise ValueError("variable_list may not be empty")
1513 # pylint: disable=protected-access
1514 for v in variable_list:
1515 # Sort the variable_list lexicographically according to var offset value.
1516 if not all(v._get_save_slice_info() is not None for v in variable_list):
1517 raise ValueError(
1518 "All variables must have a save_slice_info available: %s" %
1519 [v.name for v in variable_list])
1520 if len(shape) != len(partitions):
1521 raise ValueError("len(shape) != len(partitions): %s vs. %s" %
1522 (shape, partitions))
1523 if v._get_save_slice_info().full_shape != shape:
1524 raise ValueError("All variables' full shapes must match shape: %s; "
1525 "but full shapes were: %s" %
1526 (shape, str([v._get_save_slice_info().full_shape])))
1527 self._variable_list = sorted(
1528 variable_list, key=lambda v: v._get_save_slice_info().var_offset)
1529 # pylint: enable=protected-access
1531 self._name = name
1532 self._shape = shape
1533 self._dtype = dtype
1534 self._partitions = partitions
1535 self._as_tensor = None
1537 def __iter__(self):
1538 """Return an iterable for accessing the underlying partition Variables."""
1539 return iter(self._variable_list)
1541 def __len__(self):
1542 num_partition_axes = len(self._partition_axes())
1543 if num_partition_axes > 1:
1544 raise ValueError("Cannot get a length for %d > 1 partition axes" %
1545 num_partition_axes)
1546 return len(self._variable_list)
1548 def _partition_axes(self):
1549 if all(p == 1 for p in self._partitions):
1550 return [0]
1551 else:
1552 return [i for i, p in enumerate(self._partitions) if p > 1]
1554 def _concat(self):
1555 """Returns the overall concatenated value as a `Tensor`.
1557 This is different from using the partitioned variable directly as a tensor
1558 (through tensor conversion and `as_tensor`) in that it creates a new set of
1559 operations that keeps the control dependencies from its scope.
1561 Returns:
1562 `Tensor` containing the concatenated value.
1563 """
1564 if len(self._variable_list) == 1:
1565 with ops.name_scope(None):
1566 return array_ops.identity(self._variable_list[0], name=self._name)
1568 partition_axes = self._partition_axes()
1570 if len(partition_axes) > 1:
1571 raise NotImplementedError(
1572 "Cannot concatenate along more than one dimension: %s. "
1573 "Multi-axis partition concat is not supported" % str(partition_axes))
1574 partition_ix = partition_axes[0]
1576 with ops.name_scope(self._name + "/ConcatPartitions/"):
1577 concatenated = array_ops.concat(self._variable_list, partition_ix)
1579 with ops.name_scope(None):
1580 return array_ops.identity(concatenated, name=self._name)
1582 def as_tensor(self):
1583 """Returns the overall concatenated value as a `Tensor`.
1585 The returned tensor will not inherit the control dependencies from the scope
1586 where the value is used, which is similar to getting the value of
1587 `Variable`.
1589 Returns:
1590 `Tensor` containing the concatenated value.
1591 """
1592 with ops.control_dependencies(None):
1593 return self._concat()
1595 @staticmethod
1596 def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):
1597 # pylint: disable=invalid-name
1598 _ = name
1599 if dtype is not None and not dtype.is_compatible_with(v.dtype):
1600 raise ValueError(
1601 "Incompatible type conversion requested to type '%s' for variable "
1602 "of type '%s'" % (dtype.name, v.dtype.name))
1603 if as_ref:
1604 raise NotImplementedError(
1605 "PartitionedVariable doesn't support being used as a reference.")
1606 else:
1607 return v.as_tensor()
1609 @property
1610 def name(self):
1611 return self._name
1613 @property
1614 def dtype(self):
1615 return self._dtype
1617 @property
1618 def shape(self):
1619 return self.get_shape()
1621 @property
1622 def _distribute_strategy(self):
1623 """The `tf.distribute.Strategy` that this variable was created under."""
1624 # NOTE(yuefengz): Today, there are no partitioned variables in a distribution strategy.
1625 return None
1627 def get_shape(self):
1628 return self._shape
1630 def _get_variable_list(self):
1631 return self._variable_list
1633 def _get_partitions(self):
1634 return self._partitions
1636 def _apply_assign_fn(self, assign_fn, value):
1637 partition_axes = self._partition_axes()
1638 if len(partition_axes) > 1:
1639 raise NotImplementedError(
1640 "Cannot do assign action along more than one dimension: %s. "
1641 "Multi-axis partition assign action is not supported " %
1642 str(partition_axes))
1643 if isinstance(value, list):
1644 assert len(value) == len(self._variable_list)
1645 value_list = value
1646 elif isinstance(value, PartitionedVariable):
1647 value_list = list(value)
1648 else:
1649 partition_ix = partition_axes[0]
1650 size_splits_list = [
1651 tensor_shape.dimension_value(var.shape[partition_ix])
1652 for var in self._variable_list
1653 ]
1654 value_list = array_ops.split(value, size_splits_list, axis=partition_ix)
1656 op_list = [
1657 assign_fn(var, value_list[idx])
1658 for idx, var in enumerate(self._variable_list)
1659 ]
1660 return op_list
1662 def assign(self, value, use_locking=False, name=None, read_value=True):
1663 assign_fn = lambda var, r_value: var.assign(
1664 r_value, use_locking=use_locking, name=name, read_value=read_value)
1665 assign_list = self._apply_assign_fn(assign_fn, value)
1666 if read_value:
1667 return assign_list
1668 return [assign.op for assign in assign_list]
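# Example (a hedged sketch; `v` is the two-shard partitioned variable from the
# `as_tensor` sketch above): assigning a full-shape value splits it along the
# partition axis and issues one assign per underlying shard.
#
#   >>> new_value = tf.constant([10.0, 11.0, 12.0, 13.0])
#   >>> per_shard = v.assign(new_value)  # list with one assign per partition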
1670 def assign_add(self, value, use_locking=False, name=None, read_value=True):
1671 assign_fn = lambda var, r_value: var.assign_add(
1672 r_value, use_locking=use_locking, name=name, read_value=read_value)
1673 assign_list = self._apply_assign_fn(assign_fn, value)
1674 if read_value:
1675 return assign_list
1676 return [assign.op for assign in assign_list]
1678 def assign_sub(self, value, use_locking=False, name=None, read_value=True):
1679 assign_fn = lambda var, r_value: var.assign_sub(
1680 r_value, use_locking=use_locking, name=name, read_value=read_value)
1681 assign_list = self._apply_assign_fn(assign_fn, value)
1682 if read_value:
1683 return assign_list
1684 return [assign.op for assign in assign_list]
1687@tf_export(v1=["global_variables"])
1688def global_variables(scope=None):
1689 """Returns global variables.
1691 Global variables are variables that are shared across machines in a
1692 distributed environment. The `Variable()` constructor or `get_variable()`
1693 automatically adds new variables to the graph collection
1694 `GraphKeys.GLOBAL_VARIABLES`.
1695 This convenience function returns the contents of that collection.
1697 An alternative to global variables is local variables. See
1698 `tf.compat.v1.local_variables`.
1700 @compatibility(TF2)
1701 Not compatible with eager execution and `tf.function`. In particular, Graph
1702 collections are deprecated in TF2. Instead please create a
1703 [tf.Module](https://www.tensorflow.org/guide/intro_to_modules)
1704 container for all your model state, including variables.
1705 You can then list all the variables in your `tf.Module` through the
1706 `variables` attribute.
1707 @end_compatibility
1709 Args:
1710 scope: (Optional.) A string. If supplied, the resulting list is filtered to
1711 include only items whose `name` attribute matches `scope` using
1712 `re.match`. Items without a `name` attribute are never returned if a scope
1713 is supplied. The choice of `re.match` means that a `scope` without special
1714 tokens filters by prefix.
1716 Returns:
1717 A list of `Variable` objects.
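Example (a hedged sketch; assumes TF1 graph mode, e.g. after
`tf.compat.v1.disable_eager_execution()`):

>>> v1 = tf.Variable(0.0, name="v1")
>>> v2 = tf.Variable(1.0, name="v2", trainable=False)
>>> assert v1 in tf.compat.v1.global_variables()
>>> assert v2 in tf.compat.v1.global_variables()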
1718 """
1719 return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope)
1722@tf_export(v1=["all_variables"])
1723@deprecated("2017-03-02", "Please use tf.global_variables instead.")
1724def all_variables():
1725 """Use `tf.compat.v1.global_variables` instead."""
1726 return global_variables()
1729def _all_saveable_objects(scope=None):
1730 """Returns all variables and `SaveableObject`s that must be checkpointed.
1732 Args:
1733 scope: (Optional.) A string. If supplied, the resulting list is filtered to
1734 include only items whose `name` attribute matches `scope` using
1735 `re.match`. Items without a `name` attribute are never returned if a scope
1736 is supplied. The choice of `re.match` means that a `scope` without special
1737 tokens filters by prefix.
1739 Returns:
1740 A list of `Variable` and `SaveableObject` objects to be checkpointed.
1741 """
1742 # TODO(andreasst): make this function public once things are settled.
1743 return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) +
1744 ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))
1747@tf_export(v1=["local_variables"])
1748def local_variables(scope=None):
1749 """Returns local variables.
1751 Local variables are per-process variables, usually not saved/restored to
1752 checkpoint, and used for temporary or intermediate values.
1753 For example, they can be used as counters for metrics computation or the
1754 number of epochs this machine has read data.
1755 The `tf.contrib.framework.local_variable()` function automatically adds the
1756 new variable to `GraphKeys.LOCAL_VARIABLES`.
1757 This convenience function returns the contents of that collection.
1759 An alternative to local variables is global variables. See
1760 `tf.compat.v1.global_variables`.
1762 Args:
1763 scope: (Optional.) A string. If supplied, the resulting list is filtered to
1764 include only items whose `name` attribute matches `scope` using
1765 `re.match`. Items without a `name` attribute are never returned if a scope
1766 is supplied. The choice of `re.match` means that a `scope` without special
1767 tokens filters by prefix.
1769 Returns:
1770 A list of local `Variable` objects.
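Example (a hedged sketch; assumes TF1 graph mode; the variable is placed in the
local collection explicitly via `collections`):

>>> epochs = tf.compat.v1.get_variable(
...     "epochs_read", shape=[], dtype=tf.int64,
...     initializer=tf.compat.v1.zeros_initializer(),
...     collections=[tf.compat.v1.GraphKeys.LOCAL_VARIABLES])
>>> assert epochs in tf.compat.v1.local_variables()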
1771 """
1772 return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope)
1775@tf_export(v1=["model_variables"])
1776def model_variables(scope=None):
1777 """Returns all variables in the MODEL_VARIABLES collection.
1779 Args:
1780 scope: (Optional.) A string. If supplied, the resulting list is filtered to
1781 include only items whose `name` attribute matches `scope` using
1782 `re.match`. Items without a `name` attribute are never returned if a scope
1783 is supplied. The choice of `re.match` means that a `scope` without special
1784 tokens filters by prefix.
1786 Returns:
1787 A list of `Variable` objects in the `MODEL_VARIABLES` collection.
1788 """
1789 return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope)
1792@tf_export(v1=["trainable_variables"])
1793def trainable_variables(scope=None):
1794 """Returns all variables created with `trainable=True`.
1796 When passed `trainable=True`, the `Variable()` constructor automatically
1797 adds new variables to the graph collection
1798 `GraphKeys.TRAINABLE_VARIABLES`. This convenience function returns the
1799 contents of that collection.
1801 @compatibility(TF2)
1802 Not compatible with eager execution and `tf.function`. In particular, Graph
1803 collections are deprecated in TF2. Instead please create a `tf.Module`
1804 container for all your model state, including variables.
1805 You can then list all the trainable variables in your `tf.Module` through the
1806 `trainable_variables` attribute.
1807 @end_compatibility
1809 Args:
1810 scope: (Optional.) A string. If supplied, the resulting list is filtered to
1811 include only items whose `name` attribute matches `scope` using
1812 `re.match`. Items without a `name` attribute are never returned if a scope
1813 is supplied. The choice of `re.match` means that a `scope` without special
1814 tokens filters by prefix.
1816 Returns:
1817 A list of Variable objects.
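Example (a hedged sketch; assumes TF1 graph mode):

>>> w = tf.Variable([1.0], name="w")                     # trainable by default
>>> step = tf.Variable(0, trainable=False, name="step")  # not trainable
>>> train_vars = tf.compat.v1.trainable_variables()
>>> assert w in train_vars and step not in train_vars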
1818 """
1819 return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope)
1822@tf_export(v1=["moving_average_variables"])
1823def moving_average_variables(scope=None):
1824 """Returns all variables that maintain their moving averages.
1826 If an `ExponentialMovingAverage` object is created and the `apply()`
1827 method is called on a list of variables, these variables will
1828 be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection.
1829 This convenience function returns the contents of that collection.
1831 Args:
1832 scope: (Optional.) A string. If supplied, the resulting list is filtered to
1833 include only items whose `name` attribute matches `scope` using
1834 `re.match`. Items without a `name` attribute are never returned if a scope
1835 is supplied. The choice of `re.match` means that a `scope` without special
1836 tokens filters by prefix.
1838 Returns:
1839 A list of Variable objects.
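Example (a hedged sketch; assumes TF1 graph mode):

>>> v = tf.Variable(0.0, name="v")
>>> ema = tf.compat.v1.train.ExponentialMovingAverage(decay=0.99)
>>> maintain_op = ema.apply([v])  # adds `v` to MOVING_AVERAGE_VARIABLES
>>> assert v in tf.compat.v1.moving_average_variables()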
1840 """
1841 return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope)
1844@tf_export(v1=["initializers.variables", "variables_initializer"])
1845def variables_initializer(var_list, name="init"):
1846 """Returns an Op that initializes a list of variables.
1848 After you launch the graph in a session, you can run the returned Op to
1849 initialize all the variables in `var_list`. This Op runs all the
1850 initializers of the variables in `var_list` in parallel.
1852 Calling this function is equivalent to passing the list of variable
1853 initializers to `tf.group()`.
1855 If `var_list` is empty, however, the function still returns an Op that can
1856 be run. That Op just has no effect.
1858 @compatibility(TF2)
1859 In TF2, variables are initialized immediately when they are created. There is
1860 no longer a need to run variable initializers before using them.
1861 @end_compatibility
1863 Args:
1864 var_list: List of `Variable` objects to initialize.
1865 name: Optional name for the returned operation.
1867 Returns:
1868 An Op that runs the initializers of all the specified variables.
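Example (a hedged sketch; assumes TF1 graph mode with a `tf.compat.v1.Session`):

>>> a = tf.Variable(1.0, name="a")
>>> b = tf.Variable(2.0, name="b")
>>> init_ab = tf.compat.v1.variables_initializer([a, b], name="init_ab")
>>> with tf.compat.v1.Session() as sess:
...   sess.run(init_ab)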
1869 """
1870 if var_list and not context.executing_eagerly():
1871 return control_flow_ops.group(*[v.initializer for v in var_list], name=name)
1872 return control_flow_ops.no_op(name=name)
1875@tf_export(v1=["initialize_variables"])
1876@tf_should_use.should_use_result
1877@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.")
1878def initialize_variables(var_list, name="init"):
1879 """See `tf.compat.v1.variables_initializer`."""
1880 return variables_initializer(var_list, name=name)
1883@tf_export(v1=["initializers.global_variables", "global_variables_initializer"])
1884def global_variables_initializer():
1885 """Returns an Op that initializes global variables.
1887 This is just a shortcut for `variables_initializer(global_variables())`.
1889 @compatibility(TF2)
1890 In TF2, variables are initialized immediately when they are created. There is
1891 no longer a need to run variable initializers before using them.
1892 @end_compatibility
1894 Returns:
1895 An Op that initializes global variables in the graph.
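Example (a hedged sketch; assumes TF1 graph mode):

>>> v = tf.Variable(3.0)
>>> init_op = tf.compat.v1.global_variables_initializer()
>>> with tf.compat.v1.Session() as sess:
...   sess.run(init_op)
...   value = sess.run(v)  # now safe to read the variable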
1896 """
1897 if context.executing_eagerly():
1898 return control_flow_ops.no_op(name="global_variables_initializer")
1899 return variables_initializer(global_variables())
1902@tf_export(v1=["initialize_all_variables"])
1903@tf_should_use.should_use_result
1904@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.")
1905def initialize_all_variables():
1906 """See `tf.compat.v1.global_variables_initializer`."""
1907 return global_variables_initializer()
1910@tf_export(v1=["initializers.local_variables", "local_variables_initializer"])
1911def local_variables_initializer():
1912 """Returns an Op that initializes all local variables.
1914 This is just a shortcut for `variables_initializer(local_variables())`.
1916 @compatibility(TF2)
1917 In TF2, variables are initialized immediately when they are created. There is
1918 no longer a need to run variable initializers before using them.
1919 @end_compatibility
1921 Returns:
1922 An Op that initializes all local variables in the graph.
1923 """
1924 if context.executing_eagerly():
1925 return control_flow_ops.no_op(name="local_variables_initializer")
1926 return variables_initializer(local_variables())
1929@tf_export(v1=["initialize_local_variables"])
1930@tf_should_use.should_use_result
1931@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.")
1932def initialize_local_variables():
1933 """See `tf.compat.v1.local_variables_initializer`."""
1934 return local_variables_initializer()
1937@tf_export(v1=["assert_variables_initialized"])
1938@tf_should_use.should_use_result
1939def assert_variables_initialized(var_list=None):
1940 """Returns an Op to check if variables are initialized.
1942 NOTE: This function is obsolete and will be removed in 6 months. Please
1943 change your implementation to use `report_uninitialized_variables()`.
1945 When run, the returned Op will raise the exception `FailedPreconditionError`
1946 if any of the variables has not yet been initialized.
1948 Note: This function is implemented by trying to fetch the values of the
1949 variables. If one of the variables is not initialized a message may be
1950 logged by the C++ runtime. This is expected.
1952 Args:
1953 var_list: List of `Variable` objects to check. Defaults to the value of
1954 `global_variables() + local_variables()`.
1956 Returns:
1957 An Op, or None if there are no variables.
1958 """
1959 if var_list is None:
1960 var_list = global_variables() + local_variables()
1961 # Backwards compatibility for old-style variables. TODO(touts): remove.
1962 if not var_list:
1963 var_list = []
1964 for op in ops.get_default_graph().get_operations():
1965 if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
1966 var_list.append(op.outputs[0])
1967 if not var_list:
1968 return None
1969 else:
1970 ranks = []
1971 for var in var_list:
1972 with ops.colocate_with(var.op):
1973 ranks.append(array_ops.rank_internal(var, optimize=False))
1974 if len(ranks) == 1:
1975 return ranks[0]
1976 else:
1977 return array_ops_stack.stack(ranks)
1980@tf_export(v1=["report_uninitialized_variables"])
1981@tf_should_use.should_use_result
1982def report_uninitialized_variables(var_list=None,
1983 name="report_uninitialized_variables"):
1984 """Adds ops to list the names of uninitialized variables.
1986 When run, it returns a 1-D tensor containing the names of uninitialized
1987 variables if there are any, or an empty array if there are none.
1989 Args:
1990 var_list: List of `Variable` objects to check. Defaults to the value of
1991 `global_variables() + local_variables()`
1992 name: Optional name of the `Operation`.
1994 Returns:
1995 A 1-D tensor containing names of the uninitialized variables, or an empty
1996 1-D tensor if there are no variables or no uninitialized variables.
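Example (a hedged sketch; assumes TF1 graph mode):

>>> a = tf.Variable(1.0, name="a")
>>> b = tf.Variable(2.0, name="b")
>>> report = tf.compat.v1.report_uninitialized_variables()
>>> with tf.compat.v1.Session() as sess:
...   sess.run(a.initializer)   # initialize only `a`
...   names = sess.run(report)  # expected to contain b's name, not a's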
1997 """
1998 if var_list is None:
1999 var_list = global_variables() + local_variables()
2000 # Backwards compatibility for old-style variables. TODO(touts): remove.
2001 if not var_list:
2002 var_list = []
2003 for op in ops.get_default_graph().get_operations():
2004 if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
2005 var_list.append(op.outputs[0])
2006 with ops.name_scope(name):
2007 # Run all operations on CPU
2008 if var_list:
2009 init_vars = [state_ops.is_variable_initialized(v) for v in var_list]
2010 local_device = os.environ.get(
2011 "TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING", "/cpu:0")
2012 with ops.device(local_device):
2013 if not var_list:
2014 # Return an empty tensor so we only need to check for returned tensor
2015 # size being 0 as an indication of model ready.
2016 return array_ops.constant([], dtype=dtypes.string)
2017 else:
2018 # Get a 1-D boolean tensor listing whether each variable is initialized.
2019 variables_mask = math_ops.logical_not(array_ops_stack.stack(init_vars))
2020 # Get a 1-D string tensor containing all the variable names.
2021 variable_names_tensor = array_ops.constant(
2022 [s.op.name for s in var_list])
2023 # Return a 1-D tensor containing all the names of
2024 # uninitialized variables.
2025 return array_ops.boolean_mask(variable_names_tensor, variables_mask)
2028tensor_conversion_registry.register_tensor_conversion_function(
2029 PartitionedVariable, PartitionedVariable._TensorConversionFunction) # pylint: disable=protected-access