Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/variable_v1.py: 62%

37 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""VariableV1 class.""" 

16 

17from tensorflow.python.framework import ops 

18from tensorflow.python.ops import cond 

19from tensorflow.python.ops import state_ops 

20from tensorflow.python.ops import variables 

21from tensorflow.python.util import tf_should_use 

22from tensorflow.python.util.tf_export import tf_export 

23 

24 

# Module-level slot holding the callable that `VariableV1.from_proto` uses to
# turn a `VariableDef` proto into a variable object. It starts out as `None`
# and is filled in through `set_variable_from_proto_fn`.
_variable_from_proto_fn = None


def set_variable_from_proto_fn(variable_from_proto_fn):
  """Registers the converter used to rebuild variables from `VariableDef`s.

  Args:
    variable_from_proto_fn: Callable stored in the module-level
      `_variable_from_proto_fn` slot. `VariableV1.from_proto` calls it with
      the keyword arguments `variable_def` and `import_scope`.
  """
  # Rebind the module global rather than creating a function-local name.
  global _variable_from_proto_fn
  _variable_from_proto_fn = variable_from_proto_fn

32 

33 

@tf_export(v1=["is_variable_initialized"])
@tf_should_use.should_use_result
def is_variable_initialized(variable):
  """Checks whether `variable` has already been initialized.

  Args:
    variable: The `Variable` to inspect.

  Returns:
    A scalar boolean `Tensor`: `True` once the variable has been initialized,
    `False` otherwise.
  """
  # The actual check is delegated to the state_ops implementation.
  initialized = state_ops.is_variable_initialized(variable)
  return initialized

47 

48 

def default_variable_creator(_, **kwds):
  """Fallback variable creator that always raises `NotImplementedError`.

  It sits at the bottom of `VariableV1._variable_call`'s creator chain. Its
  error message states that `ref_variable` must be imported to supply a
  working creator.

  Args:
    _: Ignored positional argument, kept to match the creator calling
      convention `creator(next_creator, **kwargs)`.
    **kwds: Variable-construction keyword arguments; ignored.

  Raises:
    NotImplementedError: Always.
  """
  del kwds  # Unused; accepted only so the signature matches other creators.
  message = "ref_variable needs to be imported"
  raise NotImplementedError(message)

52 

53 

@tf_export(v1=["Variable"])
class VariableV1(variables.Variable):
  """See the [Variables Guide](https://tensorflow.org/guide/variables).

  A variable maintains state in the graph across calls to `run()`. You add a
  variable to the graph by constructing an instance of the class `Variable`.

  The `Variable()` constructor requires an initial value for the variable,
  which can be a `Tensor` of any type and shape. The initial value defines the
  type and shape of the variable. After construction, the type and shape of
  the variable are fixed. The value can be changed using one of the assign
  methods.

  If you want to change the shape of a variable later you have to use an
  `assign` Op with `validate_shape=False`.

  Just like any `Tensor`, variables created with `Variable()` can be used as
  inputs for other Ops in the graph. Additionally, all the operators
  overloaded for the `Tensor` class are carried over to variables, so you can
  also add nodes to the graph by just doing arithmetic on variables.

  ```python
  import tensorflow as tf

  # Create a variable.
  w = tf.Variable(<initial-value>, name=<optional-name>)

  # Use the variable in the graph like any Tensor.
  y = tf.matmul(w, ...another variable or tensor...)

  # The overloaded operators are available too.
  z = tf.sigmoid(w + y)

  # Assign a new value to the variable with `assign()` or a related method.
  w.assign(w + 1.0)
  w.assign_add(1.0)
  ```

  When you launch the graph, variables have to be explicitly initialized before
  you can run Ops that use their value. You can initialize a variable by
  running its *initializer op*, restoring the variable from a save file, or
  simply running an `assign` Op that assigns a value to the variable. In fact,
  the variable *initializer op* is just an `assign` Op that assigns the
  variable's initial value to the variable itself.

  ```python
  # Launch the graph in a session.
  with tf.compat.v1.Session() as sess:
      # Run the variable initializer.
      sess.run(w.initializer)
      # ...you now can run ops that use the value of 'w'...
  ```

  The most common initialization pattern is to use the convenience function
  `global_variables_initializer()` to add an Op to the graph that initializes
  all the variables. You then run that Op after launching the graph.

  ```python
  # Add an Op to initialize global variables.
  init_op = tf.compat.v1.global_variables_initializer()

  # Launch the graph in a session.
  with tf.compat.v1.Session() as sess:
      # Run the Op that initializes global variables.
      sess.run(init_op)
      # ...you can now run any Op that uses variable values...
  ```

  If you need to create a variable with an initial value dependent on another
  variable, use the other variable's `initialized_value()`. This ensures that
  variables are initialized in the right order.

  All variables are automatically collected in the graph where they are
  created. By default, the constructor adds the new variable to the graph
  collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function
  `global_variables()` returns the contents of that collection.

  When building a machine learning model it is often convenient to distinguish
  between variables holding the trainable model parameters and other variables
  such as a `global step` variable used to count training steps. To make this
  easier, the variable constructor supports a `trainable=<bool>` parameter. If
  `True`, the new variable is also added to the graph collection
  `GraphKeys.TRAINABLE_VARIABLES`. The convenience function
  `trainable_variables()` returns the contents of this collection. The
  various `Optimizer` classes use this collection as the default list of
  variables to optimize.

  WARNING: tf.Variable objects by default have a non-intuitive memory model. A
  Variable is represented internally as a mutable Tensor which can
  non-deterministically alias other Tensors in a graph. The set of operations
  which consume a Variable and can lead to aliasing is undetermined and can
  change across TensorFlow versions. Avoid writing code which relies on the
  value of a Variable either changing or not changing as other operations
  happen. For example, using Variable objects or simple functions thereof as
  predicates in a `tf.cond` is dangerous and error-prone:

  ```
  v = tf.Variable(True)
  tf.cond(v, lambda: v.assign(False), my_false_fn)  # Note: this is broken.
  ```

  Here, adding `use_resource=True` when constructing the variable will
  fix any nondeterminism issues:
  ```
  v = tf.Variable(True, use_resource=True)
  tf.cond(v, lambda: v.assign(False), my_false_fn)
  ```

  To use the replacement for variables, which does not have these issues:

  * Add `use_resource=True` when constructing `tf.Variable`;
  * Call `tf.compat.v1.get_variable_scope().set_use_resource(True)` inside a
    `tf.compat.v1.variable_scope` before the `tf.compat.v1.get_variable()` call.
  """

  def __init__(
      self,  # pylint: disable=super-init-not-called
      initial_value=None,
      trainable=None,
      collections=None,
      validate_shape=True,
      caching_device=None,
      name=None,
      variable_def=None,
      dtype=None,
      expected_shape=None,
      import_scope=None,
      constraint=None,
      use_resource=None,
      synchronization=variables.VariableSynchronization.AUTO,
      aggregation=variables.VariableAggregation.NONE,
      shape=None):
    """Creates a new variable with value `initial_value`.

    The new variable is added to the graph collections listed in `collections`,
    which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.

    If `trainable` is `True` the variable is also added to the graph collection
    `GraphKeys.TRAINABLE_VARIABLES`.

    This constructor creates both a `variable` Op and an `assign` Op to set the
    variable to its initial value.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
      trainable: If `True`, also adds the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default
        list of variables to use by the `Optimizer` classes. Defaults to `True`,
        unless `synchronization` is set to `ON_READ`, in which case it defaults
        to `False`.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string describing where the Variable
        should be cached for reading. Defaults to the Variable's device. If not
        `None`, caches on another device. Typical use is to cache on the device
        where the Ops using the Variable reside, to deduplicate copying through
        `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      variable_def: `VariableDef` protocol buffer. If not `None`, recreates the
        Variable object with its contents, referencing the variable's nodes in
        the graph, which must already exist. The graph is not changed.
        `variable_def` and the other arguments are mutually exclusive.
      dtype: If set, initial_value will be converted to the given type. If
        `None`, either the datatype will be kept (if `initial_value` is a
        Tensor), or `convert_to_tensor` will decide.
      expected_shape: A TensorShape. If set, initial_value is expected to have
        this shape.
      import_scope: Optional `string`. Name scope to add to the `Variable`. Only
        used when initializing from protocol buffer.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      use_resource: whether to use resource variables.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the variable
        can be assigned with values of different shapes.

    Raises:
      ValueError: If both `variable_def` and initial_value are specified.
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If eager execution is enabled.
    """
    # NOTE(review): this body intentionally does nothing. Instances appear to
    # be produced by `_variable_call` (via the graph's variable-creator
    # stack) rather than by this initializer; confirm against the metaclass
    # in `variables.py`.

  SaveSliceInfo = variables.Variable.SaveSliceInfo

  def initialized_value(self):
    """Returns this variable's value, or its initial value if uninitialized.

    The conditional is built inside `ops.init_scope()`. It evaluates to
    `self.read_value()` when `is_variable_initialized(self)` is true, and to
    `self.initial_value` otherwise.
    """
    with ops.init_scope():
      return cond.cond(
          is_variable_initialized(self), self.read_value,
          lambda: self.initial_value)

  @staticmethod
  def from_proto(variable_def, import_scope=None):
    """Rebuilds a variable from a `VariableDef` protocol buffer.

    Delegates to the converter registered with `set_variable_from_proto_fn`.

    Args:
      variable_def: A `VariableDef` protocol buffer.
      import_scope: Optional name scope, forwarded unchanged to the converter.

    Returns:
      Whatever the registered converter returns.

    Raises:
      NotImplementedError: If no converter has been registered yet.
    """
    if _variable_from_proto_fn is None:
      # Without this guard the call below fails with an opaque
      # "'NoneType' object is not callable" TypeError.
      raise NotImplementedError(
          "No VariableDef converter has been registered; call "
          "set_variable_from_proto_fn() (or import the module that does) "
          "before using VariableV1.from_proto().")
    return _variable_from_proto_fn(
        variable_def=variable_def, import_scope=import_scope)

  @classmethod
  def _variable_call(
      cls,
      initial_value=None,
      trainable=None,
      validate_shape=True,
      caching_device=None,
      name=None,
      variable_def=None,
      dtype=None,
      import_scope=None,
      constraint=None,
      synchronization=variables.VariableSynchronization.AUTO,
      aggregation=variables.VariableAggregation.NONE,
      shape=None,
      experimental_enable_variable_lifting=None,
      expected_shape=None,
      collections=None,
      use_resource=None,
      **kwargs,
  ):
    """VariableV1 class getter. Useful to force the signature.

    Builds a getter chain that starts from `default_variable_creator`. Each
    entry of the default graph's `_variable_creator_stack` wraps the getter
    built so far. The chain is then invoked with the normalized arguments.

    Returns:
      `None` when `cls` is anything other than `VariableV1` itself (presumably
      so subclasses fall back to ordinary construction; confirm against the
      metaclass). Otherwise, the result of the getter chain.
    """
    if cls is not VariableV1:
      return None

    # Innermost getter. The lambda's parameter is named `kws` so it does not
    # shadow this method's own `**kwargs`.
    previous_getter = lambda **kws: default_variable_creator(None, **kws)
    for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
      previous_getter = variables._make_getter(getter, previous_getter)  # pylint: disable=protected-access

    # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
    if aggregation is None:
      aggregation = variables.VariableAggregation.NONE
    # NOTE(review): extra `**kwargs` are accepted but deliberately not
    # forwarded to the getter chain; confirm this is intended.
    return previous_getter(
        initial_value=initial_value,
        trainable=trainable,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        variable_def=variable_def,
        dtype=dtype,
        import_scope=import_scope,
        constraint=constraint,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape,
        experimental_enable_variable_lifting=experimental_enable_variable_lifting,
        expected_shape=expected_shape,
        collections=collections,
        use_resource=use_resource,
    )