Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ref_variable.py: 33%
318 statements

# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RefVariable class."""

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util import compat
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.deprecation import deprecated


variable_scope = lazy_loader.LazyLoader(
    "variable_scope", globals(),
    "tensorflow.python.ops.variable_scope")
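

# Illustrative sketch (assumption, not part of the original file): the
# `LazyLoader` above defers importing `variable_scope` until first attribute
# access, which breaks the circular import between this module and
# variable_scope. A minimal stand-in with the same behavior, using only the
# standard library, might look like:
class _LazyModuleSketch(object):
  """Hypothetical stand-in for lazy_loader.LazyLoader, illustration only."""

  def __init__(self, module_name):
    self._module_name = module_name
    self._module = None

  def __getattr__(self, item):
    # Import the real module the first time any attribute is requested.
    if self._module is None:
      import importlib
      self._module = importlib.import_module(self._module_name)
    return getattr(self._module, item)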


def default_variable_creator(next_creator=None, **kwargs):
  """Default variable creator."""
  assert next_creator is None
  initial_value = kwargs.get("initial_value", None)
  trainable = kwargs.get("trainable", None)
  collections = kwargs.get("collections", None)
  validate_shape = kwargs.get("validate_shape", True)
  caching_device = kwargs.get("caching_device", None)
  name = kwargs.get("name", None)
  variable_def = kwargs.get("variable_def", None)
  dtype = kwargs.get("dtype", None)
  expected_shape = kwargs.get("expected_shape", None)
  import_scope = kwargs.get("import_scope", None)
  constraint = kwargs.get("constraint", None)
  use_resource = kwargs.get("use_resource", None)
  synchronization = kwargs.get("synchronization", None)
  aggregation = kwargs.get("aggregation", None)
  shape = kwargs.get("shape", None)

  if use_resource is None:
    use_resource = variable_scope.get_variable_scope().use_resource
  if use_resource is None:
    use_resource = variable_scope._DEFAULT_USE_RESOURCE  # pylint: disable=protected-access
  use_resource = use_resource or context.executing_eagerly()
  if use_resource:
    distribute_strategy = kwargs.get("distribute_strategy", None)
    return resource_variable_ops.ResourceVariable(
        initial_value=initial_value,
        trainable=trainable,
        collections=collections,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        dtype=dtype,
        constraint=constraint,
        variable_def=variable_def,
        import_scope=import_scope,
        distribute_strategy=distribute_strategy,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape)
  else:
    return RefVariable(
        initial_value=initial_value,
        trainable=trainable,
        collections=collections,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        dtype=dtype,
        constraint=constraint,
        variable_def=variable_def,
        expected_shape=expected_shape,
        import_scope=import_scope,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape)
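

# Usage sketch (assumption, not part of the original file): in TF1 graph mode
# `tf.compat.v1.Variable` dispatches through `default_variable_creator`, which
# returns a `RefVariable` unless resource variables are enabled or execution
# is eager.
def _example_default_creator_sketch():
  """Hypothetical helper, for illustration only."""
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()
  tf.disable_resource_variables()  # force the RefVariable branch above
  with tf.Graph().as_default():
    v = tf.Variable([1.0, 2.0], name="v")
    return type(v).__name__  # "RefVariable"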


variable_v1.default_variable_creator = default_variable_creator


def _to_proto_fn(v, export_scope=None):
  """Converts Variable and ResourceVariable to VariableDef for collections."""
  return v.to_proto(export_scope=export_scope)


def _from_proto_fn(v, import_scope=None):
  """Creates Variable or ResourceVariable from VariableDef as needed."""
  if v.is_resource:
    return resource_variable_ops.ResourceVariable.from_proto(
        v, import_scope=import_scope)
  return variable_v1.VariableV1.from_proto(v, import_scope=import_scope)


ops.register_proto_function(
    ops.GraphKeys.GLOBAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.TRAINABLE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.LOCAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.MODEL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.GLOBAL_STEP,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.METRIC_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
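

# Round-trip sketch (assumption, not part of the original file): the
# registrations above are what let MetaGraph-style export serialize collected
# variables via `_to_proto_fn` and rebuild them via `_from_proto_fn`. A
# hypothetical helper showing the shape of that round trip, assuming the
# variable's nodes already exist in the current default graph:
def _example_proto_round_trip_sketch(var):
  """Hypothetical helper, for illustration only."""
  var_def = var.to_proto()  # a variable_pb2.VariableDef
  # Rebuilds a Variable (Ref or Resource, per `var_def.is_resource`) that
  # points at the existing graph nodes named in the proto.
  return _from_proto_fn(var_def)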


# TODO(apassos): do not repeat all comments here
class RefVariable(variable_v1.VariableV1, core.Tensor):
  """Ref-based implementation of variables."""

  def __init__(
      self,  # pylint: disable=super-init-not-called
      initial_value=None,
      trainable=None,
      collections=None,
      validate_shape=True,
      caching_device=None,
      name=None,
      variable_def=None,
      dtype=None,
      expected_shape=None,
      import_scope=None,
      constraint=None,
      synchronization=None,
      aggregation=None,
      shape=None):
    """Creates a new variable with value `initial_value`.

    The new variable is added to the graph collections listed in `collections`,
    which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.

    If `trainable` is `True` the variable is also added to the graph collection
    `GraphKeys.TRAINABLE_VARIABLES`.

    This constructor creates both a `variable` Op and an `assign` Op to set the
    variable to its initial value.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. In that case, `dtype` must be specified. (Note that
        initializer functions from init_ops.py must first be bound to a shape
        before being used here.)
      trainable: If `True`, also adds the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default
        list of variables to use by the `Optimizer` classes. Defaults to
        `True`, unless `synchronization` is set to `ON_READ`, in which case it
        defaults to `False`.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string describing where the Variable
        should be cached for reading. Defaults to the Variable's device. If not
        `None`, caches on another device. Typical use is to cache on the device
        where the Ops using the Variable reside, to deduplicate copying through
        `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      variable_def: `VariableDef` protocol buffer. If not `None`, recreates the
        Variable object with its contents, referencing the variable's nodes in
        the graph, which must already exist. The graph is not changed.
        `variable_def` and the other arguments are mutually exclusive.
      dtype: If set, initial_value will be converted to the given type. If
        `None`, either the datatype will be kept (if `initial_value` is a
        Tensor), or `convert_to_tensor` will decide.
      expected_shape: A TensorShape. If set, initial_value is expected to have
        this shape.
      import_scope: Optional `string`. Name scope to add to the `Variable`.
        Only used when initializing from protocol buffer.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the
        variable can be assigned with values of different shapes.

    Raises:
      ValueError: If both `variable_def` and initial_value are specified.
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If eager execution is enabled.
    """
    self._in_graph_mode = True
    if variable_def:
      # If variable_def is provided, recreates the variable from its fields.
      if initial_value:
        raise ValueError("variable_def and initial_value are mutually "
                         "exclusive.")
      self._init_from_proto(variable_def, import_scope=import_scope)
    else:
      # Create from initial_value.
      self._init_from_args(
          initial_value=initial_value,
          trainable=trainable,
          collections=collections,
          validate_shape=validate_shape,
          caching_device=caching_device,
          name=name,
          dtype=dtype,
          expected_shape=expected_shape,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation,
          shape=shape)

  def __repr__(self):
    if context.executing_eagerly() and not self._in_graph_mode:
      return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
          self.name, self.get_shape(), self.dtype.name,
          ops.numpy_text(self.read_value(), is_repr=True))
    else:
      return "<tf.Variable '%s' shape=%s dtype=%s>" % (
          self.name, self.get_shape(), self.dtype.name)

  def _init_from_args(self,
                      initial_value=None,
                      trainable=None,
                      collections=None,
                      validate_shape=True,
                      caching_device=None,
                      name=None,
                      dtype=None,
                      expected_shape=None,
                      constraint=None,
                      synchronization=None,
                      aggregation=None,
                      shape=None):
    """Creates a new variable from arguments.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. (Note that initializer functions from init_ops.py must
        first be bound to a shape before being used here.)
      trainable: If `True`, also adds the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default
        list of variables to use by the `Optimizer` classes. Defaults to
        `True`, unless `synchronization` is set to `ON_READ`, in which case it
        defaults to `False`.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type. If
        None, either the datatype will be kept (if initial_value is a Tensor)
        or float32 will be used (if it is a Python object convertible to a
        Tensor).
      expected_shape: Deprecated. Ignored.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the
        variable can be assigned with values of different shapes.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If lifted into the eager context.
    """
    _ = expected_shape
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    # Store the graph key so optimizers know how to only retrieve variables
    # from this graph.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    if isinstance(initial_value, trackable.CheckpointInitialValue):
      self._maybe_initialize_trackable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))
    self._synchronization = synchronization
    self._aggregation = aggregation
    self._trainable = trainable
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    with ops.init_scope():
      # Ensure that we weren't lifted into the eager context.
      if context.executing_eagerly():
        raise RuntimeError(
            "Reference variables are not supported when eager execution is "
            "enabled. Please run `tf.compat.v1.enable_resource_variables()` to "
            "switch to resource variables.")
      with ops.name_scope(name, "Variable",
                          [] if init_from_fn else [initial_value]) as name:

        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
          true_name = ops.name_from_scope_name(name)  # pylint: disable=protected-access
          attr = attr_value_pb2.AttrValue(
              list=attr_value_pb2.AttrValue.ListValue(
                  s=[compat.as_bytes("loc:@%s" % true_name)]))
          # pylint: disable=protected-access
          with ops.get_default_graph()._attr_scope({"_class": attr}):
            with ops.name_scope("Initializer"), ops.device(None):
              initial_value = initial_value()
              if isinstance(initial_value, trackable.CheckpointInitialValue):
                self._maybe_initialize_trackable()
                self._update_uid = initial_value.checkpoint_position.restore_uid
                initial_value = initial_value.wrapped_value
              self._initial_value = ops.convert_to_tensor(
                  initial_value, name="initial_value", dtype=dtype)
            if shape is None:
              shape = (
                  self._initial_value.get_shape()
                  if validate_shape else tensor_shape.unknown_shape())
            self._variable = state_ops.variable_op_v2(
                shape, self._initial_value.dtype.base_dtype, name=name)
          # pylint: enable=protected-access

        # Or get the initial value from a Tensor or Python object.
        else:
          self._initial_value = ops.convert_to_tensor(
              initial_value, name="initial_value", dtype=dtype)
          # pylint: disable=protected-access
          if self._initial_value.op._get_control_flow_context() is not None:
            raise ValueError(
                "Initializer for variable %s is from inside a control-flow "
                "construct, such as a loop or conditional. When creating a "
                "variable inside a loop or conditional, use a lambda as the "
                "initializer." % name)
          if shape is None:
            # pylint: enable=protected-access
            shape = (
                self._initial_value.get_shape()
                if validate_shape else tensor_shape.unknown_shape())
          # In this case, the variable op can't be created until after the
          # initial_value has been converted to a Tensor with a known type.
          self._variable = state_ops.variable_op_v2(
              shape, self._initial_value.dtype.base_dtype, name=name)

        # Cache the name in `self`, because some APIs call `Variable.name` in
        # a tight loop, and this halves the cost.
        self._name = self._variable.name

        # Manually overrides the variable's shape with the initial value's.
        if validate_shape:
          initial_value_shape = self._initial_value.get_shape()
          if not initial_value_shape.is_fully_defined():
            raise ValueError("initial_value must have a shape specified: %s" %
                             self._initial_value)

        # If 'initial_value' makes use of other variables, make sure we don't
        # have an issue if these other variables aren't initialized first by
        # using their initialized_value() method.
        self._initializer_op = state_ops.assign(
            self._variable,
            variables._try_guard_against_uninitialized_dependencies(  # pylint: disable=protected-access
                name, self._initial_value),
            validate_shape=validate_shape).op

        # TODO(vrv): Change this class to not take caching_device, but
        # to take the op to colocate the snapshot with, so we can use
        # colocation rather than devices.
        if caching_device is not None:
          with ops.device(caching_device):
            self._snapshot = array_ops.identity(self._variable, name="read")
        else:
          with ops.colocate_with(self._variable.op):
            self._snapshot = array_ops.identity(self._variable, name="read")
      ops.add_to_collections(collections, self)

    self._caching_device = caching_device
    self._save_slice_info = None
    self._constraint = constraint
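
  # Note (assumption, added for illustration): because the tensor branch above
  # rejects initializers created inside a control-flow construct, such an
  # initializer should be passed as a zero-argument callable so that it is
  # materialized under the "Initializer" scope instead, with `dtype` given
  # explicitly, e.g. in TF1 graph mode:
  #
  #   v = tf.compat.v1.Variable(lambda: tf.ones([2, 2]), dtype=tf.float32)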

  def _init_from_proto(self, variable_def, import_scope=None):
    """Recreates the Variable object from a `VariableDef` protocol buffer.

    Args:
      variable_def: `VariableDef` protocol buffer, describing a variable whose
        nodes already exist in the graph.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(variable_def, variable_pb2.VariableDef)
    # Create from variable_def.
    g = ops.get_default_graph()
    self._variable = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.variable_name, import_scope=import_scope))
    self._name = self._variable.name
    self._initializer_op = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.initializer_name, import_scope=import_scope))
    # Tests whether initial_value_name exists first for backwards
    # compatibility.
    if (hasattr(variable_def, "initial_value_name") and
        variable_def.initial_value_name):
      self._initial_value = g.as_graph_element(
          ops.prepend_name_scope(
              variable_def.initial_value_name, import_scope=import_scope))
    else:
      self._initial_value = None
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            variable_def.synchronization, variable_def.aggregation,
            variable_def.trainable, variable_def.variable_name))
    self._synchronization = synchronization
    self._aggregation = aggregation
    self._trainable = trainable
    self._snapshot = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.snapshot_name, import_scope=import_scope))
    if variable_def.HasField("save_slice_info_def"):
      self._save_slice_info = variables.Variable.SaveSliceInfo(
          save_slice_info_def=variable_def.save_slice_info_def,
          import_scope=import_scope)
    else:
      self._save_slice_info = None
    self._caching_device = None
    self._constraint = None

  def _as_graph_element(self):
    """Conversion function for Graph.as_graph_element()."""
    return self._variable

  def value(self):
    """Returns the last snapshot of this variable.

    You usually do not need to call this method as all ops that need the value
    of the variable call it automatically through a `convert_to_tensor()` call.

    Returns a `Tensor` which holds the value of the variable. You can not
    assign a new value to this tensor as it is not a reference to the variable.

    To avoid copies, if the consumer of the returned value is on the same
    device as the variable, this actually returns the live value of the
    variable, not a copy. Updates to the variable are seen by the consumer. If
    the consumer is on a different device it will get a copy of the variable.

    Returns:
      A `Tensor` containing the value of the variable.
    """
    return self._snapshot

  def read_value(self):
    """Returns the value of this variable, read in the current context.

    Can be different from value() if it's on another device, with control
    dependencies, etc.

    Returns:
      A `Tensor` containing the value of the variable.
    """
    return array_ops.identity(self._variable, name="read")
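
  # Illustration (assumption, not part of the original class): `value()`
  # returns the snapshot built at construction time, while `read_value()`
  # creates a fresh read at the call site, so the two can differ under device
  # placement or control dependencies. A hypothetical demo method:
  def _example_value_vs_read_sketch(self):
    snapshot = self.value()    # pre-built "read" snapshot shared by callers
    fresh = self.read_value()  # new identity op in the current context
    return snapshot, fresh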

  def _ref(self):
    """Returns a reference to this variable.

    You usually do not need to call this method as all ops that need a
    reference to the variable call it automatically.

    The returned value is a `Tensor` which holds a reference to the variable.
    You can assign a new value to the variable by passing the tensor to an
    assign op. See `tf.Variable.value` if you want to get the value of the
    variable.

    Returns:
      A `Tensor` that is a reference to the variable.
    """
    return self._variable

  def set_shape(self, shape):
    """Overrides the shape for this variable.

    Args:
      shape: the `TensorShape` representing the overridden shape.
    """
    self._ref().set_shape(shape)
    self.value().set_shape(shape)

  @property
  def trainable(self):
    return self._trainable

  @property
  def synchronization(self):
    return self._synchronization

  @property
  def aggregation(self):
    return self._aggregation

  def eval(self, session=None):
    """In a session, computes and returns the value of this variable.

    This is not a graph construction method, it does not add ops to the graph.

    This convenience method requires a session where the graph
    containing this variable has been launched. If no session is
    passed, the default session is used. See `tf.compat.v1.Session` for more
    information on launching a graph and on sessions.

    ```python
    v = tf.Variable([1, 2])
    init = tf.compat.v1.global_variables_initializer()

    with tf.compat.v1.Session() as sess:
        sess.run(init)
        # Usage passing the session explicitly.
        print(v.eval(sess))
        # Usage with the default session. The 'with' block
        # above makes 'sess' the default session.
        print(v.eval())
    ```

    Args:
      session: The session to use to evaluate this variable. If none, the
        default session is used.

    Returns:
      A numpy `ndarray` with a copy of the value of this variable.
    """
    return self._variable.eval(session=session)

  @property
  def initial_value(self):
    """Returns the Tensor used as the initial value for the variable.

    Note that this is different from `initialized_value()` which runs
    the op that initializes the variable before returning its value.
    This method returns the tensor that is used by the op that initializes
    the variable.

    Returns:
      A `Tensor`.
    """
    return self._initial_value

  @property
  def constraint(self):
    """Returns the constraint function associated with this variable.

    Returns:
      The constraint function that was passed to the variable constructor.
      Can be `None` if no constraint was passed.
    """
    return self._constraint

  def assign(self, value, use_locking=False, name=None, read_value=True):
    """Assigns a new value to the variable.

    This is essentially a shortcut for `assign(self, value)`.

    Args:
      value: A `Tensor`. The new value for this variable.
      use_locking: If `True`, use locking during the assignment.
      name: The name of the operation to be created.
      read_value: if True, will return something which evaluates to the new
        value of the variable; if False will return the assign op.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the assignment has completed.
    """
    assign = state_ops.assign(
        self._variable, value, use_locking=use_locking, name=name)
    if read_value:
      return assign
    return assign.op
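
  # Illustration (assumption, not part of the original class): with
  # `read_value=True` (the default) `assign` returns a tensor that evaluates
  # to the new value; with `read_value=False` it returns only the assign op,
  # useful purely as a control dependency. A hypothetical demo method:
  def _example_assign_sketch(self, value):
    new_value_tensor = self.assign(value)             # a `Tensor`
    assign_op = self.assign(value, read_value=False)  # an `Operation`
    return new_value_tensor, assign_op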

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    """Adds a value to this variable.

    This is essentially a shortcut for `assign_add(self, delta)`.

    Args:
      delta: A `Tensor`. The value to add to this variable.
      use_locking: If `True`, use locking during the operation.
      name: The name of the operation to be created.
      read_value: if True, will return something which evaluates to the new
        value of the variable; if False will return the assign op.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the addition has completed.
    """
    assign = state_ops.assign_add(
        self._variable, delta, use_locking=use_locking, name=name)
    if read_value:
      return assign
    return assign.op

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    """Subtracts a value from this variable.

    This is essentially a shortcut for `assign_sub(self, delta)`.

    Args:
      delta: A `Tensor`. The value to subtract from this variable.
      use_locking: If `True`, use locking during the operation.
      name: The name of the operation to be created.
      read_value: if True, will return something which evaluates to the new
        value of the variable; if False will return the assign op.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the subtraction has completed.
    """
    assign = state_ops.assign_sub(
        self._variable, delta, use_locking=use_locking, name=name)
    if read_value:
      return assign
    return assign.op

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Subtracts `tf.IndexedSlices` from this variable.

    Args:
      sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered subtraction has completed.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
    return gen_state_ops.scatter_sub(
        self._variable,
        sparse_delta.indices,
        sparse_delta.values,
        use_locking=use_locking,
        name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Adds `tf.IndexedSlices` to this variable.

    Args:
      sparse_delta: `tf.IndexedSlices` to be added to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered addition has completed.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
    return gen_state_ops.scatter_add(
        self._variable,
        sparse_delta.indices,
        sparse_delta.values,
        use_locking=use_locking,
        name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Updates this variable with the max of `tf.IndexedSlices` and itself.

    Args:
      sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
        variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered maximization has completed.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
    return gen_state_ops.scatter_max(
        self._variable,
        sparse_delta.indices,
        sparse_delta.values,
        use_locking=use_locking,
        name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Updates this variable with the min of `tf.IndexedSlices` and itself.

    Args:
      sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
        variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered minimization has completed.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
    return gen_state_ops.scatter_min(
        self._variable,
        sparse_delta.indices,
        sparse_delta.values,
        use_locking=use_locking,
        name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Multiply this variable by `tf.IndexedSlices`.

    Args:
      sparse_delta: `tf.IndexedSlices` to multiply this variable by.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered multiplication has completed.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
    return gen_state_ops.scatter_mul(
        self._variable,
        sparse_delta.indices,
        sparse_delta.values,
        use_locking=use_locking,
        name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Divide this variable by `tf.IndexedSlices`.

    Args:
      sparse_delta: `tf.IndexedSlices` to divide this variable by.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered division has completed.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
    return gen_state_ops.scatter_div(
        self._variable,
        sparse_delta.indices,
        sparse_delta.values,
        use_locking=use_locking,
        name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `tf.IndexedSlices` to this variable.

    Args:
      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered assignment has completed.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
    return gen_state_ops.scatter_update(
        self._variable,
        sparse_delta.indices,
        sparse_delta.values,
        use_locking=use_locking,
        name=name)
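
  # Illustration (assumption, not part of the original class): each scatter_*
  # method above takes a `tf.IndexedSlices`, pairing the rows to touch
  # (`indices`) with their new contents (`values`). A hypothetical demo
  # method:
  def _example_scatter_update_sketch(self, row_values, row_indices):
    delta = indexed_slices.IndexedSlices(values=row_values,
                                         indices=row_indices)
    return self.scatter_update(delta)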

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `tf.IndexedSlices` to this variable batch-wise.

    Analogous to `batch_gather`. This assumes that this variable and the
    sparse_delta IndexedSlices have a series of leading dimensions that are
    the same for all of them, and the updates are performed on the last
    dimension of indices. In other words, the dimensions should be the
    following:

    `num_prefix_dims = sparse_delta.indices.ndims - 1`
    `batch_dim = num_prefix_dims + 1`
    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
         batch_dim:]`

    where

    `sparse_delta.updates.shape[:num_prefix_dims]`
    `== sparse_delta.indices.shape[:num_prefix_dims]`
    `== var.shape[:num_prefix_dims]`

    And the operation performed can be expressed as:

    `var[i_1, ..., i_n,
         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
            i_1, ..., i_n, j]`

    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
    `scatter_update`.

    To avoid this operation one can loop over the first `ndims` dimensions of
    the variable and use `scatter_update` on the subtensors that result from
    slicing the first dimension. This is a valid option for `ndims = 1`, but
    less efficient than this implementation.

    Args:
      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered assignment has completed.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    return state_ops.batch_scatter_update(
        self,
        sparse_delta.indices,
        sparse_delta.values,
        use_locking=use_locking,
        name=name)
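
  # Worked example (assumption, added for illustration): for a variable of
  # shape [B, N, D] and `sparse_delta.indices` of shape [B, K], the rule above
  # gives num_prefix_dims = 1 and batch_dim = 2, so `sparse_delta.values` must
  # have shape [B, K, D], and the update performs, for each batch b and each k:
  #
  #   var[b, indices[b, k], :] = values[b, k, :]
  #
  # With 1-D `indices` of shape [K] this reduces to plain `scatter_update`.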

  def scatter_nd_sub(self, indices, updates, name=None):
    """Applies sparse subtraction to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to subtract 4 scattered elements from a rank-1
    tensor with 8 elements. In Python, that update would look like this:

    ```python
    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    op = ref.scatter_nd_sub(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(op))
    ```

    The resulting update to ref would look like this:

        [1, -9, 3, -6, -4, 6, 7, -4]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered subtraction has completed.
    """
    return gen_state_ops.scatter_nd_sub(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_add(self, indices, updates, name=None):
    """Applies sparse addition to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to add 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    add = ref.scatter_nd_add(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(add))
    ```

    The resulting update to ref would look like this:

        [1, 13, 3, 14, 14, 6, 7, 20]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered addition has completed.
    """
    return gen_state_ops.scatter_nd_add(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_update(self, indices, updates, name=None):
    """Applies sparse assignment to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to assign 4 scattered elements in a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    op = ref.scatter_nd_update(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(op))
    ```

    The resulting update to ref would look like this:

        [1, 11, 3, 10, 9, 6, 7, 12]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered assignment has completed.
    """
    return gen_state_ops.scatter_nd_update(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_max(self, indices, updates, name=None):
    """Updates this variable with the scattered element-wise max of `updates`
    and itself.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered maximization has completed.
    """
    return gen_state_ops.scatter_nd_max(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_min(self, indices, updates, name=None):
    """Updates this variable with the scattered element-wise min of `updates`
    and itself.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered minimization has completed.
    """
    return gen_state_ops.scatter_nd_min(
        self._variable, indices, updates, use_locking=True, name=name)

  def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
                            end_mask, ellipsis_mask, new_axis_mask,
                            shrink_axis_mask):
    return gen_array_ops.strided_slice_assign(
        ref=self._ref(),
        begin=begin,
        end=end,
        strides=strides,
        value=value,
        name=name,
        begin_mask=begin_mask,
        end_mask=end_mask,
        ellipsis_mask=ellipsis_mask,
        new_axis_mask=new_axis_mask,
        shrink_axis_mask=shrink_axis_mask)

  @deprecated(None, "Prefer Dataset.range instead.")
  def count_up_to(self, limit):
    """Increments this variable until it reaches `limit`.

    When that Op is run it tries to increment the variable by `1`. If
    incrementing the variable would bring it above `limit` then the Op raises
    the exception `OutOfRangeError`.

    If no error is raised, the Op outputs the value of the variable before
    the increment.

    This is essentially a shortcut for `count_up_to(self, limit)`.

    Args:
      limit: value at which incrementing the variable raises an error.

    Returns:
      A `Tensor` that will hold the variable value before the increment. If no
      other Op modifies this variable, the values produced will all be
      distinct.
    """
    return state_ops.count_up_to(self._variable, limit=limit)

  # Conversion to tensor.
  @staticmethod
  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
    """Utility function for converting a Variable to a Tensor."""
    _ = name
    if dtype and not dtype.is_compatible_with(v.dtype):
      raise ValueError(
          "Incompatible type conversion requested to type '%s' for variable "
          "of type '%s'" % (dtype.name, v.dtype.name))
    if as_ref:
      return v._ref()  # pylint: disable=protected-access
    else:
      return v.value()
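
  # Illustration (assumption, not part of the original class): the conversion
  # function above, registered at the bottom of this file, is what lets a
  # RefVariable be used anywhere a `Tensor` is expected; `as_ref=True` hands
  # ops like `assign` the mutable reference rather than the read value. A
  # hypothetical demo method:
  def _example_implicit_conversion_sketch(self):
    # Goes through _TensorConversionFunction; `as_ref` defaults to False, so
    # this returns `self.value()`.
    return ops.convert_to_tensor(self)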

  # NOTE(mrry): This enables the Variable's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Variable class higher priority than an ndarray, or a
  # numpy matrix.
  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
  # mechanism, which allows more control over how Variables interact
  # with ndarrays.
  __array_priority__ = 100

  @property
  def name(self):
    """The name of this variable."""
    return self._name

  @property
  def initializer(self):
    """The initializer operation for this variable."""
    return self._initializer_op

  @property
  def device(self):
    """The device of this variable."""
    return self._variable.device

  @property
  def dtype(self):
    """The `DType` of this variable."""
    return self._variable.dtype

  @property
  def op(self):
    """The `Operation` of this variable."""
    return self._variable.op

  @property
  def graph(self):
    """The `Graph` of this variable."""
    return self._variable.graph

  @property
  def _distribute_strategy(self):
    """The `tf.distribute.Strategy` that this variable was created under."""
    return None  # Ref variables are never created inside a strategy.

  @property
  def shape(self):
    """The `TensorShape` of this variable.

    Returns:
      A `TensorShape`.
    """
    return self._variable.get_shape()

  def to_proto(self, export_scope=None):
    """Converts a `Variable` to a `VariableDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
      in the specified name scope.
    """
    if (export_scope is None or self._variable.name.startswith(export_scope)):
      var_def = variable_pb2.VariableDef()
      var_def.variable_name = ops.strip_name_scope(self._variable.name,
                                                   export_scope)
      if self._initial_value is not None:
        # For backwards compatibility.
        var_def.initial_value_name = ops.strip_name_scope(
            self._initial_value.name, export_scope)
      var_def.trainable = self.trainable
      var_def.synchronization = self.synchronization.value
      var_def.aggregation = self.aggregation.value
      var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
                                                      export_scope)
      var_def.snapshot_name = ops.strip_name_scope(self._snapshot.name,
                                                   export_scope)
      if self._save_slice_info:
        var_def.save_slice_info_def.MergeFrom(
            self._save_slice_info.to_proto(export_scope=export_scope))
      return var_def
    else:
      return None

  def __iadd__(self, other):
    logging.log_first_n(
        logging.WARN, "Variable += will be deprecated. Use variable.assign_add"
        " if you want assignment to the variable value or 'x = x + y'"
        " if you want a new python Tensor object.", 1)
    return self + other

  def __isub__(self, other):
    logging.log_first_n(
        logging.WARN, "Variable -= will be deprecated. Use variable.assign_sub"
        " if you want assignment to the variable value or 'x = x - y'"
        " if you want a new python Tensor object.", 1)
    return self - other

  def __imul__(self, other):
    logging.log_first_n(
        logging.WARN,
        "Variable *= will be deprecated. Use `var.assign(var * other)`"
        " if you want assignment to the variable value or `x = x * y`"
        " if you want a new python Tensor object.", 1)
    return self * other

  def __idiv__(self, other):
    logging.log_first_n(
        logging.WARN,
        "Variable /= will be deprecated. Use `var.assign(var / other)`"
        " if you want assignment to the variable value or `x = x / y`"
        " if you want a new python Tensor object.", 1)
    return self / other

  def __itruediv__(self, other):
    logging.log_first_n(
        logging.WARN,
        "Variable /= will be deprecated. Use `var.assign(var / other)`"
        " if you want assignment to the variable value or `x = x / y`"
        " if you want a new python Tensor object.", 1)
    return self / other

  def __irealdiv__(self, other):
    logging.log_first_n(
        logging.WARN,
        "Variable /= will be deprecated. Use `var.assign(var / other)`"
        " if you want assignment to the variable value or `x = x / y`"
        " if you want a new python Tensor object.", 1)
    return self / other

  def __ipow__(self, other):
    logging.log_first_n(
        logging.WARN,
        "Variable **= will be deprecated. Use `var.assign(var ** other)`"
        " if you want assignment to the variable value or `x = x ** y`"
        " if you want a new python Tensor object.", 1)
    return self**other

  def _serialize_to_tensors(self):
    """Implements Trackable._serialize_to_tensors."""
    return {trackable.VARIABLE_VALUE_KEY: self}

  def _restore_from_tensors(self, restored_tensors):
    """Implements Trackable._restore_from_tensors."""
    restored_tensor = restored_tensors[trackable.VARIABLE_VALUE_KEY]
    return state_ops.assign(
        self,
        restored_tensor,
        validate_shape=self.get_shape().is_fully_defined())


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
tensor_conversion_registry.register_tensor_conversion_function(
    RefVariable, RefVariable._TensorConversionFunction)  # pylint: disable=protected-access


variable_v1.set_variable_from_proto_fn(RefVariable)
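

# End-to-end sketch (assumption, not part of the original file): creating,
# initializing, mutating, and reading a RefVariable through the TF1 API.
def _example_ref_variable_session_sketch():
  """Hypothetical helper, for illustration only."""
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()
  tf.disable_resource_variables()  # make tf.Variable produce RefVariable
  with tf.Graph().as_default():
    v = tf.Variable([1.0, 2.0], name="v")
    update = v.assign_add([10.0, 10.0])
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(update)
      return sess.run(v)  # array([11., 12.], dtype=float32)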