Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/variable_scope.py: 24%

792 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""A class to store named variables and a scope operator to manage sharing.""" 

16 

17import copy 

18import enum 

19import functools 

20import sys 

21import threading 

22import traceback 

23 

24from tensorflow.python import tf2 

25from tensorflow.python.client import session 

26from tensorflow.python.eager import context 

27from tensorflow.python.eager import monitoring 

28from tensorflow.python.framework import dtypes 

29from tensorflow.python.framework import ops 

30from tensorflow.python.framework import tensor_conversion_registry 

31from tensorflow.python.framework import tensor_shape 

32from tensorflow.python.ops import array_ops 

33from tensorflow.python.ops import init_ops 

34from tensorflow.python.ops import ref_variable 

35from tensorflow.python.ops import resource_variable_ops 

36from tensorflow.python.ops import variable_v1 

37from tensorflow.python.ops import variables 

38from tensorflow.python.platform import tf_logging as logging 

39from tensorflow.python.types import core 

40from tensorflow.python.util import deprecation 

41from tensorflow.python.util import function_utils 

42from tensorflow.python.util import tf_contextlib 

43from tensorflow.python.util import tf_inspect 

44from tensorflow.python.util.compat import collections_abc 

45from tensorflow.python.util.tf_export import tf_export 

46 

47 

48__all__ = [ 

49 "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable", 

50 "get_local_variable", "variable_scope", "variable_op_scope", 

51 "no_regularizer", "VariableSynchronization", "VariableAggregation" 

52] 

53 

54_api_usage_gauge = monitoring.BoolGauge( 

55 "/tensorflow/api/resource_variables", 

56 "Whether variable_scope.enable_resource_variables() is called.") 

57 

58 

59class _PartitionInfo: 

60 """Holds partition info used by initializer functions.""" 

61 

62 __slots__ = ["_full_shape", "_var_offset"] 

63 

64 def __init__(self, full_shape, var_offset): 

65 """Constructor. 

66 

67 Args: 

68 full_shape: Tuple or list of `int` indicating the full combined shape of 

69 the partitioned variables. 

70 var_offset: Tuple or list of `int` specifying offset of this partition 

71 with respect to the full variable for each dimension. 

72 

73 Raises: 

74 TypeError: If `full_shape` or `var_offset` is not a sequence. 

75 ValueError: If `full_shape` or `var_offset` differ in length. If 

76 `var_offset` exceeds `full_shape` in any dimension. 

77 """ 

78 if not isinstance(full_shape, (list, tuple)): 

79 raise TypeError( 

80 "`full_shape` must be a sequence (like tuple or list) instead of " + 

81 type(full_shape).__name__) 

82 

83 if not isinstance(var_offset, (list, tuple)): 

84 raise TypeError( 

85 "`var_offset` must be a sequence (like tuple or list) instead of " + 

86 type(var_offset).__name__) 

87 

88 if len(var_offset) != len(full_shape): 

89 raise ValueError( 

90 "Expected equal length, but `var_offset` is of length {} while " 

91 "full_shape is of length {}.".format( 

92 len(var_offset), len(full_shape))) 

93 

94 for offset, shape in zip(var_offset, full_shape): 

95 if offset < 0 or offset >= shape: 

96 raise ValueError( 

97 "Expected 0 <= offset < shape but found offset={}, shape={} for " 

98 "var_offset={}, full_shape={}".format(offset, shape, var_offset, 

99 full_shape)) 

100 

101 self._full_shape = full_shape 

102 self._var_offset = var_offset 

103 

104 @property 

105 def full_shape(self): 

106 return self._full_shape 

107 

108 @property 

109 def var_offset(self): 

110 return self._var_offset 

111 

112 def single_offset(self, shape): 

113 """Returns the offset when the variable is partitioned in at most one dim. 

114 

115 Args: 

116 shape: Tuple or list of `int` indicating the shape of one specific 

117 variable partition. 

118 

119 Returns: 

120 `int` representing the offset in the dimension along which the variable is 

121 partitioned. Returns 0 if the variable is not being partitioned. 

122 

123 Raises: 

124 ValueError: Propagated from `self.single_slice_dim()`.

125 """ 

126 

127 single_slice_dim = self.single_slice_dim(shape) 

128 # If this variable is not being partitioned at all, single_slice_dim() could 

129 # return None. 

130 if single_slice_dim is None: 

131 return 0 

132 return self.var_offset[single_slice_dim] 

133 

134 def single_slice_dim(self, shape): 

135 """Returns the slice dim when the variable is partitioned only in one dim. 

136 

137 Args: 

138 shape: Tuple or list of `int` indicating the shape of one specific 

139 variable partition. 

140 

141 Returns: 

142 `int` representing the dimension that the variable is partitioned in, or 

143 `None` if the variable doesn't seem to be partitioned at all. 

144 

145 Raises: 

146 TypeError: If `shape` is not a sequence. 

147 ValueError: If `shape` is not the same length as `self.full_shape`. If 

148 the variable is partitioned in more than one dimension. 

149 """ 

150 if not isinstance(shape, (tuple, list)): 

151 raise TypeError( 

152 "`shape` must be a sequence (like tuple or list) instead of " + 

153 type(shape).__name__) 

154 

155 if len(shape) != len(self.full_shape): 

156 raise ValueError( 

157 "Expected equal length, but received shape={} of length {} while " 

158 "self.full_shape={} is of length {}.".format(shape, len(shape), 

159 self.full_shape, 

160 len(self.full_shape))) 

161 

162 for i in range(len(shape)): 

163 if self.var_offset[i] + shape[i] > self.full_shape[i]: 

164 raise ValueError( 

165 "With self.var_offset={}, a partition of shape={} would exceed " 

166 "self.full_shape={} in dimension {}.".format( 

167 self.var_offset, shape, self.full_shape, i)) 

168 

169 slice_dim = None 

170 for i in range(len(shape)): 

171 if shape[i] == self.full_shape[i]: 

172 continue 

173 if slice_dim is not None: 

174 raise ValueError( 

175 "Cannot use single_slice_dim() with shape={} and " 

176 "self.full_shape={} since slice dim could be either dimension {} " 

177 "or {}.".format(shape, self.full_shape, i, slice_dim)) 

178 slice_dim = i 

179 

180 return slice_dim 

181 

182 
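An illustrative sketch of the `_PartitionInfo` contract above (this is a module-private helper, so the example only restates its documented behavior): for a variable of full shape `[10, 4]` split along axis 0, the shard covering rows 5..9 reports axis 0 as its slice dimension and 5 as its offset.

```python
info = _PartitionInfo(full_shape=[10, 4], var_offset=[5, 0])
part_shape = [5, 4]                          # this shard spans rows 5..9
assert info.single_slice_dim(part_shape) == 0
assert info.single_offset(part_shape) == 5
# An unpartitioned "shard" (same shape as the full variable) has offset 0.
assert _PartitionInfo([10, 4], [0, 0]).single_offset([10, 4]) == 0
```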

183class _ReuseMode(enum.Enum): 

184 """Mode for variable access within a variable scope.""" 

185 

186 # Indicates that variables are to be fetched if they already exist or 

187 # otherwise created. 

188 AUTO_REUSE = 1 

189 

190 # TODO(alive): For TensorFlow 2.0, Deprecate True/False/None API in favor of 

191 # enum values. 

192 # REUSE_FALSE = 2 

193 # REUSE_TRUE = 3 

194 

195 

196# TODO(apassos) remove these forwarding symbols. 

197VariableSynchronization = variables.VariableSynchronization # pylint: disable=invalid-name 

198VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name 

199 

200AUTO_REUSE = _ReuseMode.AUTO_REUSE 

201tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE") 

202AUTO_REUSE.__doc__ = """ 

203@compatibility(TF2) 

204`tf.compat.v1.AUTO_REUSE` is a legacy API that is a no-op when TF2 behaviors 

205are enabled. 

206 

207If you rely on `get_variable` and auto-reuse, see the 

208[model mapping guide](https://www.tensorflow.org/guide/migrate/model_mapping) 

209for more info on how to migrate your code. 

210 

211Note: when you use the `tf.compat.v1.keras.utils.track_tf1_style_variables` 

212API as described in the above guide, `get_variable` will always behave as if 

213`v1.AUTO_REUSE` is set. Without the decorator, reuse will be ignored and new 

214variables will always be created, regardless of if they have already been 

215created. 

216@end_compatibility 

217 

218When passed in as the value for the `reuse` flag, `AUTO_REUSE` indicates that 

219get_variable() should create the requested variable if it doesn't exist or, if 

220it does exist, simply return it. 

221""" 

222 
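A minimal usage sketch of the `AUTO_REUSE` behavior described above, assuming TF1 graph-mode semantics (`tensorflow.compat.v1` with v2 behaviors disabled); under TF2 eager execution the reuse flag is ignored, as noted.

```python
import tensorflow.compat.v1 as tf1
tf1.disable_v2_behavior()  # assumption: graph-mode semantics for reuse

def dense_weights():
  # First call creates "layer/w"; subsequent calls return the stored variable.
  with tf1.variable_scope("layer", reuse=tf1.AUTO_REUSE):
    return tf1.get_variable("w", shape=[3, 3])

w1 = dense_weights()
w2 = dense_weights()
assert w1 is w2  # same variable object fetched from the variable store
```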

223_DEFAULT_USE_RESOURCE = tf2.enabled() 

224 

225 

226@tf_export(v1=["enable_resource_variables"]) 

227def enable_resource_variables(): 

228 """Creates resource variables by default. 

229 

230 Resource variables are improved versions of TensorFlow variables with a 

231 well-defined memory model. Accessing a resource variable reads its value, and 

232 all ops which access a specific read value of the variable are guaranteed to 

233 see the same value for that tensor. Writes which happen after a read (by 

234 having a control or data dependency on the read) are guaranteed not to affect 

235 the value of the read tensor, and similarly writes which happen before a read 

236 are guaranteed to affect the value. No guarantees are made about unordered 

237 read/write pairs. 

238 

239 Calling tf.enable_resource_variables() lets you opt in to this TensorFlow 2.0

240 feature. 

241 """ 

242 global _DEFAULT_USE_RESOURCE 

243 _DEFAULT_USE_RESOURCE = True 

244 logging.vlog(1, "Enabling resource variables") 

245 _api_usage_gauge.get_cell().set(True) 

246 

247 

248@tf_export(v1=["resource_variables_enabled"]) 

249def resource_variables_enabled(): 

250 """Returns `True` if resource variables are enabled. 

251 

252 Resource variables are improved versions of TensorFlow variables with a 

253 well-defined memory model. Accessing a resource variable reads its value, and 

254 all ops which access a specific read value of the variable are guaranteed to 

255 see the same value for that tensor. Writes which happen after a read (by 

256 having a control or data dependency on the read) are guaranteed not to affect 

257 the value of the read tensor, and similarly writes which happen before a read 

258 are guaranteed to affect the value. No guarantees are made about unordered 

259 read/write pairs. 

260 

261 Calling tf.enable_resource_variables() lets you opt in to this TensorFlow 2.0

262 feature. 

263 """ 

264 global _DEFAULT_USE_RESOURCE 

265 return _DEFAULT_USE_RESOURCE 

266 

267 

268@deprecation.deprecated( 

269 None, "non-resource variables are not supported in the long term") 

270@tf_export(v1=["disable_resource_variables"]) 

271def disable_resource_variables(): 

272 """Opts out of resource variables. 

273 

274 If your code needs tf.disable_resource_variables() to be called to work 

275 properly please file a bug. 

276 """ 

277 global _DEFAULT_USE_RESOURCE 

278 _DEFAULT_USE_RESOURCE = False 

279 logging.vlog(1, "Disabling resource variables") 

280 _api_usage_gauge.get_cell().set(False) 

281 

282 
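A small sketch of how these toggles interact, using the public `tf.compat.v1` endpoints (the printed class name is an implementation detail and shown only for illustration):

```python
import tensorflow.compat.v1 as tf1

tf1.enable_resource_variables()
assert tf1.resource_variables_enabled()

with tf1.Graph().as_default():
  v = tf1.get_variable("v", shape=[], initializer=tf1.zeros_initializer())
  # With resource variables enabled, get_variable yields a ResourceVariable
  # rather than a legacy reference variable.
  print(type(v).__name__)
```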

283def _needs_no_arguments(python_callable): 

284 """Returns true if the callable needs no arguments to call.""" 

285 # TODO(bfontain): Switch to inspect.signature when we are python 3 only. 

286 # signature = inspect.signature(python_callable) 

287 # return not [1 for param in signature.parameters.values() 

288 # if param.default == param.empty] 

289 num_arguments = len(tf_inspect.getargspec(python_callable).args) 

290 if not tf_inspect.isfunction(python_callable) and not isinstance( 

291 python_callable, functools.partial): 

292 # getargspec includes self for function objects (which aren't 

293 # functools.partial). This has no default so we need to remove it. 

294 # It is not even an argument, so it's odd that getargspec returns this.

295 # Note that this is fixed with inspect.signature in Python 3. 

296 num_arguments -= 1 

297 return num_arguments == len( 

298 tf_inspect.getargspec(python_callable).defaults or []) 

299 

300 
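A behavior sketch for the module-private helper defined just above; these checks simply restate its contract that every positional argument must have a default.

```python
assert _needs_no_arguments(lambda: 1)        # no positional args at all
assert _needs_no_arguments(lambda x=3: x)    # every arg has a default
assert not _needs_no_arguments(lambda x: x)  # `x` has no default
```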

301class _VariableStore: 

302 """Variable store that carries a number of named Variables. 

303 

304 New variable names and new variables can be created; all stored 

305 variables are initialized with the initializer passed to __init__. 

306 

307 Attributes: 

308 vars: a dictionary with string names (same as passed to `get_variable`) as keys and

309 the corresponding TensorFlow Variables as values. 

310 """ 

311 

312 __slots__ = ["_vars", "_partitioned_vars", "_store_eager_variables"] 

313 

314 def __init__(self): 

315 """Create a variable store.""" 

316 self._vars = {} # A dictionary of the stored TensorFlow variables. 

317 self._partitioned_vars = {} # A dict of the stored PartitionedVariables. 

318 self._store_eager_variables = False 

319 

320 def get_variable(self, 

321 name, 

322 shape=None, 

323 dtype=dtypes.float32, 

324 initializer=None, 

325 regularizer=None, 

326 reuse=None, 

327 trainable=None, 

328 collections=None, 

329 caching_device=None, 

330 partitioner=None, 

331 validate_shape=True, 

332 use_resource=None, 

333 custom_getter=None, 

334 constraint=None, 

335 synchronization=VariableSynchronization.AUTO, 

336 aggregation=VariableAggregation.NONE): 

337 """Gets an existing variable with these parameters or create a new one. 

338 

339 If a variable with the given name is already stored, we return the stored 

340 variable. Otherwise, we create a new one. 

341 

342 Set `reuse` to `True` when you only want to reuse existing Variables. 

343 Set `reuse` to `False` when you only want to create new Variables. 

344 Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want 

345 variables to be created if they don't exist or returned if they do. 

346 

347 If initializer is `None` (the default), the default initializer passed in 

348 the constructor is used. If that one is `None` too, we use a new 

349 `glorot_uniform_initializer`. If initializer is a Tensor, we use 

350 it as a value and derive the shape from the initializer. 

351 

352 If a partitioner is provided, a `PartitionedVariable` is returned. 

353 Accessing this object as a `Tensor` returns the shards concatenated along 

354 the partition axis. 

355 

356 Some useful partitioners are available. See, e.g., 

357 `variable_axis_size_partitioner` and `min_max_variable_partitioner`. 

358 

359 Args: 

360 name: The name of the new or existing variable. 

361 shape: Shape of the new or existing variable. 

362 dtype: Type of the new or existing variable (defaults to `DT_FLOAT`). 

363 initializer: Initializer for the variable. 

364 regularizer: A (Tensor -> Tensor or None) function; the result of applying 

365 it to a newly created variable will be added to the collection

366 GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. 

367 reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of 

368 variables. When eager execution is enabled this argument is always 

369 forced to be False. 

370 trainable: If `True` also add the variable to the graph collection 

371 `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable` 

372 defaults to `True`, unless `synchronization` is set to `ON_READ`, in 

373 which case it defaults to `False`. 

374 collections: List of graph collections keys to add the `Variable` to. 

375 Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`). 

376 caching_device: Optional device string or function describing where the 

377 Variable should be cached for reading. Defaults to the Variable's 

378 device. If not `None`, caches on another device. Typical use is to 

379 cache on the device where the Ops using the `Variable` reside, to 

380 deduplicate copying through `Switch` and other conditional statements. 

381 partitioner: Optional callable that accepts a fully defined `TensorShape` 

382 and dtype of the `Variable` to be created, and returns a list of 

383 partitions for each axis (currently only one axis can be partitioned). 

384 validate_shape: If False, allows the variable to be initialized with a 

385 value of unknown shape. If True, the default, the shape of initial_value 

386 must be known. 

387 use_resource: If False, creates a regular Variable. If True, creates 

388 instead an experimental ResourceVariable which has well-defined 

389 semantics. Defaults to False (will later change to True). When eager 

390 execution is enabled this argument is always forced to be True.

391 custom_getter: Callable that takes as a first argument the true getter, 

392 and allows overwriting the internal get_variable method. The signature 

393 of `custom_getter` should match that of this method, 

394 but the most future-proof version will allow for changes: `def 

395 custom_getter(getter, *args, **kwargs)`. Direct access to 

396 all `get_variable` parameters is also allowed: `def 

397 custom_getter(getter, name, *args, **kwargs)`. A simple identity 

398 custom getter that simply creates variables with modified names is: 

399 ```python 

400 def custom_getter(getter, name, *args, **kwargs): return getter(name + 

401 '_suffix', *args, **kwargs) ``` 

402 constraint: An optional projection function to be applied to the variable 

403 after being updated by an `Optimizer` (e.g. used to implement norm 

404 constraints or value constraints for layer weights). The function must 

405 take as input the unprojected Tensor representing the value of the 

406 variable and return the Tensor for the projected value (which must have 

407 the same shape). Constraints are not safe to use when doing asynchronous 

408 distributed training. 

409 synchronization: Indicates when a distributed variable will be

410 aggregated. Accepted values are constants defined in the class 

411 `tf.VariableSynchronization`. By default the synchronization is set to 

412 `AUTO` and the current `DistributionStrategy` chooses when to 

413 synchronize. 

414 aggregation: Indicates how a distributed variable will be aggregated. 

415 Accepted values are constants defined in the class 

416 `tf.VariableAggregation`. 

417 

418 Returns: 

419 The created or existing `Variable` (or `PartitionedVariable`, if a 

420 partitioner was used). 

421 

422 Raises: 

423 ValueError: when creating a new variable and shape is not declared, 

424 when reusing a variable and specifying a conflicting shape, 

425 or when violating reuse during variable creation. 

426 RuntimeError: when eager execution is enabled and not called from an 

427 EagerVariableStore. 

428 """ 

429 if custom_getter is not None and not callable(custom_getter): 

430 raise ValueError("Passed a custom_getter which is not callable: %s" % 

431 custom_getter) 

432 

433 with ops.init_scope(): 

434 if context.executing_eagerly(): 

435 # Variable creation and initialization takes place in `init_scope`s; 

436 # as such, if an `init_scope` lifts us into the eager context, then we 

437 # need to use `ResourceVariable`s. 

438 use_resource = True 

439 

440 # Note that it's fine to reuse eager variables whose initialization was 

441 # lifted from a function-building graph into the eager context (that's why 

442 # the following clause is not wrapped in an `init_scope`); lifted variables 

443 # are tracked by the graph's `VariableStore`. 

444 if context.executing_eagerly(): 

445 if not self._store_eager_variables and reuse: 

446 raise RuntimeError( 

447 "When eager execution is enabled variable reuse is only supported" 

448 " when an EagerVariableStore is active. See the documentation on" 

449 " EagerVariableStore for example usage.") 

450 if self._store_eager_variables: 

451 reuse = AUTO_REUSE 

452 

453 # If a *_ref type is passed in an error would be triggered further down the 

454 # stack. We prevent this using base_dtype to get a non-ref version of the 

455 # type, before doing anything else. When _ref types are removed in favor of 

456 # resources, this line can be removed. 

457 try: 

458 dtype = dtype.base_dtype 

459 except AttributeError: 

460 # .base_dtype not existing means that we will try and use the raw dtype 

461 # which was passed in - this might be a NumPy type which is valid. 

462 pass 

463 

464 # This is the main logic of get_variable. However, custom_getter 

465 # may override this logic. So we save it as a callable and pass 

466 # it to custom_getter. 

467 # Note: the parameters of _true_getter, and their documentation, match 

468 # *exactly* item-for-item with the docstring of this method. 

469 def _true_getter( # pylint: disable=missing-docstring 

470 name, 

471 shape=None, 

472 dtype=dtypes.float32, 

473 initializer=None, 

474 regularizer=None, 

475 reuse=None, 

476 trainable=None, 

477 collections=None, 

478 caching_device=None, 

479 partitioner=None, 

480 validate_shape=True, 

481 use_resource=None, 

482 constraint=None, 

483 synchronization=VariableSynchronization.AUTO, 

484 aggregation=VariableAggregation.NONE): 

485 is_scalar = ( 

486 shape is not None and isinstance(shape, collections_abc.Sequence) and 

487 not shape) 

488 # Partitioned variable case 

489 if partitioner is not None and not is_scalar: 

490 if not callable(partitioner): 

491 raise ValueError("Partitioner must be callable, but received: %s" % 

492 partitioner) 

493 with ops.name_scope(None): 

494 return self._get_partitioned_variable( 

495 name=name, 

496 shape=shape, 

497 dtype=dtype, 

498 initializer=initializer, 

499 regularizer=regularizer, 

500 reuse=reuse, 

501 trainable=trainable, 

502 collections=collections, 

503 caching_device=caching_device, 

504 partitioner=partitioner, 

505 validate_shape=validate_shape, 

506 use_resource=use_resource, 

507 constraint=constraint, 

508 synchronization=synchronization, 

509 aggregation=aggregation) 

510 

511 # Special case for partitioned variable to allow reuse without having to 

512 # specify partitioner. 

513 if (reuse is True and partitioner is None 

514 and name in self._partitioned_vars): 

515 return self._get_partitioned_variable( 

516 name=name, 

517 shape=shape, 

518 dtype=dtype, 

519 initializer=initializer, 

520 regularizer=regularizer, 

521 reuse=reuse, 

522 trainable=trainable, 

523 collections=collections, 

524 caching_device=caching_device, 

525 partitioner=None, 

526 validate_shape=validate_shape, 

527 use_resource=use_resource, 

528 constraint=constraint, 

529 synchronization=synchronization, 

530 aggregation=aggregation) 

531 

532 # Single variable case 

533 if "%s/part_0" % name in self._vars: 

534 raise ValueError( 

535 "No partitioner was provided, but a partitioned version of the " 

536 "variable was found: %s/part_0. Perhaps a variable of the same " 

537 "name was already created with partitioning?" % name) 

538 

539 return self._get_single_variable( 

540 name=name, 

541 shape=shape, 

542 dtype=dtype, 

543 initializer=initializer, 

544 regularizer=regularizer, 

545 reuse=reuse, 

546 trainable=trainable, 

547 collections=collections, 

548 caching_device=caching_device, 

549 validate_shape=validate_shape, 

550 use_resource=use_resource, 

551 constraint=constraint, 

552 synchronization=synchronization, 

553 aggregation=aggregation) 

554 

555 synchronization, aggregation, trainable = ( 

556 variables.validate_synchronization_aggregation_trainable( 

557 synchronization, aggregation, trainable, name)) 

558 

559 if custom_getter is not None: 

560 # Handle backwards compatibility with getter arguments that were added 

561 # to the API after users started writing custom getters. 

562 custom_getter_kwargs = { 

563 "getter": _true_getter, 

564 "name": name, 

565 "shape": shape, 

566 "dtype": dtype, 

567 "initializer": initializer, 

568 "regularizer": regularizer, 

569 "reuse": reuse, 

570 "trainable": trainable, 

571 "collections": collections, 

572 "caching_device": caching_device, 

573 "partitioner": partitioner, 

574 "validate_shape": validate_shape, 

575 "use_resource": use_resource, 

576 "synchronization": synchronization, 

577 "aggregation": aggregation, 

578 } 

579 # `fn_args` and `has_kwargs` can handle functions, `functools.partial`, 

580 # `lambda`. 

581 if ("constraint" in function_utils.fn_args(custom_getter) or 

582 function_utils.has_kwargs(custom_getter)): 

583 custom_getter_kwargs["constraint"] = constraint 

584 return custom_getter(**custom_getter_kwargs) 

585 else: 

586 return _true_getter( 

587 name, 

588 shape=shape, 

589 dtype=dtype, 

590 initializer=initializer, 

591 regularizer=regularizer, 

592 reuse=reuse, 

593 trainable=trainable, 

594 collections=collections, 

595 caching_device=caching_device, 

596 partitioner=partitioner, 

597 validate_shape=validate_shape, 

598 use_resource=use_resource, 

599 constraint=constraint, 

600 synchronization=synchronization, 

601 aggregation=aggregation) 

602 
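A hedged end-to-end sketch of the `custom_getter` hook documented above, going through the public `tf.compat.v1` wrappers (which route into this store); the resulting name is shown only as an expectation.

```python
import tensorflow.compat.v1 as tf1

def prefix_getter(getter, name, *args, **kwargs):
  # Delegate to the true getter, rewriting only the requested variable name.
  return getter("prefixed/" + name, *args, **kwargs)

with tf1.Graph().as_default():
  with tf1.variable_scope("scope", custom_getter=prefix_getter):
    v = tf1.get_variable("w", shape=[2])
  print(v.name)  # expected: "prefixed/scope/w:0"
```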

603 def _get_partitioned_variable(self, 

604 name, 

605 partitioner, 

606 shape=None, 

607 dtype=dtypes.float32, 

608 initializer=None, 

609 regularizer=None, 

610 reuse=None, 

611 trainable=None, 

612 collections=None, 

613 caching_device=None, 

614 validate_shape=True, 

615 use_resource=None, 

616 constraint=None, 

617 synchronization=VariableSynchronization.AUTO, 

618 aggregation=VariableAggregation.NONE): 

619 """Gets or creates a sharded variable list with these parameters. 

620 

621 The `partitioner` must be a callable that accepts a fully defined 

622 `TensorShape` and returns a sequence of integers (the `partitions`). 

623 These integers describe how to partition the given sharded `Variable` 

624 along the given dimension. That is, `partitions[1] = 3` means split 

625 the `Variable` into 3 shards along dimension 1. Currently, sharding along 

626 only one axis is supported. 

627 

628 If the list of variables with the given name (prefix) is already stored, 

629 we return the stored variables. Otherwise, we create a new one. 

630 

631 Set `reuse` to `True` when you only want to reuse existing Variables. 

632 Set `reuse` to `False` when you only want to create new Variables. 

633 Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want 

634 variables to be created if they don't exist or returned if they do. 

635 

636 If initializer is `None` (the default), the default initializer passed in 

637 the constructor is used. If that one is `None` too, we use a new 

638 `glorot_uniform_initializer`. If initializer is a Tensor, we use 

639 it as a value and derive the shape from the initializer. 

640 

641 If the initializer is a callable, then it will be called for each 

642 shard. Otherwise the initializer should match the shape of the entire 

643 sharded Variable, and it will be sliced accordingly for each shard. 

644 

645 Some useful partitioners are available. See, e.g., 

646 `variable_axis_size_partitioner` and `min_max_variable_partitioner`. 

647 

648 Args: 

649 name: the name of the new or existing sharded variable. 

650 partitioner: Optional callable that accepts a fully defined `TensorShape` 

651 and `dtype` of the Variable to be created, and returns a list of 

652 partitions for each axis (currently only one axis can be partitioned). 

653 shape: shape of the new or existing sharded variable. 

654 dtype: type of the new or existing sharded variable (defaults to 

655 `DT_FLOAT`). 

656 initializer: initializer for the sharded variable. 

657 regularizer: a (Tensor -> Tensor or None) function; the result of applying 

658 it to a newly created variable will be added to the collection

659 GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. 

660 reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of 

661 variables. 

662 trainable: If `True` also add the variable to the graph collection 

663 `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). 

664 collections: List of graph collections keys to add the Variable to. 

665 Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`). 

666 caching_device: Optional device string or function describing where the 

667 Variable should be cached for reading. Defaults to the Variable's 

668 device. If not `None`, caches on another device. Typical use is to 

669 cache on the device where the Ops using the Variable reside, to 

670 deduplicate copying through `Switch` and other conditional statements. 

671 validate_shape: If False, allows the variable to be initialized with a 

672 value of unknown shape. If True, the default, the shape of initial_value 

673 must be known. 

674 use_resource: If False, creates a regular Variable. If True, creates an 

675 experimental ResourceVariable which has well-defined semantics. Defaults 

676 to False (will later change to True). 

677 constraint: An optional projection function to be applied to the variable 

678 after being updated by an `Optimizer` (e.g. used to implement norm 

679 constraints or value constraints for layer weights). The function must 

680 take as input the unprojected Tensor representing the value of the 

681 variable and return the Tensor for the projected value (which must have 

682 the same shape). Constraints are not safe to use when doing asynchronous 

683 distributed training. 

684 synchronization: Indicates when a distributed variable will be

685 aggregated. Accepted values are constants defined in the class 

686 `tf.VariableSynchronization`. By default the synchronization is set to 

687 `AUTO` and the current `DistributionStrategy` chooses when to 

688 synchronize. 

689 aggregation: Indicates how a distributed variable will be aggregated. 

690 Accepted values are constants defined in the class 

691 `tf.VariableAggregation`. 

692 

693 Returns: 

694 A `PartitionedVariable` object. 

695 

696 Raises: 

697 ValueError: when creating a new variable and shape is not declared, 

698 when reusing a variable and specifying a conflicting shape, 

699 when violating reuse during variable creation, or if an existing 

700 sharded variable exists for the given name but with different sharding. 

701 """ 

702 initializing_from_value = initializer is not None and isinstance( 

703 initializer, ops.Tensor) 

704 if name in self._vars: 

705 raise ValueError( 

706 "A partitioner was provided, but an unpartitioned version of the " 

707 "variable was found: %s. Perhaps a variable of the same name was " 

708 "already created without partitioning?" % name) 

709 

710 shape = tensor_shape.as_shape(shape) 

711 if initializing_from_value: 

712 shape = shape.merge_with(initializer.get_shape()) 

713 

714 partitions = None 

715 if not reuse or partitioner: 

716 partitions = _call_partitioner(partitioner, shape, dtype) 

717 

718 if name in self._partitioned_vars: 

719 if reuse is False: 

720 raise ValueError( 

721 "Partitioned variable with name %s already exists. Did you mean to " 

722 "set reuse=True or reuse=tf.AUTO_REUSE in VarScope?" % name) 

723 

724 existing_var = self._partitioned_vars[name] 

725 if not shape.is_compatible_with(existing_var.get_shape()): 

726 raise ValueError( 

727 "Trying to reuse partitioned variable %s, but specified shape %s " 

728 "and found shape %s." % (name, shape, existing_var.get_shape())) 

729 if not dtype.is_compatible_with(existing_var.dtype): 

730 raise ValueError( 

731 "Trying to reuse partitioned variable %s, but specified dtype %s " 

732 "and found dtype %s." % (name, dtype.name, existing_var.dtype.name)) 

733 

734 # pylint: disable=protected-access 

735 if (partitions is not None and 

736 existing_var._get_partitions() != partitions): 

737 raise ValueError( 

738 "Trying to reuse partitioned variable %s, but specified partitions " 

739 "%s and found partitions %s." % 

740 (name, partitions, existing_var._get_partitions())) 

741 # pylint: enable=protected-access 

742 

743 return existing_var 

744 

745 if reuse is True: 

746 raise ValueError("PartitionedVariable %s does not exist, or was not " 

747 "created with tf.get_variable(). Did you mean to set " 

748 "reuse=False or reuse=tf.AUTO_REUSE in VarScope?" % name) 

749 

750 slice_dim, num_slices = _get_slice_dim_and_num_slices(partitions) 

751 

752 if "%s/part_0" % name in self._vars: 

753 if "%s/part_%d" % (name, num_slices - 1) not in self._vars: 

754 raise ValueError( 

755 "Partitioner returned a different partitioning than what was " 

756 "already found. Partitioner returned %d shards, and shard " 

757 "%s/part_0 was found, but %s/part_%d was not." % 

758 (num_slices, name, name, num_slices - 1)) 

759 if "%s/part_%d" % (name, num_slices) in self._vars: 

760 raise ValueError( 

761 "Partitioner returned a different partitioning than what was " 

762 "already found. Partitioner returned %d shards, and shard " 

763 "%s/part_0 was found, but so was the extra shard %s/part_%d." % 

764 (num_slices, name, name, num_slices)) 

765 

766 vs = [] 

767 for i, (var_offset, var_shape) in enumerate( 

768 _iter_slices(shape.as_list(), num_slices, slice_dim)): 

769 partition_info = _PartitionInfo( 

770 full_shape=shape.as_list(), var_offset=var_offset) 

771 var_full_name = "%s/part_%d" % (name, i) 

772 with ops.name_scope( 

773 var_full_name + "/PartitionedInitializer", skip_on_eager=False): 

774 # Create the tensor to initialize the variable with default value. 

775 if initializer is None: 

776 init, initializing_from_value = self._get_default_initializer( 

777 name=name, shape=shape, dtype=dtype) 

778 if initializing_from_value: 

779 init_shape = None 

780 else: 

781 init_shape = var_shape 

782 elif callable(initializer): 

783 init = initializer 

784 init_shape = var_shape 

785 elif isinstance(initializer, ops.Tensor): 

786 init = array_ops.slice(initializer, var_offset, var_shape) 

787 # Use the dtype of the given tensor. 

788 dtype = init.dtype.base_dtype 

789 init_shape = None 

790 else: 

791 init = ops.convert_to_tensor(initializer, dtype=dtype) 

792 init = array_ops.slice(init, var_offset, var_shape) 

793 init_shape = None 

794 

795 with ops.name_scope(None): 

796 var = self._get_single_variable( 

797 name=var_full_name, 

798 shape=init_shape, 

799 dtype=dtype, 

800 initializer=init, 

801 partition_info=partition_info, 

802 regularizer=regularizer, 

803 reuse=reuse, 

804 trainable=trainable, 

805 collections=collections, 

806 caching_device=caching_device, 

807 validate_shape=validate_shape, 

808 use_resource=use_resource, 

809 constraint=constraint, 

810 synchronization=synchronization, 

811 aggregation=aggregation) 

812 

813 # pylint: disable=protected-access 

814 var._set_save_slice_info( 

815 variables.Variable.SaveSliceInfo(name, shape.as_list(), var_offset, 

816 var_shape)) 

817 vs.append(var) 

818 # pylint: enable=protected-access 

819 

820 partitioned_var = variables.PartitionedVariable( 

821 name=name, 

822 shape=shape, 

823 dtype=dtype, 

824 variable_list=vs, 

825 partitions=partitions) 

826 if not context.executing_eagerly() or self._store_eager_variables: 

827 self._partitioned_vars[name] = partitioned_var 

828 return partitioned_var 

829 
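A usage sketch of the sharding described above via the public API, assuming the standard `tf.compat.v1.fixed_size_partitioner`; the shard shapes are shown as an expectation.

```python
import tensorflow.compat.v1 as tf1

with tf1.Graph().as_default():
  with tf1.variable_scope(
      "emb", partitioner=tf1.fixed_size_partitioner(2, axis=0)):
    table = tf1.get_variable("table", shape=[10, 8])

  # `table` acts like one [10, 8] variable but is stored as two [5, 8] shards.
  print([v.shape.as_list() for v in table])  # expected: [[5, 8], [5, 8]]
```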

830 def _get_single_variable(self, 

831 name, 

832 shape=None, 

833 dtype=dtypes.float32, 

834 initializer=None, 

835 regularizer=None, 

836 partition_info=None, 

837 reuse=None, 

838 trainable=None, 

839 collections=None, 

840 caching_device=None, 

841 validate_shape=True, 

842 use_resource=None, 

843 constraint=None, 

844 synchronization=VariableSynchronization.AUTO, 

845 aggregation=VariableAggregation.NONE): 

846 """Get or create a single Variable (e.g. 

847 

848 a shard or entire variable). 

849 

850 See the documentation of get_variable above (ignore partitioning components) 

851 for details. 

852 

853 Args: 

854 name: see get_variable. 

855 shape: see get_variable. 

856 dtype: see get_variable. 

857 initializer: see get_variable. 

858 regularizer: see get_variable. 

859 partition_info: _PartitionInfo object. 

860 reuse: see get_variable. 

861 trainable: see get_variable. 

862 collections: see get_variable. 

863 caching_device: see get_variable. 

864 validate_shape: see get_variable. 

865 use_resource: see get_variable. 

866 constraint: see get_variable. 

867 synchronization: see get_variable. 

868 aggregation: see get_variable. 

869 

870 Returns: 

871 A Variable. See documentation of get_variable above. 

872 

873 Raises: 

874 ValueError: See documentation of get_variable above. 

875 """ 

876 # Set to true if initializer is a constant. 

877 initializing_from_value = False 

878 if initializer is not None and not callable(initializer): 

879 initializing_from_value = True 

880 if shape is not None and initializing_from_value: 

881 raise ValueError("If initializer is a constant, do not specify shape.") 

882 

883 dtype = dtypes.as_dtype(dtype) 

884 shape = tensor_shape.as_shape(shape) 

885 

886 if name in self._vars: 

887 # Here we handle the case when returning an existing variable. 

888 if reuse is False: 

889 var = self._vars[name] 

890 err_msg = ("Variable %s already exists, disallowed." 

891 " Did you mean to set reuse=True or " 

892 "reuse=tf.AUTO_REUSE in VarScope?" % name) 

893 # ResourceVariables don't have an op associated with them, so no traceback.

894 if isinstance(var, resource_variable_ops.ResourceVariable): 

895 raise ValueError(err_msg) 

896 tb = var.op.traceback[::-1] 

897 # Throw away internal tf entries and only take a few lines. In some 

898 # cases the traceback can be longer (e.g. if someone uses factory 

899 # functions to create variables) so we take more than needed in the 

900 # default case. 

901 tb = [x for x in tb if "tensorflow/python" not in x[0]][:5] 

902 raise ValueError("%s Originally defined at:\n\n%s" % 

903 (err_msg, "".join(traceback.format_list(tb)))) 

904 found_var = self._vars[name] 

905 if not shape.is_compatible_with(found_var.get_shape()): 

906 raise ValueError("Trying to share variable %s, but specified shape %s" 

907 " and found shape %s." % 

908 (name, shape, found_var.get_shape())) 

909 if not dtype.is_compatible_with(found_var.dtype): 

910 dtype_str = dtype.name 

911 found_type_str = found_var.dtype.name 

912 raise ValueError("Trying to share variable %s, but specified dtype %s" 

913 " and found dtype %s." % 

914 (name, dtype_str, found_type_str)) 

915 return found_var 

916 

917 # The code below handles only the case of creating a new variable. 

918 if reuse is True: 

919 raise ValueError("Variable %s does not exist, or was not created with " 

920 "tf.get_variable(). Did you mean to set " 

921 "reuse=tf.AUTO_REUSE in VarScope?" % name) 

922 

923 # Create the tensor to initialize the variable with default value. 

924 if initializer is None: 

925 initializer, initializing_from_value = self._get_default_initializer( 

926 name=name, shape=shape, dtype=dtype) 

927 # Enter an init scope when creating the initializer. 

928 with ops.init_scope(): 

929 if initializing_from_value: 

930 init_val = initializer 

931 variable_dtype = None 

932 else: 

933 # Instantiate initializer if provided initializer is a type object. 

934 if tf_inspect.isclass(initializer): 

935 initializer = initializer() 

936 if shape.is_fully_defined(): 

937 if "partition_info" in tf_inspect.getargspec(initializer).args: 

938 init_val = functools.partial(initializer, 

939 shape.as_list(), 

940 dtype=dtype, 

941 partition_info=partition_info) 

942 else: 

943 init_val = functools.partial(initializer, 

944 shape.as_list(), dtype=dtype) 

945 variable_dtype = dtype.base_dtype 

946 elif _needs_no_arguments(initializer): 

947 init_val = initializer 

948 variable_dtype = None 

949 else: 

950 raise ValueError("The initializer passed is not valid. It should " 

951 "be a callable with no arguments and the " 

952 "shape should not be provided or an instance of " 

953 "`tf.keras.initializers.*' and `shape` should be " 

954 "fully defined.") 

955 

956 # Create the variable. 

957 if use_resource is None: 

958 # Set the default value if unspecified. 

959 use_resource = _DEFAULT_USE_RESOURCE 

960 v = variable_v1.VariableV1( 

961 initial_value=init_val, 

962 name=name, 

963 trainable=trainable, 

964 collections=collections, 

965 caching_device=caching_device, 

966 dtype=variable_dtype, 

967 validate_shape=validate_shape, 

968 constraint=constraint, 

969 use_resource=use_resource, 

970 synchronization=synchronization, 

971 aggregation=aggregation) 

972 if context.executing_eagerly() and self._store_eager_variables: 

973 if collections: 

974 ops.add_to_collections(collections, v) 

975 else: 

976 ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v) 

977 if trainable: 

978 ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v) 

979 

980 if not context.executing_eagerly() or self._store_eager_variables: 

981 # In eager mode we do not want to keep default references to Variable 

982 # objects as this will prevent their memory from being released. 

983 self._vars[name] = v 

984 logging.vlog(1, "Created variable %s with shape %s and init %s", v.name, 

985 format(shape), initializer) 

986 

987 # Run the regularizer if requested and save the resulting loss. 

988 if regularizer: 

989 def make_regularizer_op(): 

990 with ops.colocate_with(v): 

991 with ops.name_scope(name + "/Regularizer/"): 

992 return regularizer(v) 

993 

994 if regularizer(v) is not None: 

995 lazy_eval_tensor = _LazyEvalTensor(make_regularizer_op) 

996 ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, 

997 lazy_eval_tensor) 

998 

999 return v 

1000 

1001 # Initialize variable when no initializer provided 

1002 def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32): 

1003 """Provide a default initializer and a corresponding value. 

1004 

1005 Args: 

1006 name: see get_variable. 

1007 shape: see get_variable. 

1008 dtype: see get_variable. 

1009 

1010 Returns: 

1011 initializer and initializing_from_value. See get_variable above. 

1012 

1013 Raises: 

1014 ValueError: When giving unsupported dtype. 

1015 """ 

1016 del shape 

1017 # If dtype is DT_FLOAT, provide a glorot uniform initializer

1018 if dtype.is_floating: 

1019 initializer = init_ops.glorot_uniform_initializer() 

1020 initializing_from_value = False 

1021 # If dtype is DT_INT/DT_UINT, provide a default value `zero` 

1022 # If dtype is DT_BOOL, provide a default value `FALSE` 

1023 elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool or 

1024 dtype == dtypes.string): 

1025 initializer = init_ops.zeros_initializer() 

1026 initializing_from_value = False 

1027 # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?

1028 else: 

1029 raise ValueError("An initializer for variable %s of %s is required" % 

1030 (name, dtype.base_dtype)) 

1031 

1032 return initializer, initializing_from_value 

1033 

1034 
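A sketch of the default-initializer selection implemented above, exercised through `tf.compat.v1.get_variable` with no explicit initializer (graph mode assumed):

```python
import tensorflow.compat.v1 as tf1

with tf1.Graph().as_default():
  f = tf1.get_variable("f", shape=[2, 2])                # float32 -> glorot_uniform
  i = tf1.get_variable("i", shape=[2], dtype=tf1.int32)  # int -> zeros
  b = tf1.get_variable("b", shape=[], dtype=tf1.bool)    # bool -> zeros (False)
  # Unsupported dtypes such as complex64 raise ValueError and require an
  # explicit initializer.
```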

1035class _LazyEvalTensor(core.Tensor): 

1036 """A Tensor-like object that only evaluates its thunk when used.""" 

1037 

1038 def __init__(self, thunk): 

1039 """Initializes a _LazyEvalTensor object. 

1040 

1041 Args: 

1042 thunk: A callable. A thunk which computes the value of the tensor. 

1043 """ 

1044 self._thunk = thunk 

1045 self._master_tensor = thunk() 

1046 

1047 def _as_tensor(self, dtype=None, name=None, as_ref=False): 

1048 del name 

1049 assert not as_ref 

1050 assert dtype in [None, self.dtype] 

1051 

1052 return self._thunk() 

1053 

1054 

1055def _make_master_property(name): 

1056 @property 

1057 def prop(self): 

1058 return getattr(self._master_tensor, name) # pylint: disable=protected-access 

1059 return prop 

1060 

1061_master_property_list = ("device", "dtype", "graph", "name", "op", "shape", 

1062 "value_index") 

1063for _name in _master_property_list: 

1064 setattr(_LazyEvalTensor, _name, _make_master_property(_name)) 

1065 

1066 

1067def _make_master_method(name): 

1068 def method(self, *args, **kwargs): 

1069 return getattr(self._master_tensor, name)(*args, **kwargs) # pylint: disable=protected-access 

1070 return method 

1071 

1072_master_method_list = ("get_shape", "__str__", "shape_as_list") 

1073for _name in _master_method_list: 

1074 setattr(_LazyEvalTensor, _name, _make_master_method(_name)) 

1075 

1076 

1077def _make_op_method(name): 

1078 def method(self, *args, **kwargs): 

1079 return getattr(self._as_tensor(), name)(*args, **kwargs) # pylint: disable=protected-access 

1080 return method 

1081 

1082_op_list = ("__abs__", "__add__", "__and__", "__bool__", "__div__", "__eq__", 

1083 "__floordiv__", "__ge__", "__getitem__", "__gt__", "__invert__", 

1084 "__iter__", "__le__", "__len__", "__lt__", "__matmul__", "__mod__", 

1085 "__mul__", "__ne__", "__neg__", "__nonzero__", "__or__", "__pow__", 

1086 "__radd__", "__rand__", "__rdiv__", "__rfloordiv__", "__rmatmul__", 

1087 "__rmod__", "__rmul__", "__ror__", "__rpow__", "__rsub__", 

1088 "__rtruediv__", "__rxor__", "__sub__", "__truediv__", "__xor__", 

1089 "eval", "numpy") 

1090for _name in _op_list: 

1091 setattr(_LazyEvalTensor, _name, _make_op_method(_name)) 

1092 

1093 

1094tensor_conversion_registry.register_tensor_conversion_function( 

1095 _LazyEvalTensor, 

1096 lambda val, dtype, name, as_ref: val._as_tensor(dtype, name, as_ref) # pylint: disable=protected-access 

1097 ) 

1098 

1099session.register_session_run_conversion_functions( 

1100 _LazyEvalTensor, 

1101 lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0]) # pylint: disable=protected-access 

1102 ) 

1103 

1104 

1105# To stop regularization, use this regularizer 

1106@tf_export(v1=["no_regularizer"]) 

1107def no_regularizer(_): 

1108 """Use this function to prevent regularization of variables.""" 

1109 return None 

1110 

1111 
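A sketch of how `no_regularizer` opts a single variable out of a scope-level regularizer (graph mode assumed; the L2 penalty here is just an arbitrary example):

```python
import tensorflow.compat.v1 as tf1

def l2(t):
  return 0.01 * tf1.nn.l2_loss(t)

with tf1.Graph().as_default():
  with tf1.variable_scope("net", regularizer=l2):
    w = tf1.get_variable("w", shape=[4, 4])                              # regularized
    b = tf1.get_variable("b", shape=[4], regularizer=tf1.no_regularizer)  # opted out

  losses = tf1.get_collection(tf1.GraphKeys.REGULARIZATION_LOSSES)
  assert len(losses) == 1  # only the loss for `w` is collected
```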

1112# TODO(alive): support caching devices and partitioned variables in Eager mode. 

1113@tf_export(v1=["VariableScope"]) 

1114class VariableScope: 

1115 """Variable scope object to carry defaults to provide to `get_variable`. 

1116 

1117 Many of the arguments we need for `get_variable` in a variable store are most 

1118 easily handled with a context. This object is used for the defaults. 

1119 

1120 Attributes: 

1121 name: name of the current scope, used as prefix in get_variable. 

1122 initializer: default initializer passed to get_variable. 

1123 regularizer: default regularizer passed to get_variable. 

1124 reuse: Boolean, None, or tf.compat.v1.AUTO_REUSE, setting the reuse in 

1125 get_variable. When eager execution is enabled this argument is always 

1126 forced to be False. 

1127 caching_device: string, callable, or None: the caching device passed to 

1128 get_variable. 

1129 partitioner: callable or `None`: the partitioner passed to `get_variable`. 

1130 custom_getter: default custom getter passed to get_variable. 

1131 name_scope: The name passed to `tf.name_scope`. 

1132 dtype: default type passed to get_variable (defaults to DT_FLOAT). 

1133 use_resource: if False, create a normal Variable; if True create an 

1134 experimental ResourceVariable with well-defined semantics. Defaults to 

1135 False (will later change to True). When eager execution is enabled this 

1136 argument is always forced to be True. 

1137 constraint: An optional projection function to be applied to the variable 

1138 after being updated by an `Optimizer` (e.g. used to implement norm 

1139 constraints or value constraints for layer weights). The function must 

1140 take as input the unprojected Tensor representing the value of the 

1141 variable and return the Tensor for the projected value (which must have 

1142 the same shape). Constraints are not safe to use when doing asynchronous 

1143 distributed training. 

1144 """ 

1145 

1146 def __init__(self, 

1147 reuse, 

1148 name="", 

1149 initializer=None, 

1150 regularizer=None, 

1151 caching_device=None, 

1152 partitioner=None, 

1153 custom_getter=None, 

1154 name_scope="", 

1155 dtype=dtypes.float32, 

1156 use_resource=None, 

1157 constraint=None): 

1158 """Creates a new VariableScope with the given properties.""" 

1159 self._name = name 

1160 self._initializer = initializer 

1161 self._regularizer = regularizer 

1162 self._reuse = reuse 

1163 self._caching_device = caching_device 

1164 self._partitioner = partitioner 

1165 self._custom_getter = custom_getter 

1166 self._name_scope = name_scope 

1167 self._dtype = dtype 

1168 self._use_resource = use_resource 

1169 self._constraint = constraint 

1170 if context.executing_eagerly(): 

1171 if self._caching_device is not None: 

1172 raise NotImplementedError("Caching devices are not yet supported "

1173 "when eager execution is enabled.") 

1174 self._reuse = AUTO_REUSE 

1175 self._use_resource = True 

1176 

1177 @property 

1178 def name(self): 

1179 return self._name 

1180 

1181 @property 

1182 def original_name_scope(self): 

1183 return self._name_scope 

1184 

1185 @property 

1186 def reuse(self): 

1187 return self._reuse 

1188 

1189 @property 

1190 def initializer(self): 

1191 return self._initializer 

1192 

1193 @property 

1194 def dtype(self): 

1195 return self._dtype 

1196 

1197 @property 

1198 def use_resource(self): 

1199 return self._use_resource 

1200 

1201 @property 

1202 def regularizer(self): 

1203 return self._regularizer 

1204 

1205 @property 

1206 def caching_device(self): 

1207 return self._caching_device 

1208 

1209 @property 

1210 def partitioner(self): 

1211 return self._partitioner 

1212 

1213 @property 

1214 def custom_getter(self): 

1215 return self._custom_getter 

1216 

1217 @property 

1218 def constraint(self): 

1219 return self._constraint 

1220 

1221 def reuse_variables(self): 

1222 """Reuse variables in this scope.""" 

1223 self._reuse = True 

1224 

1225 def set_initializer(self, initializer): 

1226 """Set initializer for this scope.""" 

1227 self._initializer = initializer 

1228 

1229 def set_dtype(self, dtype): 

1230 """Set data type for this scope.""" 

1231 self._dtype = dtype 

1232 

1233 def set_use_resource(self, use_resource): 

1234 """Sets whether to use ResourceVariables for this scope.""" 

1235 if context.executing_eagerly() and not use_resource: 

1236 raise ValueError("When eager execution is enabled, " 

1237 "use_resource cannot be set to false.") 

1238 self._use_resource = use_resource 

1239 

1240 def set_regularizer(self, regularizer): 

1241 """Set regularizer for this scope.""" 

1242 self._regularizer = regularizer 

1243 

1244 def set_caching_device(self, caching_device): 

1245 """Set caching_device for this scope.""" 

1246 if context.executing_eagerly(): 

1247 raise NotImplementedError("Caching devices are not yet supported " 

1248 "when eager execution is enabled.") 

1249 self._caching_device = caching_device 

1250 

1251 def set_partitioner(self, partitioner): 

1252 """Set partitioner for this scope.""" 

1253 self._partitioner = partitioner 

1254 

1255 def set_custom_getter(self, custom_getter): 

1256 """Set custom getter for this scope.""" 

1257 self._custom_getter = custom_getter 

1258 

1259 def get_collection(self, name): 

1260 """Get this scope's variables.""" 

1261 scope = self._name + "/" if self._name else "" 

1262 return ops.get_collection(name, scope) 

1263 

1264 def trainable_variables(self): 

1265 """Get this scope's trainable variables.""" 

1266 return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) 

1267 

1268 def global_variables(self): 

1269 """Get this scope's global variables.""" 

1270 return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 

1271 

1272 def local_variables(self): 

1273 """Get this scope's local variables.""" 

1274 return self.get_collection(ops.GraphKeys.LOCAL_VARIABLES) 

1275 

1276 def get_variable(self, 

1277 var_store, 

1278 name, 

1279 shape=None, 

1280 dtype=None, 

1281 initializer=None, 

1282 regularizer=None, 

1283 reuse=None, 

1284 trainable=None, 

1285 collections=None, 

1286 caching_device=None, 

1287 partitioner=None, 

1288 validate_shape=True, 

1289 use_resource=None, 

1290 custom_getter=None, 

1291 constraint=None, 

1292 synchronization=VariableSynchronization.AUTO, 

1293 aggregation=VariableAggregation.NONE): 

1294 """Gets an existing variable with this name or create a new one.""" 

1295 if regularizer is None: 

1296 regularizer = self._regularizer 

1297 if caching_device is None: 

1298 caching_device = self._caching_device 

1299 if partitioner is None: 

1300 partitioner = self._partitioner 

1301 if custom_getter is None: 

1302 custom_getter = self._custom_getter 

1303 if context.executing_eagerly(): 

1304 reuse = False 

1305 use_resource = True 

1306 else: 

1307 if reuse is None: 

1308 reuse = self._reuse 

1309 if use_resource is None: 

1310 use_resource = self._use_resource 

1311 

1312 full_name = self.name + "/" + name if self.name else name 

1313 # Variable names only depend on variable_scope (full_name here), 

1314 # not name_scope, so we reset it below for the time of variable creation. 

1315 with ops.name_scope(None, skip_on_eager=False): 

1316 # Check that `initializer` dtype and `dtype` are consistent before 

1317 # replacing them with defaults. 

1318 if (dtype is not None and initializer is not None and 

1319 not callable(initializer)): 

1320 init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype 

1321 if init_dtype != dtype: 

1322 raise ValueError("Initializer type '%s' and explicit dtype '%s' " 

1323 "don't match." % (init_dtype, dtype)) 

1324 if initializer is None: 

1325 initializer = self._initializer 

1326 if constraint is None: 

1327 constraint = self._constraint 

1328 if dtype is None: 

1329 dtype = self._dtype 

1330 return var_store.get_variable( 

1331 full_name, 

1332 shape=shape, 

1333 dtype=dtype, 

1334 initializer=initializer, 

1335 regularizer=regularizer, 

1336 reuse=reuse, 

1337 trainable=trainable, 

1338 collections=collections, 

1339 caching_device=caching_device, 

1340 partitioner=partitioner, 

1341 validate_shape=validate_shape, 

1342 use_resource=use_resource, 

1343 custom_getter=custom_getter, 

1344 constraint=constraint, 

1345 synchronization=synchronization, 

1346 aggregation=aggregation) 

1347 

1348 def _get_partitioned_variable(self, 

1349 var_store, 

1350 name, 

1351 shape=None, 

1352 dtype=None, 

1353 initializer=None, 

1354 regularizer=None, 

1355 trainable=None, 

1356 collections=None, 

1357 caching_device=None, 

1358 partitioner=None, 

1359 validate_shape=True, 

1360 use_resource=None, 

1361 constraint=None, 

1362 synchronization=VariableSynchronization.AUTO, 

1363 aggregation=VariableAggregation.NONE): 

1364 """Gets an existing variable with this name or create a new one.""" 

1365 if initializer is None: 

1366 initializer = self._initializer 

1367 if regularizer is None: 

1368 regularizer = self._regularizer 

1369 if constraint is None: 

1370 constraint = self._constraint 

1371 if caching_device is None: 

1372 caching_device = self._caching_device 

1373 if partitioner is None: 

1374 partitioner = self._partitioner 

1375 if dtype is None: 

1376 dtype = self._dtype 

1377 if use_resource is None: 

1378 use_resource = self._use_resource 

1379 

1380 if self._custom_getter is not None: 

1381 raise ValueError( 

1382 "Private access to _get_partitioned_variable is not allowed when " 

1383 "a custom getter is set. Current custom getter: %s. " 

1384 "It is likely that you're using create_partitioned_variables. " 

1385 "If so, consider instead using get_variable with a non-empty " 

1386 "partitioner parameter instead." % self._custom_getter) 

1387 

1388 if partitioner is None: 

1389 raise ValueError("No partitioner was specified") 

1390 

1391 # This allows the variable scope name to be used as the variable name if 

1392 # this function is invoked with an empty name arg, for backward 

1393 # compatibility with create_partitioned_variables(). 

1394 full_name_list = [] 

1395 if self.name: 

1396 full_name_list.append(self.name) 

1397 if name: 

1398 full_name_list.append(name) 

1399 full_name = "/".join(full_name_list) 

1400 

1401 # Variable names only depend on variable_scope (full_name here), 

1402 # not name_scope, so we reset it below for the time of variable creation. 

1403 with ops.name_scope(None, skip_on_eager=False): 

1404 # pylint: disable=protected-access 

1405 return var_store._get_partitioned_variable( 

1406 full_name, 

1407 shape=shape, 

1408 dtype=dtype, 

1409 initializer=initializer, 

1410 regularizer=regularizer, 

1411 reuse=self.reuse, 

1412 trainable=trainable, 

1413 collections=collections, 

1414 caching_device=caching_device, 

1415 partitioner=partitioner, 

1416 validate_shape=validate_shape, 

1417 use_resource=use_resource, 

1418 constraint=constraint, 

1419 synchronization=synchronization, 

1420 aggregation=aggregation) 

1421 # pylint: enable=protected-access 

1422 

1423 

1424_VARSTORE_KEY = ("__variable_store",) 

1425_VARSCOPESTORE_KEY = ("__varscope",) 

1426 

1427 

1428class _VariableScopeStore(threading.local): 

1429 """A thread local store for the current variable scope and scope counts.""" 

1430 

1431 def __init__(self): 

1432 super(_VariableScopeStore, self).__init__() 

1433 self.current_scope = VariableScope(False) 

1434 self.variable_scopes_count = {} 

1435 

1436 def open_variable_scope(self, scope_name): 

1437 if scope_name in self.variable_scopes_count: 

1438 self.variable_scopes_count[scope_name] += 1 

1439 else: 

1440 self.variable_scopes_count[scope_name] = 1 

1441 

1442 def close_variable_subscopes(self, scope_name): 
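    # Resets the usage counts of every scope nested strictly below `scope_name`
    # (or of all recorded scopes when `scope_name` is None).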

1443 if scope_name is None: 

1444 for k in self.variable_scopes_count: 

1445 self.variable_scopes_count[k] = 0 

1446 else: 

1447 startswith_check = scope_name + "/" 

1448 startswith_len = len(startswith_check) 

1449 for k in self.variable_scopes_count: 

1450 if k[:startswith_len] == startswith_check: 

1451 self.variable_scopes_count[k] = 0 

1452 

1453 def variable_scope_count(self, scope_name): 

1454 return self.variable_scopes_count.get(scope_name, 0) 

1455 

1456 

1457def get_variable_scope_store(): 

1458 """Returns the variable scope store for current thread.""" 

1459 scope_store = ops.get_collection(_VARSCOPESTORE_KEY) 

1460 

1461 if not scope_store: 

1462 scope_store = _VariableScopeStore() 

1463 ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store) 

1464 else: 

1465 scope_store = scope_store[0] 

1466 

1467 return scope_store 

1468 

1469 

1470@tf_export(v1=["get_variable_scope"]) 

1471def get_variable_scope(): 

1472 """Returns the current variable scope. 

1473 

1474 @compatibility(TF2) 

1475 Although it is a legacy `compat.v1` api, 

1476 `tf.compat.v1.get_variable` is compatible with eager 

1477 execution and `tf.function`. 

1478 

1479 However, to maintain variable-scope based variable reuse 

1480 you will need to combine it with 

1481 `tf.compat.v1.keras.utils.track_tf1_style_variables`. (Though 

1482 it will behave as if reuse is always set to `tf.compat.v1.AUTO_REUSE`.) 

1483 

1484 See the 

1485 [migration guide](https://www.tensorflow.org/guide/migrate/model_mapping) 

1486 for more info. 

1487 

1488 The TF2 equivalent, if you are just trying to track 

1489 variable name prefixes and not control `get_variable`-based variable reuse, 

1490 would be to use `tf.name_scope` and capture the output of opening the 

1491 scope (which represents the current name prefix). 

1492 

1493 For example: 

1494 ```python 

1495 with tf.name_scope('foo') as current_scope: 

1496 ... 

1497 ``` 

1498 @end_compatibility 

1499 """ 

1500 return get_variable_scope_store().current_scope 

1501 

1502 

1503def _get_default_variable_store(): 

1504 store = ops.get_collection(_VARSTORE_KEY) 

1505 if store: 

1506 return store[0] 

1507 store = _VariableStore() 

1508 ops.add_to_collection(_VARSTORE_KEY, store) 

1509 return store 

1510 

1511 

1512@tf_contextlib.contextmanager 

1513def with_variable_store(store): 
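  # Temporarily installs `store` as the default variable store in the current
  # graph's collections and restores the previously registered store(s) on exit.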

1514 store_collection = ops.get_collection_ref(_VARSTORE_KEY) 

1515 old = list(store_collection) 

1516 store_collection[:] = [store] 

1517 try: 

1518 yield 

1519 finally: 

1520 store_collection[:] = old 

1521 

1522 

1523class EagerVariableStore: 

1524 """Wrapper allowing functional layers to be used with eager execution. 

1525 

1526 When eager execution is enabled, variables are deleted when they go out of 

1527 scope, and are not stored in global collections by default. A lot of code 

1528 (mostly the functional layers in tf.layers) assumes that variables are kept in 

1529 a global list. 

1530 

1531 EagerVariableStore can be used in conjunction with this code to make it 

1532 eager-friendly. For example, to create a dense layer, use: 

1533 

1534 ``` 

1535 container = tfe.EagerVariableStore() 

1536 for input in dataset_iterator: 

1537 with container.as_default(): 

1538 x = tf.compat.v1.layers.dense(input, name="l1") 

1539 print(container.variables) # Should print the variables used in the layer. 

1540 ``` 

1541 """ 

1542 

1543 def __init__(self, store=None): 

1544 if store is not None: 

1545 if not store._store_eager_variables: # pylint: disable=protected-access 

1546 raise ValueError("Cannot construct EagerVariableStore from a " 

1547 "VariableStore object that does not hold eager " 

1548 "variables.") 

1549 self._store = store 

1550 else: 

1551 self._store = _VariableStore() 

1552 self._store._store_eager_variables = True # pylint: disable=protected-access 

1553 

1554 def as_default(self): 

1555 return with_variable_store(self._store) 

1556 

1557 def variables(self): 

1558 return sorted(self._store._vars.values(), key=lambda x: x.name) # pylint: disable=protected-access 

1559 

1560 def trainable_variables(self): 

1561 # pylint: disable=protected-access 

1562 return sorted([x for x in self._store._vars.values() if x.trainable], 

1563 key=lambda x: x.name) 

1564 # pylint: enable=protected-access 

1565 

1566 def non_trainable_variables(self): 

1567 # pylint: disable=protected-access 

1568 return sorted([x for x in self._store._vars.values() if not x.trainable], 

1569 key=lambda x: x.name) 

1570 # pylint: enable=protected-access 

1571 

1572 def copy(self): 

1573 """Copy this variable store and all of its contents. 

1574 

1575 Variables contained in this store will be copied over to the new variable 

1576 store, meaning that they can be modified without affecting the variables in 

1577 this store. 

1578 
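    For example, copying lets you snapshot the variables and then mutate the
    copies without touching the originals (an illustrative sketch; `container`
    stands for any existing `EagerVariableStore`):

    ```
    snapshot = container.copy()
    for v in snapshot.variables():
      v.assign(tf.zeros_like(v))  # Variables in `container` keep their values.
    ```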

1579 Returns: 

1580 A new EagerVariableStore instance containing copied variables. 

1581 """ 

1582 # pylint: disable=protected-access 

1583 new_store = EagerVariableStore() 

1584 for key, var in self._store._vars.items(): 

1585 # Strip device out of variable name. 

1586 try: 

1587 index = var.name.index(":") 

1588 except ValueError: 

1589 stripped_var_name = var.name 

1590 else: 

1591 stripped_var_name = var.name[:index] 

1592 

1593 # Create new variable with same value, name, and "trainable" flag. 

1594 new_var = resource_variable_ops.ResourceVariable( 

1595 var.read_value(), name=stripped_var_name, trainable=var.trainable) 

1596 new_store._store._vars[key] = new_var 

1597 return new_store 

1598 # pylint: enable=protected-access 

1599 

1600 

1601# The argument list for get_variable must match arguments to get_local_variable. 

1602# So, if you are updating the arguments, also update arguments to 

1603# get_local_variable below. 

1604@tf_export(v1=["get_variable"]) 

1605def get_variable(name, 

1606 shape=None, 

1607 dtype=None, 

1608 initializer=None, 

1609 regularizer=None, 

1610 trainable=None, 

1611 collections=None, 

1612 caching_device=None, 

1613 partitioner=None, 

1614 validate_shape=True, 

1615 use_resource=None, 

1616 custom_getter=None, 

1617 constraint=None, 

1618 synchronization=VariableSynchronization.AUTO, 

1619 aggregation=VariableAggregation.NONE): 

1620 return get_variable_scope().get_variable( 

1621 _get_default_variable_store(), 

1622 name, 

1623 shape=shape, 

1624 dtype=dtype, 

1625 initializer=initializer, 

1626 regularizer=regularizer, 

1627 trainable=trainable, 

1628 collections=collections, 

1629 caching_device=caching_device, 

1630 partitioner=partitioner, 

1631 validate_shape=validate_shape, 

1632 use_resource=use_resource, 

1633 custom_getter=custom_getter, 

1634 constraint=constraint, 

1635 synchronization=synchronization, 

1636 aggregation=aggregation) 

1637 

1638 

1639get_variable_or_local_docstring = ("""%s 

1640 

1641@compatibility(TF2) 

1642Although it is a legacy `compat.v1` api, 

1643`tf.compat.v1.get_variable` is mostly compatible with eager 

1644execution and `tf.function` but only if you combine it with the 

1645`tf.compat.v1.keras.utils.track_tf1_style_variables` decorator. (Though 

1646it will behave as if reuse is always set to `AUTO_REUSE`.) 

1647 

1648See the 

1649[model migration guide](https://www.tensorflow.org/guide/migrate/model_mapping) 

1650for more info. 

1651 

1652If you do not combine it with 

1653`tf.compat.v1.keras.utils.track_tf1_style_variables`, `get_variable` will create 

1654a brand new variable every single time it is called and will never reuse 

1655variables, regardless of variable names or `reuse` arguments. 

1656 

1657The TF2 equivalent of this symbol would be `tf.Variable`, but note 

1658that when using `tf.Variable` you must make sure you track your variables 

1659(and regularizer arguments) either manually or via `tf.Module` or 

1660`tf.keras.layers.Layer` mechanisms. 

1661 

1662A section of the 

1663[migration guide](https://www.tensorflow.org/guide/migrate/model_mapping#incremental_migration_to_native_tf2) 

1664provides more details on incrementally migrating these usages to `tf.Variable` 

1665as well. 

1666 

1667Note: The `partitioner` arg is not compatible with TF2 behaviors even when 

1668using `tf.compat.v1.keras.utils.track_tf1_style_variables`. It can be replaced 

1669by using `ParameterServerStrategy` and its partitioners. See the 

1670[multi-gpu migration guide](https://www.tensorflow.org/guide/migrate/multi_worker_cpu_gpu_training) 

1671and the ParameterServerStrategy guides it references for more info. 

1672@end_compatibility 

1673 

1674%sThis function prefixes the name with the current variable scope 

1675and performs reuse checks. See the 

1676[Variable Scope How To](https://tensorflow.org/guide/variables) 

1677for an extensive description of how reusing works. Here is a basic example: 

1678 

1679```python 

1680def foo(): 

1681 with tf.variable_scope("foo", reuse=tf.AUTO_REUSE): 

1682 v = tf.get_variable("v", [1]) 

1683 return v 

1684 

1685v1 = foo() # Creates v. 

1686v2 = foo() # Gets the same, existing v. 

1687assert v1 == v2 

1688``` 

1689 

1690If initializer is `None` (the default), the default initializer passed in 

1691the variable scope will be used. If that one is `None` too, a 

1692`glorot_uniform_initializer` will be used. The initializer can also be 

1693a Tensor, in which case the variable is initialized to this value and shape. 

1694 

1695Similarly, if the regularizer is `None` (the default), the default regularizer 

1696passed in the variable scope will be used (if that is `None` too, 

1697then by default no regularization is performed). 

1698 

1699If a partitioner is provided, a `PartitionedVariable` is returned. 

1700Accessing this object as a `Tensor` returns the shards concatenated along 

1701the partition axis. 

1702 

1703Some useful partitioners are available. See, e.g., 

1704`variable_axis_size_partitioner` and `min_max_variable_partitioner`. 
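
For example, a sketch using the built-in `tf.compat.v1.fixed_size_partitioner`
(the shard count and shape below are arbitrary, and graph mode is assumed as in
the example above):

```python
partitioner = tf.compat.v1.fixed_size_partitioner(num_shards=4)
with tf.compat.v1.variable_scope("embed", partitioner=partitioner):
  # Stored as 4 shards along axis 0; reads see the concatenated value.
  embeddings = tf.compat.v1.get_variable("embeddings", [20, 10])
```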

1705 

1706Args: 

1707 name: The name of the new or existing variable. 

1708 shape: Shape of the new or existing variable. 

1709 dtype: Type of the new or existing variable (defaults to `DT_FLOAT`). 

1710 initializer: Initializer for the variable if one is created. Can either be 

1711 an initializer object or a Tensor. If it's a Tensor, its shape must be known 

1712 unless validate_shape is False. 

1713 regularizer: A (Tensor -> Tensor or None) function; the result of 

1714 applying it on a newly created variable will be added to the collection 

1715 `tf.GraphKeys.REGULARIZATION_LOSSES` and can be used for regularization. 

1716 %scollections: List of graph collections keys to add the Variable to. 

1717 Defaults to `[%s]` (see `tf.Variable`). 

1718 caching_device: Optional device string or function describing where the 

1719 Variable should be cached for reading. Defaults to the Variable's 

1720 device. If not `None`, caches on another device. Typical use is to 

1721 cache on the device where the Ops using the Variable reside, to 

1722 deduplicate copying through `Switch` and other conditional statements. 

1723 partitioner: Optional callable that accepts a fully defined `TensorShape` 

1724 and `dtype` of the Variable to be created, and returns a list of 

1725 partitions for each axis (currently only one axis can be partitioned). 

1726 validate_shape: If False, allows the variable to be initialized with a 

1727 value of unknown shape. If True, the default, the shape of initial_value 

1728 must be known. For this to be used the initializer must be a Tensor and 

1729 not an initializer object. 

1730 use_resource: If False, creates a regular Variable. If True, creates an 

1731 experimental ResourceVariable instead with well-defined semantics. 

1732 Defaults to False (will later change to True). When eager execution is 

1733 enabled this argument is always forced to be True. 

1734 custom_getter: Callable that takes as a first argument the true getter, and 

1735 allows overwriting the internal get_variable method. 

1736 The signature of `custom_getter` should match that of this method, 

1737 but the most future-proof version will allow for changes: 

1738 `def custom_getter(getter, *args, **kwargs)`. Direct access to 

1739 all `get_variable` parameters is also allowed: 

1740 `def custom_getter(getter, name, *args, **kwargs)`. A simple custom 

1741 getter that just creates variables with modified names is: 

1742 ```python 

1743 def custom_getter(getter, name, *args, **kwargs): 

1744 return getter(name + '_suffix', *args, **kwargs) 

1745 ``` 

1746 constraint: An optional projection function to be applied to the variable 

1747 after being updated by an `Optimizer` (e.g. used to implement norm 

1748 constraints or value constraints for layer weights). The function must 

1749 take as input the unprojected Tensor representing the value of the 

1750 variable and return the Tensor for the projected value 

1751 (which must have the same shape). Constraints are not safe to 

1752 use when doing asynchronous distributed training. 

1753 synchronization: Indicates when a distributed variable will be 

1754 aggregated. Accepted values are constants defined in the class 

1755 `tf.VariableSynchronization`. By default the synchronization is set to 

1756 `AUTO` and the current `DistributionStrategy` chooses 

1757 when to synchronize. 

1758 aggregation: Indicates how a distributed variable will be aggregated. 

1759 Accepted values are constants defined in the class 

1760 `tf.VariableAggregation`. 

1761 

1762Returns: 

1763 The created or existing `Variable` (or `PartitionedVariable`, if a 

1764 partitioner was used). 

1765 

1766Raises: 

1767 ValueError: when creating a new variable and shape is not declared, 

1768 when violating reuse during variable creation, or when `initializer` dtype 

1769 and `dtype` don't match. Reuse is set inside `variable_scope`. 

1770""") 

1771get_variable.__doc__ = get_variable_or_local_docstring % ( 

1772 "Gets an existing variable with these parameters or create a new one.", "", 

1773 "trainable: If `True` also add the variable to the graph collection\n" 

1774 " `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n ", 

1775 "GraphKeys.GLOBAL_VARIABLES") 

1776 

1777 

1778# The argument list for get_local_variable must match arguments to get_variable. 

1779# So, if you are updating the arguments, also update arguments to get_variable. 

1780@tf_export(v1=["get_local_variable"]) 

1781def get_local_variable( # pylint: disable=missing-docstring 

1782 name, 

1783 shape=None, 

1784 dtype=None, 

1785 initializer=None, 

1786 regularizer=None, 

1787 trainable=False, # pylint: disable=unused-argument 

1788 collections=None, 

1789 caching_device=None, 

1790 partitioner=None, 

1791 validate_shape=True, 

1792 use_resource=None, 

1793 custom_getter=None, 

1794 constraint=None, 

1795 synchronization=VariableSynchronization.AUTO, 

1796 aggregation=VariableAggregation.NONE): 

1797 if collections: 

1798 collections = list(collections) + [ops.GraphKeys.LOCAL_VARIABLES] 

1799 else: 

1800 collections = [ops.GraphKeys.LOCAL_VARIABLES] 

1801 return get_variable( 

1802 name, 

1803 shape=shape, 

1804 dtype=dtype, 

1805 initializer=initializer, 

1806 regularizer=regularizer, 

1807 trainable=False, 

1808 collections=collections, 

1809 caching_device=caching_device, 

1810 partitioner=partitioner, 

1811 validate_shape=validate_shape, 

1812 use_resource=use_resource, 

1813 synchronization=synchronization, 

1814 aggregation=aggregation, 

1815 custom_getter=custom_getter, 

1816 constraint=constraint) 

1817 

1818 

1819get_local_variable.__doc__ = get_variable_or_local_docstring % ( 

1820 "Gets an existing *local* variable or creates a new one.", 

1821 "Behavior is the same as in `get_variable`, except that variables are\n" 

1822 "added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n" 

1823 "`False`.\n", "", "GraphKeys.LOCAL_VARIABLES") 

1824 

1825 

1826def _get_partitioned_variable(name, 

1827 shape=None, 

1828 dtype=None, 

1829 initializer=None, 

1830 regularizer=None, 

1831 trainable=True, 

1832 collections=None, 

1833 caching_device=None, 

1834 partitioner=None, 

1835 validate_shape=True, 

1836 use_resource=None, 

1837 constraint=None, 

1838 synchronization=VariableSynchronization.AUTO, 

1839 aggregation=VariableAggregation.NONE): 

1840 """Gets or creates a sharded variable list with these parameters. 

1841 

1842 The `partitioner` must be a callable that accepts a fully defined 

1843 `TensorShape` and returns a sequence of integers (the `partitions`). 

1844 These integers describe how to partition the given sharded `Variable` 

1845 along the given dimension. That is, `partitions[1] = 3` means split 

1846 the `Variable` into 3 shards along dimension 1. Currently, sharding along 

1847 only one axis is supported. 

1848 

1849 If the list of variables with the given name (prefix) is already stored, 

1850 we return the stored variables. Otherwise, we create a new one. 

1851 

1852 If initializer is `None` (the default), the default initializer passed in 

1853 the constructor is used. If that one is `None` too, we use a new 

1854 `glorot_uniform_initializer`. If initializer is a Tensor, we use 

1855 it as a value and derive the shape from the initializer. 

1856 

1857 If the initializer is a callable, then it will be called for each 

1858 shard. Otherwise the initializer should match the shape of the entire 

1859 sharded Variable, and it will be sliced accordingly for each shard. 

1860 

1861 Some useful partitioners are available. See, e.g., 

1862 `variable_axis_size_partitioner` and `min_max_variable_partitioner`. 

1863 

1864 Args: 

1865 name: The name of the new or existing variable. 

1866 shape: Shape of the new or existing variable. 

1867 dtype: Type of the new or existing variable (defaults to `DT_FLOAT`). 

1868 initializer: Initializer for the variable if one is created. 

1869 regularizer: A (Tensor -> Tensor or None) function; the result of applying 

1870 it on a newly created variable will be added to the collection 

1871 GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. 

1872 trainable: If `True` also add the variable to the graph collection 

1873 `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). 

1874 collections: List of graph collections keys to add the Variable to. Defaults 

1875 to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`). 

1876 caching_device: Optional device string or function describing where the 

1877 Variable should be cached for reading. Defaults to the Variable's device. 

1878 If not `None`, caches on another device. Typical use is to cache on the 

1879 device where the Ops using the Variable reside, to deduplicate copying 

1880 through `Switch` and other conditional statements. 

1881 partitioner: Optional callable that accepts a fully defined `TensorShape` 

1882 and `dtype` of the Variable to be created, and returns a list of 

1883 partitions for each axis (currently only one axis can be partitioned). 

1884 validate_shape: If False, allows the variable to be initialized with a value 

1885 of unknown shape. If True, the default, the shape of initial_value must be 

1886 known. 

1887 use_resource: If False, creates a regular Variable. If True, creates an 

1888 experimental ResourceVariable instead which has well-defined semantics. 

1889 Defaults to False (will later change to True). 

1890 constraint: An optional projection function to be applied to the variable 

1891 after being updated by an `Optimizer` (e.g. used to implement norm 

1892 constraints or value constraints for layer weights). The function must 

1893 take as input the unprojected Tensor representing the value of the 

1894 variable and return the Tensor for the projected value (which must have 

1895 the same shape). Constraints are not safe to use when doing asynchronous 

1896 distributed training. 

1897 synchronization: Indicates when a distributed variable will be aggregated. 

1898 Accepted values are constants defined in the class 

1899 `tf.VariableSynchronization`. By default the synchronization is set to 

1900 `AUTO` and the current `DistributionStrategy` chooses when to synchronize. 

1901 aggregation: Indicates how a distributed variable will be aggregated. 

1902 Accepted values are constants defined in the class 

1903 `tf.VariableAggregation`. 

1904 

1905 Returns: 

1906 A tuple `(shards, partitions)` where `shards` is the list of `Variable` 

1907 shards and `partitions` is the output of the partitioner on the input 

1908 shape. 

1909 

1910 Raises: 

1911 ValueError: when creating a new variable and shape is not declared, 

1912 or when violating reuse during variable creation. Reuse is set inside 

1913 `variable_scope`. 

1914 """ 

1915 # pylint: disable=protected-access 

1916 scope = get_variable_scope() 

1917 if scope.custom_getter is not None: 

1918 raise ValueError( 

1919 "Private access to _get_partitioned_variable is not allowed when " 

1920 "a custom getter is set. Current custom getter: %s. " 

1921 "It is likely that you're using create_partitioned_variables. " 

1922 "If so, consider instead using get_variable with a non-empty " 

1923 "partitioner parameter instead." % scope.custom_getter) 

1924 return scope._get_partitioned_variable( 

1925 _get_default_variable_store(), 

1926 name, 

1927 shape=shape, 

1928 dtype=dtype, 

1929 initializer=initializer, 

1930 regularizer=regularizer, 

1931 trainable=trainable, 

1932 collections=collections, 

1933 caching_device=caching_device, 

1934 partitioner=partitioner, 

1935 validate_shape=validate_shape, 

1936 use_resource=use_resource, 

1937 constraint=constraint, 

1938 synchronization=synchronization, 

1939 aggregation=aggregation) 

1940 # pylint: enable=protected-access 

1941 

1942 

1943# Named like a function for compatibility with the previous 

1944# @tf_contextlib.contextmanager definition. 

1945class _pure_variable_scope: # pylint: disable=invalid-name 

1946 """A context for the variable_scope, see `variable_scope` for docs.""" 

1947 

1948 def __init__(self, 

1949 name_or_scope, 

1950 reuse=None, 

1951 initializer=None, 

1952 regularizer=None, 

1953 caching_device=None, 

1954 partitioner=None, 

1955 custom_getter=None, 

1956 old_name_scope=None, 

1957 dtype=dtypes.float32, 

1958 use_resource=None, 

1959 constraint=None): 

1960 """Creates a context for the variable_scope, see `variable_scope` for docs. 

1961 

1962 Note: this does not create a name scope. 

1963 

1964 Args: 

1965 name_or_scope: `string` or `VariableScope`: the scope to open. 

1966 reuse: `True` or None, or tf.compat.v1.AUTO_REUSE; if `None`, we inherit 

1967 the parent scope's reuse flag. 

1968 initializer: default initializer for variables within this scope. 

1969 regularizer: default regularizer for variables within this scope. 

1970 caching_device: default caching device for variables within this scope. 

1971 partitioner: default partitioner for variables within this scope. 

1972 custom_getter: default custom getter for variables within this scope. 

1973 old_name_scope: the original name scope when re-entering a variable scope. 

1974 dtype: type of the variables within this scope (defaults to `DT_FLOAT`). 

1975 use_resource: If False, variables in this scope will be regular Variables. 

1976 If True, experimental ResourceVariables will be created instead, with 

1977 well-defined semantics. Defaults to False (will later change to True). 

1978 constraint: An optional projection function to be applied to the variable 

1979 after being updated by an `Optimizer` (e.g. used to implement norm 

1980 constraints or value constraints for layer weights). The function must 

1981 take as input the unprojected Tensor representing the value of the 

1982 variable and return the Tensor for the projected value (which must have 

1983 the same shape). Constraints are not safe to use when doing asynchronous 

1984 distributed training. 

1985 """ 

1986 self._name_or_scope = name_or_scope 

1987 self._reuse = reuse 

1988 self._initializer = initializer 

1989 self._regularizer = regularizer 

1990 self._caching_device = caching_device 

1991 self._partitioner = partitioner 

1992 self._custom_getter = custom_getter 

1993 self._old_name_scope = old_name_scope 

1994 self._dtype = dtype 

1995 self._use_resource = use_resource 

1996 self._constraint = constraint 

1997 self._var_store = _get_default_variable_store() 

1998 self._var_scope_store = get_variable_scope_store() 

1999 self._last_variable_scope_object = None 

2000 if isinstance(self._name_or_scope, VariableScope): 

2001 self._new_name = self._name_or_scope.name 

2002 name_scope = self._name_or_scope._name_scope # pylint: disable=protected-access 

2003 # Handler for the case when we jump to a shared scope. We create a new 

2004 # VariableScope (self._var_scope_object) that contains a copy of the 

2005 # provided shared scope, possibly with changed reuse and initializer, if 

2006 # the user requested this. 

2007 variable_scope_object = VariableScope( 

2008 self._name_or_scope.reuse if not self._reuse else self._reuse, 

2009 name=self._new_name, 

2010 initializer=self._name_or_scope.initializer, 

2011 regularizer=self._name_or_scope.regularizer, 

2012 caching_device=self._name_or_scope.caching_device, 

2013 partitioner=self._name_or_scope.partitioner, 

2014 dtype=self._name_or_scope.dtype, 

2015 custom_getter=self._name_or_scope.custom_getter, 

2016 name_scope=name_scope, 

2017 use_resource=self._name_or_scope.use_resource, 

2018 constraint=self._constraint) 

2019 if self._initializer is not None: 

2020 variable_scope_object.set_initializer(self._initializer) 

2021 if self._regularizer is not None: 

2022 variable_scope_object.set_regularizer(self._regularizer) 

2023 if self._caching_device is not None: 

2024 variable_scope_object.set_caching_device(self._caching_device) 

2025 if self._partitioner is not None: 

2026 variable_scope_object.set_partitioner(self._partitioner) 

2027 if self._custom_getter is not None: 

2028 variable_scope_object.set_custom_getter( 

2029 _maybe_wrap_custom_getter(self._custom_getter, 

2030 self._name_or_scope.custom_getter)) 

2031 if self._dtype is not None: 

2032 variable_scope_object.set_dtype(self._dtype) 

2033 if self._use_resource is not None: 

2034 variable_scope_object.set_use_resource(self._use_resource) 

2035 self._cached_variable_scope_object = variable_scope_object 

2036 

2037 def __enter__(self): 

2038 """Begins the scope block. 

2039 

2040 Returns: 

2041 A VariableScope. 

2042 Raises: 

2043 ValueError: when trying to reuse within a create scope, or create within 

2044 a reuse scope, or if reuse is not `None` or `True`. 

2045 TypeError: when the types of some arguments are not appropriate. 

2046 """ 

2047 self._old = self._var_scope_store.current_scope 

2048 if isinstance(self._name_or_scope, VariableScope): 

2049 self._var_scope_store.open_variable_scope(self._new_name) 

2050 self._old_subscopes = copy.copy( 

2051 self._var_scope_store.variable_scopes_count) 

2052 variable_scope_object = self._cached_variable_scope_object 

2053 else: 

2054 # Handler for the case when we just prolong current variable scope. 

2055 # VariableScope with name extended by the provided one, and inherited 

2056 # reuse and initializer (except if the user provided values to set). 

2057 self._new_name = ( 

2058 self._old.name + "/" + 

2059 self._name_or_scope if self._old.name else self._name_or_scope) 

2060 self._reuse = (self._reuse or 

2061 self._old.reuse) # Re-using is inherited by sub-scopes. 

2062 if self._old_name_scope is None: 

2063 name_scope = self._name_or_scope 

2064 else: 

2065 name_scope = self._old_name_scope 

2066 variable_scope_object = VariableScope( 

2067 self._reuse, 

2068 name=self._new_name, 

2069 initializer=self._old.initializer, 

2070 regularizer=self._old.regularizer, 

2071 caching_device=self._old.caching_device, 

2072 partitioner=self._old.partitioner, 

2073 dtype=self._old.dtype, 

2074 use_resource=self._old.use_resource, 

2075 custom_getter=self._old.custom_getter, 

2076 name_scope=name_scope, 

2077 constraint=self._constraint) 

2078 if self._initializer is not None: 

2079 variable_scope_object.set_initializer(self._initializer) 

2080 if self._regularizer is not None: 

2081 variable_scope_object.set_regularizer(self._regularizer) 

2082 if self._caching_device is not None: 

2083 variable_scope_object.set_caching_device(self._caching_device) 

2084 if self._partitioner is not None: 

2085 variable_scope_object.set_partitioner(self._partitioner) 

2086 if self._custom_getter is not None: 

2087 variable_scope_object.set_custom_getter( 

2088 _maybe_wrap_custom_getter(self._custom_getter, 

2089 self._old.custom_getter)) 

2090 if self._dtype is not None: 

2091 variable_scope_object.set_dtype(self._dtype) 

2092 if self._use_resource is not None: 

2093 variable_scope_object.set_use_resource(self._use_resource) 

2094 self._var_scope_store.open_variable_scope(self._new_name) 

2095 self._var_scope_store.current_scope = variable_scope_object 

2096 self._last_variable_scope_object = variable_scope_object 

2097 return variable_scope_object 

2098 

2099 def __exit__(self, type_arg, value_arg, traceback_arg): 

2100 if (self._var_scope_store.current_scope is 

2101 not self._last_variable_scope_object): 

2102 raise RuntimeError("Improper nesting of variable_scope.") 

2103 # If jumping out from a non-prolonged scope, restore counts. 

2104 if isinstance(self._name_or_scope, VariableScope): 

2105 self._var_scope_store.variable_scopes_count = self._old_subscopes 

2106 else: 

2107 self._var_scope_store.close_variable_subscopes(self._new_name) 

2108 self._var_scope_store.current_scope = self._old 

2109 

2110 

2111def _maybe_wrap_custom_getter(custom_getter, old_getter): 

2112 """Wrap a call to a custom_getter to use the old_getter internally.""" 

2113 if old_getter is None: 

2114 return custom_getter 

2115 

2116 # The new custom_getter should call the old one 

2117 def wrapped_custom_getter(getter, *args, **kwargs): 

2118 # Call: 

2119 # custom_getter( 

2120 # lambda: old_getter(true_getter, ...), *args, **kwargs) 

2121 # which means custom_getter will call old_getter, which 

2122 # will call the true_getter, perform any intermediate 

2123 # processing, and return the results to the current 

2124 # getter, which will also perform additional processing. 

2125 return custom_getter(functools.partial(old_getter, getter), *args, **kwargs) 

2126 

2127 return wrapped_custom_getter 

2128 

2129 

2130def _get_unique_variable_scope(prefix): 

2131 """Get a name with the given prefix unique in the current variable scope.""" 

2132 var_scope_store = get_variable_scope_store() 

2133 current_scope = get_variable_scope() 

2134 name = current_scope.name + "/" + prefix if current_scope.name else prefix 

2135 if var_scope_store.variable_scope_count(name) == 0: 

2136 return prefix 

2137 idx = 1 

2138 while var_scope_store.variable_scope_count(name + ("_%d" % idx)) > 0: 

2139 idx += 1 

2140 return prefix + ("_%d" % idx) 

2141 

2142 

2143# Named like a function for backwards compatibility with the 

2144# @tf_contextlib.contextmanager version, which was switched to a class to avoid 

2145# some object creation overhead. 

2146@tf_export(v1=["variable_scope"]) # pylint: disable=invalid-name 

2147class variable_scope: 

2148 """A context manager for defining ops that creates variables (layers). 

2149 

2150 @compatibility(TF2) 

2151 Although it is a legacy `compat.v1` api, 

2152 `tf.compat.v1.variable_scope` is mostly compatible with eager 

2153 execution and `tf.function` as long as you combine it with the 

2154 `tf.compat.v1.keras.utils.track_tf1_style_variables` decorator (though 

2155 it will behave as if reuse is always set to `AUTO_REUSE`.) 

2156 

2157 See the 

2158 [model migration guide]( 

2159 https://www.tensorflow.org/guide/migrate/model_mapping) 

2160 for more info on 

2161 migrating code that relies on `variable_scope`-based variable reuse. 

2162 

2163 When you use it with eager execution enabled but without 

2164 `tf.compat.v1.keras.utils.track_tf1_style_variables`, 

2165 `tf.compat.v1.variable_scope` will still be able to prefix the names 

2166 of variables created within the scope but it will not enable variable reuse 

2167 or error-raising checks around variable reuse (`get_variable` calls within 

2168 it would always create new variables). 

2169 

2170 Once you have switched away from `get_variable`-based variable reuse 

2171 mechanisms, to switch to TF2 APIs you can just use 

2172 `tf.name_scope` to prefix variable names. 

2173 @end_compatibility 

2174 

2175 This context manager validates that the (optional) `values` are from the same 

2176 graph, ensures that graph is the default graph, and pushes a name scope and a 

2177 variable scope. 

2178 

2179 If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None, 

2180 then `default_name` is used. In that case, if the same name has been 

2181 previously used in the same scope, it will be made unique by appending `_N` 

2182 to it. 

2183 

2184 Variable scope allows you to create new variables and to share already created 

2185 ones while providing checks to not create or share by accident. For details, 

2186 see the [Variable Scope How To](https://tensorflow.org/guide/variables), here 

2187 we present only a few basic examples. 

2188 

2189 Variable scopes work as expected when eager execution is disabled. 

2190 

2191 ```python 

2192 tf.compat.v1.disable_eager_execution() 

2193 ``` 

2194 

2195 Simple example of how to create a new variable: 

2196 

2197 ```python 

2198 with tf.compat.v1.variable_scope("foo"): 

2199 with tf.compat.v1.variable_scope("bar"): 

2200 v = tf.compat.v1.get_variable("v", [1]) 

2201 assert v.name == "foo/bar/v:0" 

2202 ``` 

2203 

2204 Simple example of how to reenter a premade variable scope safely: 

2205 

2206 ```python 

2207 with tf.compat.v1.variable_scope("foo") as vs: 

2208 pass 

2209 

2210 # Re-enter the variable scope. 

2211 with tf.compat.v1.variable_scope(vs, 

2212 auxiliary_name_scope=False) as vs1: 

2213 # Restore the original name_scope. 

2214 with tf.name_scope(vs1.original_name_scope): 

2215 v = tf.compat.v1.get_variable("v", [1]) 

2216 assert v.name == "foo/v:0" 

2217 c = tf.constant([1], name="c") 

2218 assert c.name == "foo/c:0" 

2219 ``` 

2220 

2221 Keep in mind that the counters for `default_name` are discarded once the 

2222 parent scope is exited. Therefore when the code re-enters the scope (for 

2223 instance by saving it), all nested default_name counters will be restarted. 

2224 

2225 For instance: 

2226 

2227 ```python 

2228 with tf.compat.v1.variable_scope("foo") as vs: 

2229 with tf.compat.v1.variable_scope(None, default_name="bar"): 

2230 v = tf.compat.v1.get_variable("a", [1]) 

2231 assert v.name == "foo/bar/a:0", v.name 

2232 with tf.compat.v1.variable_scope(None, default_name="bar"): 

2233 v = tf.compat.v1.get_variable("b", [1]) 

2234 assert v.name == "foo/bar_1/b:0" 

2235 

2236 with tf.compat.v1.variable_scope(vs): 

2237 with tf.compat.v1.variable_scope(None, default_name="bar"): 

2238 v = tf.compat.v1.get_variable("c", [1]) 

2239 assert v.name == "foo/bar/c:0" # Uses bar instead of bar_2! 

2240 ``` 

2241 

2242 Basic example of sharing a variable AUTO_REUSE: 

2243 

2244 ```python 

2245 def foo(): 

2246 with tf.compat.v1.variable_scope("foo", reuse=tf.compat.v1.AUTO_REUSE): 

2247 v = tf.compat.v1.get_variable("v", [1]) 

2248 return v 

2249 

2250 v1 = foo() # Creates v. 

2251 v2 = foo() # Gets the same, existing v. 

2252 assert v1 == v2 

2253 ``` 

2254 

2255 Basic example of sharing a variable with reuse=True: 

2256 

2257 ```python 

2258 with tf.compat.v1.variable_scope("foo"): 

2259 v = tf.compat.v1.get_variable("v", [1]) 

2260 with tf.compat.v1.variable_scope("foo", reuse=True): 

2261 v1 = tf.compat.v1.get_variable("v", [1]) 

2262 assert v1 == v 

2263 ``` 

2264 

2265 Sharing a variable by capturing a scope and setting reuse: 

2266 

2267 ```python 

2268 with tf.compat.v1.variable_scope("foo") as scope: 

2269 v = tf.compat.v1.get_variable("v", [1]) 

2270 scope.reuse_variables() 

2271 v1 = tf.compat.v1.get_variable("v", [1]) 

2272 assert v1 == v 

2273 ``` 

2274 

2275 To prevent accidental sharing of variables, we raise an exception when getting 

2276 an existing variable in a non-reusing scope. 

2277 

2278 ```python 

2279 with tf.compat.v1.variable_scope("foo"): 

2280 v = tf.compat.v1.get_variable("v", [1]) 

2281 v1 = tf.compat.v1.get_variable("v", [1]) 

2282 # Raises ValueError("... v already exists ..."). 

2283 ``` 

2284 

2285 Similarly, we raise an exception when trying to get a variable that does not 

2286 exist in reuse mode. 

2287 

2288 ```python 

2289 with tf.compat.v1.variable_scope("foo", reuse=True): 

2290 v = tf.compat.v1.get_variable("v", [1]) 

2291 # Raises ValueError("... v does not exist ..."). 

2292 ``` 

2293 

2294 Note that the `reuse` flag is inherited: if we open a reusing scope, then all 

2295 its sub-scopes become reusing as well. 

2296 

2297 A note about name scoping: Setting `reuse` does not impact the naming of other 

2298 ops such as mult. See related discussion on 

2299 [github#6189](https://github.com/tensorflow/tensorflow/issues/6189) 

2300 

2301 Note that up to and including version 1.0, it was allowed (though explicitly 

2302 discouraged) to pass False to the reuse argument, yielding undocumented 

2303 behaviour slightly different from None. Starting at 1.1.0 passing None and 

2304 False as reuse has exactly the same effect. 

2305 

2306 A note about using variable scopes in multi-threaded environment: Variable 

2307 scopes are thread local, so one thread will not see another thread's current 

2308 scope. Also, when using `default_name`, unique scope names are generated 

2309 only on a per-thread basis. If the same name was used within a different 

2310 thread, that doesn't prevent a new thread from creating the same scope. 

2311 However, the underlying variable store is shared across threads (within the 

2312 same graph). As such, if another thread tries to create a new variable with 

2313 the same name as a variable created by a previous thread, it will fail unless 

2314 reuse is True. 

2315 

2316 Further, each thread starts with an empty variable scope. So if you wish to 

2317 preserve name prefixes from the main thread, you should capture 

2318 the main thread's scope and re-enter it in each thread. For example: 

2319 

2320 ``` 

2321 main_thread_scope = variable_scope.get_variable_scope() 

2322 

2323 # Thread's target function: 

2324 def thread_target_fn(captured_scope): 

2325 with variable_scope.variable_scope(captured_scope): 

2326 # .... regular code for this thread 

2327 

2328 

2329 thread = threading.Thread(target=thread_target_fn, args=(main_thread_scope,)) 

2330 ``` 

2331 """ 

2332 

2333 def __init__(self, 

2334 name_or_scope, 

2335 default_name=None, 

2336 values=None, 

2337 initializer=None, 

2338 regularizer=None, 

2339 caching_device=None, 

2340 partitioner=None, 

2341 custom_getter=None, 

2342 reuse=None, 

2343 dtype=None, 

2344 use_resource=None, 

2345 constraint=None, 

2346 auxiliary_name_scope=True): 

2347 """Initialize the context manager. 

2348 

2349 Args: 

2350 name_or_scope: `string` or `VariableScope`: the scope to open. 

2351 default_name: The default name to use if the `name_or_scope` argument is 

2352 `None`; this name will be uniquified. If name_or_scope is provided it 

2353 won't be used and therefore it is not required and can be None. 

2354 values: The list of `Tensor` arguments that are passed to the op function. 

2355 initializer: default initializer for variables within this scope. 

2356 regularizer: default regularizer for variables within this scope. 

2357 caching_device: default caching device for variables within this scope. 

2358 partitioner: default partitioner for variables within this scope. 

2359 custom_getter: default custom getter for variables within this scope. 

2360 reuse: `True`, None, or tf.compat.v1.AUTO_REUSE; if `True`, we go into 

2361 reuse mode for this scope as well as all sub-scopes; if 

2362 tf.compat.v1.AUTO_REUSE, we create variables if they do not exist, and 

2363 return them otherwise; if None, we inherit the parent scope's reuse 

2364 flag. When eager execution is enabled, new variables are always created 

2365 unless an EagerVariableStore or template is currently active. 

2366 dtype: type of variables created in this scope (defaults to the type in 

2367 the passed scope, or inherited from parent scope). 

2368 use_resource: If False, all variables will be regular Variables. If True, 

2369 experimental ResourceVariables with well-defined semantics will be used 

2370 instead. Defaults to False (will later change to True). When eager 

2371 execution is enabled this argument is always forced to be True. 

2372 constraint: An optional projection function to be applied to the variable 

2373 after being updated by an `Optimizer` (e.g. used to implement norm 

2374 constraints or value constraints for layer weights). The function must 

2375 take as input the unprojected Tensor representing the value of the 

2376 variable and return the Tensor for the projected value (which must have 

2377 the same shape). Constraints are not safe to use when doing asynchronous 

2378 distributed training. 

2379 auxiliary_name_scope: If `True`, we create an auxiliary name scope with 

2380 the scope. If `False`, we don't create it. Note that the argument is not 

2381 inherited, and it only takes effect once, when the scope is created. You should 

2382 only use it for re-entering a premade variable scope. 

2383 

2384 Returns: 

2385 A scope that can be captured and reused. 

2386 

2387 Raises: 

2388 ValueError: when trying to reuse within a create scope, or create within 

2389 a reuse scope. 

2390 TypeError: when the types of some arguments are not appropriate. 

2391 """ 

2392 self._name_or_scope = name_or_scope 

2393 self._default_name = default_name 

2394 self._values = values 

2395 self._initializer = initializer 

2396 self._regularizer = regularizer 

2397 self._caching_device = caching_device 

2398 self._partitioner = partitioner 

2399 self._custom_getter = custom_getter 

2400 self._reuse = reuse 

2401 self._dtype = dtype 

2402 self._use_resource = use_resource 

2403 self._constraint = constraint 

2404 if self._default_name is None and self._name_or_scope is None: 

2405 raise TypeError("If default_name is None then name_or_scope is required") 

2406 if self._reuse is False: 

2407 # We don't allow non-inheriting scopes, False = None here. 

2408 self._reuse = None 

2409 if not (self._reuse is True 

2410 or self._reuse is None 

2411 or self._reuse is AUTO_REUSE): 

2412 raise ValueError("The reuse parameter must be True, False, None, or " 
 "tf.compat.v1.AUTO_REUSE.") 

2413 if self._values is None: 

2414 self._values = [] 

2415 self._in_graph_mode = not context.executing_eagerly() 

2416 if self._in_graph_mode: 

2417 self._graph = ops._get_graph_from_inputs(self._values) # pylint: disable=protected-access 

2418 self._cached_pure_variable_scope = None 

2419 self._current_name_scope = None 

2420 if not isinstance(auxiliary_name_scope, bool): 

2421 raise TypeError("The auxiliary_name_scope must be `True` or `False`, " 

2422 "while get {}".format(auxiliary_name_scope)) 

2423 self._auxiliary_name_scope = auxiliary_name_scope 

2424 

2425 def __enter__(self): 

2426 # If the default graph is building a function, then we should not replace it 

2427 # with the cached graph. 

2428 if ops.get_default_graph().building_function: 

2429 self._building_function = True 

2430 else: 

2431 self._building_function = False 

2432 if self._in_graph_mode and not self._building_function: 

2433 self._graph_context_manager = self._graph.as_default() 

2434 self._graph_context_manager.__enter__() 

2435 if self._cached_pure_variable_scope is not None: 

2436 # Fast path for re-entering variable_scopes. We've held on to the pure 

2437 # variable scope from a previous successful __enter__, so we avoid some 

2438 # overhead by re-using that object. 

2439 if self._current_name_scope is not None: 

2440 self._current_name_scope.__enter__() 

2441 return self._cached_pure_variable_scope.__enter__() 

2442 

2443 try: 

2444 return self._enter_scope_uncached() 

2445 except: 

2446 if (self._in_graph_mode and not self._building_function and 

2447 self._graph_context_manager is not None): 

2448 self._graph_context_manager.__exit__(*sys.exc_info()) 

2449 raise 

2450 

2451 def _enter_scope_uncached(self): 

2452 """Enters the context manager when there is no cached scope yet. 

2453 

2454 Returns: 

2455 The entered variable scope. 

2456 

2457 Raises: 

2458 TypeError: A wrong type is passed as `scope` at __init__(). 

2459 ValueError: `reuse` is incorrectly set at __init__(). 

2460 """ 

2461 if self._auxiliary_name_scope: 

2462 # Create a new name scope later 

2463 current_name_scope = None 

2464 else: 

2465 # Reenter the current name scope 

2466 name_scope = ops.get_name_scope() 

2467 if name_scope: 

2468 # Hack to reenter 

2469 name_scope += "/" 

2470 current_name_scope = ops.name_scope(name_scope, skip_on_eager=False) 

2471 else: 

2472 # Root scope 

2473 current_name_scope = ops.name_scope(name_scope, skip_on_eager=False) 

2474 

2475 # IMPORTANT: Only assign to self._cached_pure_variable_scope and 

2476 # self._current_name_scope after successful __enter__() calls. 

2477 if self._name_or_scope is not None: 

2478 if not isinstance(self._name_or_scope, (VariableScope, str)): 

2479 raise TypeError("VariableScope: name_or_scope must be a string or " 

2480 "VariableScope.") 

2481 if isinstance(self._name_or_scope, str): 

2482 name_scope = self._name_or_scope 

2483 else: 

2484 name_scope = self._name_or_scope.name.split("/")[-1] 

2485 if name_scope or current_name_scope: 

2486 current_name_scope = current_name_scope or ops.name_scope( 

2487 name_scope, skip_on_eager=False) 

2488 try: 

2489 current_name_scope_name = current_name_scope.__enter__() 

2490 except: 

2491 current_name_scope.__exit__(*sys.exc_info()) 

2492 raise 

2493 self._current_name_scope = current_name_scope 

2494 if isinstance(self._name_or_scope, str): 

2495 old_name_scope = current_name_scope_name 

2496 else: 

2497 old_name_scope = self._name_or_scope.original_name_scope 

2498 pure_variable_scope = _pure_variable_scope( 

2499 self._name_or_scope, 

2500 reuse=self._reuse, 

2501 initializer=self._initializer, 

2502 regularizer=self._regularizer, 

2503 caching_device=self._caching_device, 

2504 partitioner=self._partitioner, 

2505 custom_getter=self._custom_getter, 

2506 old_name_scope=old_name_scope, 

2507 dtype=self._dtype, 

2508 use_resource=self._use_resource, 

2509 constraint=self._constraint) 

2510 try: 

2511 entered_pure_variable_scope = pure_variable_scope.__enter__() 

2512 except: 

2513 pure_variable_scope.__exit__(*sys.exc_info()) 

2514 raise 

2515 self._cached_pure_variable_scope = pure_variable_scope 

2516 return entered_pure_variable_scope 

2517 else: 

2518 self._current_name_scope = None 

2519 # This can only happen if someone is entering the root variable scope. 

2520 pure_variable_scope = _pure_variable_scope( 

2521 self._name_or_scope, 

2522 reuse=self._reuse, 

2523 initializer=self._initializer, 

2524 regularizer=self._regularizer, 

2525 caching_device=self._caching_device, 

2526 partitioner=self._partitioner, 

2527 custom_getter=self._custom_getter, 

2528 dtype=self._dtype, 

2529 use_resource=self._use_resource, 

2530 constraint=self._constraint) 

2531 try: 

2532 entered_pure_variable_scope = pure_variable_scope.__enter__() 

2533 except: 

2534 pure_variable_scope.__exit__(*sys.exc_info()) 

2535 raise 

2536 self._cached_pure_variable_scope = pure_variable_scope 

2537 return entered_pure_variable_scope 

2538 

2539 else: # Here name_or_scope is None. Using default name, but made unique. 

2540 if self._reuse: 

2541 raise ValueError("reuse=True cannot be used without a name_or_scope") 

2542 current_name_scope = current_name_scope or ops.name_scope( 

2543 self._default_name, skip_on_eager=False) 

2544 try: 

2545 current_name_scope_name = current_name_scope.__enter__() 

2546 except: 

2547 current_name_scope.__exit__(*sys.exc_info()) 

2548 raise 

2549 self._current_name_scope = current_name_scope 

2550 unique_default_name = _get_unique_variable_scope(self._default_name) 

2551 pure_variable_scope = _pure_variable_scope( 

2552 unique_default_name, 

2553 initializer=self._initializer, 

2554 regularizer=self._regularizer, 

2555 caching_device=self._caching_device, 

2556 partitioner=self._partitioner, 

2557 custom_getter=self._custom_getter, 

2558 old_name_scope=current_name_scope_name, 

2559 dtype=self._dtype, 

2560 use_resource=self._use_resource, 

2561 constraint=self._constraint) 

2562 try: 

2563 entered_pure_variable_scope = pure_variable_scope.__enter__() 

2564 except: 

2565 pure_variable_scope.__exit__(*sys.exc_info()) 

2566 raise 

2567 self._cached_pure_variable_scope = pure_variable_scope 

2568 return entered_pure_variable_scope 

2569 

2570 def __exit__(self, type_arg, value_arg, traceback_arg): 

2571 try: 

2572 self._cached_pure_variable_scope.__exit__(type_arg, value_arg, 

2573 traceback_arg) 

2574 finally: 

2575 try: 

2576 if self._current_name_scope: 

2577 self._current_name_scope.__exit__(type_arg, value_arg, 

2578 traceback_arg) 

2579 finally: 

2580 if self._in_graph_mode and not self._building_function: 

2581 self._graph_context_manager.__exit__(type_arg, value_arg, 

2582 traceback_arg) 

2583 

2584 

2585# pylint: disable=g-doc-return-or-yield 

2586@tf_export(v1=["variable_op_scope"]) 

2587@tf_contextlib.contextmanager 

2588def variable_op_scope(values, 

2589 name_or_scope, 

2590 default_name=None, 

2591 initializer=None, 

2592 regularizer=None, 

2593 caching_device=None, 

2594 partitioner=None, 

2595 custom_getter=None, 

2596 reuse=None, 

2597 dtype=None, 

2598 use_resource=None, 

2599 constraint=None): 

2600 """Deprecated: context manager for defining an op that creates variables.""" 

2601 logging.warn("tf.variable_op_scope(values, name, default_name) is deprecated," 

2602 " use tf.variable_scope(name, default_name, values)") 

2603 with variable_scope( 

2604 name_or_scope, 

2605 default_name=default_name, 

2606 values=values, 

2607 initializer=initializer, 

2608 regularizer=regularizer, 

2609 caching_device=caching_device, 

2610 partitioner=partitioner, 

2611 custom_getter=custom_getter, 

2612 reuse=reuse, 

2613 dtype=dtype, 

2614 use_resource=use_resource, 

2615 constraint=constraint) as scope: 

2616 yield scope 

2617 

2618 

2619def _call_partitioner(partitioner, shape, dtype): 

2620 """Call partitioner validating its inputs/output. 

2621 

2622 Args: 

2623 partitioner: a function mapping `Tensor` shape and dtype to a list of 

2624 partitions. 

2625 shape: shape of the `Tensor` to partition; must have at least one 

2626 dimension. 

2627 dtype: dtype of the elements in the `Tensor`. 

2628 

2629 Returns: 

2630 A list with elements >=1 and at most one element >1. The index of that 

2631 element (if any) corresponds to the partitioning axis. 
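
    For example, splitting a rank-2 variable into 4 shards along axis 0
    corresponds to the partition list `[4, 1]`.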

2632 """ 

2633 if not shape.is_fully_defined(): 

2634 raise ValueError("Shape of a new partitioned variable must be " 

2635 "fully defined, but instead was %s." % (shape,)) 

2636 if shape.ndims < 1: 

2637 raise ValueError("A partitioned Variable must have rank at least 1, " 

2638 "shape: %s" % shape) 

2639 

2640 slicing = partitioner(shape=shape, dtype=dtype) 

2641 if not isinstance(slicing, collections_abc.Sequence): 

2642 raise ValueError("Partitioner must return a sequence, but saw: %s" % 

2643 slicing) 

2644 if len(slicing) != shape.ndims: 

2645 raise ValueError( 

2646 "Partitioner returned a partition list that does not match the " 

2647 "Variable's rank: %s vs. %s" % (slicing, shape)) 

2648 if any(p < 1 for p in slicing): 

2649 raise ValueError("Partitioner returned zero partitions for some axes: %s" % 

2650 slicing) 

2651 if sum(p > 1 for p in slicing) > 1: 

2652 raise ValueError("Can only slice a variable along one dimension: " 

2653 "shape: %s, partitioning: %s" % (shape, slicing)) 

2654 return slicing 

2655 

2656 

2657# TODO(slebedev): could be inlined, but 

2658# `_VariableStore._get_partitioned_variable` is too complex even 

2659# without this logic. 

2660def _get_slice_dim_and_num_slices(slicing): 

2661 """Get slicing dimension and number of slices from the partitioner output.""" 

2662 for slice_dim, num_slices in enumerate(slicing): 

2663 if num_slices > 1: 

2664 break 

2665 else: 

2666 # Degenerate case: no partitioning applied. 

2667 slice_dim = 0 

2668 num_slices = 1 

2669 return slice_dim, num_slices 

2670 

2671 

2672def _iter_slices(full_shape, num_slices, slice_dim): 

2673 """Slices a given a shape along the specified dimension.""" 

2674 num_slices_with_excess = full_shape[slice_dim] % num_slices 

2675 offset = [0] * len(full_shape) 

2676 min_slice_len = full_shape[slice_dim] // num_slices 

2677 for i in range(num_slices): 

2678 shape = full_shape[:] 

2679 shape[slice_dim] = min_slice_len + bool(i < num_slices_with_excess) 

2680 yield offset[:], shape 

2681 offset[slice_dim] += shape[slice_dim] 

2682 

2683 

2684def _make_getter(captured_getter, captured_previous): 

2685 """Gets around capturing loop variables in python being broken.""" 

2686 return lambda **kwargs: captured_getter(captured_previous, **kwargs) 

2687 

2688 

2689# TODO(apassos) remove forwarding symbol 

2690variable = variable_v1.VariableV1 

2691 

2692# temporary references needed while refactors are in progress 

2693default_variable_creator = ref_variable.default_variable_creator 

2694_to_proto_fn = ref_variable._to_proto_fn # pylint: disable=protected-access 

2695_from_proto_fn = ref_variable._from_proto_fn # pylint: disable=protected-access 

2696 

2697 

2698@tf_export(v1=["variable_creator_scope"]) 

2699@tf_contextlib.contextmanager 

2700def variable_creator_scope_v1(variable_creator): 

2701 """Scope which defines a variable creation function to be used by variable(). 

2702 

2703 variable_creator is expected to be a function with the following signature: 

2704 

2705 ``` 

2706 def variable_creator(next_creator, **kwargs) 

2707 ``` 

2708 

2709 If a creator does want to create a variable, it is supposed to eventually 

2710 call the next_creator rather than calling Variable or ResourceVariable 

2711 directly. This helps make creators composable. A creator may 

2712 choose to create multiple variables, return already existing variables, or 

2713 simply register that a variable was created and defer to the next creators in 

2714 line. Creators can also modify the keyword arguments seen by the next 

2715 creators. 

2716 

2717 Custom getters in the variable scope will eventually resolve down to these 

2718 custom creators when they do create variables. 

2719 

2720 The valid keyword arguments in `kwargs` are: 

2721 

2722 * initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 

2723 which is the initial value for the Variable. The initial value must have 

2724 a shape specified unless `validate_shape` is set to False. Can also be a 

2725 callable with no argument that returns the initial value when called. In 

2726 that case, `dtype` must be specified. (Note that initializer functions 

2727 from init_ops.py must first be bound to a shape before being used here.) 

2728 * trainable: If `True`, the default, also adds the variable to the graph 

2729 collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as 

2730 the default list of variables to use by the `Optimizer` classes. 

2731 `trainable` defaults to `True`, unless `synchronization` is 

2732 set to `ON_READ`, in which case it defaults to `False`. 

2733 * collections: List of graph collections keys. The new variable is added to 

2734 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 

2735 * validate_shape: If `False`, allows the variable to be initialized with a 

2736 value of unknown shape. If `True`, the default, the shape of 

2737 `initial_value` must be known. 

2738 * caching_device: Optional device string describing where the Variable 

2739 should be cached for reading. Defaults to the Variable's device. 

2740 If not `None`, caches on another device. Typical use is to cache 

2741 on the device where the Ops using the Variable reside, to deduplicate 

2742 copying through `Switch` and other conditional statements. 

2743 * name: Optional name for the variable. Defaults to `'Variable'` and gets 

2744 uniquified automatically. 

2745 * dtype: If set, initial_value will be converted to the given type. 

2746 If `None`, either the datatype will be kept (if `initial_value` is 

2747 a Tensor), or `convert_to_tensor` will decide. 

2748 * constraint: A constraint function to be applied to the variable after 

2749 updates by some algorithms. 

2750 * use_resource: if True, a ResourceVariable is always created. 

2751 * synchronization: Indicates when a distributed variable will be 

2752 aggregated. Accepted values are constants defined in the class 

2753 `tf.VariableSynchronization`. By default the synchronization is set to 

2754 `AUTO` and the current `DistributionStrategy` chooses 

2755 when to synchronize. 

2756 * aggregation: Indicates how a distributed variable will be aggregated. 

2757 Accepted values are constants defined in the class 

2758 `tf.VariableAggregation`. 

2759 

2760 This set may grow over time, so it's important that creators keep the 

2761 `**kwargs` signature shown above. 

2762 

2763 Args: 

2764 variable_creator: the passed creator 

2765 

2766 Yields: 

2767 A scope in which the creator is active 

2768 """ 

2769 with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access 

2770 yield 

2771 

2772 
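A usage sketch for the v1 symbol above: a composable creator that inspects the keyword arguments and then defers to the next creator in line (the creator and variable names are illustrative):

```
import tensorflow.compat.v1 as tf1

def logging_creator(next_creator, **kwargs):
  # Observe the kwargs, then hand off to the next creator unchanged.
  print("creating variable:", kwargs.get("name"))
  return next_creator(**kwargs)

with tf1.variable_creator_scope(logging_creator):
  v = tf1.Variable(1.0, name="weights")
```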

2773# Note: only the docstrings differ between this and v1. 

2774@tf_export("variable_creator_scope", v1=[]) 

2775@tf_contextlib.contextmanager 

2776def variable_creator_scope(variable_creator): 

2777 """Scope which defines a variable creation function to be used by variable(). 

2778 

2779 variable_creator is expected to be a function with the following signature: 

2780 

2781 ``` 

2782 def variable_creator(next_creator, **kwargs) 

2783 ``` 

2784 

2785 The creator is supposed to eventually call `next_creator` to create a 

2786 variable if it does want to create a variable, rather than calling Variable or 

2787 ResourceVariable directly. This helps make creators composable. A creator may 

2788 choose to create multiple variables, return already existing variables, or 

2789 simply register that a variable was created and defer to the next creators in 

2790 line. Creators can also modify the keyword arguments seen by the next 

2791 creators. 

2792 

2793 Custom getters in the variable scope will eventually resolve down to these 

2794 custom creators when they do create variables. 

2795 

2796 The valid keyword arguments in `kwargs` are: 

2797 

2798 * initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 

2799 which is the initial value for the Variable. The initial value must have 

2800 a shape specified unless `validate_shape` is set to False. Can also be a 

2801 callable with no argument that returns the initial value when called. In 

2802 that case, `dtype` must be specified. (Note that initializer functions 

2803 from init_ops.py must first be bound to a shape before being used here.) 

2804 * trainable: If `True`, the default, GradientTapes automatically watch 

2805 uses of this Variable. 

2806 * validate_shape: If `False`, allows the variable to be initialized with a 

2807 value of unknown shape. If `True`, the default, the shape of 

2808 `initial_value` must be known. 

2809 * caching_device: Optional device string describing where the Variable 

2810 should be cached for reading. Defaults to the Variable's device. 

2811 If not `None`, caches on another device. Typical use is to cache 

2812 on the device where the Ops using the Variable reside, to deduplicate 

2813 copying through `Switch` and other conditional statements. 

2814 * name: Optional name for the variable. Defaults to `'Variable'` and gets 

2815 uniquified automatically. 

2816 * dtype: If set, initial_value will be converted to the given type. 

2817 If `None`, either the datatype will be kept (if `initial_value` is 

2818 a Tensor), or `convert_to_tensor` will decide. 

2819 * constraint: A constraint function to be applied to the variable after 

2820 updates by some algorithms. 

2821 * synchronization: Indicates when a distributed variable will be 

2822 aggregated. Accepted values are constants defined in the class 

2823 `tf.VariableSynchronization`. By default the synchronization is set to 

2824 `AUTO` and the current `DistributionStrategy` chooses 

2825 when to synchronize. 

2826 * aggregation: Indicates how a distributed variable will be aggregated. 

2827 Accepted values are constants defined in the class 

2828 `tf.VariableAggregation`. 

2829 

2830 This set may grow over time, so it's important that creators keep the 

2831 `**kwargs` signature shown above. 

2832 

2833 Args: 

2834 variable_creator: the passed creator 

2835 

2836 Yields: 

2837 A scope in which the creator is active 

2838 """ 

2839 with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access 

2840 yield
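A matching usage sketch for the TF2 symbol: because creators may rewrite the keyword arguments seen by the next creator, a scope can, for example, force every variable created inside it to be non-trainable (names are illustrative):

```
import tensorflow as tf

def non_trainable_creator(next_creator, **kwargs):
  # Rewrite the kwargs before deferring to the next creator in line.
  kwargs["trainable"] = False
  return next_creator(**kwargs)

with tf.variable_creator_scope(non_trainable_creator):
  v = tf.Variable(1.0, name="bias")

print(v.trainable)  # False -- the creator overrode the default
```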