Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/variables.py: 41%

540 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Variable class.""" 

16 

17import abc 

18import enum 

19import functools 

20import itertools 

21import os 

22 

23from tensorflow.core.framework import variable_pb2 

24from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import 

25from tensorflow.python.eager import context 

26from tensorflow.python.framework import dtypes 

27from tensorflow.python.framework import ops 

28from tensorflow.python.framework import tensor_conversion_registry 

29from tensorflow.python.framework import tensor_shape 

30from tensorflow.python.ops import array_ops 

31from tensorflow.python.ops import array_ops_stack 

32from tensorflow.python.ops import control_flow_ops 

33from tensorflow.python.ops import gen_math_ops 

34from tensorflow.python.ops import math_ops 

35from tensorflow.python.ops import state_ops 

36from tensorflow.python.trackable import base as trackable 

37from tensorflow.python.util import _pywrap_utils 

38from tensorflow.python.util import object_identity 

39from tensorflow.python.util import tf_should_use 

40from tensorflow.python.util import traceback_utils 

41from tensorflow.python.util.deprecation import deprecated 

42from tensorflow.python.util.deprecation import deprecated_args 

43from tensorflow.python.util.tf_export import tf_export 

44 

45 

def default_variable_creator_v2(_, **kwds):
  """Fallback variable creator that unconditionally raises.

  The error message tells the caller to import `resource_variable_ops`;
  presumably that module installs the real creator. NOTE(review): confirm.

  Args:
    _: Ignored positional argument.
    **kwds: Ignored keyword arguments.

  Raises:
    NotImplementedError: Always.
  """
  del _, kwds  # Both arguments are intentionally unused.
  raise NotImplementedError("resource_variable_ops needs to be imported")

49 

50 

51def _make_getter(captured_getter, captured_previous): 

52 """To avoid capturing loop variables.""" 

53 

54 def getter(**kwargs): 

55 return captured_getter(captured_previous, **kwargs) 

56 

57 return getter 

58 

59 

@tf_export("VariableSynchronization")
class VariableSynchronization(enum.Enum):
  """Indicates when a distributed variable will be synced.

  * `AUTO`: The synchronization is decided by the current
    `DistributionStrategy` (e.g. with `MirroredStrategy` this would be
    `ON_WRITE`).
  * `NONE`: There will only be one copy of the variable, so there is no need
    to sync.
  * `ON_WRITE`: The variable will be updated across devices every time it is
    written.
  * `ON_READ`: The variable will be aggregated across devices when it is read
    (e.g. when checkpointing or when evaluating an op that uses the variable).

  Example:
  >>> temp_grad=[tf.Variable([0.], trainable=False,
  ...                      synchronization=tf.VariableSynchronization.ON_READ,
  ...                      aggregation=tf.VariableAggregation.MEAN
  ...                      )]
  """
  AUTO = 0  # Chosen by the active distribution strategy.
  NONE = 1  # Single copy; nothing to sync.
  ON_WRITE = 2  # Synced across devices on every write.
  ON_READ = 3  # Aggregated across devices on read.

85 

86 

87# LINT.IfChange 

@tf_export("VariableAggregation", v1=[])
class VariableAggregationV2(enum.Enum):
  """Indicates how a distributed variable will be aggregated.

  `tf.distribute.Strategy` distributes a model by making multiple copies
  (called "replicas") acting on different elements of the input batch in a
  data parallel model. When performing some variable-update operation,
  for example `var.assign_add(x)`, in a model, we need to resolve how to combine
  the different values for `x` computed in the different replicas.

  * `NONE`: This is the default, giving an error if you use a
    variable-update operation with multiple replicas.
  * `SUM`: Add the updates across replicas.
  * `MEAN`: Take the arithmetic mean ("average") of the updates across replicas.
  * `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same
    update, but we only want to perform the update once. Used, e.g., for the
    global step counter.

  For example:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> with strategy.scope():
  ...   v = tf.Variable(5.0, aggregation=tf.VariableAggregation.MEAN)
  >>> @tf.function
  ... def update_fn():
  ...   return v.assign_add(1.0)
  >>> strategy.run(update_fn)
  PerReplica:{
    0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>,
    1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
  }

  """
  NONE = 0
  SUM = 1
  MEAN = 2
  ONLY_FIRST_REPLICA = 3

  def __hash__(self):
    # Hash by integer value so that members which compare equal to the v1
    # `VariableAggregation` members (same values) also hash equal.
    return hash(self.value)

  def __eq__(self, other):
    """Compares by integer value against v1 and v2 aggregation members.

    Args:
      other: Object to compare against.

    Returns:
      `True`/`False` when `other` is a `VariableAggregationV2` or
      `VariableAggregation` member. For any other type, returns
      `NotImplemented` so Python can try the reflected comparison. Without a
      reflected handler, the comparison falls back to identity, which is
      unequal.
    """
    if self is other:
      return True
    if isinstance(other, (VariableAggregationV2, VariableAggregation)):
      return int(self.value) == int(other.value)
    return NotImplemented

136 

137 

@tf_export(v1=["VariableAggregation"])
class VariableAggregation(enum.Enum):
  # TF1 counterpart of `VariableAggregationV2`, using the same integer values.
  # Its `__doc__` is assigned after the class definition.
  NONE = 0
  SUM = 1
  MEAN = 2
  ONLY_FIRST_REPLICA = 3
  # DEPRECATED. It shares the value of ONLY_FIRST_REPLICA, so `enum` makes
  # this name an alias of that member rather than a distinct member.
  ONLY_FIRST_TOWER = 3

  def __hash__(self):
    # Hash by integer value, mirroring VariableAggregationV2.__hash__.
    member_value = self.value
    return hash(member_value)

148 

149 

150# LINT.ThenChange(//tensorflow/core/framework/variable.proto) 

151# 

152# Note that we are currently relying on the integer values of the Python enums 

153# matching the integer values of the proto enums. 

154 

# Under `python -OO` docstrings are discarded and
# `VariableAggregationV2.__doc__` is None. Fall back to "" so that building the
# v1 docstring does not raise TypeError at import time.
VariableAggregation.__doc__ = (
    (VariableAggregationV2.__doc__ or "") +
    "* `ONLY_FIRST_TOWER`: Deprecated alias for `ONLY_FIRST_REPLICA`.\n ")

158 

159 

def validate_synchronization_aggregation_trainable(synchronization, aggregation,
                                                   trainable, name):
  """Given user-provided variable properties, sets defaults and validates.

  Args:
    synchronization: `None`, a `VariableSynchronization` member, or a value
      accepted by the `VariableSynchronization` constructor (e.g. its integer
      value). `None` becomes `VariableSynchronization.AUTO`.
    aggregation: `None`, a `VariableAggregation` / `VariableAggregationV2`
      member (returned unchanged), or a value accepted by the
      `VariableAggregationV2` constructor. `None` becomes
      `VariableAggregation.NONE`.
    trainable: If `None`, defaults to `False` when the resolved
      synchronization is `ON_READ` and `True` otherwise. Any other value is
      returned unchanged.
    name: Variable name, used only in error messages.

  Returns:
    A `(synchronization, aggregation, trainable)` tuple.

  Raises:
    ValueError: If `aggregation` or `synchronization` cannot be converted to
      its enum type. The underlying conversion error is chained as the cause.
  """
  if aggregation is None:
    aggregation = VariableAggregation.NONE
  elif not isinstance(aggregation,
                      (VariableAggregation, VariableAggregationV2)):
    try:
      aggregation = VariableAggregationV2(aggregation)
    except ValueError as e:
      raise ValueError(
          "Invalid variable aggregation mode: {} for variable: {}".format(
              aggregation, name)) from e
  if synchronization is None:
    synchronization = VariableSynchronization.AUTO
  else:
    try:
      synchronization = VariableSynchronization(synchronization)
    except ValueError as e:
      raise ValueError(
          "Invalid variable synchronization mode: {} for variable: {}".format(
              synchronization, name)) from e
  if trainable is None:
    # Default documented on `Variable.__init__`: False for ON_READ, else True.
    trainable = synchronization != VariableSynchronization.ON_READ
  return synchronization, aggregation, trainable

186 

187 

class VariableMetaclass(abc.ABCMeta):
  """Metaclass that allows construction of tf.Variable to be overridden.

  When the called class exposes a callable `_variable_call`, that hook is
  invoked first with the constructor arguments. A non-`None` result is
  returned in place of a normally constructed instance.
  """

  @traceback_utils.filter_traceback
  def __call__(cls, *args, **kwargs):
    has_override = (hasattr(cls, "_variable_call") and
                    callable(cls._variable_call))
    if has_override:
      override_result = cls._variable_call(*args, **kwargs)
      if override_result is not None:
        return override_result
    # No override (or it declined by returning None): construct normally.
    return super().__call__(*args, **kwargs)

197 return super(VariableMetaclass, cls).__call__(*args, **kwargs) 

198 

199 

200@tf_export("Variable", v1=[]) 

201# TODO(mdan): This should subclass core.Tensor, and not all its subclasses? 

202class Variable(trackable.Trackable, metaclass=VariableMetaclass): 

203 """See the [variable guide](https://tensorflow.org/guide/variable). 

204 

205 A variable maintains shared, persistent state manipulated by a program. 

206 

207 The `Variable()` constructor requires an initial value for the variable, which 

208 can be a `Tensor` of any type and shape. This initial value defines the type 

209 and shape of the variable. After construction, the type and shape of the 

210 variable are fixed. The value can be changed using one of the assign methods. 

211 

212 >>> v = tf.Variable(1.) 

213 >>> v.assign(2.) 

214 <tf.Variable ... shape=() dtype=float32, numpy=2.0> 

215 >>> v.assign_add(0.5) 

216 <tf.Variable ... shape=() dtype=float32, numpy=2.5> 

217 

218 The `shape` argument to `Variable`'s constructor allows you to construct a 

219 variable with a less defined shape than its `initial_value`: 

220 

221 >>> v = tf.Variable(1., shape=tf.TensorShape(None)) 

222 >>> v.assign([[1.]]) 

223 <tf.Variable ... shape=<unknown> dtype=float32, numpy=array([[1.]], ...)> 

224 

225 Just like any `Tensor`, variables created with `Variable()` can be used as 

226 inputs to operations. Additionally, all the operators overloaded for the 

227 `Tensor` class are carried over to variables. 

228 

229 >>> w = tf.Variable([[1.], [2.]]) 

230 >>> x = tf.constant([[3., 4.]]) 

231 >>> tf.matmul(w, x) 

232 <tf.Tensor:... shape=(2, 2), ... numpy= 

233 array([[3., 4.], 

234 [6., 8.]], dtype=float32)> 

235 >>> tf.sigmoid(w + x) 

236 <tf.Tensor:... shape=(2, 2), ...> 

237 

238 When building a machine learning model it is often convenient to distinguish 

239 between variables holding trainable model parameters and other variables such 

240 as a `step` variable used to count training steps. To make this easier, the 

241 variable constructor supports a `trainable=<bool>` 

242 parameter. `tf.GradientTape` watches trainable variables by default: 

243 

244 >>> with tf.GradientTape(persistent=True) as tape: 

245 ... trainable = tf.Variable(1.) 

246 ... non_trainable = tf.Variable(2., trainable=False) 

247 ... x1 = trainable * 2. 

248 ... x2 = non_trainable * 3. 

249 >>> tape.gradient(x1, trainable) 

250 <tf.Tensor:... shape=(), dtype=float32, numpy=2.0> 

251 >>> assert tape.gradient(x2, non_trainable) is None # Unwatched 

252 

253 Variables are automatically tracked when assigned to attributes of types 

254 inheriting from `tf.Module`. 

255 

256 >>> m = tf.Module() 

257 >>> m.v = tf.Variable([1.]) 

258 >>> m.trainable_variables 

259 (<tf.Variable ... shape=(1,) ... numpy=array([1.], dtype=float32)>,) 

260 

261 This tracking then allows saving variable values to 

262 [training checkpoints](https://www.tensorflow.org/guide/checkpoint), or to 

263 [SavedModels](https://www.tensorflow.org/guide/saved_model) which include 

264 serialized TensorFlow graphs. 

265 

266 Variables are often captured and manipulated by `tf.function`s. This works the 

267 same way the un-decorated function would have: 

268 

269 >>> v = tf.Variable(0.) 

270 >>> read_and_decrement = tf.function(lambda: v.assign_sub(0.1)) 

271 >>> read_and_decrement() 

272 <tf.Tensor: shape=(), dtype=float32, numpy=-0.1> 

273 >>> read_and_decrement() 

274 <tf.Tensor: shape=(), dtype=float32, numpy=-0.2> 

275 

276 Variables created inside a `tf.function` must be owned outside the function 

277 and be created only once: 

278 

279 >>> class M(tf.Module): 

280 ... @tf.function 

281 ... def __call__(self, x): 

282 ... if not hasattr(self, "v"): # Or set self.v to None in __init__ 

283 ... self.v = tf.Variable(x) 

284 ... return self.v * x 

285 >>> m = M() 

286 >>> m(2.) 

287 <tf.Tensor: shape=(), dtype=float32, numpy=4.0> 

288 >>> m(3.) 

289 <tf.Tensor: shape=(), dtype=float32, numpy=6.0> 

290 >>> m.v 

291 <tf.Variable ... shape=() dtype=float32, numpy=2.0> 

292 

293 See the `tf.function` documentation for details. 

294 """ 

295 

@deprecated_args(
    None, "A variable's value can be manually cached by calling "
    "tf.Variable.read_value() under a tf.device scope. The caching_device "
    "argument does not work properly.", "caching_device")
def __init__(self,
             initial_value=None,
             trainable=None,
             validate_shape=True,
             caching_device=None,
             name=None,
             variable_def=None,
             dtype=None,
             import_scope=None,
             constraint=None,
             synchronization=VariableSynchronization.AUTO,
             aggregation=VariableAggregation.NONE,
             shape=None,
             experimental_enable_variable_lifting=True,
            ):
  """Creates a new variable with value `initial_value`.

  Args:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
      which is the initial value for the Variable. The initial value must have
      a shape specified unless `validate_shape` is set to False. Can also be a
      callable with no argument that returns the initial value when called. In
      that case, `dtype` must be specified. (Note that initializer functions
      from init_ops.py must first be bound to a shape before being used here.)
    trainable: If `True`, GradientTapes automatically watch uses of this
      variable. Defaults to `True`, unless `synchronization` is set to
      `ON_READ`, in which case it defaults to `False`.
    validate_shape: If `False`, allows the variable to be initialized with a
      value of unknown shape. If `True`, the default, the shape of
      `initial_value` must be known.
    caching_device: Note: This argument is only valid when using a v1-style
      `Session`. Optional device string describing where the Variable should
      be cached for reading. Defaults to the Variable's device. If not `None`,
      caches on another device. Typical use is to cache on the device where
      the Ops using the Variable reside, to deduplicate copying through
      `Switch` and other conditional statements.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.
    variable_def: `VariableDef` protocol buffer. If not `None`, recreates the
      Variable object with its contents, referencing the variable's nodes in
      the graph, which must already exist. The graph is not changed.
      `variable_def` and the other arguments are mutually exclusive.
    dtype: If set, initial_value will be converted to the given type. If
      `None`, either the datatype will be kept (if `initial_value` is a
      Tensor), or `convert_to_tensor` will decide.
    import_scope: Optional `string`. Name scope to add to the `Variable`. Only
      used when initializing from protocol buffer.
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value (which must have
      the same shape). Constraints are not safe to use when doing asynchronous
      distributed training.
    synchronization: Indicates when a distributed variable will be
      synchronized. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses when to
      synchronize.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.
    shape: (optional) The shape of this variable. If None, the shape of
      `initial_value` will be used. When setting this argument to
      `tf.TensorShape(None)` (representing an unspecified shape), the variable
      can be assigned with values of different shapes.
    experimental_enable_variable_lifting: Whether to lift the variable out if
      it's in a `tf.function`. Default is `True`. When this argument
      is `True`, variable creation will follow the behavior and
      restrictions described
      [here](https://www.tensorflow.org/guide/function#creating_tfvariables).
      If this argument is `False`, that description doesn't apply,
      and you can freely create and use the variable in the
      `tf.function`, as if it's a "mutable `tf.Tensor`". You can't
      return the variable though.

  Raises:
    ValueError: If both `variable_def` and `initial_value` are specified.
    ValueError: If the initial value is not specified, or does not have a
      shape and `validate_shape` is `True`.
  """
  # The base class does not construct anything itself. Note that
  # `VariableMetaclass.__call__` first gives `cls._variable_call` a chance to
  # return a variable before normal construction reaches this method.
  raise NotImplementedError

382 

def __repr__(self):
  """Not implemented on the base `Variable` class."""
  raise NotImplementedError()

385 

def value(self):
  """Returns the most recent snapshot of this variable.

  Calling this directly is rarely necessary: every op that needs the
  variable's value obtains it automatically through `convert_to_tensor()`.

  The result is a `Tensor` holding the variable's value. It is not a
  reference to the variable, so a new value cannot be assigned through it.

  To avoid copies, a consumer on the same device as the variable receives the
  live value rather than a copy, and therefore sees later updates to the
  variable. A consumer on a different device gets a copy.

  Returns:
    A `Tensor` containing the value of the variable.
  """
  raise NotImplementedError()

404 

def read_value(self):
  """Returns the value of this variable, read in the current context.

  The result may differ from `value()`, for example when the read happens on
  another device or under control dependencies.

  Returns:
    A `Tensor` containing the value of the variable.
  """
  raise NotImplementedError()

415 

def set_shape(self, shape):
  """Overrides the shape of this variable.

  Args:
    shape: A `TensorShape` giving the new (overriding) shape.
  """
  raise NotImplementedError()

423 

@property
def trainable(self):
  """Whether GradientTapes automatically watch uses of this variable."""
  raise NotImplementedError()

427 

@property
def synchronization(self):
  """The `VariableSynchronization` mode: when a distributed copy is synced."""
  raise NotImplementedError()

431 

@property
def aggregation(self):
  """The `VariableAggregation` mode: how distributed updates are combined."""
  raise NotImplementedError()

435 

def eval(self, session=None):
  """In a session, computes and returns the value of this variable.

  This does not construct anything: no ops are added to the graph.

  It requires a session in which the graph containing this variable has been
  launched. When no session is passed, the default session is used. See
  `tf.compat.v1.Session` for more information on launching a graph and on
  sessions.

  ```python
  v = tf.Variable([1, 2])
  init = tf.compat.v1.global_variables_initializer()

  with tf.compat.v1.Session() as sess:
      sess.run(init)
      # Usage passing the session explicitly.
      print(v.eval(sess))
      # Usage with the default session.  The 'with' block
      # above makes 'sess' the default session.
      print(v.eval())
  ```

  Args:
    session: The session used to evaluate this variable. If `None`, the
      default session is used.

  Returns:
    A numpy `ndarray` with a copy of the value of this variable.
  """
  raise NotImplementedError()

467 

@deprecated(
    None, "Use Variable.read_value. Variables in 2.X are initialized "
    "automatically both in eager and graph (inside tf.defun) contexts.")
def initialized_value(self):
  """Returns the value of the variable after it has been initialized.

  Use this, rather than the variable itself, to initialize another variable
  with a value that depends on this variable's value.

  ```python
  # Initialize 'v' with a random tensor.
  v = tf.Variable(tf.random.truncated_normal([10, 40]))
  # Use `initialized_value` to guarantee that `v` has been
  # initialized before its value is used to initialize `w`.
  # The random values are picked only once.
  w = tf.Variable(v.initialized_value() * 2.0)
  ```

  Returns:
    A `Tensor` holding the value of this variable once its initializer has
    run.
  """
  raise NotImplementedError()

491 

@property
def initial_value(self):
  """Returns the Tensor used as the initial value for the variable.

  Unlike `initialized_value()`, which runs the variable's initializer before
  returning the value, this returns the tensor consumed by the initializing
  op itself.

  Returns:
    A `Tensor`.
  """
  raise NotImplementedError()

505 

@property
def constraint(self):
  """Returns the constraint function associated with this variable.

  Returns:
    The constraint function given to the variable constructor, or `None`
    if no constraint was passed.
  """
  raise NotImplementedError()

515 

def assign(self, value, use_locking=False, name=None, read_value=True):
  """Assigns a new value to the variable.

  Essentially a shortcut for `assign(self, value)`.

  Args:
    value: A `Tensor` holding the new value for this variable.
    use_locking: If `True`, use locking during the assignment.
    name: The name of the operation to be created.
    read_value: If True, return something that evaluates to the variable's
      new value; if False, return the assign op.

  Returns:
    The updated variable. When `read_value` is false, returns None in Eager
    mode and the assign op in graph mode instead.
  """
  raise NotImplementedError()

533 

def assign_add(self, delta, use_locking=False, name=None, read_value=True):
  """Adds a value to this variable.

  Essentially a shortcut for `assign_add(self, delta)`.

  Args:
    delta: A `Tensor` holding the value to add to this variable.
    use_locking: If `True`, use locking during the operation.
    name: The name of the operation to be created.
    read_value: If True, return something that evaluates to the variable's
      new value; if False, return the assign op.

  Returns:
    The updated variable. When `read_value` is false, returns None in Eager
    mode and the assign op in graph mode instead.
  """
  raise NotImplementedError()

551 

def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
  """Subtracts a value from this variable.

  Essentially a shortcut for `assign_sub(self, delta)`.

  Args:
    delta: A `Tensor` holding the value to subtract from this variable.
    use_locking: If `True`, use locking during the operation.
    name: The name of the operation to be created.
    read_value: If True, return something that evaluates to the variable's
      new value; if False, return the assign op.

  Returns:
    The updated variable. When `read_value` is false, returns None in Eager
    mode and the assign op in graph mode instead.
  """
  raise NotImplementedError()

569 

def scatter_sub(self, sparse_delta, use_locking=False, name=None):
  """Subtracts `tf.IndexedSlices` from this variable.

  Args:
    sparse_delta: The `tf.IndexedSlices` to subtract from this variable.
    use_locking: If `True`, use locking during the operation.
    name: The name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: If `sparse_delta` is not an `IndexedSlices`.
  """
  raise NotImplementedError()

585 

def scatter_add(self, sparse_delta, use_locking=False, name=None):
  """Adds `tf.IndexedSlices` to this variable.

  Args:
    sparse_delta: The `tf.IndexedSlices` to add to this variable.
    use_locking: If `True`, use locking during the operation.
    name: The name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: If `sparse_delta` is not an `IndexedSlices`.
  """
  raise NotImplementedError()

601 

def scatter_max(self, sparse_delta, use_locking=False, name=None):
  """Updates this variable with the max of `tf.IndexedSlices` and itself.

  Args:
    sparse_delta: The `tf.IndexedSlices` used as one argument of the max,
      with this variable as the other.
    use_locking: If `True`, use locking during the operation.
    name: The name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: If `sparse_delta` is not an `IndexedSlices`.
  """
  raise NotImplementedError()

618 

def scatter_min(self, sparse_delta, use_locking=False, name=None):
  """Updates this variable with the min of `tf.IndexedSlices` and itself.

  Args:
    sparse_delta: The `tf.IndexedSlices` used as one argument of the min,
      with this variable as the other.
    use_locking: If `True`, use locking during the operation.
    name: The name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: If `sparse_delta` is not an `IndexedSlices`.
  """
  raise NotImplementedError()

635 

def scatter_mul(self, sparse_delta, use_locking=False, name=None):
  """Multiplies this variable by `tf.IndexedSlices`.

  Args:
    sparse_delta: The `tf.IndexedSlices` to multiply this variable by.
    use_locking: If `True`, use locking during the operation.
    name: The name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: If `sparse_delta` is not an `IndexedSlices`.
  """
  raise NotImplementedError()

651 

def scatter_div(self, sparse_delta, use_locking=False, name=None):
  """Divides this variable by `tf.IndexedSlices`.

  Args:
    sparse_delta: The `tf.IndexedSlices` to divide this variable by.
    use_locking: If `True`, use locking during the operation.
    name: The name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: If `sparse_delta` is not an `IndexedSlices`.
  """
  raise NotImplementedError()

667 

def scatter_update(self, sparse_delta, use_locking=False, name=None):
  """Assigns `tf.IndexedSlices` to this variable.

  Args:
    sparse_delta: The `tf.IndexedSlices` to assign to this variable.
    use_locking: If `True`, use locking during the operation.
    name: The name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: If `sparse_delta` is not an `IndexedSlices`.
  """
  raise NotImplementedError()

683 

def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
  """Assigns `tf.IndexedSlices` to this variable batch-wise.

  Analogous to `batch_gather`. This assumes that this variable and the
  `sparse_delta` IndexedSlices share a series of leading dimensions, and that
  the updates are performed on the last dimension of the indices. In other
  words, the dimensions should be the following:

  `num_prefix_dims = sparse_delta.indices.ndims - 1`
  `batch_dim = num_prefix_dims + 1`
  `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
       batch_dim:]`

  where

  `sparse_delta.updates.shape[:num_prefix_dims]`
  `== sparse_delta.indices.shape[:num_prefix_dims]`
  `== var.shape[:num_prefix_dims]`

  The operation performed can be expressed as:

  `var[i_1, ..., i_n,
       sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
          i_1, ..., i_n, j]`

  When `sparse_delta.indices` is a 1D tensor, this operation is equivalent
  to `scatter_update`.

  Instead of using this operation, one can loop over the first `ndims` of
  the variable and call `scatter_update` on the subtensors that result from
  slicing the first dimension. This is a valid option for `ndims = 1`, but
  less efficient than this implementation.

  Args:
    sparse_delta: The `tf.IndexedSlices` to assign to this variable.
    use_locking: If `True`, use locking during the operation.
    name: The name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: If `sparse_delta` is not an `IndexedSlices`.
  """
  raise NotImplementedError()

729 

def scatter_nd_sub(self, indices, updates, name=None):
  """Applies sparse subtraction to individual values or slices in a Variable.

  Assuming the variable has rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be an integer tensor containing indices into self.
  It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of self.

  `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
  ```

  For example, say we want to subtract 4 scattered elements from a rank-1
  tensor with 8 elements. In Python, that update would look like this:

  ```python
      v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1] ,[7]])
      updates = tf.constant([9, 10, 11, 12])
      v.scatter_nd_sub(indices, updates)
      print(v)
  ```

  After the update `v` would look like this:

      [1, -9, 3, -6, -4, 6, 7, -4]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices to be used in the operation.
    updates: The values to be used in the operation.
    name: the name of the operation.

  Returns:
    The updated variable.
  """
  raise NotImplementedError

775 

def scatter_nd_add(self, indices, updates, name=None):
  """Applies sparse addition to individual values or slices in a Variable.

  The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be an integer tensor containing indices into self.
  It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of self.

  `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
  ```

  For example, say we want to add 4 scattered elements to a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
      v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1] ,[7]])
      updates = tf.constant([9, 10, 11, 12])
      v.scatter_nd_add(indices, updates)
      print(v)
  ```

  The resulting update to v would look like this:

      [1, 13, 3, 14, 14, 6, 7, 20]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices used in the operation.
    updates: The values used in the operation.
    name: The name of the operation.

  Returns:
    The updated variable.
  """
  raise NotImplementedError()

821 

def scatter_nd_update(self, indices, updates, name=None):
  """Applies sparse assignment to individual values or slices in a Variable.

  The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be an integer tensor containing indices into self.
  It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of self.

  `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
  ```

  For example, say we want to assign 4 scattered elements in a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
      v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1] ,[7]])
      updates = tf.constant([9, 10, 11, 12])
      v.scatter_nd_update(indices, updates)
      print(v)
  ```

  The resulting update to v would look like this:

      [1, 11, 3, 10, 9, 6, 7, 12]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices to be used in the operation.
    updates: The values to be used in the operation.
    name: the name of the operation.

  Returns:
    The updated variable.
  """
  raise NotImplementedError

867 

def sparse_read(self, indices, name=None):
  """Gathers slices of this variable along its first axis at `indices`.

  Supports a subset of `tf.gather`; consult `tf.gather` for usage details.

  Args:
    indices: The index `Tensor`. Must be one of the following types: `int32`,
      `int64`. Must be in range `[0, params.shape[axis])`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `params`.

  Raises:
    AttributeError: always, in this base implementation.
  """
  raise AttributeError()

883 

def gather_nd(self, indices, name=None):
  """Gathers slices of this variable into a Tensor shaped by `indices`.

  Behaves like `tf.gather_nd`; see that function for details.

  Args:
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      Index tensor.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `params`.

  Raises:
    AttributeError: always, in this base implementation.
  """
  raise AttributeError()

898 

@deprecated(None, "Prefer Dataset.range instead.")
def count_up_to(self, limit):
  """Increments this variable until it reaches `limit`.

  Each run of the resulting op tries to add `1` to the variable. If that
  increment would push the variable above `limit`, the op raises
  `OutOfRangeError`. Otherwise the op outputs the variable's value from
  before the increment.

  This is essentially a shortcut for `count_up_to(self, limit)`.

  Args:
    limit: value at which incrementing the variable raises an error.

  Returns:
    A `Tensor` holding the variable's value before the increment. As long as
    no other op modifies this variable, every value produced is distinct.

  Raises:
    NotImplementedError: always, in this base implementation.
  """
  raise NotImplementedError()

921 

@deprecated(None,
            "Prefer Variable.assign which has equivalent behavior in 2.X.")
def load(self, value, session=None):
  """Load new value into this variable.

  Writes new value to variable's memory. Doesn't add ops to the graph.

  This convenience method requires a session where the graph
  containing this variable has been launched. If no session is
  passed, the default session is used. See `tf.compat.v1.Session` for more
  information on launching a graph and on sessions.

  ```python
  v = tf.Variable([1, 2])
  init = tf.compat.v1.global_variables_initializer()

  with tf.compat.v1.Session() as sess:
      sess.run(init)
      # Usage passing the session explicitly.
      v.load([2, 3], sess)
      print(v.eval(sess)) # prints [2 3]
      # Usage with the default session. The 'with' block
      # above makes 'sess' the default session.
      v.load([3, 4])
      print(v.eval()) # prints [3 4]
  ```

  When executing eagerly, no session is used and the value is written with
  `self.assign(value)`.

  Args:
    value: New variable value
    session: The session to use to evaluate this variable. If None, the
      default session is used.

  Raises:
    ValueError: Session is not passed and no default session
  """
  if context.executing_eagerly():
    # Eager variables are written directly; no session is involved.
    self.assign(value)
  else:
    session = session or ops.get_default_session()
    if session is None:
      raise ValueError(
          "Either session argument should be provided or default session "
          "should be established")
    # Feed `value` in place of the initializer's second input and run the
    # initializer. NOTE(review): presumably inputs[1] is the initializer's
    # value operand; confirm against the concrete initializer op.
    session.run(self.initializer, {self.initializer.inputs[1]: value})

966 

# Conversion to tensor.
@staticmethod
def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
  """Converts the Variable `v` to a Tensor for the conversion registry.

  Args:
    v: The variable to convert.
    dtype: Optional requested dtype; must be compatible with `v.dtype`.
    name: Unused; accepted to match the conversion-function signature.
    as_ref: If True, return `v._ref()` instead of `v.value()`.

  Returns:
    `v._ref()` when `as_ref` is true, otherwise `v.value()`.

  Raises:
    ValueError: if `dtype` is given and incompatible with `v.dtype`.
  """
  del name  # Only present for the conversion-registry signature.
  if dtype and not dtype.is_compatible_with(v.dtype):
    raise ValueError(
        f"Incompatible type conversion requested to type '{dtype.name}' for "
        f"variable of type '{v.dtype.name}' (Variable: {v}).")
  return v._ref() if as_ref else v.value()  # pylint: disable=protected-access

980 

@classmethod
def _OverloadAllOperators(cls):  # pylint: disable=invalid-name
  """Installs every `ops.Tensor` operator overload on `cls`.

  Each name in `ops.Tensor.OVERLOADABLE_OPERATORS` is handed to
  `cls._OverloadOperator`. `__getitem__` is then bound to
  `array_ops._SliceHelperVar` rather than the tensor slicing helper.
  """
  for op_name in ops.Tensor.OVERLOADABLE_OPERATORS:
    cls._OverloadOperator(op_name)
  # Slicing a variable uses the variable-specific slice helper.
  cls.__getitem__ = array_ops._SliceHelperVar  # pylint: disable=protected-access

990 

@classmethod
def _OverloadOperator(cls, operator):  # pylint: disable=invalid-name
  """Defers the operator `operator` on `cls` to the `ops.Tensor` version.

  The `ops.Tensor` attribute is looked up when this runs, not at import time,
  to avoid module ordering issues.

  Args:
    operator: string. The operator name.
  """
  # `__eq__` and `__ne__` are skipped: `__eq__` is called when adding a
  # variable to sets, and routing it through `a.value()` causes infinite
  # recursion when operating within a GradientTape.
  # TODO(gjn): Consider removing this
  if operator in ("__eq__", "__ne__"):
    return

  tensor_method = getattr(ops.Tensor, operator)

  def _run_op(a, *args, **kwargs):
    # Apply the Tensor operator to the variable's current value.
    return tensor_method(a.value(), *args, **kwargs)

  functools.update_wrapper(_run_op, tensor_method)
  setattr(cls, operator, _run_op)

1015 

def __hash__(self):
  """Returns `id(self)` unless TF2 tensor equality is in effect.

  Raises:
    TypeError: when `ops.Tensor._USE_EQUALITY` is set and
      `ops.executing_eagerly_outside_functions()` is true; callers should key
      on `variable.ref()` instead.
  """
  equality_on = ops.Tensor._USE_EQUALITY  # pylint: disable=protected-access
  if not (equality_on and ops.executing_eagerly_outside_functions()):
    return id(self)
  raise TypeError(
      "Variable is unhashable. "
      f"Instead, use variable.ref() as the key. (Variable: {self})")

1023 

# TODO(gjn): duplicate of math_ops.tensor_equals, consider removing
def __eq__(self, other):
  """Element-wise equality under TF2 semantics; object identity otherwise.

  Returns `gen_math_ops.equal(self, other, incompatible_shape_error=False)`
  when `ops.Tensor._USE_EQUALITY` is set and execution is eager outside of
  functions, else the Python bool `self is other`.
  """
  if not (ops.Tensor._USE_EQUALITY and  # pylint: disable=protected-access
          ops.executing_eagerly_outside_functions()):
    # In legacy graph mode, tensor equality is object equality
    return self is other
  return gen_math_ops.equal(self, other, incompatible_shape_error=False)

1032 

# TODO(gjn): duplicate of math_ops.tensor_not_equals, consider removing
def __ne__(self, other):
  """Compares two variables element-wise for inequality.

  When `ops.Tensor._USE_EQUALITY` is set and execution is eager outside of
  functions, returns `gen_math_ops.not_equal(self, other,
  incompatible_shape_error=False)`; otherwise returns the Python bool
  `self is not other`.
  """
  if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
    return gen_math_ops.not_equal(self, other, incompatible_shape_error=False)
  else:
    # In legacy graph mode, tensor equality is object equality
    return self is not other

1041 

def __iter__(self):
  """Returns an iterator over `self.read_value()` (intended for eager use)."""
  current_value = self.read_value()
  return iter(current_value)

1045 

# NOTE(mrry): This enables the Variable's overloaded "right" binary
# operators to run when the left operand is an ndarray, because it
# accords the Variable class higher priority than an ndarray, or a
# numpy matrix.
# TODO(mrry): Convert this to using numpy's __array_ufunc__ mechanism (the
# shipped successor of the proposed __numpy_ufunc__), which allows more
# control over how Variables interact with ndarrays.
__array_priority__ = 100

1054 

@property
def name(self):
  """The name of this variable; concrete subclasses must provide it."""
  raise NotImplementedError()

1059 

@property
def _shared_name(self):
  """The variable's `name` truncated at its first ":".

  Unlike `name`, there is no ":0" suffix: this is the user-specified name
  together with its name-scope prefix.

  Returns:
    variable name.

  Raises:
    ValueError: if `name` contains no ":".
  """
  full_name = self.name
  colon_at = full_name.index(":")
  return full_name[:colon_at]

1071 

@property
def initializer(self):
  """The op that initializes this variable; provided by subclasses."""
  raise NotImplementedError()

1076 

@property
def device(self):
  """The device this variable lives on; provided by subclasses."""
  raise NotImplementedError()

1081 

@property
def dtype(self):
  """This variable's `DType`; provided by subclasses."""
  raise NotImplementedError()

1086 

@property
def op(self):
  """This variable's `Operation`; provided by subclasses."""
  raise NotImplementedError()

1091 

@property
def graph(self):
  """The `Graph` holding this variable; provided by subclasses."""
  raise NotImplementedError()

1096 

@property
def shape(self):
  """This variable's `TensorShape`; provided by subclasses.

  Returns:
    A `TensorShape`.
  """
  raise NotImplementedError()

1105 

def get_shape(self):
  """Returns `self.shape`; kept as a method-style alias of the property."""
  current_shape = self.shape
  return current_shape

1109 

def _gather_saveables_for_checkpoint(self):
  """`Trackable` hook: the variable saves itself under the value key.

  Returns:
    A single-entry dict mapping `trackable.VARIABLE_VALUE_KEY` to `self`.
  """
  saveables = {trackable.VARIABLE_VALUE_KEY: self}
  return saveables

1113 

def to_proto(self, export_scope=None):
  """Serializes this `Variable` into a `VariableDef` protocol buffer.

  Args:
    export_scope: Optional `string`. Name scope to remove.

  Returns:
    A `VariableDef` protocol buffer, or `None` if the `Variable` is not
    in the specified name scope.

  Raises:
    NotImplementedError: always, in this base implementation.
  """
  raise NotImplementedError()

1125 

@staticmethod
def from_proto(variable_def, import_scope=None):
  """Builds a `Variable` from `variable_def`; provided by subclasses."""
  raise NotImplementedError()

1130 

def _set_save_slice_info(self, save_slice_info):
  """Stores the slice info describing how this `Variable` is saved.

  Args:
    save_slice_info: A `Variable.SaveSliceInfo` object.
  """
  setattr(self, "_save_slice_info", save_slice_info)

1138 

def _get_save_slice_info(self):
  """Returns the stored slice info (raises AttributeError if never set)."""
  return getattr(self, "_save_slice_info")

1141 

@deprecated(None, "Use ref() instead.")
def experimental_ref(self):
  """Deprecated alias that returns `self.ref()`."""
  reference = self.ref()
  return reference

1145 

def ref(self):
  # tf.Tensor also has the same ref() API. If you update the
  # documentation here, please update tf.Tensor.ref() as well.
  """Returns a hashable reference object to this Variable.

  The main use is placing variables in a set or using them as dict keys.
  Since Tensorflow 2.0, `variable.__hash__()` is no longer available, so the
  variables themselves cannot be put in a set/dictionary.

  The following will raise an exception starting 2.0

  >>> x = tf.Variable(5)
  >>> y = tf.Variable(10)
  >>> z = tf.Variable(10)
  >>> variable_set = {x, y, z}
  Traceback (most recent call last):
    ...
  TypeError: Variable is unhashable. Instead, use variable.ref() as the key.
  >>> variable_dict = {x: 'five', y: 'ten'}
  Traceback (most recent call last):
    ...
  TypeError: Variable is unhashable. Instead, use variable.ref() as the key.

  Use `variable.ref()` instead:

  >>> variable_set = {x.ref(), y.ref(), z.ref()}
  >>> x.ref() in variable_set
  True
  >>> variable_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'}
  >>> variable_dict[y.ref()]
  'ten'

  The reference's `.deref()` method returns the original Variable.

  >>> x = tf.Variable(5)
  >>> x.ref().deref()
  <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=5>
  """
  reference = object_identity.Reference(self)
  return reference

1186 

@classmethod
def _variable_call(
    cls,
    initial_value=None,
    trainable=None,
    validate_shape=True,
    caching_device=None,
    name=None,
    variable_def=None,
    dtype=None,
    import_scope=None,
    constraint=None,
    synchronization=VariableSynchronization.AUTO,
    aggregation=VariableAggregation.NONE,
    shape=None,
    experimental_enable_variable_lifting=None,
    **kwargs,
):
  """Variable class getter. Useful to force the signature.

  Only acts when `cls` is exactly `Variable`; for any other class it returns
  `None`. Otherwise it chains the default graph's variable creators on top
  of `default_variable_creator_v2` and calls the resulting getter with the
  named arguments. Extra `**kwargs` are accepted but not forwarded.
  """
  if cls is not Variable:
    return None

  # Normalize an explicit `aggregation=None` to the enum member NONE.
  if aggregation is None:
    aggregation = VariableAggregation.NONE

  # Each registered creator wraps the getter built so far.
  getter = lambda **kws: default_variable_creator_v2(None, **kws)
  for _, creator in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
    getter = _make_getter(creator, getter)

  return getter(
      initial_value=initial_value,
      trainable=trainable,
      validate_shape=validate_shape,
      caching_device=caching_device,
      name=name,
      variable_def=variable_def,
      dtype=dtype,
      import_scope=import_scope,
      constraint=constraint,
      synchronization=synchronization,
      aggregation=aggregation,
      shape=shape,
      experimental_enable_variable_lifting=experimental_enable_variable_lifting,
  )

1230 

class SaveSliceInfo:
  """Describes how this Variable is saved as one slice of a larger variable.

  Internal support for saving a variable as a partition of a bigger logical
  variable. This API is not public and is subject to change.

  Available properties:

  * full_name
  * full_shape
  * var_offset
  * var_shape
  """

  def __init__(self,
               full_name=None,
               full_shape=None,
               var_offset=None,
               var_shape=None,
               save_slice_info_def=None,
               import_scope=None):
    """Create a `SaveSliceInfo`.

    Args:
      full_name: Name of the full variable this `Variable` is a slice of.
      full_shape: Shape of the full variable, as a list of int.
      var_offset: Offset of this `Variable` into the full variable, as a list
        of int.
      var_shape: Shape of this `Variable`, as a list of int.
      save_slice_info_def: `SaveSliceInfoDef` protocol buffer. When truthy, the
        object is rebuilt from its contents and the other arguments are
        ignored (the two ways of constructing are mutually exclusive).
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.
    """
    if not save_slice_info_def:
      self.full_name = full_name
      self.full_shape = full_shape
      self.var_offset = var_offset
      self.var_shape = var_shape
      return
    assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
    self.full_name = ops.prepend_name_scope(
        save_slice_info_def.full_name, import_scope=import_scope)
    self.full_shape = list(save_slice_info_def.full_shape)
    self.var_offset = list(save_slice_info_def.var_offset)
    self.var_shape = list(save_slice_info_def.var_shape)

  @property
  def spec(self):
    """Computes the spec string used for saving.

    The full shape is written as space-separated ints, followed by a space,
    then one "offset,length" pair per dimension joined by ":". For example,
    full shape [4, 6], offset [0, 3], shape [4, 3] gives "4 6 0,4:3,3".
    """
    dims = " ".join("%d" % dim for dim in self.full_shape)
    slices = ":".join(
        "%d,%d" % pair for pair in zip(self.var_offset, self.var_shape))
    return dims + " " + slices

  def to_proto(self, export_scope=None):
    """Returns a SaveSliceInfoDef() proto.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not
      in the specified name scope.
    """
    if export_scope is not None and not self.full_name.startswith(export_scope):
      return None
    proto = variable_pb2.SaveSliceInfoDef()
    proto.full_name = ops.strip_name_scope(self.full_name, export_scope)
    proto.full_shape.extend(self.full_shape)
    proto.var_offset.extend(self.var_offset)
    proto.var_shape.extend(self.var_shape)
    return proto

1311 

1312 

# Install the `ops.Tensor` operator overloads on `Variable` now that the class
# is fully defined.
Variable._OverloadAllOperators()  # pylint: disable=protected-access
# Register the `Variable` class with the native utils under the name
# "Variable". NOTE(review): presumably used for native-side type checks;
# confirm in _pywrap_utils.
_pywrap_utils.RegisterType("Variable", Variable)

1315 

1316 

def _try_guard_against_uninitialized_dependencies(name, initial_value):
  """Attempt to guard against dependencies on uninitialized variables.

  Replace references to variables in `initial_value` with references to the
  variable's initialized values. The initialized values are essentially
  conditional TensorFlow graphs that return a variable's value if it is
  initialized or its `initial_value` if it hasn't been initialized. This
  replacement is done on a best effort basis:

  - If the `initial_value` graph contains cycles, we don't do any
    replacements for that graph.
  - If the variables that `initial_value` depends on are not present in the
    `GLOBAL_VARIABLES` or `LOCAL_VARIABLES` we don't replace them.

  In these cases, it is up to the caller to ensure that the `initial_value`
  graph uses initialized variables or that they guard access to variables
  using their `initialized_value` method.

  Args:
    name: Variable name.
    initial_value: `Tensor`. The initial value.

  Returns:
    A `Tensor` suitable to initialize a variable.
  Raises:
    TypeError: If `initial_value` is not a `Tensor`.
  """
  if not isinstance(initial_value, ops.Tensor):
    # Wrap in a 1-tuple. A bare `%` would treat a tuple `initial_value` as the
    # argument list and fail with an unrelated formatting error.
    raise TypeError("initial_value needs to be a Tensor: %s" %
                    (initial_value,))

  # Don't modify initial_value if it contains any cyclic dependencies.
  if _has_cycle(initial_value.op, state={}):
    return initial_value
  return _safe_initial_value_from_tensor(name, initial_value, op_cache={})

1351 

1352 

# Visit states for the depth-first search in `_has_cycle`: not yet seen,
# currently on the traversal path, fully explored.
_UNKNOWN, _STARTED, _FINISHED = 0, 1, 2

1354 

1355 

def _has_cycle(op, state):
  """Reports whether a dependency cycle is reachable from `op`.

  Recursive depth-first walk over each op's data-input producers
  (`t.op for t in op.inputs`) and its `control_inputs`. `state` maps op names
  to `_STARTED`/`_FINISHED` and is updated in place. Reaching an op that is
  still `_STARTED` means a cycle.

  Args:
    op: The `Operation` to start from.
    state: Dict of op name to visit state, shared across the recursion.

  Returns:
    True if a cycle was found, False otherwise.
  """
  visit_state = state.get(op.name, _UNKNOWN)
  if visit_state != _UNKNOWN:
    # Still on the current path means a back edge; finished means acyclic.
    return visit_state == _STARTED

  state[op.name] = _STARTED
  upstream = itertools.chain((t.op for t in op.inputs), op.control_inputs)
  if any(_has_cycle(dep, state) for dep in upstream):
    return True
  state[op.name] = _FINISHED
  return False

1370 

1371 

def _safe_initial_value_from_tensor(name, tensor, op_cache):
  """Returns `tensor` rewritten to read initialized variable values.

  The producing op is rewritten at most once: results are memoized in
  `op_cache` by op name. The output at `tensor.value_index` of the
  (possibly rewritten) op is returned. This is best-effort; when nothing
  needs to change, the op is returned as-is and so is its output.

  Args:
    name: Variable name.
    tensor: A `Tensor`. The tensor to replace.
    op_cache: A dict mapping operation names to `Operation`s, shared across
      the traversal to avoid creating redundant operations.

  Returns:
    A `Tensor` compatible with `tensor`.
  """
  op_key = tensor.op.name
  rewritten_op = op_cache.get(op_key)
  if rewritten_op is None:
    rewritten_op = op_cache[op_key] = _safe_initial_value_from_op(
        name, tensor.op, op_cache)
  return rewritten_op.outputs[tensor.value_index]

1393 

1394 

def _safe_initial_value_from_op(name, op, op_cache):
  """Replace dependencies on variables with their initialized values.

  Args:
    name: Variable name.
    op: An `Operation`. The operation to replace.
    op_cache: A dict mapping operation names to `Operation`s. Used to memoize
      the results so as to avoid creating redundant operations.

  Returns:
    An `Operation` compatible with `op`. Any inputs that lead to variable
    values will be replaced with a corresponding graph that uses the
    variable's initialized values. This is done on a best-effort basis. If no
    modifications need to be made then `op` will be returned unchanged. When
    a replacement is created it is named `<op node name>_<name>`, with ":"
    replaced by "_".
  """
  op_type = op.node_def.op
  # These op types are returned unchanged; their inputs are not traversed.
  if op_type in ("IsVariableInitialized", "VarIsInitializedOp",
                 "ReadVariableOp", "If"):
    return op

  # Attempt to find the initialized_value of any variable reference / handles.
  # TODO(b/70206927): Fix handling of ResourceVariables.
  if op_type in ("Variable", "VariableV2", "VarHandleOp"):
    initialized_value = _find_initialized_value_for_variable(op)
    return op if initialized_value is None else initialized_value.op

  # Recursively build initializer expressions for inputs.
  modified = False
  new_op_inputs = []
  for op_input in op.inputs:
    new_op_input = _safe_initial_value_from_tensor(name, op_input, op_cache)
    new_op_inputs.append(new_op_input)
    # Compare by identity. Graph-mode `Tensor.__ne__` already reduces to
    # `is not`, but under TF2 equality semantics `!=` can return an
    # element-wise Tensor, which is not usable as a Python bool.
    modified = modified or new_op_input is not op_input

  # If at least one input was modified, replace the op.
  if modified:
    new_op_type = op_type
    if new_op_type == "RefSwitch":
      new_op_type = "Switch"
    new_op_name = op.node_def.name + "_" + name
    new_op_name = new_op_name.replace(":", "_")
    return op.graph.create_op(
        new_op_type,
        new_op_inputs,
        op._output_types,  # pylint: disable=protected-access
        name=new_op_name,
        attrs=op.node_def.attr)

  return op

1444 

1445 

def _find_initialized_value_for_variable(variable_op):
  """Looks up the initialized value of a variable op in the collections.

  Searches the op's graph, `GLOBAL_VARIABLES` first and then
  `LOCAL_VARIABLES`, for a variable whose `name` equals the op's node name
  with or without a ":0" suffix.

  Args:
    variable_op: A variable `Operation`.

  Returns:
    The first match's `initialized_value()`, or `None` if there is no match
    or an `AttributeError` is raised during the lookup.
  """
  try:
    node_name = variable_op.node_def.name
    accepted_names = [node_name, node_name + ":0"]
    collection_keys = (ops.GraphKeys.GLOBAL_VARIABLES,
                       ops.GraphKeys.LOCAL_VARIABLES)
    for key in collection_keys:
      match = next((var for var in variable_op.graph.get_collection(key)
                    if var.name in accepted_names), None)
      if match is not None:
        return match.initialized_value()
  except AttributeError:
    # Tolerate incomplete user-defined variable types put in the collection
    # by treating them as "not found".
    pass
  return None

1470 

1471 

class PartitionedVariable:
  """A container for partitioned `Variable` objects.

  @compatibility(eager) `tf.PartitionedVariable` is not compatible with
  eager execution.  Use `tf.Variable` instead which is compatible
  with both eager execution and graph construction.  See [the
  TensorFlow Eager Execution
  guide](https://www.tensorflow.org/guide/eager#variables_and_optimizers)
  for details on how variables work in eager execution.
  @end_compatibility
  """

  def __init__(self, name, shape, dtype, variable_list, partitions):
    """Creates a new partitioned variable wrapper.

    Variables passed via the variable_list must contain a save_slice_info
    field.  Concatenation and iteration is in lexicographic order according
    to the var_offset property of the save_slice_info.

    Args:
      name: String. Overall name of the variables.
      shape: List of integers.  Overall shape of the variables.
      dtype: Type of the variables.
      variable_list: List of `Variable` that comprise this partitioned variable.
      partitions: List of integers.  Number of partitions for each dimension.

    Raises:
      TypeError: If `variable_list` is not a list of `Variable` objects, or
        `partitions` is not a list.
      ValueError: If `variable_list` is empty, or the `Variable` shape
        information does not match `shape`, or `partitions` has invalid values.
    """
    # Format arguments are wrapped in 1-tuples: these values may themselves
    # be tuples, which `%` would otherwise unpack as the argument list.
    if not isinstance(variable_list, (list, tuple)):
      raise TypeError("variable_list is not a list or tuple: %s" %
                      (variable_list,))
    if not isinstance(partitions, (list, tuple)):
      raise TypeError("partitions is not a list or tuple: %s" % (partitions,))
    if not all(p >= 1 for p in partitions):
      raise ValueError("partition values must be positive: %s" %
                       (partitions,))
    if not variable_list:
      raise ValueError("variable_list may not be empty")
    # pylint: disable=protected-access
    # These checks do not depend on an individual variable, so run them once
    # rather than once per element of `variable_list`.
    if not all(v._get_save_slice_info() is not None for v in variable_list):
      raise ValueError(
          "All variables must have a save_slice_info available: %s" %
          [v.name for v in variable_list])
    if len(shape) != len(partitions):
      raise ValueError("len(shape) != len(partitions): %s vs. %s" %
                       (shape, partitions))
    full_shapes = [v._get_save_slice_info().full_shape for v in variable_list]
    if any(full_shape != shape for full_shape in full_shapes):
      raise ValueError("All variables' full shapes must match shape: %s; "
                       "but full shapes were: %s" % (shape, str(full_shapes)))
    # Sort the variable_list lexicographically according to var offset value.
    self._variable_list = sorted(
        variable_list, key=lambda v: v._get_save_slice_info().var_offset)
    # pylint: enable=protected-access

    self._name = name
    self._shape = shape
    self._dtype = dtype
    self._partitions = partitions
    self._as_tensor = None

  def __iter__(self):
    """Return an iterable for accessing the underlying partition Variables."""
    return iter(self._variable_list)

  def __len__(self):
    """Number of partitions; only defined for at most one partitioned axis.

    Raises:
      ValueError: if more than one axis is partitioned.
    """
    num_partition_axes = len(self._partition_axes())
    if num_partition_axes > 1:
      raise ValueError("Cannot get a length for %d > 1 partition axes" %
                       num_partition_axes)
    return len(self._variable_list)

  def _partition_axes(self):
    """Axes with more than one partition, or `[0]` if no axis is split."""
    if all(p == 1 for p in self._partitions):
      return [0]
    else:
      return [i for i, p in enumerate(self._partitions) if p > 1]

  def _concat(self):
    """Returns the overall concatenated value as a `Tensor`.

    This is different from using the partitioned variable directly as a tensor
    (through tensor conversion and `as_tensor`) in that it creates a new set of
    operations that keeps the control dependencies from its scope.

    Returns:
      `Tensor` containing the concatenated value.

    Raises:
      NotImplementedError: if more than one axis is partitioned.
    """
    if len(self._variable_list) == 1:
      with ops.name_scope(None):
        return array_ops.identity(self._variable_list[0], name=self._name)

    partition_axes = self._partition_axes()

    if len(partition_axes) > 1:
      raise NotImplementedError(
          "Cannot concatenate along more than one dimension: %s.  "
          "Multi-axis partition concat is not supported" % str(partition_axes))
    partition_ix = partition_axes[0]

    with ops.name_scope(self._name + "/ConcatPartitions/"):
      concatenated = array_ops.concat(self._variable_list, partition_ix)

    with ops.name_scope(None):
      return array_ops.identity(concatenated, name=self._name)

  def as_tensor(self):
    """Returns the overall concatenated value as a `Tensor`.

    The returned tensor will not inherit the control dependencies from the scope
    where the value is used, which is similar to getting the value of
    `Variable`.

    Returns:
      `Tensor` containing the concatenated value.
    """
    with ops.control_dependencies(None):
      return self._concat()

  @staticmethod
  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):
    # pylint: disable=invalid-name
    """Tensor conversion hook; returns `v.as_tensor()`.

    Raises:
      ValueError: if `dtype` is incompatible with `v.dtype`.
      NotImplementedError: if `as_ref` is requested.
    """
    _ = name
    if dtype is not None and not dtype.is_compatible_with(v.dtype):
      raise ValueError(
          "Incompatible type conversion requested to type '%s' for variable "
          "of type '%s'" % (dtype.name, v.dtype.name))
    if as_ref:
      raise NotImplementedError(
          "PartitionedVariable doesn't support being used as a reference.")
    else:
      return v.as_tensor()

  @property
  def name(self):
    return self._name

  @property
  def dtype(self):
    return self._dtype

  @property
  def shape(self):
    return self.get_shape()

  @property
  def _distribute_strategy(self):
    """The `tf.distribute.Strategy` that this variable was created under."""
    # NOTE(yuefengz): Today, no partitioned variables in a distribute strategy.
    return None

  def get_shape(self):
    return self._shape

  def _get_variable_list(self):
    return self._variable_list

  def _get_partitions(self):
    return self._partitions

  def _apply_assign_fn(self, assign_fn, value):
    """Applies `assign_fn(var, part)` to each partition and its part of `value`.

    `value` may be a list with one entry per partition, another
    `PartitionedVariable`, or a tensor-like value that is split along the
    single partitioned axis to match the partition sizes.

    Raises:
      NotImplementedError: if more than one axis is partitioned.
    """
    partition_axes = self._partition_axes()
    if len(partition_axes) > 1:
      raise NotImplementedError(
          "Cannot do assign action along more than one dimension: %s.  "
          "Multi-axis partition assign action is not supported " %
          str(partition_axes))
    if isinstance(value, list):
      # NOTE(review): a length mismatch surfaces as AssertionError (and is not
      # checked at all under `python -O`).
      assert len(value) == len(self._variable_list)
      value_list = value
    elif isinstance(value, PartitionedVariable):
      value_list = list(value)
    else:
      partition_ix = partition_axes[0]
      size_splits_list = [
          tensor_shape.dimension_value(var.shape[partition_ix])
          for var in self._variable_list
      ]
      value_list = array_ops.split(value, size_splits_list, axis=partition_ix)

    op_list = [
        assign_fn(var, value_list[idx])
        for idx, var in enumerate(self._variable_list)
    ]
    return op_list

  # NOTE(review): with read_value=False the per-partition `var.assign*` calls
  # may already return Operations; confirm those expose `.op` as used below.
  def assign(self, value, use_locking=False, name=None, read_value=True):
    assign_fn = lambda var, r_value: var.assign(
        r_value, use_locking=use_locking, name=name, read_value=read_value)
    assign_list = self._apply_assign_fn(assign_fn, value)
    if read_value:
      return assign_list
    return [assign.op for assign in assign_list]

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    assign_fn = lambda var, r_value: var.assign_add(
        r_value, use_locking=use_locking, name=name, read_value=read_value)
    assign_list = self._apply_assign_fn(assign_fn, value)
    if read_value:
      return assign_list
    return [assign.op for assign in assign_list]

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    assign_fn = lambda var, r_value: var.assign_sub(
        r_value, use_locking=use_locking, name=name, read_value=read_value)
    assign_list = self._apply_assign_fn(assign_fn, value)
    if read_value:
      return assign_list
    return [assign.op for assign in assign_list]

1685 

1686 

@tf_export(v1=["global_variables"])
def global_variables(scope=None):
  """Returns the contents of the global-variables graph collection.

  Global variables are variables shared across machines in a distributed
  environment. Both the `Variable()` constructor and `get_variable()` add each
  newly created variable to the graph collection
  `GraphKeys.GLOBAL_VARIABLES`; this convenience function returns what that
  collection holds.

  The alternative to global variables is local variables; see
  `tf.compat.v1.local_variables`.

  @compatibility(TF2)
  Not compatible with eager execution and `tf.function`. In particular, Graph
  collections are deprecated in TF2. Instead please create a
  [tf.Module](https://www.tensorflow.org/guide/intro_to_modules)
  container for all your model state, including variables.
  You can then list all the variables in your `tf.Module` through the
  `variables` attribute.
  @end_compatibility

  Args:
    scope: (Optional.) A string. When given, only items whose `name`
      attribute matches `scope` under `re.match` are returned, and items
      without a `name` attribute are dropped. Because `re.match` anchors at
      the start of the name, a `scope` free of regex special characters acts
      as a prefix filter.

  Returns:
    A list of `Variable` objects.
  """
  collection_key = ops.GraphKeys.GLOBAL_VARIABLES
  return ops.get_collection(collection_key, scope)

1720 

1721 

@tf_export(v1=["all_variables"])
@deprecated("2017-03-02", "Please use tf.global_variables instead.")
def all_variables():
  """Deprecated alias; use `tf.compat.v1.global_variables` instead."""
  unscoped = None
  return global_variables(unscoped)

1727 

1728 

def _all_saveable_objects(scope=None):
  """Collects every variable and `SaveableObject` that needs checkpointing.

  The result is the `GraphKeys.GLOBAL_VARIABLES` collection followed by the
  `GraphKeys.SAVEABLE_OBJECTS` collection, both filtered by `scope`.

  Args:
    scope: (Optional.) A string. When given, only items whose `name`
      attribute matches `scope` under `re.match` are returned, and items
      without a `name` attribute are dropped. A `scope` free of regex special
      characters therefore filters by name prefix.

  Returns:
    A list of `Variable` and `SaveableObject` instances to be checkpointed.
  """
  # TODO(andreasst): make this function public once things are settled.
  global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope)
  saveables = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope)
  return global_vars + saveables

1745 

1746 

@tf_export(v1=["local_variables"])
def local_variables(scope=None):
  """Returns the contents of the local-variables graph collection.

  Local variables are per-process variables that are usually not saved to or
  restored from checkpoints. They hold temporary or intermediate values, for
  example counters used when computing metrics or the number of epochs this
  machine has read data for. The `tf.contrib.framework.local_variable()`
  function automatically adds each new variable it creates to
  `GraphKeys.LOCAL_VARIABLES`; this convenience function returns what that
  collection holds.

  The alternative to local variables is global variables; see
  `tf.compat.v1.global_variables`.

  Args:
    scope: (Optional.) A string. When given, only items whose `name`
      attribute matches `scope` under `re.match` are returned, and items
      without a `name` attribute are dropped. A `scope` free of regex special
      characters therefore filters by name prefix.

  Returns:
    A list of local `Variable` objects.
  """
  collection_key = ops.GraphKeys.LOCAL_VARIABLES
  return ops.get_collection(collection_key, scope)

1773 

1774 

@tf_export(v1=["model_variables"])
def model_variables(scope=None):
  """Returns all variables in the MODEL_VARIABLES collection.

  This is a convenience wrapper that returns the contents of the graph
  collection `GraphKeys.MODEL_VARIABLES`.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a scope
      is supplied. The choice of `re.match` means that a `scope` without special
      tokens filters by prefix.

  Returns:
    A list of the `Variable` objects in the `GraphKeys.MODEL_VARIABLES`
    collection.
  """
  return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope)

1790 

1791 

@tf_export(v1=["trainable_variables"])
def trainable_variables(scope=None):
  """Returns the variables that were created with `trainable=True`.

  The `Variable()` constructor, when given `trainable=True`, automatically adds
  the new variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
  This convenience function returns what that collection holds.

  @compatibility(TF2)
  Not compatible with eager execution and `tf.function`. In particular, Graph
  collections are deprecated in TF2. Instead please create a `tf.Module`
  container for all your model state, including variables.
  You can then list all the trainable variables in your `tf.Module` through the
  `trainable_variables` attribute.
  @end_compatibility

  Args:
    scope: (Optional.) A string. When given, only items whose `name`
      attribute matches `scope` under `re.match` are returned, and items
      without a `name` attribute are dropped. A `scope` free of regex special
      characters therefore filters by name prefix.

  Returns:
    A list of Variable objects.
  """
  collection_key = ops.GraphKeys.TRAINABLE_VARIABLES
  return ops.get_collection(collection_key, scope)

1820 

1821 

@tf_export(v1=["moving_average_variables"])
def moving_average_variables(scope=None):
  """Returns the variables whose moving averages are being maintained.

  When an `ExponentialMovingAverage` object is created and its `apply()`
  method is called on a list of variables, those variables are added to the
  `GraphKeys.MOVING_AVERAGE_VARIABLES` collection. This convenience function
  returns what that collection holds.

  Args:
    scope: (Optional.) A string. When given, only items whose `name`
      attribute matches `scope` under `re.match` are returned, and items
      without a `name` attribute are dropped. A `scope` free of regex special
      characters therefore filters by name prefix.

  Returns:
    A list of Variable objects.
  """
  collection_key = ops.GraphKeys.MOVING_AVERAGE_VARIABLES
  return ops.get_collection(collection_key, scope)

1842 

1843 

@tf_export(v1=["initializers.variables", "variables_initializer"])
def variables_initializer(var_list, name="init"):
  """Returns an Op that initializes a list of variables.

  After you launch the graph in a session, you can run the returned Op to
  initialize all the variables in `var_list`. The returned Op groups the
  `initializer` of every variable in `var_list` with `tf.group`, so the
  initializers have no ordering constraints among themselves.

  If `var_list` is empty, or if eager execution is enabled, a `no_op` is
  returned instead. It can still be run, but it has no effect.

  @compatibility(TF2)
  In TF2, variables are initialized immediately when they are created. There is
  no longer a need to run variable initializers before using them.
  @end_compatibility

  Args:
    var_list: List of `Variable` objects to initialize.
    name: Optional name for the returned operation.

  Returns:
    An Op that runs the initializers of all the specified variables.
  """
  if not var_list or context.executing_eagerly():
    # Nothing to group (empty list), or eager mode: hand back a harmless op.
    return control_flow_ops.no_op(name=name)
  return control_flow_ops.group(*[v.initializer for v in var_list], name=name)

1873 

1874 

@tf_export(v1=["initialize_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.")
def initialize_variables(var_list, name="init"):
  """Deprecated alias; see `tf.compat.v1.variables_initializer`."""
  return variables_initializer(var_list=var_list, name=name)

1881 

1882 

@tf_export(v1=["initializers.global_variables", "global_variables_initializer"])
def global_variables_initializer():
  """Returns an Op that initializes every global variable.

  Shorthand for `variables_initializer(global_variables())`. Under eager
  execution a `no_op` named "global_variables_initializer" is returned.

  @compatibility(TF2)
  In TF2, variables are initialized immediately when they are created. There is
  no longer a need to run variable initializers before using them.
  @end_compatibility

  Returns:
    An Op that initializes global variables in the graph.
  """
  if not context.executing_eagerly():
    return variables_initializer(global_variables())
  return control_flow_ops.no_op(name="global_variables_initializer")

1900 

1901 

@tf_export(v1=["initialize_all_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.")
def initialize_all_variables():
  """Deprecated alias; see `tf.compat.v1.global_variables_initializer`."""
  init_op = global_variables_initializer()
  return init_op

1908 

1909 

@tf_export(v1=["initializers.local_variables", "local_variables_initializer"])
def local_variables_initializer():
  """Returns an Op that initializes every local variable.

  Shorthand for `variables_initializer(local_variables())`. Under eager
  execution a `no_op` named "local_variables_initializer" is returned.

  @compatibility(TF2)
  In TF2, variables are initialized immediately when they are created. There is
  no longer a need to run variable initializers before using them.
  @end_compatibility

  Returns:
    An Op that initializes all local variables in the graph.
  """
  if not context.executing_eagerly():
    return variables_initializer(local_variables())
  return control_flow_ops.no_op(name="local_variables_initializer")

1927 

1928 

@tf_export(v1=["initialize_local_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.")
def initialize_local_variables():
  """Deprecated alias; see `tf.compat.v1.local_variables_initializer`."""
  init_op = local_variables_initializer()
  return init_op

1935 

1936 

@tf_export(v1=["assert_variables_initialized"])
@tf_should_use.should_use_result
def assert_variables_initialized(var_list=None):
  """Returns an Op to check if variables are initialized.

  NOTE: This function is obsolete and will be removed in 6 months. Please
  change your implementation to use `report_uninitialized_variables()`.

  When run, the returned Op will raise the exception `FailedPreconditionError`
  if any of the variables has not yet been initialized.

  Note: This function is implemented by trying to fetch the values of the
  variables. If one of the variables is not initialized a message may be
  logged by the C++ runtime. This is expected.

  Args:
    var_list: List of `Variable` objects to check. Defaults to the value of
      `global_variables() + local_variables()`. If the resulting list is
      empty, the outputs of every op in the default graph whose type is
      `Variable`, `VariableV2` or `AutoReloadVariable` are checked instead
      (backwards compatibility for old-style variables).

  Returns:
    An Op, or None if there are no variables. With exactly one variable this
    is that variable's rank op; otherwise it is the stack of all rank ops.
  """
  if var_list is None:
    var_list = global_variables() + local_variables()
  # Backwards compatibility for old-style variables. TODO(touts): remove.
  if not var_list:
    var_list = [
        op.outputs[0]
        for op in ops.get_default_graph().get_operations()
        if op.type in ("Variable", "VariableV2", "AutoReloadVariable")
    ]
  if not var_list:
    return None

  ranks = []
  for var in var_list:
    # NOTE(review): presumably `optimize=False` forces a runtime Rank op
    # (instead of a folded constant) so running it really reads the
    # variable -- confirm against `array_ops.rank_internal`.
    with ops.colocate_with(var.op):
      ranks.append(array_ops.rank_internal(var, optimize=False))
  if len(ranks) == 1:
    return ranks[0]
  return array_ops_stack.stack(ranks)

1978 

1979 

@tf_export(v1=["report_uninitialized_variables"])
@tf_should_use.should_use_result
def report_uninitialized_variables(var_list=None,
                                   name="report_uninitialized_variables"):
  """Adds ops to list the names of uninitialized variables.

  When run, it returns a 1-D tensor containing the names of uninitialized
  variables if there are any, or an empty array if there are none.

  The ops that build the result are placed on the device named by the
  environment variable `TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING`
  (default "/cpu:0").

  Args:
    var_list: List of `Variable` objects to check. Defaults to the value of
      `global_variables() + local_variables()`. If the resulting list is
      empty, the outputs of every op in the default graph whose type is
      `Variable`, `VariableV2` or `AutoReloadVariable` are checked instead
      (backwards compatibility for old-style variables).
    name: Optional name of the `Operation`.

  Returns:
    A 1-D tensor containing names of the uninitialized variables, or an empty
    1-D tensor if there are no variables or no uninitialized variables.
  """
  if var_list is None:
    var_list = global_variables() + local_variables()
  # Backwards compatibility for old-style variables. TODO(touts): remove.
  if not var_list:
    var_list = []
    for op in ops.get_default_graph().get_operations():
      if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
        var_list.append(op.outputs[0])
  with ops.name_scope(name):
    # The per-variable initialization checks are created *before* entering
    # the reporting device scope below, so they are not pinned to that device.
    if var_list:
      init_vars = [state_ops.is_variable_initialized(v) for v in var_list]
    # Only the ops that assemble the report are placed on this device.
    local_device = os.environ.get(
        "TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING", "/cpu:0")
    with ops.device(local_device):
      if not var_list:
        # Return an empty tensor so we only need to check for returned tensor
        # size being 0 as an indication of model ready.
        return array_ops.constant([], dtype=dtypes.string)
      # Get a 1-D boolean tensor listing whether each variable is
      # uninitialized.
      variables_mask = math_ops.logical_not(array_ops_stack.stack(init_vars))
      # Get a 1-D string tensor containing all the variable names.
      variable_names_tensor = array_ops.constant(
          [s.op.name for s in var_list])
      # Return a 1-D tensor containing all the names of
      # uninitialized variables.
      return array_ops.boolean_mask(variable_names_tensor, variables_mask)

2026 

2027 

# Register PartitionedVariable._TensorConversionFunction as the tensor
# conversion function for PartitionedVariable (presumably so instances can be
# passed where a Tensor is expected -- see tensor_conversion_registry).
tensor_conversion_registry.register_tensor_conversion_function(
    PartitionedVariable, PartitionedVariable._TensorConversionFunction)  # pylint: disable=protected-access