Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/variable_v1.py: 62%
37 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""VariableV1 class."""
17from tensorflow.python.framework import ops
18from tensorflow.python.ops import cond
19from tensorflow.python.ops import state_ops
20from tensorflow.python.ops import variables
21from tensorflow.python.util import tf_should_use
22from tensorflow.python.util.tf_export import tf_export
# Registry slot populated by `set_variable_from_proto_fn` below and consumed by
# `VariableV1.from_proto`. NOTE(review): presumably set lazily from another
# module to break an import cycle — confirm against the caller of
# `set_variable_from_proto_fn`.
_variable_from_proto_fn = None
def set_variable_from_proto_fn(variable_from_proto_fn):
  """Set the variable class that variable proto defs will be converted to.

  Args:
    variable_from_proto_fn: A callable accepting `variable_def` and
      `import_scope` keyword arguments (see `VariableV1.from_proto`, which
      forwards to it) and returning a variable reconstructed from the proto.
  """
  global _variable_from_proto_fn
  _variable_from_proto_fn = variable_from_proto_fn
@tf_export(v1=["is_variable_initialized"])
@tf_should_use.should_use_result
def is_variable_initialized(variable):
  """Tests if a variable has been initialized.

  Args:
    variable: A `Variable`.

  Returns:
    Returns a scalar boolean Tensor, `True` if the variable has been
    initialized, `False` otherwise.
  """
  # Delegate to the underlying op; `@should_use_result` warns callers that
  # discard the returned tensor.
  initialized_tensor = state_ops.is_variable_initialized(variable)
  return initialized_tensor
def default_variable_creator(_, **kwds):
  """Placeholder creator that fails until the real one is registered.

  Raises:
    NotImplementedError: always; `ref_variable` must be imported first so the
      actual creator can take over.
  """
  del kwds  # Unused; accepted only so creator-style keyword calls parse.
  raise NotImplementedError("ref_variable needs to be imported")
@tf_export(v1=["Variable"])
class VariableV1(variables.Variable):
  """See the [Variables Guide](https://tensorflow.org/guide/variables).

  A variable maintains state in the graph across calls to `run()`. You add a
  variable to the graph by constructing an instance of the class `Variable`.

  The `Variable()` constructor requires an initial value for the variable,
  which can be a `Tensor` of any type and shape. The initial value defines the
  type and shape of the variable. After construction, the type and shape of
  the variable are fixed. The value can be changed using one of the assign
  methods.

  If you want to change the shape of a variable later you have to use an
  `assign` Op with `validate_shape=False`.

  Just like any `Tensor`, variables created with `Variable()` can be used as
  inputs for other Ops in the graph. Additionally, all the operators
  overloaded for the `Tensor` class are carried over to variables, so you can
  also add nodes to the graph by just doing arithmetic on variables.

  ```python
  import tensorflow as tf

  # Create a variable.
  w = tf.Variable(<initial-value>, name=<optional-name>)

  # Use the variable in the graph like any Tensor.
  y = tf.matmul(w, ...another variable or tensor...)

  # The overloaded operators are available too.
  z = tf.sigmoid(w + y)

  # Assign a new value to the variable with `assign()` or a related method.
  w.assign(w + 1.0)
  w.assign_add(1.0)
  ```

  When you launch the graph, variables have to be explicitly initialized before
  you can run Ops that use their value. You can initialize a variable by
  running its *initializer op*, restoring the variable from a save file, or
  simply running an `assign` Op that assigns a value to the variable. In fact,
  the variable *initializer op* is just an `assign` Op that assigns the
  variable's initial value to the variable itself.

  ```python
  # Launch the graph in a session.
  with tf.compat.v1.Session() as sess:
      # Run the variable initializer.
      sess.run(w.initializer)
      # ...you now can run ops that use the value of 'w'...
  ```

  The most common initialization pattern is to use the convenience function
  `global_variables_initializer()` to add an Op to the graph that initializes
  all the variables. You then run that Op after launching the graph.

  ```python
  # Add an Op to initialize global variables.
  init_op = tf.compat.v1.global_variables_initializer()

  # Launch the graph in a session.
  with tf.compat.v1.Session() as sess:
      # Run the Op that initializes global variables.
      sess.run(init_op)
      # ...you can now run any Op that uses variable values...
  ```

  If you need to create a variable with an initial value dependent on another
  variable, use the other variable's `initialized_value()`. This ensures that
  variables are initialized in the right order.

  All variables are automatically collected in the graph where they are
  created. By default, the constructor adds the new variable to the graph
  collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function
  `global_variables()` returns the contents of that collection.

  When building a machine learning model it is often convenient to distinguish
  between variables holding the trainable model parameters and other variables
  such as a `global step` variable used to count training steps. To make this
  easier, the variable constructor supports a `trainable=<bool>` parameter. If
  `True`, the new variable is also added to the graph collection
  `GraphKeys.TRAINABLE_VARIABLES`. The convenience function
  `trainable_variables()` returns the contents of this collection. The
  various `Optimizer` classes use this collection as the default list of
  variables to optimize.

  WARNING: tf.Variable objects by default have a non-intuitive memory model. A
  Variable is represented internally as a mutable Tensor which can
  non-deterministically alias other Tensors in a graph. The set of operations
  which consume a Variable and can lead to aliasing is undetermined and can
  change across TensorFlow versions. Avoid writing code which relies on the
  value of a Variable either changing or not changing as other operations
  happen. For example, using Variable objects or simple functions thereof as
  predicates in a `tf.cond` is dangerous and error-prone:

  ```
  v = tf.Variable(True)
  tf.cond(v, lambda: v.assign(False), my_false_fn)  # Note: this is broken.
  ```

  Here, adding `use_resource=True` when constructing the variable will
  fix any nondeterminism issues:
  ```
  v = tf.Variable(True, use_resource=True)
  tf.cond(v, lambda: v.assign(False), my_false_fn)
  ```

  To use the replacement for variables which does
  not have these issues:

  * Add `use_resource=True` when constructing `tf.Variable`;
  * Call `tf.compat.v1.get_variable_scope().set_use_resource(True)` inside a
    `tf.compat.v1.variable_scope` before the `tf.compat.v1.get_variable()` call.
  """

  def __init__(
      self,  # pylint: disable=super-init-not-called
      initial_value=None,
      trainable=None,
      collections=None,
      validate_shape=True,
      caching_device=None,
      name=None,
      variable_def=None,
      dtype=None,
      expected_shape=None,
      import_scope=None,
      constraint=None,
      use_resource=None,
      synchronization=variables.VariableSynchronization.AUTO,
      aggregation=variables.VariableAggregation.NONE,
      shape=None):
    """Creates a new variable with value `initial_value`.

    The new variable is added to the graph collections listed in `collections`,
    which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.

    If `trainable` is `True` the variable is also added to the graph collection
    `GraphKeys.TRAINABLE_VARIABLES`.

    This constructor creates both a `variable` Op and an `assign` Op to set the
    variable to its initial value.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
      trainable: If `True`, also adds the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default
        list of variables to use by the `Optimizer` classes. Defaults to `True`,
        unless `synchronization` is set to `ON_READ`, in which case it defaults
        to `False`.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string describing where the Variable
        should be cached for reading. Defaults to the Variable's device. If not
        `None`, caches on another device. Typical use is to cache on the device
        where the Ops using the Variable reside, to deduplicate copying through
        `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      variable_def: `VariableDef` protocol buffer. If not `None`, recreates the
        Variable object with its contents, referencing the variable's nodes in
        the graph, which must already exist. The graph is not changed.
        `variable_def` and the other arguments are mutually exclusive.
      dtype: If set, initial_value will be converted to the given type. If
        `None`, either the datatype will be kept (if `initial_value` is a
        Tensor), or `convert_to_tensor` will decide.
      expected_shape: A TensorShape. If set, initial_value is expected to have
        this shape.
      import_scope: Optional `string`. Name scope to add to the `Variable.` Only
        used when initializing from protocol buffer.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      use_resource: whether to use resource variables.
      synchronization: Indicates when a distributed a variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the variable
        can be assigned with values of different shapes.

    Raises:
      ValueError: If both `variable_def` and initial_value are specified.
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If eager execution is enabled.
    """
    # NOTE(review): intentionally empty — no `super().__init__()` call (see the
    # pylint disable above). Construction appears to be routed through
    # `_variable_call` below instead of this body; confirm against the
    # metaclass machinery in `variables.py`.

  # Re-export the base class's slice-save metadata helper under this class.
  SaveSliceInfo = variables.Variable.SaveSliceInfo

  def initialized_value(self):
    # Returns `read_value()` once the variable is initialized, otherwise the
    # variable's `initial_value`, so the result is safe to use as the
    # initializer of another variable. `init_scope` lifts the ops out of any
    # enclosing function-building graph.
    with ops.init_scope():
      return cond.cond(
          is_variable_initialized(self), self.read_value,
          lambda: self.initial_value)

  @staticmethod
  def from_proto(variable_def, import_scope=None):
    # Delegates to the creator registered via `set_variable_from_proto_fn`;
    # raises TypeError (NoneType not callable) if nothing was registered.
    return _variable_from_proto_fn(
        variable_def=variable_def, import_scope=import_scope)

  @classmethod
  def _variable_call(
      cls,
      initial_value=None,
      trainable=None,
      validate_shape=True,
      caching_device=None,
      name=None,
      variable_def=None,
      dtype=None,
      import_scope=None,
      constraint=None,
      synchronization=variables.VariableSynchronization.AUTO,
      aggregation=variables.VariableAggregation.NONE,
      shape=None,
      experimental_enable_variable_lifting=None,
      expected_shape=None,
      collections=None,
      use_resource=None,
      **kwargs,
  ):
    """VariableV1 class getter. Useful to force the signature."""
    # Only intercept construction of VariableV1 itself; subclasses get the
    # default construction path.
    if cls is not VariableV1:
      return None
    # Chain the graph's registered variable creators around the default
    # creator, innermost-first, so each creator can wrap the one below it.
    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
    for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
      previous_getter = variables._make_getter(getter, previous_getter)  # pylint: disable=protected-access

    # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
    if aggregation is None:
      aggregation = variables.VariableAggregation.NONE
    return previous_getter(
        initial_value=initial_value,
        trainable=trainable,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        variable_def=variable_def,
        dtype=dtype,
        import_scope=import_scope,
        constraint=constraint,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape,
        experimental_enable_variable_lifting=experimental_enable_variable_lifting,
        expected_shape=expected_shape,
        collections=collections,
        use_resource=use_resource,
    )