Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/variable_scope.py: 24%
792 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A class to store named variables and a scope operator to manage sharing."""
17import copy
18import enum
19import functools
20import sys
21import threading
22import traceback
24from tensorflow.python import tf2
25from tensorflow.python.client import session
26from tensorflow.python.eager import context
27from tensorflow.python.eager import monitoring
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_conversion_registry
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import init_ops
34from tensorflow.python.ops import ref_variable
35from tensorflow.python.ops import resource_variable_ops
36from tensorflow.python.ops import variable_v1
37from tensorflow.python.ops import variables
38from tensorflow.python.platform import tf_logging as logging
39from tensorflow.python.types import core
40from tensorflow.python.util import deprecation
41from tensorflow.python.util import function_utils
42from tensorflow.python.util import tf_contextlib
43from tensorflow.python.util import tf_inspect
44from tensorflow.python.util.compat import collections_abc
45from tensorflow.python.util.tf_export import tf_export
48__all__ = [
49 "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
50 "get_local_variable", "variable_scope", "variable_op_scope",
51 "no_regularizer", "VariableSynchronization", "VariableAggregation"
52]
54_api_usage_gauge = monitoring.BoolGauge(
55 "/tensorflow/api/resource_variables",
56 "Whether variable_scope.enable_resource_variables() is called.")
59class _PartitionInfo:
60 """Holds partition info used by initializer functions."""
62 __slots__ = ["_full_shape", "_var_offset"]
64 def __init__(self, full_shape, var_offset):
65 """Constructor.
67 Args:
68 full_shape: Tuple or list of `int` indicating the full combined shape of
69 the partitioned variables.
70 var_offset: Tuple or list of `int` specifying offset of this partition
71 with respect to the full variable for each dimension.
73 Raises:
74 TypeError: If `full_shape` or `var_offset` is not a sequence.
75 ValueError: If `full_shape` or `var_offset` differ in length. If
76 `var_offset` exceeds `full_shape` in any dimension.
77 """
78 if not isinstance(full_shape, (list, tuple)):
79 raise TypeError(
80 "`full_shape` must be a sequence (like tuple or list) instead of " +
81 type(full_shape).__name__)
83 if not isinstance(var_offset, (list, tuple)):
84 raise TypeError(
85 "`var_offset` must be a sequence (like tuple or list) instead of " +
86 type(var_offset).__name__)
88 if len(var_offset) != len(full_shape):
89 raise ValueError(
90 "Expected equal length, but `var_offset` is of length {} while "
91 "full_shape is of length {}.".format(
92 len(var_offset), len(full_shape)))
94 for offset, shape in zip(var_offset, full_shape):
95 if offset < 0 or offset >= shape:
96 raise ValueError(
97 "Expected 0 <= offset < shape but found offset={}, shape={} for "
98 "var_offset={}, full_shape={}".format(offset, shape, var_offset,
99 full_shape))
101 self._full_shape = full_shape
102 self._var_offset = var_offset
104 @property
105 def full_shape(self):
106 return self._full_shape
108 @property
109 def var_offset(self):
110 return self._var_offset
112 def single_offset(self, shape):
113 """Returns the offset when the variable is partitioned in at most one dim.
115 Args:
116 shape: Tuple or list of `int` indicating the shape of one specific
117 variable partition.
119 Returns:
120 `int` representing the offset in the dimension along which the variable is
121 partitioned. Returns 0 if the variable is not being partitioned.
123 Raises:
124 ValueError: Depending on self.single_slice_dim().
125 """
127 single_slice_dim = self.single_slice_dim(shape)
128 # If this variable is not being partitioned at all, single_slice_dim() could
129 # return None.
130 if single_slice_dim is None:
131 return 0
132 return self.var_offset[single_slice_dim]
134 def single_slice_dim(self, shape):
135 """Returns the slice dim when the variable is partitioned only in one dim.
137 Args:
138 shape: Tuple or list of `int` indicating the shape of one specific
139 variable partition.
141 Returns:
142 `int` representing the dimension that the variable is partitioned in, or
143 `None` if the variable doesn't seem to be partitioned at all.
145 Raises:
146 TypeError: If `shape` is not a sequence.
147 ValueError: If `shape` is not the same length as `self.full_shape`. If
148 the variable is partitioned in more than one dimension.
149 """
150 if not isinstance(shape, (tuple, list)):
151 raise TypeError(
152 "`shape` must be a sequence (like tuple or list) instead of " +
153 type(shape).__name__)
155 if len(shape) != len(self.full_shape):
156 raise ValueError(
157 "Expected equal length, but received shape={} of length {} while "
158 "self.full_shape={} is of length {}.".format(shape, len(shape),
159 self.full_shape,
160 len(self.full_shape)))
162 for i in range(len(shape)):
163 if self.var_offset[i] + shape[i] > self.full_shape[i]:
164 raise ValueError(
165 "With self.var_offset={}, a partition of shape={} would exceed "
166 "self.full_shape={} in dimension {}.".format(
167 self.var_offset, shape, self.full_shape, i))
169 slice_dim = None
170 for i in range(len(shape)):
171 if shape[i] == self.full_shape[i]:
172 continue
173 if slice_dim is not None:
174 raise ValueError(
175 "Cannot use single_slice_dim() with shape={} and "
176 "self.full_shape={} since slice dim could be either dimension {} "
177 "or {}.".format(shape, self.full_shape, i, slice_dim))
178 slice_dim = i
180 return slice_dim
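A minimal sketch of what an initializer sees for one shard of a variable with full shape `[10, 4]` split in two along dimension 0 (this references the private `_PartitionInfo` class defined above, purely for illustration):

```python
info = _PartitionInfo(full_shape=[10, 4], var_offset=[5, 0])
# Only dimension 0 differs from the full shape, so that is the slice dim.
assert info.single_slice_dim([5, 4]) == 0
# This shard starts at row 5 of the full [10, 4] variable.
assert info.single_offset([5, 4]) == 5
```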
183class _ReuseMode(enum.Enum):
184 """Mode for variable access within a variable scope."""
186 # Indicates that variables are to be fetched if they already exist or
187 # otherwise created.
188 AUTO_REUSE = 1
190 # TODO(alive): For TensorFlow 2.0, Deprecate True/False/None API in favor of
191 # enum values.
192 # REUSE_FALSE = 2
193 # REUSE_TRUE = 3
196# TODO(apassos) remove these forwarding symbols.
197VariableSynchronization = variables.VariableSynchronization # pylint: disable=invalid-name
198VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name
200AUTO_REUSE = _ReuseMode.AUTO_REUSE
201tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE")
202AUTO_REUSE.__doc__ = """
203@compatibility(TF2)
204`tf.compat.v1.AUTO_REUSE` is a legacy API that is a no-op when TF2 behaviors
205are enabled.
207If you rely on `get_variable` and auto-reuse, see the
208[model mapping guide](https://www.tensorflow.org/guide/migrate/model_mapping)
209for more info on how to migrate your code.
211Note: when you use the `tf.compat.v1.keras.utils.track_tf1_style_variables`
212API as described in the above guide, `get_variable` will always behave as if
213`v1.AUTO_REUSE` is set. Without the decorator, reuse will be ignored and new
214variables will always be created, regardless of whether they have already
215been created.
216@end_compatibility
218When passed in as the value for the `reuse` flag, `AUTO_REUSE` indicates that
219get_variable() should create the requested variable if it doesn't exist or, if
220it does exist, simply return it.
221"""
223_DEFAULT_USE_RESOURCE = tf2.enabled()
226@tf_export(v1=["enable_resource_variables"])
227def enable_resource_variables():
228 """Creates resource variables by default.
230 Resource variables are improved versions of TensorFlow variables with a
231 well-defined memory model. Accessing a resource variable reads its value, and
232 all ops which access a specific read value of the variable are guaranteed to
233 see the same value for that tensor. Writes which happen after a read (by
234 having a control or data dependency on the read) are guaranteed not to affect
235 the value of the read tensor, and similarly writes which happen before a read
236 are guaranteed to affect the value. No guarantees are made about unordered
237 read/write pairs.
239 Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
240 feature.
241 """
242 global _DEFAULT_USE_RESOURCE
243 _DEFAULT_USE_RESOURCE = True
244 logging.vlog(1, "Enabling resource variables")
245 _api_usage_gauge.get_cell().set(True)
248@tf_export(v1=["resource_variables_enabled"])
249def resource_variables_enabled():
250 """Returns `True` if resource variables are enabled.
252 Resource variables are improved versions of TensorFlow variables with a
253 well-defined memory model. Accessing a resource variable reads its value, and
254 all ops which access a specific read value of the variable are guaranteed to
255 see the same value for that tensor. Writes which happen after a read (by
256 having a control or data dependency on the read) are guaranteed not to affect
257 the value of the read tensor, and similarly writes which happen before a read
258 are guaranteed to affect the value. No guarantees are made about unordered
259 read/write pairs.
261 Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
262 feature.
263 """
264 global _DEFAULT_USE_RESOURCE
265 return _DEFAULT_USE_RESOURCE
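A short sketch of how the two functions above are typically used together:

```python
import tensorflow.compat.v1 as tf

tf.enable_resource_variables()
assert tf.resource_variables_enabled()
# Variables created from here on use ResourceVariable semantics by default.
```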
268@deprecation.deprecated(
269 None, "non-resource variables are not supported in the long term")
270@tf_export(v1=["disable_resource_variables"])
271def disable_resource_variables():
272 """Opts out of resource variables.
274 If your code needs tf.disable_resource_variables() to be called to work
275 properly please file a bug.
276 """
277 global _DEFAULT_USE_RESOURCE
278 _DEFAULT_USE_RESOURCE = False
279 logging.vlog(1, "Disabling resource variables")
280 _api_usage_gauge.get_cell().set(False)
283def _needs_no_arguments(python_callable):
284 """Returns true if the callable needs no arguments to call."""
285 # TODO(bfontain): Switch to inspect.signature when we are python 3 only.
286 # signature = inspect.signature(python_callable)
287 # return not [1 for param in signature.parameters.values()
288 # if param.default == param.empty]
289 num_arguments = len(tf_inspect.getargspec(python_callable).args)
290 if not tf_inspect.isfunction(python_callable) and not isinstance(
291 python_callable, functools.partial):
292 # getargspec includes `self` for callables that aren't plain functions or
293 # functools.partial objects (e.g. bound methods). `self` has no default,
294 # so we remove it from the count; it's odd that getargspec reports it.
295 # Note that this is fixed with inspect.signature in Python 3.
296 num_arguments -= 1
297 return num_arguments == len(
298 tf_inspect.getargspec(python_callable).defaults or [])
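Illustrative checks of the helper above (it is private to this module, so this is a sketch rather than public API usage):

```python
assert _needs_no_arguments(lambda: 1.0)      # no parameters at all
assert _needs_no_arguments(lambda x=3: x)    # every parameter has a default
assert not _needs_no_arguments(lambda x: x)  # requires an argument
```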
301class _VariableStore:
302 """Variable store that carries a number of named Variables.
304 New variable names and new variables can be created; all stored
305 variables are initialized with the initializer passed to __init__.
307 Attributes:
308 vars: a dictionary with string names (same as passed in GetVar) as keys and
309 the corresponding TensorFlow Variables as values.
310 """
312 __slots__ = ["_vars", "_partitioned_vars", "_store_eager_variables"]
314 def __init__(self):
315 """Create a variable store."""
316 self._vars = {} # A dictionary of the stored TensorFlow variables.
317 self._partitioned_vars = {} # A dict of the stored PartitionedVariables.
318 self._store_eager_variables = False
320 def get_variable(self,
321 name,
322 shape=None,
323 dtype=dtypes.float32,
324 initializer=None,
325 regularizer=None,
326 reuse=None,
327 trainable=None,
328 collections=None,
329 caching_device=None,
330 partitioner=None,
331 validate_shape=True,
332 use_resource=None,
333 custom_getter=None,
334 constraint=None,
335 synchronization=VariableSynchronization.AUTO,
336 aggregation=VariableAggregation.NONE):
337 """Gets an existing variable with these parameters or create a new one.
339 If a variable with the given name is already stored, we return the stored
340 variable. Otherwise, we create a new one.
342 Set `reuse` to `True` when you only want to reuse existing Variables.
343 Set `reuse` to `False` when you only want to create new Variables.
344 Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
345 variables to be created if they don't exist or returned if they do.
347 If initializer is `None` (the default), the default initializer passed in
348 the constructor is used. If that one is `None` too, we use a new
349 `glorot_uniform_initializer`. If initializer is a Tensor, we use
350 it as a value and derive the shape from the initializer.
352 If a partitioner is provided, a `PartitionedVariable` is returned.
353 Accessing this object as a `Tensor` returns the shards concatenated along
354 the partition axis.
356 Some useful partitioners are available. See, e.g.,
357 `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
359 Args:
360 name: The name of the new or existing variable.
361 shape: Shape of the new or existing variable.
362 dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
363 initializer: Initializer for the variable.
364 regularizer: A (Tensor -> Tensor or None) function; the result of applying
365 it on a newly created variable will be added to the collection
366 GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
367 reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
368 variables. When eager execution is enabled this argument is always
369 forced to be False.
370 trainable: If `True` also add the variable to the graph collection
371 `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
372 defaults to `True`, unless `synchronization` is set to `ON_READ`, in
373 which case it defaults to `False`.
374 collections: List of graph collections keys to add the `Variable` to.
375 Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
376 caching_device: Optional device string or function describing where the
377 Variable should be cached for reading. Defaults to the Variable's
378 device. If not `None`, caches on another device. Typical use is to
379 cache on the device where the Ops using the `Variable` reside, to
380 deduplicate copying through `Switch` and other conditional statements.
381 partitioner: Optional callable that accepts a fully defined `TensorShape`
382 and dtype of the `Variable` to be created, and returns a list of
383 partitions for each axis (currently only one axis can be partitioned).
384 validate_shape: If False, allows the variable to be initialized with a
385 value of unknown shape. If True, the default, the shape of initial_value
386 must be known.
387 use_resource: If False, creates a regular Variable. If True, creates
388 instead an experimental ResourceVariable which has well-defined
389 semantics. Defaults to False (will later change to True). When eager
390 execution is enabled this argument is always forced to be true.
391 custom_getter: Callable that takes as a first argument the true getter,
392 and allows overwriting the internal get_variable method. The signature
393 of `custom_getter` should match that of this method,
394 but the most future-proof version will allow for changes: `def
395 custom_getter(getter, *args, **kwargs)`. Direct access to
396 all `get_variable` parameters is also allowed: `def
397 custom_getter(getter, name, *args, **kwargs)`. A simple identity
398 custom getter that simply creates variables with modified names is:
399 ```python
400 def custom_getter(getter, name, *args, **kwargs): return getter(name +
401 '_suffix', *args, **kwargs) ```
402 constraint: An optional projection function to be applied to the variable
403 after being updated by an `Optimizer` (e.g. used to implement norm
404 constraints or value constraints for layer weights). The function must
405 take as input the unprojected Tensor representing the value of the
406 variable and return the Tensor for the projected value (which must have
407 the same shape). Constraints are not safe to use when doing asynchronous
408 distributed training.
409 synchronization: Indicates when a distributed variable will be
410 aggregated. Accepted values are constants defined in the class
411 `tf.VariableSynchronization`. By default the synchronization is set to
412 `AUTO` and the current `DistributionStrategy` chooses when to
413 synchronize.
414 aggregation: Indicates how a distributed variable will be aggregated.
415 Accepted values are constants defined in the class
416 `tf.VariableAggregation`.
418 Returns:
419 The created or existing `Variable` (or `PartitionedVariable`, if a
420 partitioner was used).
422 Raises:
423 ValueError: when creating a new variable and shape is not declared,
424 when reusing a variable and specifying a conflicting shape,
425 or when violating reuse during variable creation.
426 RuntimeError: when eager execution is enabled and not called from an
427 EagerVariableStore.
428 """
429 if custom_getter is not None and not callable(custom_getter):
430 raise ValueError("Passed a custom_getter which is not callable: %s" %
431 custom_getter)
433 with ops.init_scope():
434 if context.executing_eagerly():
435 # Variable creation and initialization takes place in `init_scope`s;
436 # as such, if an `init_scope` lifts us into the eager context, then we
437 # need to use `ResourceVariable`s.
438 use_resource = True
440 # Note that it's fine to reuse eager variables whose initialization was
441 # lifted from a function-building graph into the eager context (that's why
442 # the following clause is not wrapped in an `init_scope`); lifted variables
443 # are tracked by the graph's `VariableStore`.
444 if context.executing_eagerly():
445 if not self._store_eager_variables and reuse:
446 raise RuntimeError(
447 "When eager execution is enabled variable reuse is only supported"
448 " when an EagerVariableStore is active. See the documentation on"
449 " EagerVariableStore for example usage.")
450 if self._store_eager_variables:
451 reuse = AUTO_REUSE
453 # If a *_ref type is passed in an error would be triggered further down the
454 # stack. We prevent this using base_dtype to get a non-ref version of the
455 # type, before doing anything else. When _ref types are removed in favor of
456 # resources, this line can be removed.
457 try:
458 dtype = dtype.base_dtype
459 except AttributeError:
460 # .base_dtype not existing means that we will try and use the raw dtype
461 # which was passed in - this might be a NumPy type which is valid.
462 pass
464 # This is the main logic of get_variable. However, custom_getter
465 # may override this logic. So we save it as a callable and pass
466 # it to custom_getter.
467 # Note: the parameters of _true_getter, and their documentation, match
468 # *exactly* item-for-item with the docstring of this method.
469 def _true_getter( # pylint: disable=missing-docstring
470 name,
471 shape=None,
472 dtype=dtypes.float32,
473 initializer=None,
474 regularizer=None,
475 reuse=None,
476 trainable=None,
477 collections=None,
478 caching_device=None,
479 partitioner=None,
480 validate_shape=True,
481 use_resource=None,
482 constraint=None,
483 synchronization=VariableSynchronization.AUTO,
484 aggregation=VariableAggregation.NONE):
485 is_scalar = (
486 shape is not None and isinstance(shape, collections_abc.Sequence) and
487 not shape)
488 # Partitioned variable case
489 if partitioner is not None and not is_scalar:
490 if not callable(partitioner):
491 raise ValueError("Partitioner must be callable, but received: %s" %
492 partitioner)
493 with ops.name_scope(None):
494 return self._get_partitioned_variable(
495 name=name,
496 shape=shape,
497 dtype=dtype,
498 initializer=initializer,
499 regularizer=regularizer,
500 reuse=reuse,
501 trainable=trainable,
502 collections=collections,
503 caching_device=caching_device,
504 partitioner=partitioner,
505 validate_shape=validate_shape,
506 use_resource=use_resource,
507 constraint=constraint,
508 synchronization=synchronization,
509 aggregation=aggregation)
511 # Special case for partitioned variable to allow reuse without having to
512 # specify partitioner.
513 if (reuse is True and partitioner is None
514 and name in self._partitioned_vars):
515 return self._get_partitioned_variable(
516 name=name,
517 shape=shape,
518 dtype=dtype,
519 initializer=initializer,
520 regularizer=regularizer,
521 reuse=reuse,
522 trainable=trainable,
523 collections=collections,
524 caching_device=caching_device,
525 partitioner=None,
526 validate_shape=validate_shape,
527 use_resource=use_resource,
528 constraint=constraint,
529 synchronization=synchronization,
530 aggregation=aggregation)
532 # Single variable case
533 if "%s/part_0" % name in self._vars:
534 raise ValueError(
535 "No partitioner was provided, but a partitioned version of the "
536 "variable was found: %s/part_0. Perhaps a variable of the same "
537 "name was already created with partitioning?" % name)
539 return self._get_single_variable(
540 name=name,
541 shape=shape,
542 dtype=dtype,
543 initializer=initializer,
544 regularizer=regularizer,
545 reuse=reuse,
546 trainable=trainable,
547 collections=collections,
548 caching_device=caching_device,
549 validate_shape=validate_shape,
550 use_resource=use_resource,
551 constraint=constraint,
552 synchronization=synchronization,
553 aggregation=aggregation)
555 synchronization, aggregation, trainable = (
556 variables.validate_synchronization_aggregation_trainable(
557 synchronization, aggregation, trainable, name))
559 if custom_getter is not None:
560 # Handle backwards compatibility with getter arguments that were added
561 # to the API after users started writing custom getters.
562 custom_getter_kwargs = {
563 "getter": _true_getter,
564 "name": name,
565 "shape": shape,
566 "dtype": dtype,
567 "initializer": initializer,
568 "regularizer": regularizer,
569 "reuse": reuse,
570 "trainable": trainable,
571 "collections": collections,
572 "caching_device": caching_device,
573 "partitioner": partitioner,
574 "validate_shape": validate_shape,
575 "use_resource": use_resource,
576 "synchronization": synchronization,
577 "aggregation": aggregation,
578 }
579 # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
580 # `lambda`.
581 if ("constraint" in function_utils.fn_args(custom_getter) or
582 function_utils.has_kwargs(custom_getter)):
583 custom_getter_kwargs["constraint"] = constraint
584 return custom_getter(**custom_getter_kwargs)
585 else:
586 return _true_getter(
587 name,
588 shape=shape,
589 dtype=dtype,
590 initializer=initializer,
591 regularizer=regularizer,
592 reuse=reuse,
593 trainable=trainable,
594 collections=collections,
595 caching_device=caching_device,
596 partitioner=partitioner,
597 validate_shape=validate_shape,
598 use_resource=use_resource,
599 constraint=constraint,
600 synchronization=synchronization,
601 aggregation=aggregation)
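The `custom_getter` contract described in the docstring above can be exercised through the public `tf.compat.v1` wrappers; a minimal graph-mode sketch with a hypothetical suffix-adding getter:

```python
import tensorflow.compat.v1 as tf

def suffix_getter(getter, name, *args, **kwargs):
  # Delegate to the true getter, but with a modified variable name.
  return getter(name + "_suffix", *args, **kwargs)

with tf.Graph().as_default():
  with tf.variable_scope("scope", custom_getter=suffix_getter):
    v = tf.get_variable("w", shape=[2])
  assert v.name == "scope/w_suffix:0"
```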
603 def _get_partitioned_variable(self,
604 name,
605 partitioner,
606 shape=None,
607 dtype=dtypes.float32,
608 initializer=None,
609 regularizer=None,
610 reuse=None,
611 trainable=None,
612 collections=None,
613 caching_device=None,
614 validate_shape=True,
615 use_resource=None,
616 constraint=None,
617 synchronization=VariableSynchronization.AUTO,
618 aggregation=VariableAggregation.NONE):
619 """Gets or creates a sharded variable list with these parameters.
621 The `partitioner` must be a callable that accepts a fully defined
622 `TensorShape` and returns a sequence of integers (the `partitions`).
623 These integers describe how to partition the given sharded `Variable`
624 along the given dimension. That is, `partitions[1] = 3` means split
625 the `Variable` into 3 shards along dimension 1. Currently, sharding along
626 only one axis is supported.
628 If the list of variables with the given name (prefix) is already stored,
629 we return the stored variables. Otherwise, we create a new one.
631 Set `reuse` to `True` when you only want to reuse existing Variables.
632 Set `reuse` to `False` when you only want to create new Variables.
633 Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
634 variables to be created if they don't exist or returned if they do.
636 If initializer is `None` (the default), the default initializer passed in
637 the constructor is used. If that one is `None` too, we use a new
638 `glorot_uniform_initializer`. If initializer is a Tensor, we use
639 it as a value and derive the shape from the initializer.
641 If the initializer is a callable, then it will be called for each
642 shard. Otherwise the initializer should match the shape of the entire
643 sharded Variable, and it will be sliced accordingly for each shard.
645 Some useful partitioners are available. See, e.g.,
646 `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
648 Args:
649 name: the name of the new or existing sharded variable.
650 partitioner: Optional callable that accepts a fully defined `TensorShape`
651 and `dtype` of the Variable to be created, and returns a list of
652 partitions for each axis (currently only one axis can be partitioned).
653 shape: shape of the new or existing sharded variable.
654 dtype: type of the new or existing sharded variable (defaults to
655 `DT_FLOAT`).
656 initializer: initializer for the sharded variable.
657 regularizer: a (Tensor -> Tensor or None) function; the result of applying
658 it on a newly created variable will be added to the collection
659 GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
660 reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
661 variables.
662 trainable: If `True` also add the variable to the graph collection
663 `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
664 collections: List of graph collections keys to add the Variable to.
665 Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
666 caching_device: Optional device string or function describing where the
667 Variable should be cached for reading. Defaults to the Variable's
668 device. If not `None`, caches on another device. Typical use is to
669 cache on the device where the Ops using the Variable reside, to
670 deduplicate copying through `Switch` and other conditional statements.
671 validate_shape: If False, allows the variable to be initialized with a
672 value of unknown shape. If True, the default, the shape of initial_value
673 must be known.
674 use_resource: If False, creates a regular Variable. If True, creates an
675 experimental ResourceVariable which has well-defined semantics. Defaults
676 to False (will later change to True).
677 constraint: An optional projection function to be applied to the variable
678 after being updated by an `Optimizer` (e.g. used to implement norm
679 constraints or value constraints for layer weights). The function must
680 take as input the unprojected Tensor representing the value of the
681 variable and return the Tensor for the projected value (which must have
682 the same shape). Constraints are not safe to use when doing asynchronous
683 distributed training.
684 synchronization: Indicates when a distributed variable will be
685 aggregated. Accepted values are constants defined in the class
686 `tf.VariableSynchronization`. By default the synchronization is set to
687 `AUTO` and the current `DistributionStrategy` chooses when to
688 synchronize.
689 aggregation: Indicates how a distributed variable will be aggregated.
690 Accepted values are constants defined in the class
691 `tf.VariableAggregation`.
693 Returns:
694 A `PartitionedVariable` object.
696 Raises:
697 ValueError: when creating a new variable and shape is not declared,
698 when reusing a variable and specifying a conflicting shape,
699 when violating reuse during variable creation, or if an existing
700 sharded variable exists for the given name but with different sharding.
701 """
702 initializing_from_value = initializer is not None and isinstance(
703 initializer, ops.Tensor)
704 if name in self._vars:
705 raise ValueError(
706 "A partitioner was provided, but an unpartitioned version of the "
707 "variable was found: %s. Perhaps a variable of the same name was "
708 "already created without partitioning?" % name)
710 shape = tensor_shape.as_shape(shape)
711 if initializing_from_value:
712 shape = shape.merge_with(initializer.get_shape())
714 partitions = None
715 if not reuse or partitioner:
716 partitions = _call_partitioner(partitioner, shape, dtype)
718 if name in self._partitioned_vars:
719 if reuse is False:
720 raise ValueError(
721 "Partitioned variable with name %s already exists. Did you mean to "
722 "set reuse=True or reuse=tf.AUTO_REUSE in VarScope?" % name)
724 existing_var = self._partitioned_vars[name]
725 if not shape.is_compatible_with(existing_var.get_shape()):
726 raise ValueError(
727 "Trying to reuse partitioned variable %s, but specified shape %s "
728 "and found shape %s." % (name, shape, existing_var.get_shape()))
729 if not dtype.is_compatible_with(existing_var.dtype):
730 raise ValueError(
731 "Trying to reuse partitioned variable %s, but specified dtype %s "
732 "and found dtype %s." % (name, dtype.name, existing_var.dtype.name))
734 # pylint: disable=protected-access
735 if (partitions is not None and
736 existing_var._get_partitions() != partitions):
737 raise ValueError(
738 "Trying to reuse partitioned variable %s, but specified partitions "
739 "%s and found partitions %s." %
740 (name, partitions, existing_var._get_partitions()))
741 # pylint: enable=protected-access
743 return existing_var
745 if reuse is True:
746 raise ValueError("PartitionedVariable %s does not exist, or was not "
747 "created with tf.get_variable(). Did you mean to set "
748 "reuse=False or reuse=tf.AUTO_REUSE in VarScope?" % name)
750 slice_dim, num_slices = _get_slice_dim_and_num_slices(partitions)
752 if "%s/part_0" % name in self._vars:
753 if "%s/part_%d" % (name, num_slices - 1) not in self._vars:
754 raise ValueError(
755 "Partitioner returned a different partitioning than what was "
756 "already found. Partitioner returned %d shards, and shard "
757 "%s/part_0 was found, but %s/part_%d was not." %
758 (num_slices, name, name, num_slices - 1))
759 if "%s/part_%d" % (name, num_slices) in self._vars:
760 raise ValueError(
761 "Partitioner returned a different partitioning than what was "
762 "already found. Partitioner returned %d shards, and shard "
763 "%s/part_0 was found, but so was the extra shard %s/part_%d." %
764 (num_slices, name, name, num_slices))
766 vs = []
767 for i, (var_offset, var_shape) in enumerate(
768 _iter_slices(shape.as_list(), num_slices, slice_dim)):
769 partition_info = _PartitionInfo(
770 full_shape=shape.as_list(), var_offset=var_offset)
771 var_full_name = "%s/part_%d" % (name, i)
772 with ops.name_scope(
773 var_full_name + "/PartitionedInitializer", skip_on_eager=False):
774 # Create the tensor to initialize the variable with default value.
775 if initializer is None:
776 init, initializing_from_value = self._get_default_initializer(
777 name=name, shape=shape, dtype=dtype)
778 if initializing_from_value:
779 init_shape = None
780 else:
781 init_shape = var_shape
782 elif callable(initializer):
783 init = initializer
784 init_shape = var_shape
785 elif isinstance(initializer, ops.Tensor):
786 init = array_ops.slice(initializer, var_offset, var_shape)
787 # Use the dtype of the given tensor.
788 dtype = init.dtype.base_dtype
789 init_shape = None
790 else:
791 init = ops.convert_to_tensor(initializer, dtype=dtype)
792 init = array_ops.slice(init, var_offset, var_shape)
793 init_shape = None
795 with ops.name_scope(None):
796 var = self._get_single_variable(
797 name=var_full_name,
798 shape=init_shape,
799 dtype=dtype,
800 initializer=init,
801 partition_info=partition_info,
802 regularizer=regularizer,
803 reuse=reuse,
804 trainable=trainable,
805 collections=collections,
806 caching_device=caching_device,
807 validate_shape=validate_shape,
808 use_resource=use_resource,
809 constraint=constraint,
810 synchronization=synchronization,
811 aggregation=aggregation)
813 # pylint: disable=protected-access
814 var._set_save_slice_info(
815 variables.Variable.SaveSliceInfo(name, shape.as_list(), var_offset,
816 var_shape))
817 vs.append(var)
818 # pylint: enable=protected-access
820 partitioned_var = variables.PartitionedVariable(
821 name=name,
822 shape=shape,
823 dtype=dtype,
824 variable_list=vs,
825 partitions=partitions)
826 if not context.executing_eagerly() or self._store_eager_variables:
827 self._partitioned_vars[name] = partitioned_var
828 return partitioned_var
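Through the public API, a partitioner supplied on the enclosing scope triggers this code path; a minimal graph-mode sketch:

```python
import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
  with tf.variable_scope(
      "part", partitioner=tf.fixed_size_partitioner(num_shards=2)):
    v = tf.get_variable("w", shape=[10, 4])
  # Two shards of shape [5, 4] are stored as "part/w/part_0" and
  # "part/w/part_1"; reading `v` as a tensor concatenates them along axis 0.
```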
830 def _get_single_variable(self,
831 name,
832 shape=None,
833 dtype=dtypes.float32,
834 initializer=None,
835 regularizer=None,
836 partition_info=None,
837 reuse=None,
838 trainable=None,
839 collections=None,
840 caching_device=None,
841 validate_shape=True,
842 use_resource=None,
843 constraint=None,
844 synchronization=VariableSynchronization.AUTO,
845 aggregation=VariableAggregation.NONE):
846 """Get or create a single Variable (e.g.
848 a shard or entire variable).
850 See the documentation of get_variable above (ignore partitioning components)
851 for details.
853 Args:
854 name: see get_variable.
855 shape: see get_variable.
856 dtype: see get_variable.
857 initializer: see get_variable.
858 regularizer: see get_variable.
859 partition_info: _PartitionInfo object.
860 reuse: see get_variable.
861 trainable: see get_variable.
862 collections: see get_variable.
863 caching_device: see get_variable.
864 validate_shape: see get_variable.
865 use_resource: see get_variable.
866 constraint: see get_variable.
867 synchronization: see get_variable.
868 aggregation: see get_variable.
870 Returns:
871 A Variable. See documentation of get_variable above.
873 Raises:
874 ValueError: See documentation of get_variable above.
875 """
876 # Set to true if initializer is a constant.
877 initializing_from_value = False
878 if initializer is not None and not callable(initializer):
879 initializing_from_value = True
880 if shape is not None and initializing_from_value:
881 raise ValueError("If initializer is a constant, do not specify shape.")
883 dtype = dtypes.as_dtype(dtype)
884 shape = tensor_shape.as_shape(shape)
886 if name in self._vars:
887 # Here we handle the case when returning an existing variable.
888 if reuse is False:
889 var = self._vars[name]
890 err_msg = ("Variable %s already exists, disallowed."
891 " Did you mean to set reuse=True or "
892 "reuse=tf.AUTO_REUSE in VarScope?" % name)
893 # ResourceVariables don't have an op associated with them, so no traceback
894 if isinstance(var, resource_variable_ops.ResourceVariable):
895 raise ValueError(err_msg)
896 tb = var.op.traceback[::-1]
897 # Throw away internal tf entries and only take a few lines. In some
898 # cases the traceback can be longer (e.g. if someone uses factory
899 # functions to create variables) so we take more than needed in the
900 # default case.
901 tb = [x for x in tb if "tensorflow/python" not in x[0]][:5]
902 raise ValueError("%s Originally defined at:\n\n%s" %
903 (err_msg, "".join(traceback.format_list(tb))))
904 found_var = self._vars[name]
905 if not shape.is_compatible_with(found_var.get_shape()):
906 raise ValueError("Trying to share variable %s, but specified shape %s"
907 " and found shape %s." %
908 (name, shape, found_var.get_shape()))
909 if not dtype.is_compatible_with(found_var.dtype):
910 dtype_str = dtype.name
911 found_type_str = found_var.dtype.name
912 raise ValueError("Trying to share variable %s, but specified dtype %s"
913 " and found dtype %s." %
914 (name, dtype_str, found_type_str))
915 return found_var
917 # The code below handles only the case of creating a new variable.
918 if reuse is True:
919 raise ValueError("Variable %s does not exist, or was not created with "
920 "tf.get_variable(). Did you mean to set "
921 "reuse=tf.AUTO_REUSE in VarScope?" % name)
923 # Create the tensor to initialize the variable with default value.
924 if initializer is None:
925 initializer, initializing_from_value = self._get_default_initializer(
926 name=name, shape=shape, dtype=dtype)
927 # Enter an init scope when creating the initializer.
928 with ops.init_scope():
929 if initializing_from_value:
930 init_val = initializer
931 variable_dtype = None
932 else:
933 # Instantiate initializer if provided initializer is a type object.
934 if tf_inspect.isclass(initializer):
935 initializer = initializer()
936 if shape.is_fully_defined():
937 if "partition_info" in tf_inspect.getargspec(initializer).args:
938 init_val = functools.partial(initializer,
939 shape.as_list(),
940 dtype=dtype,
941 partition_info=partition_info)
942 else:
943 init_val = functools.partial(initializer,
944 shape.as_list(), dtype=dtype)
945 variable_dtype = dtype.base_dtype
946 elif _needs_no_arguments(initializer):
947 init_val = initializer
948 variable_dtype = None
949 else:
950 raise ValueError("The initializer passed is not valid. It should "
951 "be a callable with no arguments and the "
952 "shape should not be provided or an instance of "
953 "`tf.keras.initializers.*' and `shape` should be "
954 "fully defined.")
956 # Create the variable.
957 if use_resource is None:
958 # Set the default value if unspecified.
959 use_resource = _DEFAULT_USE_RESOURCE
960 v = variable_v1.VariableV1(
961 initial_value=init_val,
962 name=name,
963 trainable=trainable,
964 collections=collections,
965 caching_device=caching_device,
966 dtype=variable_dtype,
967 validate_shape=validate_shape,
968 constraint=constraint,
969 use_resource=use_resource,
970 synchronization=synchronization,
971 aggregation=aggregation)
972 if context.executing_eagerly() and self._store_eager_variables:
973 if collections:
974 ops.add_to_collections(collections, v)
975 else:
976 ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v)
977 if trainable:
978 ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v)
980 if not context.executing_eagerly() or self._store_eager_variables:
981 # In eager mode we do not want to keep default references to Variable
982 # objects as this will prevent their memory from being released.
983 self._vars[name] = v
984 logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
985 format(shape), initializer)
987 # Run the regularizer if requested and save the resulting loss.
988 if regularizer:
989 def make_regularizer_op():
990 with ops.colocate_with(v):
991 with ops.name_scope(name + "/Regularizer/"):
992 return regularizer(v)
994 if regularizer(v) is not None:
995 lazy_eval_tensor = _LazyEvalTensor(make_regularizer_op)
996 ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
997 lazy_eval_tensor)
999 return v
1001 # Initialize variable when no initializer provided
1002 def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
1003 """Provide a default initializer and a corresponding value.
1005 Args:
1006 name: see get_variable.
1007 shape: see get_variable.
1008 dtype: see get_variable.
1010 Returns:
1011 initializer and initializing_from_value. See get_variable above.
1013 Raises:
1014 ValueError: When giving unsupported dtype.
1015 """
1016 del shape
1017 # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
1018 if dtype.is_floating:
1019 initializer = init_ops.glorot_uniform_initializer()
1020 initializing_from_value = False
1021 # If dtype is DT_INT/DT_UINT, provide a default value `zero`
1022 # If dtype is DT_BOOL, provide a default value `FALSE`
1023 elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool or
1024 dtype == dtypes.string):
1025 initializer = init_ops.zeros_initializer()
1026 initializing_from_value = False
1027 # NOTE: Do we need to support DT_STRING and DT_COMPLEX here?
1028 else:
1029 raise ValueError("An initializer for variable %s of %s is required" %
1030 (name, dtype.base_dtype))
1032 return initializer, initializing_from_value
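Two behaviors implemented above, shown through the public API (a graph-mode sketch): floating-point variables default to `glorot_uniform`, integer variables default to zeros, and creating a second variable with the same name without setting `reuse` raises the ValueError built in `_get_single_variable`:

```python
import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
  w = tf.get_variable("w", shape=[2])                  # glorot_uniform init
  c = tf.get_variable("c", shape=[2], dtype=tf.int32)  # zeros init
  try:
    tf.get_variable("w", shape=[2])                    # same name, reuse unset
  except ValueError:
    pass  # "Variable w already exists, disallowed. Did you mean to set reuse..."
```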
1035class _LazyEvalTensor(core.Tensor):
1036 """A Tensor-like object that only evaluates its thunk when used."""
1038 def __init__(self, thunk):
1039 """Initializes a _LazyEvalTensor object.
1041 Args:
1042 thunk: A callable. A thunk which computes the value of the tensor.
1043 """
1044 self._thunk = thunk
1045 self._master_tensor = thunk()
1047 def _as_tensor(self, dtype=None, name=None, as_ref=False):
1048 del name
1049 assert not as_ref
1050 assert dtype in [None, self.dtype]
1052 return self._thunk()
1055def _make_master_property(name):
1056 @property
1057 def prop(self):
1058 return getattr(self._master_tensor, name) # pylint: disable=protected-access
1059 return prop
1061_master_property_list = ("device", "dtype", "graph", "name", "op", "shape",
1062 "value_index")
1063for _name in _master_property_list:
1064 setattr(_LazyEvalTensor, _name, _make_master_property(_name))
1067def _make_master_method(name):
1068 def method(self, *args, **kwargs):
1069 return getattr(self._master_tensor, name)(*args, **kwargs) # pylint: disable=protected-access
1070 return method
1072_master_method_list = ("get_shape", "__str__", "shape_as_list")
1073for _name in _master_method_list:
1074 setattr(_LazyEvalTensor, _name, _make_master_method(_name))
1077def _make_op_method(name):
1078 def method(self, *args, **kwargs):
1079 return getattr(self._as_tensor(), name)(*args, **kwargs) # pylint: disable=protected-access
1080 return method
1082_op_list = ("__abs__", "__add__", "__and__", "__bool__", "__div__", "__eq__",
1083 "__floordiv__", "__ge__", "__getitem__", "__gt__", "__invert__",
1084 "__iter__", "__le__", "__len__", "__lt__", "__matmul__", "__mod__",
1085 "__mul__", "__ne__", "__neg__", "__nonzero__", "__or__", "__pow__",
1086 "__radd__", "__rand__", "__rdiv__", "__rfloordiv__", "__rmatmul__",
1087 "__rmod__", "__rmul__", "__ror__", "__rpow__", "__rsub__",
1088 "__rtruediv__", "__rxor__", "__sub__", "__truediv__", "__xor__",
1089 "eval", "numpy")
1090for _name in _op_list:
1091 setattr(_LazyEvalTensor, _name, _make_op_method(_name))
1094tensor_conversion_registry.register_tensor_conversion_function(
1095 _LazyEvalTensor,
1096 lambda val, dtype, name, as_ref: val._as_tensor(dtype, name, as_ref) # pylint: disable=protected-access
1097 )
1099session.register_session_run_conversion_functions(
1100 _LazyEvalTensor,
1101 lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0]) # pylint: disable=protected-access
1102 )
1105# To stop regularization, use this regularizer
1106@tf_export(v1=["no_regularizer"])
1107def no_regularizer(_):
1108 """Use this function to prevent regularization of variables."""
1109 return None
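For example, `no_regularizer` can override a regularizer inherited from the enclosing scope (a minimal graph-mode sketch):

```python
import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
  with tf.variable_scope("scope", regularizer=tf.nn.l2_loss):
    w = tf.get_variable("w", shape=[3])  # picks up the scope's regularizer
    b = tf.get_variable(
        "b", shape=[3], regularizer=tf.no_regularizer)  # opts out
  losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
  assert len(losses) == 1  # only `w` contributed a regularization loss
```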
1112# TODO(alive): support caching devices and partitioned variables in Eager mode.
1113@tf_export(v1=["VariableScope"])
1114class VariableScope:
1115 """Variable scope object to carry defaults to provide to `get_variable`.
1117 Many of the arguments we need for `get_variable` in a variable store are most
1118 easily handled with a context. This object is used for the defaults.
1120 Attributes:
1121 name: name of the current scope, used as prefix in get_variable.
1122 initializer: default initializer passed to get_variable.
1123 regularizer: default regularizer passed to get_variable.
1124 reuse: Boolean, None, or tf.compat.v1.AUTO_REUSE, setting the reuse in
1125 get_variable. When eager execution is enabled this argument is always
1126 forced to be False.
1127 caching_device: string, callable, or None: the caching device passed to
1128 get_variable.
1129 partitioner: callable or `None`: the partitioner passed to `get_variable`.
1130 custom_getter: default custom getter passed to get_variable.
1131 name_scope: The name passed to `tf.name_scope`.
1132 dtype: default type passed to get_variable (defaults to DT_FLOAT).
1133 use_resource: if False, create a normal Variable; if True create an
1134 experimental ResourceVariable with well-defined semantics. Defaults to
1135 False (will later change to True). When eager execution is enabled this
1136 argument is always forced to be True.
1137 constraint: An optional projection function to be applied to the variable
1138 after being updated by an `Optimizer` (e.g. used to implement norm
1139 constraints or value constraints for layer weights). The function must
1140 take as input the unprojected Tensor representing the value of the
1141 variable and return the Tensor for the projected value (which must have
1142 the same shape). Constraints are not safe to use when doing asynchronous
1143 distributed training.
1144 """
1146 def __init__(self,
1147 reuse,
1148 name="",
1149 initializer=None,
1150 regularizer=None,
1151 caching_device=None,
1152 partitioner=None,
1153 custom_getter=None,
1154 name_scope="",
1155 dtype=dtypes.float32,
1156 use_resource=None,
1157 constraint=None):
1158 """Creates a new VariableScope with the given properties."""
1159 self._name = name
1160 self._initializer = initializer
1161 self._regularizer = regularizer
1162 self._reuse = reuse
1163 self._caching_device = caching_device
1164 self._partitioner = partitioner
1165 self._custom_getter = custom_getter
1166 self._name_scope = name_scope
1167 self._dtype = dtype
1168 self._use_resource = use_resource
1169 self._constraint = constraint
1170 if context.executing_eagerly():
1171 if self._caching_device is not None:
1172 raise NotImplementedError("Caching devices are not yet supported "
1173 "when eager execution is enabled.")
1174 self._reuse = AUTO_REUSE
1175 self._use_resource = True
1177 @property
1178 def name(self):
1179 return self._name
1181 @property
1182 def original_name_scope(self):
1183 return self._name_scope
1185 @property
1186 def reuse(self):
1187 return self._reuse
1189 @property
1190 def initializer(self):
1191 return self._initializer
1193 @property
1194 def dtype(self):
1195 return self._dtype
1197 @property
1198 def use_resource(self):
1199 return self._use_resource
1201 @property
1202 def regularizer(self):
1203 return self._regularizer
1205 @property
1206 def caching_device(self):
1207 return self._caching_device
1209 @property
1210 def partitioner(self):
1211 return self._partitioner
1213 @property
1214 def custom_getter(self):
1215 return self._custom_getter
1217 @property
1218 def constraint(self):
1219 return self._constraint
1221 def reuse_variables(self):
1222 """Reuse variables in this scope."""
1223 self._reuse = True
1225 def set_initializer(self, initializer):
1226 """Set initializer for this scope."""
1227 self._initializer = initializer
1229 def set_dtype(self, dtype):
1230 """Set data type for this scope."""
1231 self._dtype = dtype
1233 def set_use_resource(self, use_resource):
1234 """Sets whether to use ResourceVariables for this scope."""
1235 if context.executing_eagerly() and not use_resource:
1236 raise ValueError("When eager execution is enabled, "
1237 "use_resource cannot be set to false.")
1238 self._use_resource = use_resource
1240 def set_regularizer(self, regularizer):
1241 """Set regularizer for this scope."""
1242 self._regularizer = regularizer
1244 def set_caching_device(self, caching_device):
1245 """Set caching_device for this scope."""
1246 if context.executing_eagerly():
1247 raise NotImplementedError("Caching devices are not yet supported "
1248 "when eager execution is enabled.")
1249 self._caching_device = caching_device
1251 def set_partitioner(self, partitioner):
1252 """Set partitioner for this scope."""
1253 self._partitioner = partitioner
1255 def set_custom_getter(self, custom_getter):
1256 """Set custom getter for this scope."""
1257 self._custom_getter = custom_getter
1259 def get_collection(self, name):
1260 """Get this scope's variables."""
1261 scope = self._name + "/" if self._name else ""
1262 return ops.get_collection(name, scope)
1264 def trainable_variables(self):
1265 """Get this scope's trainable variables."""
1266 return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
1268 def global_variables(self):
1269 """Get this scope's global variables."""
1270 return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
1272 def local_variables(self):
1273 """Get this scope's local variables."""
1274 return self.get_collection(ops.GraphKeys.LOCAL_VARIABLES)
1276 def get_variable(self,
1277 var_store,
1278 name,
1279 shape=None,
1280 dtype=None,
1281 initializer=None,
1282 regularizer=None,
1283 reuse=None,
1284 trainable=None,
1285 collections=None,
1286 caching_device=None,
1287 partitioner=None,
1288 validate_shape=True,
1289 use_resource=None,
1290 custom_getter=None,
1291 constraint=None,
1292 synchronization=VariableSynchronization.AUTO,
1293 aggregation=VariableAggregation.NONE):
1294 """Gets an existing variable with this name or create a new one."""
1295 if regularizer is None:
1296 regularizer = self._regularizer
1297 if caching_device is None:
1298 caching_device = self._caching_device
1299 if partitioner is None:
1300 partitioner = self._partitioner
1301 if custom_getter is None:
1302 custom_getter = self._custom_getter
1303 if context.executing_eagerly():
1304 reuse = False
1305 use_resource = True
1306 else:
1307 if reuse is None:
1308 reuse = self._reuse
1309 if use_resource is None:
1310 use_resource = self._use_resource
1312 full_name = self.name + "/" + name if self.name else name
1313 # Variable names only depend on variable_scope (full_name here),
1314 # not name_scope, so we reset it below for the time of variable creation.
1315 with ops.name_scope(None, skip_on_eager=False):
1316 # Check that `initializer` dtype and `dtype` are consistent before
1317 # replacing them with defaults.
1318 if (dtype is not None and initializer is not None and
1319 not callable(initializer)):
1320 init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype
1321 if init_dtype != dtype:
1322 raise ValueError("Initializer type '%s' and explicit dtype '%s' "
1323 "don't match." % (init_dtype, dtype))
1324 if initializer is None:
1325 initializer = self._initializer
1326 if constraint is None:
1327 constraint = self._constraint
1328 if dtype is None:
1329 dtype = self._dtype
1330 return var_store.get_variable(
1331 full_name,
1332 shape=shape,
1333 dtype=dtype,
1334 initializer=initializer,
1335 regularizer=regularizer,
1336 reuse=reuse,
1337 trainable=trainable,
1338 collections=collections,
1339 caching_device=caching_device,
1340 partitioner=partitioner,
1341 validate_shape=validate_shape,
1342 use_resource=use_resource,
1343 custom_getter=custom_getter,
1344 constraint=constraint,
1345 synchronization=synchronization,
1346 aggregation=aggregation)
1348 def _get_partitioned_variable(self,
1349 var_store,
1350 name,
1351 shape=None,
1352 dtype=None,
1353 initializer=None,
1354 regularizer=None,
1355 trainable=None,
1356 collections=None,
1357 caching_device=None,
1358 partitioner=None,
1359 validate_shape=True,
1360 use_resource=None,
1361 constraint=None,
1362 synchronization=VariableSynchronization.AUTO,
1363 aggregation=VariableAggregation.NONE):
1364 """Gets an existing variable with this name or create a new one."""
1365 if initializer is None:
1366 initializer = self._initializer
1367 if regularizer is None:
1368 regularizer = self._regularizer
1369 if constraint is None:
1370 constraint = self._constraint
1371 if caching_device is None:
1372 caching_device = self._caching_device
1373 if partitioner is None:
1374 partitioner = self._partitioner
1375 if dtype is None:
1376 dtype = self._dtype
1377 if use_resource is None:
1378 use_resource = self._use_resource
1380 if self._custom_getter is not None:
1381 raise ValueError(
1382 "Private access to _get_partitioned_variable is not allowed when "
1383 "a custom getter is set. Current custom getter: %s. "
1384 "It is likely that you're using create_partitioned_variables. "
1385 "If so, consider instead using get_variable with a non-empty "
1386 "partitioner parameter instead." % self._custom_getter)
1388 if partitioner is None:
1389 raise ValueError("No partitioner was specified")
1391 # This allows the variable scope name to be used as the variable name if
1392 # this function is invoked with an empty name arg, for backward
1393 # compatibility with create_partitioned_variables().
1394 full_name_list = []
1395 if self.name:
1396 full_name_list.append(self.name)
1397 if name:
1398 full_name_list.append(name)
1399 full_name = "/".join(full_name_list)
1401 # Variable names only depend on variable_scope (full_name here),
1402 # not name_scope, so we reset it below for the time of variable creation.
1403 with ops.name_scope(None, skip_on_eager=False):
1404 # pylint: disable=protected-access
1405 return var_store._get_partitioned_variable(
1406 full_name,
1407 shape=shape,
1408 dtype=dtype,
1409 initializer=initializer,
1410 regularizer=regularizer,
1411 reuse=self.reuse,
1412 trainable=trainable,
1413 collections=collections,
1414 caching_device=caching_device,
1415 partitioner=partitioner,
1416 validate_shape=validate_shape,
1417 use_resource=use_resource,
1418 constraint=constraint,
1419 synchronization=synchronization,
1420 aggregation=aggregation)
1421 # pylint: enable=protected-access
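A common use of the scope object is flipping it into reuse mode partway through building a graph (a minimal sketch):

```python
import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
  with tf.variable_scope("layer"):
    a = tf.get_variable("w", shape=[4])
    tf.get_variable_scope().reuse_variables()  # this scope now reuses
    b = tf.get_variable("w", shape=[4])        # returns the existing variable
  assert a is b
```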
1424_VARSTORE_KEY = ("__variable_store",)
1425_VARSCOPESTORE_KEY = ("__varscope",)
1428class _VariableScopeStore(threading.local):
1429 """A thread local store for the current variable scope and scope counts."""
1431 def __init__(self):
1432 super(_VariableScopeStore, self).__init__()
1433 self.current_scope = VariableScope(False)
1434 self.variable_scopes_count = {}
1436 def open_variable_scope(self, scope_name):
1437 if scope_name in self.variable_scopes_count:
1438 self.variable_scopes_count[scope_name] += 1
1439 else:
1440 self.variable_scopes_count[scope_name] = 1
1442 def close_variable_subscopes(self, scope_name):
1443 if scope_name is None:
1444 for k in self.variable_scopes_count:
1445 self.variable_scopes_count[k] = 0
1446 else:
1447 startswith_check = scope_name + "/"
1448 startswith_len = len(startswith_check)
1449 for k in self.variable_scopes_count:
1450 if k[:startswith_len] == startswith_check:
1451 self.variable_scopes_count[k] = 0
1453 def variable_scope_count(self, scope_name):
1454 return self.variable_scopes_count.get(scope_name, 0)
1457def get_variable_scope_store():
1458 """Returns the variable scope store for current thread."""
1459 scope_store = ops.get_collection(_VARSCOPESTORE_KEY)
1461 if not scope_store:
1462 scope_store = _VariableScopeStore()
1463 ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store)
1464 else:
1465 scope_store = scope_store[0]
1467 return scope_store
1470@tf_export(v1=["get_variable_scope"])
1471def get_variable_scope():
1472 """Returns the current variable scope.
1474 @compatibility(TF2)
1475 Although it is a legacy `compat.v1` api,
1476 `tf.compat.v1.get_variable` is compatible with eager
1477 execution and `tf.function`.
1479 However, to maintain variable-scope based variable reuse
1480 you will need to combine it with
1481 `tf.compat.v1.keras.utils.track_tf1_style_variables`. (Though
1482 it will behave as if reuse is always set to `tf.compat.v1.AUTO_REUSE`.)
1484 See the
1485 [migration guide](https://www.tensorflow.org/guide/migrate/model_mapping)
1486 for more info.
1488 The TF2 equivalent, if you are just trying to track
1489 variable name prefixes and not control `get_variable`-based variable reuse,
1490 would be to use `tf.name_scope` and capture the output of opening the
1491 scope (which represents the current name prefix).
1493 For example:
1494 ```python
1495 with tf.name_scope('foo') as current_scope:
1496 ...
1497 ```
1498 @end_compatibility
1499 """
1500 return get_variable_scope_store().current_scope
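The TF2-equivalent pattern mentioned in the compatibility note above looks like this (a small sketch; `tf.name_scope` yields the current name prefix when entered):

```python
import tensorflow as tf

with tf.name_scope("foo") as current_scope:
  print(current_scope)  # prints "foo/"
```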
1503def _get_default_variable_store():
1504 store = ops.get_collection(_VARSTORE_KEY)
1505 if store:
1506 return store[0]
1507 store = _VariableStore()
1508 ops.add_to_collection(_VARSTORE_KEY, store)
1509 return store
1512@tf_contextlib.contextmanager
1513def with_variable_store(store):
1514 store_collection = ops.get_collection_ref(_VARSTORE_KEY)
1515 old = list(store_collection)
1516 store_collection[:] = [store]
1517 try:
1518 yield
1519 finally:
1520 store_collection[:] = old
1523class EagerVariableStore:
1524 """Wrapper allowing functional layers to be used with eager execution.
1526 When eager execution is enabled Variables get deleted when they go out of
1527 scope, and are not stored in global collections by default. A lot of code
1528 (mostly the functional layers in tf.layers) assumes that variables are kept in
1529 a global list.
1531 EagerVariableStore can be used in conjunction with this code to make it
1532 eager-friendly. For example, to create a dense layer, use:
1534 ```
1535 container = tfe.EagerVariableStore()
1536 for input in dataset_iterator:
1537 with container.as_default():
1538 x = tf.compat.v1.layers.dense(input, units=10, name="l1")
1539 print(container.variables()) # Should print the variables used in the layer.
1540 ```
1541 """
1543 def __init__(self, store=None):
1544 if store is not None:
1545 if not store._store_eager_variables: # pylint: disable=protected-access
1546 raise ValueError("Cannot construct EagerVariableStore from a "
1547 "VariableStore object that does not hold eager "
1548 "variables.")
1549 self._store = store
1550 else:
1551 self._store = _VariableStore()
1552 self._store._store_eager_variables = True # pylint: disable=protected-access
1554 def as_default(self):
1555 return with_variable_store(self._store)
1557 def variables(self):
1558 return sorted(self._store._vars.values(), key=lambda x: x.name) # pylint: disable=protected-access
1560 def trainable_variables(self):
1561 # pylint: disable=protected-access
1562 return sorted([x for x in self._store._vars.values() if x.trainable],
1563 key=lambda x: x.name)
1564 # pylint: enable=protected-access
1566 def non_trainable_variables(self):
1567 # pylint: disable=protected-access
1568 return sorted([x for x in self._store._vars.values() if not x.trainable],
1569 key=lambda x: x.name)
1570 # pylint: enable=protected-access
1572 def copy(self):
1573 """Copy this variable store and all of its contents.
1575 Variables contained in this store will be copied over to the new variable
1576 store, meaning that they can be modified without affecting the variables in
1577 this store.
1579 Returns:
1580 A new EagerVariableStore instance containing copied variables.
1581 """
1582 # pylint: disable=protected-access
1583 new_store = EagerVariableStore()
1584 for key, var in self._store._vars.items():
1586 # Strip the ":0" output suffix from the variable name.
1586 try:
1587 index = var.name.index(":")
1588 except ValueError:
1589 stripped_var_name = var.name
1590 else:
1591 stripped_var_name = var.name[:index]
1593 # Create new variable with same value, name, and "trainable" flag.
1594 new_var = resource_variable_ops.ResourceVariable(
1595 var.read_value(), name=stripped_var_name, trainable=var.trainable)
1596 new_store._store._vars[key] = new_var
1597 return new_store
1598 # pylint: enable=protected-access
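# A minimal sketch (assuming eager execution, the TF2 default) of copy():
# the copied store holds independent variables, so mutating them does not
# affect the originals. The `_demo_*` helper is illustrative only.
def _demo_eager_variable_store_copy():
  import tensorflow as tf
  container = EagerVariableStore()
  with container.as_default():
    tf.compat.v1.get_variable(
        "v", shape=[2], initializer=tf.compat.v1.zeros_initializer())
  snapshot = container.copy()
  snapshot.variables()[0].assign([1.0, 1.0])  # Mutate only the copy.
  assert container.variables()[0].numpy().tolist() == [0.0, 0.0]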
1601# The argument list for get_variable must match arguments to get_local_variable.
1602# So, if you are updating the arguments, also update arguments to
1603# get_local_variable below.
1604@tf_export(v1=["get_variable"])
1605def get_variable(name,
1606 shape=None,
1607 dtype=None,
1608 initializer=None,
1609 regularizer=None,
1610 trainable=None,
1611 collections=None,
1612 caching_device=None,
1613 partitioner=None,
1614 validate_shape=True,
1615 use_resource=None,
1616 custom_getter=None,
1617 constraint=None,
1618 synchronization=VariableSynchronization.AUTO,
1619 aggregation=VariableAggregation.NONE):
1620 return get_variable_scope().get_variable(
1621 _get_default_variable_store(),
1622 name,
1623 shape=shape,
1624 dtype=dtype,
1625 initializer=initializer,
1626 regularizer=regularizer,
1627 trainable=trainable,
1628 collections=collections,
1629 caching_device=caching_device,
1630 partitioner=partitioner,
1631 validate_shape=validate_shape,
1632 use_resource=use_resource,
1633 custom_getter=custom_getter,
1634 constraint=constraint,
1635 synchronization=synchronization,
1636 aggregation=aggregation)
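# A minimal sketch (assuming TF2 with an explicit graph) of the reuse checks
# described in the docstring template below: the second lookup with reuse=True
# returns the very same variable object. `_demo_*` is illustrative only.
def _demo_get_variable_reuse():
  import tensorflow as tf
  with tf.Graph().as_default():
    with tf.compat.v1.variable_scope("model"):
      w = tf.compat.v1.get_variable("w", shape=[3, 3])
    with tf.compat.v1.variable_scope("model", reuse=True):
      w_again = tf.compat.v1.get_variable("w", shape=[3, 3])
    assert w is w_again
    assert w.name == "model/w:0"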
1639get_variable_or_local_docstring = ("""%s
1641@compatibility(TF2)
1642Although it is a legacy `compat.v1` api,
1643`tf.compat.v1.get_variable` is mostly compatible with eager
1644execution and `tf.function` but only if you combine it with the
1645`tf.compat.v1.keras.utils.track_tf1_style_variables` decorator. (Though
1646it will behave as if reuse is always set to `AUTO_REUSE`.)
1648See the
1649[model migration guide](https://www.tensorflow.org/guide/migrate/model_mapping)
1650for more info.
1652If you do not combine it with
1653`tf.compat.v1.keras.utils.track_tf1_style_variables`, `get_variable` will create
1654a brand new variable every single time it is called and will never reuse
1655variables, regardless of variable names or `reuse` arguments.
1657The TF2 equivalent of this symbol would be `tf.Variable`, but note
1658that when using `tf.Variable` you must make sure you track your variables
1659(and regularizer arguments) either manually or via `tf.Module` or
1660`tf.keras.layers.Layer` mechanisms.
1662A section of the
1663[migration guide](https://www.tensorflow.org/guide/migrate/model_mapping#incremental_migration_to_native_tf2)
1664provides more details on incrementally migrating these usages to `tf.Variable`
1665as well.
1667Note: The `partitioner` arg is not compatible with TF2 behaviors even when
1668using `tf.compat.v1.keras.utils.track_tf1_style_variables`. It can be replaced
1669by using `ParameterServerStrategy` and its partitioners. See the
1670[multi-gpu migration guide](https://www.tensorflow.org/guide/migrate/multi_worker_cpu_gpu_training)
1671and the ParameterServerStrategy guides it references for more info.
1672@end_compatibility
1674%sThis function prefixes the name with the current variable scope
1675and performs reuse checks. See the
1676[Variable Scope How To](https://tensorflow.org/guide/variables)
1677for an extensive description of how reusing works. Here is a basic example:
1679```python
1680def foo():
1681 with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
1682 v = tf.get_variable("v", [1])
1683 return v
1685v1 = foo() # Creates v.
1686v2 = foo() # Gets the same, existing v.
1687assert v1 == v2
1688```
1690If initializer is `None` (the default), the default initializer passed in
1691the variable scope will be used. If that one is `None` too, a
1692`glorot_uniform_initializer` will be used. The initializer can also be
1693a Tensor, in which case the variable is initialized to this value and shape.
1695Similarly, if the regularizer is `None` (the default), the default regularizer
1696passed in the variable scope will be used (if that is `None` too,
1697then by default no regularization is performed).
1699If a partitioner is provided, a `PartitionedVariable` is returned.
1700Accessing this object as a `Tensor` returns the shards concatenated along
1701the partition axis.
1703Some useful partitioners are available. See, e.g.,
1704`variable_axis_size_partitioner` and `min_max_variable_partitioner`.
1706Args:
1707 name: The name of the new or existing variable.
1708 shape: Shape of the new or existing variable.
1709 dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
1710 initializer: Initializer for the variable if one is created. Can either be
1711 an initializer object or a Tensor. If it's a Tensor, its shape must be known
1712 unless validate_shape is False.
1713 regularizer: A (Tensor -> Tensor or None) function; the result of
1714 applying it on a newly created variable will be added to the collection
1715 `tf.GraphKeys.REGULARIZATION_LOSSES` and can be used for regularization.
1716 %scollections: List of graph collections keys to add the Variable to.
1717 Defaults to `[%s]` (see `tf.Variable`).
1718 caching_device: Optional device string or function describing where the
1719 Variable should be cached for reading. Defaults to the Variable's
1720 device. If not `None`, caches on another device. Typical use is to
1721 cache on the device where the Ops using the Variable reside, to
1722 deduplicate copying through `Switch` and other conditional statements.
1723 partitioner: Optional callable that accepts a fully defined `TensorShape`
1724 and `dtype` of the Variable to be created, and returns a list of
1725 partitions for each axis (currently only one axis can be partitioned).
1726 validate_shape: If False, allows the variable to be initialized with a
1727 value of unknown shape. If True, the default, the shape of initial_value
1728 must be known. For this to be used the initializer must be a Tensor and
1729 not an initializer object.
1730 use_resource: If False, creates a regular Variable. If true, creates an
1731 experimental ResourceVariable instead with well-defined semantics.
1732 Defaults to False (will later change to True). When eager execution is
1733 enabled this argument is always forced to be True.
1734 custom_getter: Callable that takes as a first argument the true getter, and
1735 allows overwriting the internal get_variable method.
1736 The signature of `custom_getter` should match that of this method,
1737 but the most future-proof version will allow for changes:
1738 `def custom_getter(getter, *args, **kwargs)`. Direct access to
1739 all `get_variable` parameters is also allowed:
1740 `def custom_getter(getter, name, *args, **kwargs)`. A simple identity
1741 custom getter that simply creates variables with modified names is:
1742 ```python
1743 def custom_getter(getter, name, *args, **kwargs):
1744 return getter(name + '_suffix', *args, **kwargs)
1745 ```
1746 constraint: An optional projection function to be applied to the variable
1747 after being updated by an `Optimizer` (e.g. used to implement norm
1748 constraints or value constraints for layer weights). The function must
1749 take as input the unprojected Tensor representing the value of the
1750 variable and return the Tensor for the projected value
1751 (which must have the same shape). Constraints are not safe to
1752 use when doing asynchronous distributed training.
1753 synchronization: Indicates when a distributed variable will be
1754 aggregated. Accepted values are constants defined in the class
1755 `tf.VariableSynchronization`. By default the synchronization is set to
1756 `AUTO` and the current `DistributionStrategy` chooses
1757 when to synchronize.
1758 aggregation: Indicates how a distributed variable will be aggregated.
1759 Accepted values are constants defined in the class
1760 `tf.VariableAggregation`.
1762Returns:
1763 The created or existing `Variable` (or `PartitionedVariable`, if a
1764 partitioner was used).
1766Raises:
1767 ValueError: when creating a new variable and shape is not declared,
1768 when violating reuse during variable creation, or when `initializer` dtype
1769 and `dtype` don't match. Reuse is set inside `variable_scope`.
1770""")
1771get_variable.__doc__ = get_variable_or_local_docstring % (
1772 "Gets an existing variable with these parameters or create a new one.", "",
1773 "trainable: If `True` also add the variable to the graph collection\n"
1774 " `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n ",
1775 "GraphKeys.GLOBAL_VARIABLES")
1778# The argument list for get_local_variable must match arguments to get_variable.
1779# So, if you are updating the arguments, also update arguments to get_variable.
1780@tf_export(v1=["get_local_variable"])
1781def get_local_variable( # pylint: disable=missing-docstring
1782 name,
1783 shape=None,
1784 dtype=None,
1785 initializer=None,
1786 regularizer=None,
1787 trainable=False, # pylint: disable=unused-argument
1788 collections=None,
1789 caching_device=None,
1790 partitioner=None,
1791 validate_shape=True,
1792 use_resource=None,
1793 custom_getter=None,
1794 constraint=None,
1795 synchronization=VariableSynchronization.AUTO,
1796 aggregation=VariableAggregation.NONE):
1797 if collections:
1798 collections += [ops.GraphKeys.LOCAL_VARIABLES]
1799 else:
1800 collections = [ops.GraphKeys.LOCAL_VARIABLES]
1801 return get_variable(
1802 name,
1803 shape=shape,
1804 dtype=dtype,
1805 initializer=initializer,
1806 regularizer=regularizer,
1807 trainable=False,
1808 collections=collections,
1809 caching_device=caching_device,
1810 partitioner=partitioner,
1811 validate_shape=validate_shape,
1812 use_resource=use_resource,
1813 synchronization=synchronization,
1814 aggregation=aggregation,
1815 custom_getter=custom_getter,
1816 constraint=constraint)
1819get_local_variable.__doc__ = get_variable_or_local_docstring % (
1820 "Gets an existing *local* variable or creates a new one.",
1821 "Behavior is the same as in `get_variable`, except that variables are\n"
1822 "added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n"
1823 "`False`.\n", "", "GraphKeys.LOCAL_VARIABLES")
1826def _get_partitioned_variable(name,
1827 shape=None,
1828 dtype=None,
1829 initializer=None,
1830 regularizer=None,
1831 trainable=True,
1832 collections=None,
1833 caching_device=None,
1834 partitioner=None,
1835 validate_shape=True,
1836 use_resource=None,
1837 constraint=None,
1838 synchronization=VariableSynchronization.AUTO,
1839 aggregation=VariableAggregation.NONE):
1840 """Gets or creates a sharded variable list with these parameters.
1842 The `partitioner` must be a callable that accepts a fully defined
1843 `TensorShape` and returns a sequence of integers (the `partitions`).
1844 These integers describe how to partition the given sharded `Variable`
1845 along the given dimension. That is, `partitions[1] = 3` means split
1846 the `Variable` into 3 shards along dimension 1. Currently, sharding along
1847 only one axis is supported.
1849 If the list of variables with the given name (prefix) is already stored,
1850 we return the stored variables. Otherwise, we create a new one.
1852 If initializer is `None` (the default), the default initializer passed in
1853 the constructor is used. If that one is `None` too, we use a new
1854 `glorot_uniform_initializer`. If initializer is a Tensor, we use
1855 it as a value and derive the shape from the initializer.
1857 If the initializer is a callable, then it will be called for each
1858 shard. Otherwise the initializer should match the shape of the entire
1859 sharded Variable, and it will be sliced accordingly for each shard.
1861 Some useful partitioners are available. See, e.g.,
1862 `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
1864 Args:
1865 name: The name of the new or existing variable.
1866 shape: Shape of the new or existing variable.
1867 dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
1868 initializer: Initializer for the variable if one is created.
1869 regularizer: A (Tensor -> Tensor or None) function; the result of applying
1870 it on a newly created variable will be added to the collection
1871 GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
1872 trainable: If `True` also add the variable to the graph collection
1873 `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
1874 collections: List of graph collections keys to add the Variable to. Defaults
1875 to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
1876 caching_device: Optional device string or function describing where the
1877 Variable should be cached for reading. Defaults to the Variable's device.
1878 If not `None`, caches on another device. Typical use is to cache on the
1879 device where the Ops using the Variable reside, to deduplicate copying
1880 through `Switch` and other conditional statements.
1881 partitioner: Optional callable that accepts a fully defined `TensorShape`
1882 and `dtype` of the Variable to be created, and returns a list of
1883 partitions for each axis (currently only one axis can be partitioned).
1884 validate_shape: If False, allows the variable to be initialized with a value
1885 of unknown shape. If True, the default, the shape of initial_value must be
1886 known.
1887 use_resource: If False, creates a regular Variable. If True, creates an
1888 experimental ResourceVariable instead which has well-defined semantics.
1889 Defaults to False (will later change to True).
1890 constraint: An optional projection function to be applied to the variable
1891 after being updated by an `Optimizer` (e.g. used to implement norm
1892 constraints or value constraints for layer weights). The function must
1893 take as input the unprojected Tensor representing the value of the
1894 variable and return the Tensor for the projected value (which must have
1895 the same shape). Constraints are not safe to use when doing asynchronous
1896 distributed training.
1897 synchronization: Indicates when a distributed variable will be aggregated.
1898 Accepted values are constants defined in the class
1899 `tf.VariableSynchronization`. By default the synchronization is set to
1900 `AUTO` and the current `DistributionStrategy` chooses when to synchronize.
1901 aggregation: Indicates how a distributed variable will be aggregated.
1902 Accepted values are constants defined in the class
1903 `tf.VariableAggregation`.
1905 Returns:
1906 A tuple `(shards, partitions)` where `shards` is the list of `Variable`
1907 shards and `partitions` is the output of the partitioner on the input
1908 shape.
1910 Raises:
1911 ValueError: when creating a new variable and shape is not declared,
1912 or when violating reuse during variable creation. Reuse is set inside
1913 `variable_scope`.
1914 """
1915 # pylint: disable=protected-access
1916 scope = get_variable_scope()
1917 if scope.custom_getter is not None:
1918 raise ValueError(
1919 "Private access to _get_partitioned_variable is not allowed when "
1920 "a custom getter is set. Current custom getter: %s. "
1921 "It is likely that you're using create_partitioned_variables. "
1922 "If so, consider instead using get_variable with a non-empty "
1923 "partitioner parameter instead." % scope.custom_getter)
1924 return scope._get_partitioned_variable(
1925 _get_default_variable_store(),
1926 name,
1927 shape=shape,
1928 dtype=dtype,
1929 initializer=initializer,
1930 regularizer=regularizer,
1931 trainable=trainable,
1932 collections=collections,
1933 caching_device=caching_device,
1934 partitioner=partitioner,
1935 validate_shape=validate_shape,
1936 use_resource=use_resource,
1937 constraint=constraint,
1938 synchronization=synchronization,
1939 aggregation=aggregation)
1940 # pylint: enable=protected-access
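# A minimal sketch (assuming TF2 with an explicit graph) of the public route to
# partitioned variables: pass a partitioner to variable_scope (or to
# get_variable) instead of calling the private helper above. `_demo_*` is
# illustrative only.
def _demo_partitioned_variable():
  import tensorflow as tf
  with tf.Graph().as_default():
    with tf.compat.v1.variable_scope(
        "embeddings",
        partitioner=tf.compat.v1.fixed_size_partitioner(4)):
      table = tf.compat.v1.get_variable("table", shape=[8, 16])
    shards = list(table)  # A PartitionedVariable iterates over its shards.
    assert len(shards) == 4
    assert all(s.shape.as_list() == [2, 16] for s in shards)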
1943# Named like a function for compatibility with the previous
1944# @tf_contextlib.contextmanager definition.
1945class _pure_variable_scope: # pylint: disable=invalid-name
1946 """A context for the variable_scope, see `variable_scope` for docs."""
1948 def __init__(self,
1949 name_or_scope,
1950 reuse=None,
1951 initializer=None,
1952 regularizer=None,
1953 caching_device=None,
1954 partitioner=None,
1955 custom_getter=None,
1956 old_name_scope=None,
1957 dtype=dtypes.float32,
1958 use_resource=None,
1959 constraint=None):
1960 """Creates a context for the variable_scope, see `variable_scope` for docs.
1962 Note: this does not create a name scope.
1964 Args:
1965 name_or_scope: `string` or `VariableScope`: the scope to open.
1966 reuse: `True`, `None`, or tf.compat.v1.AUTO_REUSE; if `None`, we inherit
1967 the parent scope's reuse flag.
1968 initializer: default initializer for variables within this scope.
1969 regularizer: default regularizer for variables within this scope.
1970 caching_device: default caching device for variables within this scope.
1971 partitioner: default partitioner for variables within this scope.
1972 custom_getter: default custom getter for variables within this scope.
1973 old_name_scope: the original name scope when re-entering a variable scope.
1974 dtype: type of the variables within this scope (defaults to `DT_FLOAT`).
1975 use_resource: If False, variables in this scope will be regular Variables.
1976 If True, experimental ResourceVariables will be created instead, with
1977 well-defined semantics. Defaults to False (will later change to True).
1978 constraint: An optional projection function to be applied to the variable
1979 after being updated by an `Optimizer` (e.g. used to implement norm
1980 constraints or value constraints for layer weights). The function must
1981 take as input the unprojected Tensor representing the value of the
1982 variable and return the Tensor for the projected value (which must have
1983 the same shape). Constraints are not safe to use when doing asynchronous
1984 distributed training.
1985 """
1986 self._name_or_scope = name_or_scope
1987 self._reuse = reuse
1988 self._initializer = initializer
1989 self._regularizer = regularizer
1990 self._caching_device = caching_device
1991 self._partitioner = partitioner
1992 self._custom_getter = custom_getter
1993 self._old_name_scope = old_name_scope
1994 self._dtype = dtype
1995 self._use_resource = use_resource
1996 self._constraint = constraint
1997 self._var_store = _get_default_variable_store()
1998 self._var_scope_store = get_variable_scope_store()
1999 self._last_variable_scope_object = None
2000 if isinstance(self._name_or_scope, VariableScope):
2001 self._new_name = self._name_or_scope.name
2002 name_scope = self._name_or_scope._name_scope # pylint: disable=protected-access
2003 # Handler for the case when we jump to a shared scope. We create a new
2004 # VariableScope (self._cached_variable_scope_object) that contains a copy of the
2005 # provided shared scope, possibly with changed reuse and initializer, if
2006 # the user requested this.
2007 variable_scope_object = VariableScope(
2008 self._name_or_scope.reuse if not self._reuse else self._reuse,
2009 name=self._new_name,
2010 initializer=self._name_or_scope.initializer,
2011 regularizer=self._name_or_scope.regularizer,
2012 caching_device=self._name_or_scope.caching_device,
2013 partitioner=self._name_or_scope.partitioner,
2014 dtype=self._name_or_scope.dtype,
2015 custom_getter=self._name_or_scope.custom_getter,
2016 name_scope=name_scope,
2017 use_resource=self._name_or_scope.use_resource,
2018 constraint=self._constraint)
2019 if self._initializer is not None:
2020 variable_scope_object.set_initializer(self._initializer)
2021 if self._regularizer is not None:
2022 variable_scope_object.set_regularizer(self._regularizer)
2023 if self._caching_device is not None:
2024 variable_scope_object.set_caching_device(self._caching_device)
2025 if self._partitioner is not None:
2026 variable_scope_object.set_partitioner(self._partitioner)
2027 if self._custom_getter is not None:
2028 variable_scope_object.set_custom_getter(
2029 _maybe_wrap_custom_getter(self._custom_getter,
2030 self._name_or_scope.custom_getter))
2031 if self._dtype is not None:
2032 variable_scope_object.set_dtype(self._dtype)
2033 if self._use_resource is not None:
2034 variable_scope_object.set_use_resource(self._use_resource)
2035 self._cached_variable_scope_object = variable_scope_object
2037 def __enter__(self):
2038 """Begins the scope block.
2040 Returns:
2041 A VariableScope.
2042 Raises:
2043 ValueError: when trying to reuse within a create scope, or create within
2044 a reuse scope, or if reuse is not `None` or `True`.
2045 TypeError: when the types of some arguments are not appropriate.
2046 """
2047 self._old = self._var_scope_store.current_scope
2048 if isinstance(self._name_or_scope, VariableScope):
2049 self._var_scope_store.open_variable_scope(self._new_name)
2050 self._old_subscopes = copy.copy(
2051 self._var_scope_store.variable_scopes_count)
2052 variable_scope_object = self._cached_variable_scope_object
2053 else:
2054 # Handler for the case when we just prolong current variable scope.
2055 # VariableScope with name extended by the provided one, and inherited
2056 # reuse and initializer (except if the user provided values to set).
2057 self._new_name = (
2058 self._old.name + "/" +
2059 self._name_or_scope if self._old.name else self._name_or_scope)
2060 self._reuse = (self._reuse or
2061 self._old.reuse) # Re-using is inherited by sub-scopes.
2062 if self._old_name_scope is None:
2063 name_scope = self._name_or_scope
2064 else:
2065 name_scope = self._old_name_scope
2066 variable_scope_object = VariableScope(
2067 self._reuse,
2068 name=self._new_name,
2069 initializer=self._old.initializer,
2070 regularizer=self._old.regularizer,
2071 caching_device=self._old.caching_device,
2072 partitioner=self._old.partitioner,
2073 dtype=self._old.dtype,
2074 use_resource=self._old.use_resource,
2075 custom_getter=self._old.custom_getter,
2076 name_scope=name_scope,
2077 constraint=self._constraint)
2078 if self._initializer is not None:
2079 variable_scope_object.set_initializer(self._initializer)
2080 if self._regularizer is not None:
2081 variable_scope_object.set_regularizer(self._regularizer)
2082 if self._caching_device is not None:
2083 variable_scope_object.set_caching_device(self._caching_device)
2084 if self._partitioner is not None:
2085 variable_scope_object.set_partitioner(self._partitioner)
2086 if self._custom_getter is not None:
2087 variable_scope_object.set_custom_getter(
2088 _maybe_wrap_custom_getter(self._custom_getter,
2089 self._old.custom_getter))
2090 if self._dtype is not None:
2091 variable_scope_object.set_dtype(self._dtype)
2092 if self._use_resource is not None:
2093 variable_scope_object.set_use_resource(self._use_resource)
2094 self._var_scope_store.open_variable_scope(self._new_name)
2095 self._var_scope_store.current_scope = variable_scope_object
2096 self._last_variable_scope_object = variable_scope_object
2097 return variable_scope_object
2099 def __exit__(self, type_arg, value_arg, traceback_arg):
2100 if (self._var_scope_store.current_scope is
2101 not self._last_variable_scope_object):
2102 raise RuntimeError("Improper nesting of variable_scope.")
2103 # If jumping out from a non-prolonged scope, restore counts.
2104 if isinstance(self._name_or_scope, VariableScope):
2105 self._var_scope_store.variable_scopes_count = self._old_subscopes
2106 else:
2107 self._var_scope_store.close_variable_subscopes(self._new_name)
2108 self._var_scope_store.current_scope = self._old
2111def _maybe_wrap_custom_getter(custom_getter, old_getter):
2112 """Wrap a call to a custom_getter to use the old_getter internally."""
2113 if old_getter is None:
2114 return custom_getter
2116 # The new custom_getter should call the old one
2117 def wrapped_custom_getter(getter, *args, **kwargs):
2118 # Call:
2119 # custom_getter(
2120 # lambda: old_getter(true_getter, ...), *args, **kwargs)
2121 # which means custom_getter will call old_getter, which
2122 # will call the true_getter, perform any intermediate
2123 # processing, and return the results to the current
2124 # getter, which will also perform additional processing.
2125 return custom_getter(functools.partial(old_getter, getter), *args, **kwargs)
2127 return wrapped_custom_getter
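# A minimal sketch (assuming TF2 with an explicit graph) of how the wrapping
# above composes nested custom getters: the innermost scope's getter runs first
# and delegates outward before the true getter finally creates the variable.
# The toy getters and the `_demo_*` helper are illustrative only.
def _demo_nested_custom_getters():
  import tensorflow as tf
  calls = []
  def outer_getter(getter, *args, **kwargs):
    calls.append("outer")
    return getter(*args, **kwargs)
  def inner_getter(getter, *args, **kwargs):
    calls.append("inner")
    return getter(*args, **kwargs)
  with tf.Graph().as_default():
    with tf.compat.v1.variable_scope("a", custom_getter=outer_getter):
      with tf.compat.v1.variable_scope("b", custom_getter=inner_getter):
        tf.compat.v1.get_variable("v", shape=[1])
  assert calls == ["inner", "outer"]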
2130def _get_unique_variable_scope(prefix):
2131 """Get a name with the given prefix unique in the current variable scope."""
2132 var_scope_store = get_variable_scope_store()
2133 current_scope = get_variable_scope()
2134 name = current_scope.name + "/" + prefix if current_scope.name else prefix
2135 if var_scope_store.variable_scope_count(name) == 0:
2136 return prefix
2137 idx = 1
2138 while var_scope_store.variable_scope_count(name + ("_%d" % idx)) > 0:
2139 idx += 1
2140 return prefix + ("_%d" % idx)
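# A minimal sketch (assuming TF2 with an explicit graph) of how this
# uniquification surfaces through `default_name`: the second anonymous scope
# receives a "_1" suffix. `_demo_*` is illustrative only.
def _demo_default_name_uniquification():
  import tensorflow as tf
  with tf.Graph().as_default():
    with tf.compat.v1.variable_scope(None, default_name="block"):
      a = tf.compat.v1.get_variable("a", [1])
    with tf.compat.v1.variable_scope(None, default_name="block"):
      b = tf.compat.v1.get_variable("b", [1])
    assert a.name == "block/a:0"
    assert b.name == "block_1/b:0"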
2143# Named like a function for backwards compatibility with the
2144# @tf_contextlib.contextmanager version, which was switched to a class to avoid
2145# some object creation overhead.
2146@tf_export(v1=["variable_scope"]) # pylint: disable=invalid-name
2147class variable_scope:
2148 """A context manager for defining ops that creates variables (layers).
2150 @compatibility(TF2)
2151 Although it is a legacy `compat.v1` api,
2152 `tf.compat.v1.variable_scope` is mostly compatible with eager
2153 execution and `tf.function` as long as you combine it with the
2154 `tf.compat.v1.keras.utils.track_tf1_style_variables` decorator (though
2155 it will behave as if reuse is always set to `AUTO_REUSE`).
2157 See the
2158 [model migration guide](
2159 https://www.tensorflow.org/guide/migrate/model_mapping)
2160 for more info on
2161 migrating code that relies on `variable_scope`-based variable reuse.
2163 When you use it with eager execution enabled but without
2164 `tf.compat.v1.keras.utils.track_tf1_style_variables`,
2165 `tf.compat.v1.variable_scope` will still be able to prefix the names
2166 of variables created within the scope but it will not enable variable reuse
2167 or error-raising checks around variable reuse (`get_variable` calls within
2168 it would always create new variables).
2170 Once you have switched away from `get_variable`-based variable reuse
2171 mechanisms, to switch to TF2 APIs you can just use
2172 `tf.name_scope` to prefix variable names.
2173 @end_compatibility
2175 This context manager validates that the (optional) `values` are from the same
2176 graph, ensures that graph is the default graph, and pushes a name scope and a
2177 variable scope.
2179 If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None,
2180 then `default_name` is used. In that case, if the same name has been
2181 previously used in the same scope, it will be made unique by appending `_N`
2182 to it.
2184 Variable scope allows you to create new variables and to share already created
2185 ones while providing checks to not create or share by accident. For details,
2186 see the [Variable Scope How To](https://tensorflow.org/guide/variables); here
2187 we present only a few basic examples.
2189 Variable scope works as expected when eager execution is disabled:
2191 ```python
2192 tf.compat.v1.disable_eager_execution()
2193 ```
2195 Simple example of how to create a new variable:
2197 ```python
2198 with tf.compat.v1.variable_scope("foo"):
2199 with tf.compat.v1.variable_scope("bar"):
2200 v = tf.compat.v1.get_variable("v", [1])
2201 assert v.name == "foo/bar/v:0"
2202 ```
2204 Simple example of how to reenter a premade variable scope safely:
2206 ```python
2207 with tf.compat.v1.variable_scope("foo") as vs:
2208 pass
2210 # Re-enter the variable scope.
2211 with tf.compat.v1.variable_scope(vs,
2212 auxiliary_name_scope=False) as vs1:
2213 # Restore the original name_scope.
2214 with tf.name_scope(vs1.original_name_scope):
2215 v = tf.compat.v1.get_variable("v", [1])
2216 assert v.name == "foo/v:0"
2217 c = tf.constant([1], name="c")
2218 assert c.name == "foo/c:0"
2219 ```
2221 Keep in mind that the counters for `default_name` are discarded once the
2222 parent scope is exited. Therefore when the code re-enters the scope (for instance
2223 by capturing and later re-entering it), all nested default_name counters will be restarted.
2225 For instance:
2227 ```python
2228 with tf.compat.v1.variable_scope("foo") as vs:
2229 with tf.compat.v1.variable_scope(None, default_name="bar"):
2230 v = tf.compat.v1.get_variable("a", [1])
2231 assert v.name == "foo/bar/a:0", v.name
2232 with tf.compat.v1.variable_scope(None, default_name="bar"):
2233 v = tf.compat.v1.get_variable("b", [1])
2234 assert v.name == "foo/bar_1/b:0"
2236 with tf.compat.v1.variable_scope(vs):
2237 with tf.compat.v1.variable_scope(None, default_name="bar"):
2238 v = tf.compat.v1.get_variable("c", [1])
2239 assert v.name == "foo/bar/c:0" # Uses bar instead of bar_2!
2240 ```
2242 Basic example of sharing a variable with AUTO_REUSE:
2244 ```python
2245 def foo():
2246 with tf.compat.v1.variable_scope("foo", reuse=tf.compat.v1.AUTO_REUSE):
2247 v = tf.compat.v1.get_variable("v", [1])
2248 return v
2250 v1 = foo() # Creates v.
2251 v2 = foo() # Gets the same, existing v.
2252 assert v1 == v2
2253 ```
2255 Basic example of sharing a variable with reuse=True:
2257 ```python
2258 with tf.compat.v1.variable_scope("foo"):
2259 v = tf.compat.v1.get_variable("v", [1])
2260 with tf.compat.v1.variable_scope("foo", reuse=True):
2261 v1 = tf.compat.v1.get_variable("v", [1])
2262 assert v1 == v
2263 ```
2265 Sharing a variable by capturing a scope and setting reuse:
2267 ```python
2268 with tf.compat.v1.variable_scope("foo") as scope:
2269 v = tf.compat.v1.get_variable("v", [1])
2270 scope.reuse_variables()
2271 v1 = tf.compat.v1.get_variable("v", [1])
2272 assert v1 == v
2273 ```
2275 To prevent accidental sharing of variables, we raise an exception when getting
2276 an existing variable in a non-reusing scope.
2278 ```python
2279 with tf.compat.v1.variable_scope("foo"):
2280 v = tf.compat.v1.get_variable("v", [1])
2281 v1 = tf.compat.v1.get_variable("v", [1])
2282 # Raises ValueError("... v already exists ...").
2283 ```
2285 Similarly, we raise an exception when trying to get a variable that does not
2286 exist in reuse mode.
2288 ```python
2289 with tf.compat.v1.variable_scope("foo", reuse=True):
2290 v = tf.compat.v1.get_variable("v", [1])
2291 # Raises ValueError("... v does not exist ...").
2292 ```
2294 Note that the `reuse` flag is inherited: if we open a reusing scope, then all
2295 its sub-scopes become reusing as well.
2297 A note about name scoping: Setting `reuse` does not impact the naming of other
2298 ops such as `tf.multiply`. See related discussion on
2299 [github#6189](https://github.com/tensorflow/tensorflow/issues/6189)
2301 Note that up to and including version 1.0, it was allowed (though explicitly
2302 discouraged) to pass False to the reuse argument, yielding undocumented
2303 behaviour slightly different from None. Starting at 1.1.0 passing None and
2304 False as reuse has exactly the same effect.
2306 A note about using variable scopes in a multi-threaded environment: Variable
2307 scopes are thread local, so one thread will not see another thread's current
2308 scope. Also, when using `default_name`, unique scope names are generated only
2309 on a per-thread basis; using the same name in one thread does not prevent
2310 another thread from creating a scope with that name.
2311 However, the underlying variable store is shared across threads (within the
2312 same graph). As such, if another thread tries to create a new variable with
2313 the same name as a variable created by a previous thread, it will fail unless
2314 reuse is True.
2316 Further, each thread starts with an empty variable scope. So if you wish to
2317 preserve name prefixes from a scope from the main thread, you should capture
2318 the main thread's scope and re-enter it in each thread. For example:
2320 ```
2321 main_thread_scope = variable_scope.get_variable_scope()
2323 # Thread's target function:
2324 def thread_target_fn(captured_scope):
2325 with variable_scope.variable_scope(captured_scope):
2326 ...  # regular code for this thread
2329 thread = threading.Thread(target=thread_target_fn, args=(main_thread_scope,))
2330 ```
2331 """
2333 def __init__(self,
2334 name_or_scope,
2335 default_name=None,
2336 values=None,
2337 initializer=None,
2338 regularizer=None,
2339 caching_device=None,
2340 partitioner=None,
2341 custom_getter=None,
2342 reuse=None,
2343 dtype=None,
2344 use_resource=None,
2345 constraint=None,
2346 auxiliary_name_scope=True):
2347 """Initialize the context manager.
2349 Args:
2350 name_or_scope: `string` or `VariableScope`: the scope to open.
2351 default_name: The default name to use if the `name_or_scope` argument is
2352 `None`; this name will be uniquified. If `name_or_scope` is provided,
2353 `default_name` is unused, so it is not required and can be `None`.
2354 values: The list of `Tensor` arguments that are passed to the op function.
2355 initializer: default initializer for variables within this scope.
2356 regularizer: default regularizer for variables within this scope.
2357 caching_device: default caching device for variables within this scope.
2358 partitioner: default partitioner for variables within this scope.
2359 custom_getter: default custom getter for variables within this scope.
2360 reuse: `True`, None, or tf.compat.v1.AUTO_REUSE; if `True`, we go into
2361 reuse mode for this scope as well as all sub-scopes; if
2362 tf.compat.v1.AUTO_REUSE, we create variables if they do not exist, and
2363 return them otherwise; if None, we inherit the parent scope's reuse
2364 flag. When eager execution is enabled, new variables are always created
2365 unless an EagerVariableStore or template is currently active.
2366 dtype: type of variables created in this scope (defaults to the type in
2367 the passed scope, or inherited from parent scope).
2368 use_resource: If False, all variables will be regular Variables. If True,
2369 experimental ResourceVariables with well-defined semantics will be used
2370 instead. Defaults to False (will later change to True). When eager
2371 execution is enabled this argument is always forced to be True.
2372 constraint: An optional projection function to be applied to the variable
2373 after being updated by an `Optimizer` (e.g. used to implement norm
2374 constraints or value constraints for layer weights). The function must
2375 take as input the unprojected Tensor representing the value of the
2376 variable and return the Tensor for the projected value (which must have
2377 the same shape). Constraints are not safe to use when doing asynchronous
2378 distributed training.
2379 auxiliary_name_scope: If `True`, we create an auxiliary name scope with
2380 the scope. If `False`, we don't create it. Note that the argument is not
2381 inherited, and it only takes effect once, when the scope is created. You should
2382 only use it for re-entering a premade variable scope.
2384 Returns:
2385 A scope that can be captured and reused.
2387 Raises:
2388 ValueError: when trying to reuse within a create scope, or create within
2389 a reuse scope.
2390 TypeError: when the types of some arguments are not appropriate.
2391 """
2392 self._name_or_scope = name_or_scope
2393 self._default_name = default_name
2394 self._values = values
2395 self._initializer = initializer
2396 self._regularizer = regularizer
2397 self._caching_device = caching_device
2398 self._partitioner = partitioner
2399 self._custom_getter = custom_getter
2400 self._reuse = reuse
2401 self._dtype = dtype
2402 self._use_resource = use_resource
2403 self._constraint = constraint
2404 if self._default_name is None and self._name_or_scope is None:
2405 raise TypeError("If default_name is None then name_or_scope is required")
2406 if self._reuse is False:
2407 # We don't allow non-inheriting scopes, False = None here.
2408 self._reuse = None
2409 if not (self._reuse is True
2410 or self._reuse is None
2411 or self._reuse is AUTO_REUSE):
2412 raise ValueError("The reuse parameter must be True, False, None, or tf.AUTO_REUSE.")
2413 if self._values is None:
2414 self._values = []
2415 self._in_graph_mode = not context.executing_eagerly()
2416 if self._in_graph_mode:
2417 self._graph = ops._get_graph_from_inputs(self._values) # pylint: disable=protected-access
2418 self._cached_pure_variable_scope = None
2419 self._current_name_scope = None
2420 if not isinstance(auxiliary_name_scope, bool):
2421 raise TypeError("The auxiliary_name_scope must be `True` or `False`, "
2422 "while get {}".format(auxiliary_name_scope))
2423 self._auxiliary_name_scope = auxiliary_name_scope
2425 def __enter__(self):
2426 # If the default graph is building a function, then we should not replace it
2427 # with the cached graph.
2428 if ops.get_default_graph().building_function:
2429 self._building_function = True
2430 else:
2431 self._building_function = False
2432 if self._in_graph_mode and not self._building_function:
2433 self._graph_context_manager = self._graph.as_default()
2434 self._graph_context_manager.__enter__()
2435 if self._cached_pure_variable_scope is not None:
2436 # Fast path for re-entering variable_scopes. We've held on to the pure
2437 # variable scope from a previous successful __enter__, so we avoid some
2438 # overhead by re-using that object.
2439 if self._current_name_scope is not None:
2440 self._current_name_scope.__enter__()
2441 return self._cached_pure_variable_scope.__enter__()
2443 try:
2444 return self._enter_scope_uncached()
2445 except:
2446 if (self._in_graph_mode and not self._building_function and
2447 self._graph_context_manager is not None):
2448 self._graph_context_manager.__exit__(*sys.exc_info())
2449 raise
2451 def _enter_scope_uncached(self):
2452 """Enters the context manager when there is no cached scope yet.
2454 Returns:
2455 The entered variable scope.
2457 Raises:
2458 TypeError: A wrong type is passed as `scope` at __init__().
2459 ValueError: `reuse` is incorrectly set at __init__().
2460 """
2461 if self._auxiliary_name_scope:
2462 # Create a new name scope later
2463 current_name_scope = None
2464 else:
2465 # Reenter the current name scope
2466 name_scope = ops.get_name_scope()
2467 if name_scope:
2468 # Hack to reenter
2469 name_scope += "/"
2470 current_name_scope = ops.name_scope(name_scope, skip_on_eager=False)
2471 else:
2472 # Root scope
2473 current_name_scope = ops.name_scope(name_scope, skip_on_eager=False)
2475 # IMPORTANT: Only assign to self._cached_pure_variable_scope and
2476 # self._current_name_scope after successful __enter__() calls.
2477 if self._name_or_scope is not None:
2478 if not isinstance(self._name_or_scope, (VariableScope, str)):
2479 raise TypeError("VariableScope: name_or_scope must be a string or "
2480 "VariableScope.")
2481 if isinstance(self._name_or_scope, str):
2482 name_scope = self._name_or_scope
2483 else:
2484 name_scope = self._name_or_scope.name.split("/")[-1]
2485 if name_scope or current_name_scope:
2486 current_name_scope = current_name_scope or ops.name_scope(
2487 name_scope, skip_on_eager=False)
2488 try:
2489 current_name_scope_name = current_name_scope.__enter__()
2490 except:
2491 current_name_scope.__exit__(*sys.exc_info())
2492 raise
2493 self._current_name_scope = current_name_scope
2494 if isinstance(self._name_or_scope, str):
2495 old_name_scope = current_name_scope_name
2496 else:
2497 old_name_scope = self._name_or_scope.original_name_scope
2498 pure_variable_scope = _pure_variable_scope(
2499 self._name_or_scope,
2500 reuse=self._reuse,
2501 initializer=self._initializer,
2502 regularizer=self._regularizer,
2503 caching_device=self._caching_device,
2504 partitioner=self._partitioner,
2505 custom_getter=self._custom_getter,
2506 old_name_scope=old_name_scope,
2507 dtype=self._dtype,
2508 use_resource=self._use_resource,
2509 constraint=self._constraint)
2510 try:
2511 entered_pure_variable_scope = pure_variable_scope.__enter__()
2512 except:
2513 pure_variable_scope.__exit__(*sys.exc_info())
2514 raise
2515 self._cached_pure_variable_scope = pure_variable_scope
2516 return entered_pure_variable_scope
2517 else:
2518 self._current_name_scope = None
2519 # This can only happen if someone is entering the root variable scope.
2520 pure_variable_scope = _pure_variable_scope(
2521 self._name_or_scope,
2522 reuse=self._reuse,
2523 initializer=self._initializer,
2524 regularizer=self._regularizer,
2525 caching_device=self._caching_device,
2526 partitioner=self._partitioner,
2527 custom_getter=self._custom_getter,
2528 dtype=self._dtype,
2529 use_resource=self._use_resource,
2530 constraint=self._constraint)
2531 try:
2532 entered_pure_variable_scope = pure_variable_scope.__enter__()
2533 except:
2534 pure_variable_scope.__exit__(*sys.exc_info())
2535 raise
2536 self._cached_pure_variable_scope = pure_variable_scope
2537 return entered_pure_variable_scope
2539 else: # Here name_or_scope is None. Using default name, but made unique.
2540 if self._reuse:
2541 raise ValueError("reuse=True cannot be used without a name_or_scope")
2542 current_name_scope = current_name_scope or ops.name_scope(
2543 self._default_name, skip_on_eager=False)
2544 try:
2545 current_name_scope_name = current_name_scope.__enter__()
2546 except:
2547 current_name_scope.__exit__(*sys.exc_info())
2548 raise
2549 self._current_name_scope = current_name_scope
2550 unique_default_name = _get_unique_variable_scope(self._default_name)
2551 pure_variable_scope = _pure_variable_scope(
2552 unique_default_name,
2553 initializer=self._initializer,
2554 regularizer=self._regularizer,
2555 caching_device=self._caching_device,
2556 partitioner=self._partitioner,
2557 custom_getter=self._custom_getter,
2558 old_name_scope=current_name_scope_name,
2559 dtype=self._dtype,
2560 use_resource=self._use_resource,
2561 constraint=self._constraint)
2562 try:
2563 entered_pure_variable_scope = pure_variable_scope.__enter__()
2564 except:
2565 pure_variable_scope.__exit__(*sys.exc_info())
2566 raise
2567 self._cached_pure_variable_scope = pure_variable_scope
2568 return entered_pure_variable_scope
2570 def __exit__(self, type_arg, value_arg, traceback_arg):
2571 try:
2572 self._cached_pure_variable_scope.__exit__(type_arg, value_arg,
2573 traceback_arg)
2574 finally:
2575 try:
2576 if self._current_name_scope:
2577 self._current_name_scope.__exit__(type_arg, value_arg,
2578 traceback_arg)
2579 finally:
2580 if self._in_graph_mode and not self._building_function:
2581 self._graph_context_manager.__exit__(type_arg, value_arg,
2582 traceback_arg)
2585# pylint: disable=g-doc-return-or-yield
2586@tf_export(v1=["variable_op_scope"])
2587@tf_contextlib.contextmanager
2588def variable_op_scope(values,
2589 name_or_scope,
2590 default_name=None,
2591 initializer=None,
2592 regularizer=None,
2593 caching_device=None,
2594 partitioner=None,
2595 custom_getter=None,
2596 reuse=None,
2597 dtype=None,
2598 use_resource=None,
2599 constraint=None):
2600 """Deprecated: context manager for defining an op that creates variables."""
2601 logging.warn("tf.variable_op_scope(values, name, default_name) is deprecated,"
2602 " use tf.variable_scope(name, default_name, values)")
2603 with variable_scope(
2604 name_or_scope,
2605 default_name=default_name,
2606 values=values,
2607 initializer=initializer,
2608 regularizer=regularizer,
2609 caching_device=caching_device,
2610 partitioner=partitioner,
2611 custom_getter=custom_getter,
2612 reuse=reuse,
2613 dtype=dtype,
2614 use_resource=use_resource,
2615 constraint=constraint) as scope:
2616 yield scope
2619def _call_partitioner(partitioner, shape, dtype):
2620 """Call partitioner validating its inputs/output.
2622 Args:
2623 partitioner: a function mapping `Tensor` shape and dtype to a list of
2624 partitions.
2625 shape: shape of the `Tensor` to partition; must be fully defined and have
2626 rank at least 1.
2627 dtype: dtype of the elements in the `Tensor`.
2629 Returns:
2630 A list with one element per axis, each >= 1 and at most one > 1. The index
2631 of the element > 1 (if any) is the partitioning axis.
2632 """
2633 if not shape.is_fully_defined():
2634 raise ValueError("Shape of a new partitioned variable must be "
2635 "fully defined, but instead was %s." % (shape,))
2636 if shape.ndims < 1:
2637 raise ValueError("A partitioned Variable must have rank at least 1, "
2638 "shape: %s" % shape)
2640 slicing = partitioner(shape=shape, dtype=dtype)
2641 if not isinstance(slicing, collections_abc.Sequence):
2642 raise ValueError("Partitioner must return a sequence, but saw: %s" %
2643 slicing)
2644 if len(slicing) != shape.ndims:
2645 raise ValueError(
2646 "Partitioner returned a partition list that does not match the "
2647 "Variable's rank: %s vs. %s" % (slicing, shape))
2648 if any(p < 1 for p in slicing):
2649 raise ValueError("Partitioner returned zero partitions for some axes: %s" %
2650 slicing)
2651 if sum(p > 1 for p in slicing) > 1:
2652 raise ValueError("Can only slice a variable along one dimension: "
2653 "shape: %s, partitioning: %s" % (shape, slicing))
2654 return slicing
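# A minimal sketch of a partitioner that passes the validation above: one entry
# per axis, every entry >= 1, and at most one entry > 1. The toy partitioner
# and the `_demo_*` helper are illustrative only.
def _demo_call_partitioner():
  import tensorflow as tf
  def split_first_axis_in_three(shape, dtype):
    del dtype  # Unused by this toy partitioner.
    return [3] + [1] * (shape.ndims - 1)
  slicing = _call_partitioner(
      split_first_axis_in_three, tf.TensorShape([9, 5]), tf.float32)
  assert slicing == [3, 1]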
2657# TODO(slebedev): could be inlined, but
2658# `_VariableStore._get_partitioned_variable` is too complex even
2659# without this logic.
2660def _get_slice_dim_and_num_slices(slicing):
2661 """Get slicing dimension and number of slices from the partitioner output."""
2662 for slice_dim, num_slices in enumerate(slicing):
2663 if num_slices > 1:
2664 break
2665 else:
2666 # Degenerate case: no partitioning applied.
2667 slice_dim = 0
2668 num_slices = 1
2669 return slice_dim, num_slices
2672def _iter_slices(full_shape, num_slices, slice_dim):
2673 """Slices a given a shape along the specified dimension."""
2674 num_slices_with_excess = full_shape[slice_dim] % num_slices
2675 offset = [0] * len(full_shape)
2676 min_slice_len = full_shape[slice_dim] // num_slices
2677 for i in range(num_slices):
2678 shape = full_shape[:]
2679 shape[slice_dim] = min_slice_len + bool(i < num_slices_with_excess)
2680 yield offset[:], shape
2681 offset[slice_dim] += shape[slice_dim]
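# A minimal illustration of the slicing above: cutting a [10, 4] shape into 3
# slices along axis 0 gives shapes [4, 4], [3, 4], [3, 4] at offsets 0, 4 and 7;
# the leading slices absorb the remainder one element each. `_demo_*` is
# illustrative only.
def _demo_iter_slices():
  pieces = list(_iter_slices([10, 4], num_slices=3, slice_dim=0))
  assert pieces == [([0, 0], [4, 4]), ([4, 0], [3, 4]), ([7, 0], [3, 4])]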
2684def _make_getter(captured_getter, captured_previous):
2685 """Gets around capturing loop variables in python being broken."""
2686 return lambda **kwargs: captured_getter(captured_previous, **kwargs)
2689# TODO(apassos) remove forwarding symbol
2690variable = variable_v1.VariableV1
2692# temporary references needed while refactors are in progress
2693default_variable_creator = ref_variable.default_variable_creator
2694_to_proto_fn = ref_variable._to_proto_fn # pylint: disable=protected-access
2695_from_proto_fn = ref_variable._from_proto_fn # pylint: disable=protected-access
2698@tf_export(v1=["variable_creator_scope"])
2699@tf_contextlib.contextmanager
2700def variable_creator_scope_v1(variable_creator):
2701 """Scope which defines a variable creation function to be used by variable().
2703 variable_creator is expected to be a function with the following signature:
2705 ```
2706 def variable_creator(next_creator, **kwargs)
2707 ```
2709 The creator is expected to eventually call `next_creator` when it wants a
2710 variable to be created, rather than calling `Variable` or `ResourceVariable`
2711 directly. This helps make creators composable. A creator may
2712 choose to create multiple variables, return already existing variables, or
2713 simply register that a variable was created and defer to the next creators in
2714 line. Creators can also modify the keyword arguments seen by the next
2715 creators.
2717 Custom getters in the variable scope will eventually resolve down to these
2718 custom creators when they do create variables.
2720 The valid keyword arguments in kwargs are:
2722 * initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
2723 which is the initial value for the Variable. The initial value must have
2724 a shape specified unless `validate_shape` is set to False. Can also be a
2725 callable with no argument that returns the initial value when called. In
2726 that case, `dtype` must be specified. (Note that initializer functions
2727 from init_ops.py must first be bound to a shape before being used here.)
2728 * trainable: If `True`, the default, also adds the variable to the graph
2729 collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
2730 the default list of variables to use by the `Optimizer` classes.
2731 `trainable` defaults to `True`, unless `synchronization` is
2732 set to `ON_READ`, in which case it defaults to `False`.
2733 * collections: List of graph collections keys. The new variable is added to
2734 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
2735 * validate_shape: If `False`, allows the variable to be initialized with a
2736 value of unknown shape. If `True`, the default, the shape of
2737 `initial_value` must be known.
2738 * caching_device: Optional device string describing where the Variable
2739 should be cached for reading. Defaults to the Variable's device.
2740 If not `None`, caches on another device. Typical use is to cache
2741 on the device where the Ops using the Variable reside, to deduplicate
2742 copying through `Switch` and other conditional statements.
2743 * name: Optional name for the variable. Defaults to `'Variable'` and gets
2744 uniquified automatically.
2745 * dtype: If set, initial_value will be converted to the given type.
2746 If `None`, either the datatype will be kept (if `initial_value` is
2747 a Tensor), or `convert_to_tensor` will decide.
2748 * constraint: A constraint function to be applied to the variable after
2749 updates by some algorithms.
2750 * use_resource: if True, a ResourceVariable is always created.
2751 * synchronization: Indicates when a distributed variable will be
2752 aggregated. Accepted values are constants defined in the class
2753 `tf.VariableSynchronization`. By default the synchronization is set to
2754 `AUTO` and the current `DistributionStrategy` chooses
2755 when to synchronize.
2756 * aggregation: Indicates how a distributed variable will be aggregated.
2757 Accepted values are constants defined in the class
2758 `tf.VariableAggregation`.
2760 This set may grow over time, so it's important the signature of creators is as
2761 mentioned above.
2763 Args:
2764 variable_creator: the passed creator
2766 Yields:
2767 A scope in which the creator is active
2768 """
2769 with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access
2770 yield
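# A minimal sketch of a composable creator: it records the requested variable
# name and then defers to the next creator in the chain. Works under eager
# execution (the TF2 default); the `_demo_*` helper is illustrative only.
def _demo_variable_creator_scope():
  import tensorflow as tf
  created_names = []
  def recording_creator(next_creator, **kwargs):
    created_names.append(kwargs.get("name"))
    return next_creator(**kwargs)
  with tf.compat.v1.variable_creator_scope(recording_creator):
    v = tf.Variable(1.0, name="counter")
  assert created_names == ["counter"]
  return v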
2773# Note: only the docstrings differ between this and v1.
2774@tf_export("variable_creator_scope", v1=[])
2775@tf_contextlib.contextmanager
2776def variable_creator_scope(variable_creator):
2777 """Scope which defines a variable creation function to be used by variable().
2779 variable_creator is expected to be a function with the following signature:
2781 ```
2782 def variable_creator(next_creator, **kwargs)
2783 ```
2785 The creator is expected to eventually call `next_creator` when it wants a
2786 variable to be created, rather than calling `Variable` or `ResourceVariable`
2787 directly. This helps make creators composable. A creator may
2788 choose to create multiple variables, return already existing variables, or
2789 simply register that a variable was created and defer to the next creators in
2790 line. Creators can also modify the keyword arguments seen by the next
2791 creators.
2793 Custom getters in the variable scope will eventually resolve down to these
2794 custom creators when they do create variables.
2796 The valid keyword arguments in kwargs are:
2798 * initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
2799 which is the initial value for the Variable. The initial value must have
2800 a shape specified unless `validate_shape` is set to False. Can also be a
2801 callable with no argument that returns the initial value when called. In
2802 that case, `dtype` must be specified. (Note that initializer functions
2803 from init_ops.py must first be bound to a shape before being used here.)
2804 * trainable: If `True`, the default, GradientTapes automatically watch
2805 uses of this Variable.
2806 * validate_shape: If `False`, allows the variable to be initialized with a
2807 value of unknown shape. If `True`, the default, the shape of
2808 `initial_value` must be known.
2809 * caching_device: Optional device string describing where the Variable
2810 should be cached for reading. Defaults to the Variable's device.
2811 If not `None`, caches on another device. Typical use is to cache
2812 on the device where the Ops using the Variable reside, to deduplicate
2813 copying through `Switch` and other conditional statements.
2814 * name: Optional name for the variable. Defaults to `'Variable'` and gets
2815 uniquified automatically.
2816 * dtype: If set, initial_value will be converted to the given type.
2817 If `None`, either the datatype will be kept (if `initial_value` is
2818 a Tensor), or `convert_to_tensor` will decide.
2819 * constraint: A constraint function to be applied to the variable after
2820 updates by some algorithms.
2821 * synchronization: Indicates when a distributed variable will be
2822 aggregated. Accepted values are constants defined in the class
2823 `tf.VariableSynchronization`. By default the synchronization is set to
2824 `AUTO` and the current `DistributionStrategy` chooses
2825 when to synchronize.
2826 * aggregation: Indicates how a distributed variable will be aggregated.
2827 Accepted values are constants defined in the class
2828 `tf.VariableAggregation`.
2830 This set may grow over time, so it's important the signature of creators is as
2831 mentioned above.
2833 Args:
2834 variable_creator: the passed creator
2836 Yields:
2837 A scope in which the creator is active
2838 """
2839 with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access
2840 yield