Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/resource_variable_ops.py: 27%

897 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Ops to use variables as resources.""" 

16 

17# pylint: disable=g-bad-name 

18import contextlib 

19import functools 

20import weakref 

21 

22import numpy as np 

23 

24from tensorflow.core.framework import attr_value_pb2 

25from tensorflow.core.framework import variable_pb2 

26from tensorflow.core.function import trace_type 

27from tensorflow.core.protobuf import struct_pb2 

28from tensorflow.python.checkpoint import tensor_callable 

29from tensorflow.python.client import pywrap_tf_session 

30from tensorflow.python.compat import compat as forward_compat 

31from tensorflow.python.eager import context 

32from tensorflow.python.eager import record 

33from tensorflow.python.eager import tape 

34from tensorflow.python.framework import auto_control_deps_utils as acd 

35from tensorflow.python.framework import composite_tensor 

36from tensorflow.python.framework import composite_tensor_gradient 

37from tensorflow.python.framework import constant_op 

38from tensorflow.python.framework import cpp_shape_inference_pb2 

39from tensorflow.python.framework import dtypes 

40from tensorflow.python.framework import errors 

41from tensorflow.python.framework import indexed_slices 

42from tensorflow.python.framework import ops 

43from tensorflow.python.framework import tensor as tensor_module 

44from tensorflow.python.framework import tensor_conversion_registry 

45from tensorflow.python.framework import tensor_shape 

46from tensorflow.python.framework import tensor_spec 

47from tensorflow.python.ops import array_ops 

48from tensorflow.python.ops import gen_array_ops 

49from tensorflow.python.ops import gen_resource_variable_ops 

50from tensorflow.python.ops import gen_state_ops 

51from tensorflow.python.ops import handle_data_util 

52from tensorflow.python.ops import math_ops 

53from tensorflow.python.ops import state_ops 

54from tensorflow.python.ops import variables 

55# go/tf-wildcard-import 

56# pylint: disable=wildcard-import 

57from tensorflow.python.ops.gen_resource_variable_ops import * 

58# pylint: enable=wildcard-import 

59from tensorflow.python.saved_model import nested_structure_coder 

60from tensorflow.python.trackable import base as trackable 

61from tensorflow.python.types import core 

62from tensorflow.python.util import _pywrap_utils 

63from tensorflow.python.util import compat 

64from tensorflow.python.util.deprecation import deprecated 

65from tensorflow.python.util.tf_export import tf_export 

66 

# Mark these op types as read-only resource ops for the auto control
# dependency utilities (`acd`).
for _read_only_op_type in ("ReadVariableOp", "VariableShape",
                           "ResourceGather", "ResourceGatherNd",
                           "_ReadVariablesOp"):
  acd.register_read_only_resource_op(_read_only_op_type)
del _read_only_op_type

# TODO(allenl): Remove this alias and migrate callers.
get_resource_handle_data = handle_data_util.get_resource_handle_data

75 

76 

def get_eager_safe_handle_data(handle):
  """Get the data handle from the Tensor `handle`.

  Args:
    handle: A `Tensor`; may be a graph tensor or an `EagerTensor`.

  Returns:
    For an `EagerTensor`, its `_handle_data` attribute (which may be None);
    for any other tensor, the result of `get_resource_handle_data(handle)`.
  """
  assert isinstance(handle, ops.Tensor)

  if not isinstance(handle, ops.EagerTensor):
    return get_resource_handle_data(handle)
  return handle._handle_data  # pylint: disable=protected-access

85 

86 

def _set_handle_shapes_and_types(tensor, handle_data, graph_mode):
  """Attaches shape-inference `HandleData` to `tensor`.

  The handle data is always stored on the Python object as `_handle_data`.
  In graph mode it is additionally written into the C graph for the tensor's
  output via `TF_GraphSetOutputHandleShapesAndTypes_wrapper`.

  Args:
    tensor: A `Tensor` or `EagerTensor`.
    handle_data: A `CppShapeInferenceResult.HandleData`.
    graph_mode: A python bool; when False only the Python attribute is set.
  """
  tensor._handle_data = handle_data  # pylint: disable=protected-access
  if graph_mode:
    # Not an EagerTensor, so a graph tensor: split the (shape, dtype) pairs
    # into the parallel sequences the C API takes.
    shape_protos, types = zip(
        *[(entry.shape, entry.dtype) for entry in handle_data.shape_and_type])
    ranks = []
    shapes = []
    for shape_proto in shape_protos:
      if shape_proto.unknown_rank:
        # Unknown rank is passed as rank -1 with no dimension list.
        ranks.append(-1)
        shapes.append(None)
      else:
        ranks.append(len(shape_proto.dim))
        shapes.append([dim.size for dim in shape_proto.dim])
    c_graph_holder = tensor._op.graph._c_graph  # pylint: disable=protected-access
    with c_graph_holder.get() as c_graph:
      pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
          c_graph,
          tensor._as_tf_output(),  # pylint: disable=protected-access
          shapes,
          ranks,
          types)

114 

115 

def _combine_handle_data(handle, initial_value):
  """Concats HandleData from tensors `handle` and `initial_value`.

  Args:
    handle: A `Tensor` of dtype `resource`.
    initial_value: A `Tensor`.

  Returns:
    A `CppShapeInferenceResult.HandleData`: the handle data of `handle`. When
    `initial_value` has dtype `variant` and carries set handle data, its
    `shape_and_type` entries are appended (in place) to the returned value.

  Raises:
    RuntimeError: If handle, which was returned by VarHandleOp, either has
      no handle data, or its len(handle_data.shape_and_type) != 1.
  """
  assert handle.dtype == dtypes.resource

  variable_handle_data = get_eager_safe_handle_data(handle)

  if initial_value.dtype == dtypes.variant:
    extra_handle_data = get_eager_safe_handle_data(initial_value)
    if extra_handle_data is not None and extra_handle_data.is_set:
      handle_data_ok = (
          variable_handle_data is not None and variable_handle_data.is_set and
          len(variable_handle_data.shape_and_type) == 1)
      if not handle_data_ok:
        raise RuntimeError(
            "Expected VarHandleOp to return a length==1 shape_and_type, "
            f"but saw: '{variable_handle_data}'")
      variable_handle_data.shape_and_type.extend(
          extra_handle_data.shape_and_type)

  return variable_handle_data

148 

149 

def _variable_handle_from_shape_and_dtype(shape,
                                          dtype,
                                          shared_name,
                                          name,
                                          graph_mode,
                                          initial_value=None):
  """Create a variable handle, copying in handle data from `initial_value`.

  Args:
    shape: Anything accepted by `tensor_shape.as_shape`.
    dtype: Anything accepted by `dtypes.as_dtype`.
    shared_name: Shared name for the handle. Must be None when not in graph
      mode; a fresh `context.anonymous_name()` is used instead.
    name: Op name for the `VarHandleOp`.
    graph_mode: A python bool.
    initial_value: Optional `Tensor` whose handle data (for `variant` dtype)
      is appended to the handle's handle data. Defaults to the handle itself.

  Returns:
    The `resource` handle tensor returned by `var_handle_op`.

  Raises:
    errors.InternalError: If `shared_name` is given outside graph mode.
    RuntimeError: If the handle data to extend does not have exactly one
      `shape_and_type` entry.
  """
  container = ops.get_default_graph()._container  # pylint: disable=protected-access
  container = "" if container is None else container
  shape = tensor_shape.as_shape(shape)
  dtype = dtypes.as_dtype(dtype)

  if not graph_mode:
    # Explicit shared names are rejected when executing eagerly; each eager
    # handle gets an anonymous name.
    if shared_name is not None:
      raise errors.InternalError(
          node_def=None,
          op=None,
          message="Using an explicit shared_name is "
          "not allowed when executing eagerly.")
    shared_name = context.anonymous_name()

  handle = gen_resource_variable_ops.var_handle_op(
      shape=shape,
      dtype=dtype,
      shared_name=shared_name,
      name=name,
      container=container)
  if initial_value is None:
    initial_value = handle

  if graph_mode:
    handle_data = _combine_handle_data(handle, initial_value)
  else:
    # Eager handles carry no graph-inferred handle data, so build it from
    # shape/dtype and append any variant handle data from `initial_value`.
    handle_data = handle_data_util.create_handle_data(shape, dtype)
    if initial_value.dtype == dtypes.variant:
      extra_handle_data = get_eager_safe_handle_data(initial_value)
      if extra_handle_data is not None and extra_handle_data.is_set:
        if not handle_data.is_set or len(handle_data.shape_and_type) != 1:
          raise RuntimeError(
              "Expected VarHandleOp to return a length==1 shape_and_type, "
              f"but saw: '{handle_data}'")
        handle_data.shape_and_type.extend(extra_handle_data.shape_and_type)

  _set_handle_shapes_and_types(handle, handle_data, graph_mode)
  return handle

196 

197 

def eager_safe_variable_handle(initial_value, shape, shared_name, name,
                               graph_mode):
  """Creates a variable handle with information to do shape inference.

  The dtype is taken from `initial_value.dtype.base_dtype` and stored in the
  returned resource tensor's handle data.

  If `initial_value.dtype == tf.variant`, the handle data (if any) of
  `initial_value` is additionally appended to the `handle_data`, so the
  returned tensor's handle data has the form

  ```
  is_set: true
  shape_and_type {
    shape {
      // initial_value.shape
    }
    dtype: DT_VARIANT
  }
  shape_and_type {
    // handle_data(initial_value).shape_and_type[0]
  }
  shape_and_type {
    // handle_data(initial_value).shape_and_type[1]
  }
  ...
  ```

  Ops that read from this tensor, such as `ReadVariableOp` and
  `AssignVariableOp`, know that `handle_data(handle).shape_and_type[1:]`
  correspond to the handle data of the variant(s) stored in the Variable.

  Args:
    initial_value: A `Tensor`.
    shape: The shape of the handle data. Can be `TensorShape(None)` (i.e.
      unknown shape).
    shared_name: A string.
    name: A string.
    graph_mode: A python bool.

  Returns:
    The handle, a `Tensor` of type `resource`.
  """
  return _variable_handle_from_shape_and_dtype(
      shape=shape,
      dtype=initial_value.dtype.base_dtype,
      shared_name=shared_name,
      name=name,
      graph_mode=graph_mode,
      initial_value=initial_value)

244 

245 

@contextlib.contextmanager
def _handle_graph(handle):
  """Context manager that enters `handle.graph` as default only if needed.

  Yields immediately when executing eagerly, when `handle` is an
  `EagerTensor`, or when `ops.has_default_graph()` is true; otherwise yields
  inside `handle.graph.as_default()`.
  """
  # Note: might have an eager tensor but not be executing eagerly when building
  # functions.
  stay_in_current_graph = (
      context.executing_eagerly() or isinstance(handle, ops.EagerTensor) or
      ops.has_default_graph())
  if stay_in_current_graph:
    yield
    return
  with handle.graph.as_default():
    yield

256 

257 

class EagerResourceDeleter:
  """Destroys a resource handle when this object is garbage collected.

  This is an alternative to defining `__del__` on the owning object. A
  ResourceVariable (or any other object holding a resource handle) keeps a
  single reference to one of these; when the owner is collected, so is the
  deleter. Even if the owner is part of a reference cycle, the cycle stays
  collectable.
  """

  __slots__ = ["_handle", "_handle_device", "_context"]

  def __init__(self, handle, handle_device):
    """Stores `handle` (must be a `tf.Tensor`) and its device for cleanup."""
    if not isinstance(handle, ops.Tensor):
      raise ValueError(
          (f"Passed handle={handle} to EagerResourceDeleter. Was expecting "
           f"the handle to be a `tf.Tensor`."))
    self._handle = handle
    self._handle_device = handle_device
    # Hold on to the context: __del__ runs an op, and if the context() is
    # collected before this object, running that op would segfault.
    self._context = context.context()

  def __del__(self):
    # Resources follow object-identity when executing eagerly, so deleting the
    # resource we hold a handle to is safe.
    try:
      handle = self._handle
      # A packed EagerTensor doesn't own any resource.
      if not (isinstance(handle, ops.EagerTensor) and handle.is_packed):
        # The resource was created eagerly, but this destructor may run in
        # graph mode (notably in unit tests), so switch back to eager mode.
        with context.eager_mode():
          with ops.device(self._handle_device):
            gen_resource_variable_ops.destroy_resource_op(
                handle, ignore_lookup_error=True)
    except (TypeError, AttributeError):
      # Best-effort cleanup during interpreter/module teardown: e.g. "'NoneType'
      # object is not callable" when the handle was partially unloaded, or
      # "'NoneType' object has no attribute 'eager_mode'" once the context
      # module is gone. Exceptions escaping __del__ would only be printed as
      # warnings to stderr, which is noise here.
      pass

307 

308 

def shape_safe_assign_variable_handle(handle, shape, value, name=None):
  """Assigns `value` to the variable behind `handle` after a shape check.

  Args:
    handle: The variable's resource handle.
    shape: A `TensorShape`; `value`'s shape must be compatible with it.
    value: Anything accepted by `ops.convert_to_tensor`.
    name: Optional name for the assign op.

  Returns:
    The result of `gen_resource_variable_ops.assign_variable_op`.

  Raises:
    Whatever `shape.assert_is_compatible_with` raises on a mismatch (callers
    in this file catch `ValueError`).
  """
  with _handle_graph(handle):
    new_value = ops.convert_to_tensor(value)
    shape.assert_is_compatible_with(new_value.shape)
    return gen_resource_variable_ops.assign_variable_op(
        handle, new_value, name=name)

316 

317 

def _maybe_set_handle_data(dtype, handle, tensor):
  """Copies variant handle data from a variable handle onto a read result.

  Only acts when `dtype` is `variant`. The handle's
  `shape_and_type[1:]` entries are attached to `tensor` as a new
  `CppShapeInferenceResult.HandleData`.

  Args:
    dtype: The dtype of the variable that `handle` refers to.
    handle: The variable's resource handle.
    tensor: The tensor produced by reading the variable; its `_handle_data`
      is replaced in place when there is variant handle data to copy.
  """
  if dtype != dtypes.variant:
    return
  # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
  # variant's handle data. Extract it.
  handle_data = get_eager_safe_handle_data(handle)
  # get_eager_safe_handle_data returns an EagerTensor's raw `_handle_data`,
  # which can be None (other callers in this file check for that); treat it
  # like unset handle data instead of raising AttributeError.
  if (handle_data is not None and handle_data.is_set and
      len(handle_data.shape_and_type) > 1):
    tensor._handle_data = (  # pylint: disable=protected-access
        cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
            is_set=True, shape_and_type=handle_data.shape_and_type[1:]))

327 

328 

def variable_accessed(variable):
  """Records that `variable` was accessed for the tape and FuncGraph.

  The default graph is told via `watch_variable` when it has that method;
  trainable variables are also reported to `tape.variable_accessed`.
  """
  default_graph = ops.get_default_graph()
  if hasattr(default_graph, "watch_variable"):
    default_graph.watch_variable(variable)
  if variable.trainable:
    tape.variable_accessed(variable)

335 

336 

def default_variable_creator_v2(next_creator=None, **kwargs):
  """Default variable creator: builds a `ResourceVariable` from `kwargs`.

  Only the keyword arguments named in `_FORWARDED` below are passed on to
  `ResourceVariable`; any other entries in `kwargs` are ignored. Missing
  entries default to `None`, except `validate_shape`, which defaults to
  `True`.

  Args:
    next_creator: Must be None; this is the end of the creator chain.
    **kwargs: Variable construction arguments.

  Returns:
    A new `ResourceVariable`.
  """
  assert next_creator is None
  defaults = dict.fromkeys((
      "initial_value",
      "trainable",
      "validate_shape",
      "caching_device",
      "name",
      "variable_def",
      "dtype",
      "import_scope",
      "constraint",
      "distribute_strategy",
      "synchronization",
      "aggregation",
      "shape",
      "experimental_enable_variable_lifting",
  ))
  defaults["validate_shape"] = True
  forwarded = {key: kwargs.get(key, default) for key, default in defaults.items()}
  return ResourceVariable(**forwarded)


# Install this function as `variables.default_variable_creator_v2`.
variables.default_variable_creator_v2 = default_variable_creator_v2

375 

376 

377class BaseResourceVariable(variables.Variable, core.Tensor): 

378 """A python variable from an existing handle.""" 

379 

# TODO(wangpeng): Deprecate `constraint` when callers no longer pass it in.
def __init__(  # pylint: disable=super-init-not-called
    self,
    trainable=None,
    shape=None,
    dtype=None,
    handle=None,
    constraint=None,
    synchronization=None,
    aggregation=None,
    distribute_strategy=None,
    name=None,
    unique_id=None,
    handle_name=None,
    graph_element=None,
    initial_value=None,
    initializer_op=None,
    is_initialized_op=None,
    cached_value=None,
    save_slice_info=None,
    caching_device=None,
    in_graph_mode=None,
    validate_shape=True,
    **unused_kwargs):
  """Creates a variable from a handle.

  Args:
    trainable: If `True`, GradientTapes automatically watch uses of this
      Variable.
    shape: The variable's shape. This shape can be set to tf.TensorShape(None)
      in order to assign values of different shapes to this variable.
      Otherwise (i.e. if the shape is fully determined), it will trigger run
      time checks to ensure that each assignment is of the same shape.
    dtype: The variable's dtype.
    handle: The variable's handle.
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value (which must have
      the same shape). Constraints are not safe to use when doing asynchronous
      distributed training.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses when to
      synchronize.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.
    distribute_strategy: The distribution strategy this variable was created
      under.
    name: The name for this variable.
    unique_id: Internal. Unique ID for this variable's handle.
    handle_name: The name for the variable's handle; `":0"` is appended.
      Defaults to `"Variable"`.
    graph_element: Optional, required only in session.run-mode. Pre-created
      tensor which reads this variable's value.
    initial_value: Optional. Variable's initial value.
    initializer_op: Operation which assigns the variable's initial value.
    is_initialized_op: Pre-created operation to check whether this variable is
      initialized.
    cached_value: Pre-created operation to read this variable in a specific
      device.
    save_slice_info: Metadata for variable partitioning.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device. If not `None`, caches on another device. Typical use is to
      cache on the device where the Ops using the Variable reside, to
      deduplicate copying through `Switch` and other conditional statements.
    in_graph_mode: Whether we are executing in TF1 graph mode. If None, it is
      detected here. Passing it avoids repeated init_scope() context
      entrances, which can add up.
    validate_shape: If `False`, allows the variable to be initialized with a
      value of unknown shape. If `True`, the default, the shape of
      `initial_value` must be known.
  """
  if in_graph_mode is not None:
    self._in_graph_mode = in_graph_mode
  else:
    # Detect the execution mode from within ops.init_scope().
    with ops.init_scope():
      self._in_graph_mode = not context.executing_eagerly()
  synchronization, aggregation, trainable = (
      variables.validate_synchronization_aggregation_trainable(
          synchronization, aggregation, trainable, name))
  self._trainable = trainable
  self._synchronization = synchronization
  self._aggregation = aggregation
  self._save_slice_info = save_slice_info
  self._initial_value = initial_value
  self._initializer_op = initializer_op
  self._is_initialized_op = is_initialized_op
  self._graph_element = graph_element
  self._caching_device = caching_device
  self._cached_value = cached_value
  self._distribute_strategy = distribute_strategy
  # Store the graph key so optimizers know how to only retrieve variables from
  # this graph. Guaranteed to be the same as the eager graph_key.
  self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
  self._shape = tensor_shape.as_shape(shape)
  self._dtype = dtypes.as_dtype(dtype)
  self._handle = handle
  self._unique_id = unique_id
  self._handle_name = (
      "Variable:0" if handle_name is None else handle_name + ":0")
  self._constraint = constraint
  self._cached_shape_as_list = None
  self._validate_shape = validate_shape

489 

def __repr__(self):
  """Returns `<tf.Variable 'name' shape=... dtype=...>`.

  When executing eagerly and not in graph mode, the current value is
  appended (via `ops.value_text`); if it cannot be read, the text
  `numpy=<unavailable>` is used instead.
  """
  if context.executing_eagerly() and not self._in_graph_mode:
    # If we cannot read the value for any reason (e.g. variable uninitialized
    # during tf.function tracing), still produce a __repr__. Note that for
    # async eager, errors due to uninitialized variables will raise in
    # ops.value_text when the handle is resolved, so we need to keep that
    # under the try...except if we want to suppress them.
    try:
      with ops.device(self.device):
        value_text = ops.value_text(self.read_value(), is_repr=True)
    except Exception:  # pylint: disable=broad-except
      # `Exception` rather than a bare `except:` so KeyboardInterrupt and
      # SystemExit raised while reading are not swallowed by repr().
      value_text = "numpy=<unavailable>"

    return "<tf.Variable '%s' shape=%s dtype=%s, %s>" % (
        self.name, self.get_shape(), self.dtype.name, value_text)
  return "<tf.Variable '%s' shape=%s dtype=%s>" % (
      self.name, self.get_shape(), self.dtype.name)

508 

def __tf_tracing_type__(self, signature_context):
  """Returns the `VariableSpec` used as this variable's tracing type.

  Also registers `self` in `signature_context` as the placeholder value for
  the alias id derived from the handle's `_id`.
  """
  handle_id = self._handle._id  # pylint:disable=protected-access
  alias_id = signature_context.alias_global_id(handle_id)
  # TODO(xjun): Create variable placeholders directly from VariableSpec
  # without using original values.
  signature_context.add_placeholder(alias_id, self)
  return VariableSpec(
      shape=self.shape,
      dtype=self.dtype,
      trainable=self.trainable,
      alias_id=alias_id)

518 

@contextlib.contextmanager
def _assign_dependencies(self):
  """Makes assignments depend on the cached value, if any.

  This prevents undefined behavior with reads not ordered wrt writes.

  Yields:
    None.
  """
  if self._cached_value is None:
    dependency_scope = contextlib.nullcontext()
  else:
    dependency_scope = ops.control_dependencies([self._cached_value])
  with dependency_scope:
    yield

533 

def __array__(self, dtype=None):
  """Allows direct conversion to a numpy array.

  >>> np.array(tf.Variable([1.0]))
  array([1.], dtype=float32)

  Args:
    dtype: Optional numpy dtype passed to `np.asarray`.

  Returns:
    The variable value as a numpy array.
  """
  # Returning `self.numpy()` directly fails for scalars with
  # "ValueError: object __array__ method not producing an array"; so do
  # `self.read_value().__array__()` and `self.read_value()._numpy()`.
  # Wrapping the value in np.asarray produces a real ndarray in every case.
  numpy_value = self.numpy()
  return np.asarray(numpy_value, dtype=dtype)

550 

def __nonzero__(self):
  """Python 2 name for truthiness; delegates to `__bool__`."""
  return self.__bool__()

def __bool__(self):
  """Truthiness of the variable's current value (via `read_value`)."""
  current_value = self.read_value()
  return bool(current_value)

def __copy__(self):
  """A shallow copy of a variable is the variable itself."""
  return self

559 

def __deepcopy__(self, memo):
  """Creates a new `ResourceVariable` initialized from the current value.

  Only supported when executing eagerly. The copy reuses this variable's
  trainability, constraint, dtype, shared name, distribution strategy,
  synchronization and aggregation settings.

  Raises:
    NotImplementedError: If not executing eagerly.
  """
  if not context.executing_eagerly():
    raise NotImplementedError(
        "__deepcopy__() is only available when eager execution is enabled.")
  current_value = self.read_value()
  copied_variable = ResourceVariable(
      initial_value=current_value,
      trainable=self._trainable,
      constraint=self._constraint,
      dtype=self._dtype,
      name=self._shared_name,
      distribute_strategy=self._distribute_strategy,
      synchronization=self.synchronization,
      aggregation=self.aggregation)
  memo[self._unique_id] = copied_variable
  return copied_variable

575 

@property
def dtype(self):
  """The dtype of this variable."""
  return self._dtype

@property
def device(self):
  """The device this variable is on (the handle's device)."""
  return self.handle.device

@property
def graph(self):
  """The `Graph` of this variable (the handle's graph)."""
  return self.handle.graph

@property
def name(self):
  """The name of the handle for this variable."""
  return self._handle_name

@property
def shape(self):
  """The shape of this variable."""
  return self._shape

def set_shape(self, shape):
  """Refines the stored shape by merging it with `shape`."""
  self._shape = self._shape.merge_with(shape)

def _shape_as_list(self):
  """Returns each dimension's `.value` as a list, or None for unknown rank."""
  if self.shape.ndims is not None:
    return [dim.value for dim in self.shape.dims]
  return None

def _shape_tuple(self):
  """Like `_shape_as_list`, but returns a tuple (None for unknown rank)."""
  dims = self._shape_as_list()
  return None if dims is None else tuple(dims)

614 

@property
def create(self):
  """The op responsible for initializing this variable (graph mode only)."""
  if self._in_graph_mode:
    return self._initializer_op
  raise RuntimeError("This operation is not supported "
                     "when eager execution is enabled.")

@property
def handle(self):
  """The handle by which this variable can be accessed."""
  return self._handle

def value(self):
  """A cached operation which reads the value of this variable.

  Returns the pre-created cached value when one was supplied; otherwise
  reads the variable with existing colocation constraints ignored.
  """
  cached = self._cached_value
  if cached is None:
    with ops.colocate_with(None, ignore_existing=True):
      return self._read_variable_op()
  return cached

def _as_graph_element(self):
  """Conversion function for Graph.as_graph_element()."""
  return self._graph_element

638 

@property
def initializer(self):
  """The op responsible for initializing this variable."""
  return self._initializer_op

@property
def initial_value(self):
  """Returns the Tensor used as the initial value for the variable.

  Raises:
    RuntimeError: If executing eagerly.
  """
  if not context.executing_eagerly():
    return self._initial_value
  raise RuntimeError("This property is not supported "
                     "when eager execution is enabled.")

@property
def constraint(self):
  """Returns the constraint function associated with this variable.

  Returns:
    The constraint function that was passed to the variable constructor.
    Can be `None` if no constraint was passed.
  """
  return self._constraint

@property
def op(self):
  """The op for this variable (the handle's op)."""
  return self.handle.op

@property
def trainable(self):
  """Whether GradientTapes automatically watch this variable."""
  return self._trainable

@property
def synchronization(self):
  """The validated `tf.VariableSynchronization` setting."""
  return self._synchronization

@property
def aggregation(self):
  """The validated `tf.VariableAggregation` setting."""
  return self._aggregation

678 

def eval(self, session=None):
  """Evaluates and returns the value of this variable (graph mode only).

  Raises:
    RuntimeError: If executing eagerly.
  """
  if not context.executing_eagerly():
    return self._graph_element.eval(session=session)
  raise RuntimeError("This operation is not supported "
                     "when eager execution is enabled.")

def numpy(self):
  """Returns the current value as a numpy value (eager only).

  Raises:
    NotImplementedError: If not executing eagerly.
  """
  if not context.executing_eagerly():
    raise NotImplementedError(
        "numpy() is only available when eager execution is enabled.")
  return self.read_value().numpy()

691 

@deprecated(None, "Prefer Dataset.range instead.")
def count_up_to(self, limit):
  """Increments this variable until it reaches `limit`.

  When the returned Op runs it tries to increment the variable by `1`. If
  incrementing would bring it above `limit`, the Op raises
  `OutOfRangeError`. Otherwise it outputs the variable's value from before
  the increment.

  This is essentially a shortcut for `count_up_to(self, limit)`.

  Args:
    limit: value at which incrementing the variable raises an error.

  Returns:
    A `Tensor` that will hold the variable value before the increment. If no
    other Op modifies this variable, the values produced will all be
    distinct.
  """
  variable_handle = self.handle
  return gen_state_ops.resource_count_up_to(
      variable_handle, limit=limit, T=self.dtype)

715 

def _export_to_saved_model_graph(self, object_map=None, tensor_map=None,
                                 options=None, **kwargs):
  """For implementing `Trackable`.

  Copies this variable with `copy_to_graph_uninitialized` (under this
  variable's device when the save options' variable policy asks for saving
  devices). The copy is recorded in `object_map`, and the copy's handle in
  `tensor_map`.

  Returns:
    A one-element list containing this variable's handle.
  """
  save_devices = options.experimental_variable_policy._save_variable_devices()  # pylint:disable=protected-access
  if save_devices:
    device_scope = ops.device(self.device)
  else:
    device_scope = contextlib.nullcontext()
  with device_scope:
    new_variable = copy_to_graph_uninitialized(self)
  object_map[self] = new_variable
  tensor_map[self.handle] = new_variable.handle
  return [self.handle]

728 

def _serialize_to_tensors(self):
  """Implements Trackable._serialize_to_tensors.

  Returns:
    A dict mapping `trackable.VARIABLE_VALUE_KEY` to a
    `tensor_callable.Callable` that reads the variable lazily. When executing
    eagerly and the variable is uninitialized, the callable returns None.
  """

  def _read_variable_closure():
    with ops.device(self.device):
      if context.executing_eagerly() and not self.is_initialized():
        # A SaveSpec tensor value of `None` indicates that the variable is
        # uninitialized.
        return None
      # Read the variable without making a copy to limit memory usage.
      snapshot = self.read_value_no_copy()
      # To allow variables placed on non-CPU devices to be checkpointed,
      # copy them to CPU on the same machine first. This scope is nested in
      # the variable's device scope on purpose.
      with ops.device("/device:CPU:0"):
        return array_ops.identity(snapshot)

  lazy_value = tensor_callable.Callable(
      _read_variable_closure, dtype=self.dtype, device=self.device)
  return {trackable.VARIABLE_VALUE_KEY: lazy_value}

751 

def _restore_from_tensors(self, restored_tensors):
  """Implements Trackable._restore_from_tensors.

  Args:
    restored_tensors: Dict holding the saved value under
      `trackable.VARIABLE_VALUE_KEY`.

  Returns:
    The assign op from `shape_safe_assign_variable_handle`.

  Raises:
    ValueError: If the saved tensor's shape is incompatible with the
      variable's shape.
  """
  with ops.device(self.device):
    restored_tensor = array_ops.identity(
        restored_tensors[trackable.VARIABLE_VALUE_KEY])
    try:
      return shape_safe_assign_variable_handle(
          self.handle, self.shape, restored_tensor)
    except ValueError as e:
      raise ValueError(
          f"Received incompatible tensor with shape {restored_tensor.shape} "
          f"when attempting to restore variable with shape {self.shape} "
          f"and name {self.name}.") from e

766 

def _read_variable_op(self, no_copy=False):
  """Reads the value of the variable.

  If the variable is in copy-on-read mode and `no_copy` is True, the variable
  is converted to copy-on-write mode before it is read. That conversion is
  only issued when `forward_compat.forward_compatible(2022, 5, 3)` holds.

  Args:
    no_copy: Whether to prevent a copy of the variable.

  Returns:
    The value of the variable.
  """
  variable_accessed(self)

  def _read(skip_copy):
    if skip_copy and forward_compat.forward_compatible(2022, 5, 3):
      gen_resource_variable_ops.disable_copy_on_read(self.handle)
    value = gen_resource_variable_ops.read_variable_op(
        self.handle, self._dtype)
    _maybe_set_handle_data(self._dtype, self.handle, value)
    return value

  caching_device = getattr(self, "_caching_device", None)
  if caching_device is None:
    result = _read(no_copy)
  else:
    # Read on the caching device, ignoring any colocation constraints.
    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(caching_device):
        result = _read(no_copy)

  if not context.executing_eagerly():
    # Note that if a control flow context is active the input of the read op
    # might not actually be the handle. This line bypasses it.
    record.record_operation(
        "ReadVariableOp", [result], [self.handle],
        backward_function=lambda x: [x],
        forward_function=lambda x: [x])
  return result

804 

def read_value(self):
  """Constructs an op which reads the value of this variable.

  Should be used when there are multiple reads, or when it is desirable to
  read the value only after some condition is true.

  Returns:
    The value of the variable.
  """
  with ops.name_scope("Read"):
    raw_value = self._read_variable_op()
  # Return an identity so it can get placed on whatever device the context
  # specifies instead of the device where the variable is.
  return array_ops.identity(raw_value)

819 

def read_value_no_copy(self):
  """Returns a read of this variable's value that avoids copying the buffer.

  No copy is made even if the variable has been sparsely accessed. A
  variable in copy-on-read mode is switched to copy-on-write mode as a side
  effect.

  Returns:
    The value of the variable.
  """
  with ops.name_scope("Read"):
    raw_value = self._read_variable_op(no_copy=True)
  # The identity is created outside the "Read" scope. It lets the returned
  # tensor be placed by the current device context rather than on the
  # variable's device.
  return array_ops.identity(raw_value)

835 

def sparse_read(self, indices, name=None):
  """Reads the value of this variable sparsely, using `gather`.

  Args:
    indices: Indices to gather, passed through to `resource_gather`.
    name: Optional name for the operation; "Gather" is used when None.

  Returns:
    A `Tensor` holding the gathered values. For `dtypes.variant` variables,
    the handle's variant handle data is attached to the gathered tensor when
    it is available, and an `identity` of that tensor is returned.
  """
  with ops.name_scope("Gather" if name is None else name) as name:
    # Register the read with the tape / enclosing function graph.
    variable_accessed(self)
    value = gen_resource_variable_ops.resource_gather(
        self.handle, indices, dtype=self._dtype, name=name)

    if self._dtype == dtypes.variant:
      # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
      # variant's handle data. Extract it.
      handle_data = get_eager_safe_handle_data(self.handle)
      if handle_data.is_set and len(handle_data.shape_and_type) > 1:
        value._handle_data = (  # pylint: disable=protected-access
            cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
                is_set=True, shape_and_type=handle_data.shape_and_type[1:]))
      # Note: this identity is created inside the name scope.
      return array_ops.identity(value)

  return value

854 

def gather_nd(self, indices, name=None):
  """Reads the value of this variable sparsely, using `gather_nd`.

  Args:
    indices: Index tensor passed through to `resource_gather_nd`.
    name: Optional name for the operation; "GatherNd" is used when None.

  Returns:
    A `Tensor` holding the gathered values.
  """
  with ops.name_scope("GatherNd" if name is None else name) as name:
    # Record the access for every variable. This matches `sparse_read`,
    # `_read_variable_op` and `_lazy_read`, which never gate this call on
    # `self.trainable`.
    variable_accessed(self)
    value = gen_resource_variable_ops.resource_gather_nd(
        self.handle, indices, dtype=self._dtype, name=name)

  return array_ops.identity(value)

864 

def to_proto(self, export_scope=None):
  """Serializes this `ResourceVariable` into a `VariableDef` protocol buffer.

  Args:
    export_scope: Optional `string`. Name scope to strip from the names
      recorded in the proto.

  Raises:
    RuntimeError: If run in EAGER mode.

  Returns:
    A `VariableDef` protocol buffer, or `None` if `export_scope` is given and
    the handle's name does not start with it.
  """
  if context.executing_eagerly():
    raise RuntimeError("This operation is not supported "
                       "when eager execution is enabled.")
  in_scope = export_scope is None or self.handle.name.startswith(export_scope)
  if not in_scope:
    return None

  def _strip(tensor_name):
    return ops.strip_name_scope(tensor_name, export_scope)

  var_def = variable_pb2.VariableDef()
  var_def.variable_name = _strip(self.handle.name)
  if self._initial_value is not None:
    # Variables rebuilt from old protos may lack an initial value, so this
    # field is only filled in when one exists (backwards compatibility).
    var_def.initial_value_name = _strip(self._initial_value.name)
  var_def.initializer_name = _strip(self.initializer.name)
  # Prefer the cached snapshot; otherwise store the graph element.
  if self._cached_value is not None:
    snapshot = self._cached_value
  else:
    snapshot = self._graph_element
  var_def.snapshot_name = _strip(snapshot.name)
  var_def.is_resource = True
  var_def.trainable = self.trainable
  var_def.synchronization = self.synchronization.value
  var_def.aggregation = self.aggregation.value
  if self._save_slice_info:
    var_def.save_slice_info_def.MergeFrom(
        self._save_slice_info.to_proto(export_scope=export_scope))
  return var_def

910 

@staticmethod
def from_proto(variable_def, import_scope=None):
  """Builds a `ResourceVariable` from a `VariableDef` protocol buffer.

  Args:
    variable_def: The `VariableDef` proto to rebuild from.
    import_scope: Optional name scope, forwarded to the `ResourceVariable`
      constructor.

  Raises:
    RuntimeError: If called while eager execution is enabled.

  Returns:
    A new `ResourceVariable`.
  """
  if not context.executing_eagerly():
    return ResourceVariable(
        variable_def=variable_def, import_scope=import_scope)
  raise RuntimeError("This operation is not supported "
                     "when eager execution is enabled.")

918 

# NumPy hint: NumPy consults `__array_priority__` when an ndarray and this
# object meet in a binary operation. A high value is meant to make NumPy
# defer to this type's own (reflected) operators.
# NOTE(review): exact dispatch rules depend on the NumPy version in use.
__array_priority__ = 100

920 

def is_initialized(self, name=None):
  """Builds an op reporting whether this resource variable is initialized.

  The op produces a boolean scalar.

  Args:
    name: Optional name for the created operation.

  Returns:
    A `Tensor` of type `bool`.
  """
  handle = self.handle
  return gen_resource_variable_ops.var_is_initialized_op(handle, name)

933 

def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
  """Subtracts a value from this variable.

  Args:
    delta: A `Tensor`. The value to subtract from this variable.
    use_locking: Ignored. It is not forwarded to `assign_sub_variable_op`.
    name: The name to use for the operation.
    read_value: A `bool`. Whether to read and return the new value of the
      variable or not.

  Returns:
    If `read_value` is `True`, this method will return the new value of the
    variable after the assignment has completed. Otherwise, when in graph mode
    it will return the `Operation` that does the assignment, and when in eager
    mode it will return `None`.
  """
  del use_locking  # Not supported by the underlying resource op.
  # TODO(apassos): this here and below is not atomic. Consider making it
  # atomic if there's a way to do so without a performance cost for those who
  # don't need it.
  with _handle_graph(self.handle), self._assign_dependencies():
    assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
        self.handle,
        ops.convert_to_tensor(delta, dtype=self.dtype),
        name=name)
  if read_value:
    return self._lazy_read(assign_sub_op)
  return assign_sub_op

961 

def assign_add(self, delta, use_locking=None, name=None, read_value=True):
  """Adds a value to this variable.

  Args:
    delta: A `Tensor`. The value to add to this variable.
    use_locking: Ignored. It is not forwarded to `assign_add_variable_op`.
    name: The name to use for the operation.
    read_value: A `bool`. Whether to read and return the new value of the
      variable or not.

  Returns:
    If `read_value` is `True`, this method will return the new value of the
    variable after the assignment has completed. Otherwise, when in graph mode
    it will return the `Operation` that does the assignment, and when in eager
    mode it will return `None`.
  """
  del use_locking  # Not supported by the underlying resource op.
  with _handle_graph(self.handle), self._assign_dependencies():
    assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
        self.handle,
        ops.convert_to_tensor(delta, dtype=self.dtype),
        name=name)
  if read_value:
    return self._lazy_read(assign_add_op)
  return assign_add_op

986 

def _lazy_read(self, op):
  """Returns an `_UnreadVariable` view of this variable with `op` as parent.

  Args:
    op: The operation (typically an update) the returned view depends on.

  Returns:
    An `_UnreadVariable` sharing this variable's handle, dtype, shape and
    unique id.
  """
  variable_accessed(self)
  unread_kwargs = dict(
      handle=self.handle,
      dtype=self.dtype,
      shape=self._shape,
      in_graph_mode=self._in_graph_mode,
      parent_op=op,
      unique_id=self._unique_id)
  return _UnreadVariable(**unread_kwargs)

996 

def assign(self, value, use_locking=None, name=None, read_value=True):
  """Assigns a new value to this variable.

  Args:
    value: A `Tensor`. The new value for this variable.
    use_locking: Ignored. It is not forwarded to `assign_variable_op`.
    name: The name to use for the assignment.
    read_value: A `bool`. Whether to read and return the new value of the
      variable or not.

  Returns:
    If `read_value` is `True`, this method will return the new value of the
    variable after the assignment has completed. Otherwise, when in graph mode
    it will return the `Operation` that does the assignment, and when in eager
    mode it will return `None`.

  Raises:
    ValueError: If the static shape of `value` is incompatible with the
      variable's shape.
  """
  del use_locking  # Not supported by the underlying resource op.
  # Note: not depending on the cached value here since this can be used to
  # initialize the variable.
  with _handle_graph(self.handle):
    value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
    if not self._shape.is_compatible_with(value_tensor.shape):
      # No stray leading space inside the quotes, and a space between the two
      # sentences of the message.
      tensor_name = "" if self.name is None else str(self.name)
      raise ValueError(
          f"Cannot assign value to variable '{tensor_name}': Shape mismatch. "
          f"The variable shape {self._shape}, and the "
          f"assigned value shape {value_tensor.shape} are incompatible.")
    kwargs = {}
    if forward_compat.forward_compatible(2022, 3, 23):
      # If the shape is fully defined, we do a runtime check with the shape of
      # value (only when shape validation is enabled for this variable).
      validate_shape = self._validate_shape and self._shape.is_fully_defined()
      kwargs["validate_shape"] = validate_shape
    assign_op = gen_resource_variable_ops.assign_variable_op(
        self.handle, value_tensor, name=name, **kwargs)
    if read_value:
      return self._lazy_read(assign_op)
  return assign_op

1037 

def __reduce__(self):
  """Supports pickling by rebuilding a variable from its current value.

  This mirrors `__deepcopy__`. The reconstructor is a `functools.partial` of
  `ResourceVariable` built from `self.numpy()` and the variable's
  configuration; it is called with no extra positional arguments.
  """
  ctor_kwargs = dict(
      initial_value=self.numpy(),
      trainable=self.trainable,
      name=self._shared_name,
      dtype=self.dtype,
      constraint=self.constraint,
      distribute_strategy=self._distribute_strategy)
  return functools.partial(ResourceVariable, **ctor_kwargs), ()

1048 

def scatter_sub(self, sparse_delta, use_locking=False, name=None):
  """Subtracts `tf.IndexedSlices` from this variable.

  Args:
    sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
    use_locking: Ignored. It is not forwarded to `resource_scatter_sub`.
    name: the name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: if `sparse_delta` is not an `IndexedSlices`.
  """
  del use_locking  # Not supported by the underlying resource op.
  if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
    raise TypeError(f"Argument `sparse_delta` must be a "
                    f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
  return self._lazy_read(
      gen_resource_variable_ops.resource_scatter_sub(
          self.handle,
          sparse_delta.indices,
          ops.convert_to_tensor(sparse_delta.values, self.dtype),
          name=name))

1072 

def scatter_add(self, sparse_delta, use_locking=False, name=None):
  """Adds `tf.IndexedSlices` to this variable.

  Args:
    sparse_delta: `tf.IndexedSlices` to be added to this variable.
    use_locking: Ignored. It is not forwarded to `resource_scatter_add`.
    name: the name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: if `sparse_delta` is not an `IndexedSlices`.
  """
  del use_locking  # Not supported by the underlying resource op.
  if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
    raise TypeError(f"Argument `sparse_delta` must be a "
                    f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
  return self._lazy_read(
      gen_resource_variable_ops.resource_scatter_add(
          self.handle,
          sparse_delta.indices,
          ops.convert_to_tensor(sparse_delta.values, self.dtype),
          name=name))

1096 

def scatter_max(self, sparse_delta, use_locking=False, name=None):
  """Updates this variable with the max of `tf.IndexedSlices` and itself.

  Args:
    sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
      variable.
    use_locking: Ignored. It is not forwarded to `resource_scatter_max`.
    name: the name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: if `sparse_delta` is not an `IndexedSlices`.
  """
  del use_locking  # Not supported by the underlying resource op.
  if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
    raise TypeError(f"Argument `sparse_delta` must be a "
                    f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
  return self._lazy_read(
      gen_resource_variable_ops.resource_scatter_max(
          self.handle,
          sparse_delta.indices,
          ops.convert_to_tensor(sparse_delta.values, self.dtype),
          name=name))

1121 

def scatter_min(self, sparse_delta, use_locking=False, name=None):
  """Updates this variable with the min of `tf.IndexedSlices` and itself.

  Args:
    sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
      variable.
    use_locking: Ignored. It is not forwarded to `resource_scatter_min`.
    name: the name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: if `sparse_delta` is not an `IndexedSlices`.
  """
  del use_locking  # Not supported by the underlying resource op.
  if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
    raise TypeError(f"Argument `sparse_delta` must be a "
                    f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
  return self._lazy_read(
      gen_resource_variable_ops.resource_scatter_min(
          self.handle,
          sparse_delta.indices,
          ops.convert_to_tensor(sparse_delta.values, self.dtype),
          name=name))

1146 

def scatter_mul(self, sparse_delta, use_locking=False, name=None):
  """Multiply this variable by `tf.IndexedSlices`.

  Args:
    sparse_delta: `tf.IndexedSlices` to multiply this variable by.
    use_locking: Ignored. It is not forwarded to `resource_scatter_mul`.
    name: the name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: if `sparse_delta` is not an `IndexedSlices`.
  """
  del use_locking  # Not supported by the underlying resource op.
  if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
    raise TypeError(f"Argument `sparse_delta` must be a "
                    f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
  return self._lazy_read(
      gen_resource_variable_ops.resource_scatter_mul(
          self.handle,
          sparse_delta.indices,
          ops.convert_to_tensor(sparse_delta.values, self.dtype),
          name=name))

1170 

def scatter_div(self, sparse_delta, use_locking=False, name=None):
  """Divide this variable by `tf.IndexedSlices`.

  Args:
    sparse_delta: `tf.IndexedSlices` to divide this variable by.
    use_locking: Ignored. It is not forwarded to `resource_scatter_div`.
    name: the name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: if `sparse_delta` is not an `IndexedSlices`.
  """
  del use_locking  # Not supported by the underlying resource op.
  if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
    raise TypeError(f"Argument `sparse_delta` must be a "
                    f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
  return self._lazy_read(
      gen_resource_variable_ops.resource_scatter_div(
          self.handle,
          sparse_delta.indices,
          ops.convert_to_tensor(sparse_delta.values, self.dtype),
          name=name))

1194 

def scatter_update(self, sparse_delta, use_locking=False, name=None):
  """Assigns `tf.IndexedSlices` to this variable.

  Args:
    sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
    use_locking: Ignored. It is not forwarded to `resource_scatter_update`.
    name: the name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: if `sparse_delta` is not an `IndexedSlices`.
  """
  del use_locking  # Not supported by the underlying resource op.
  if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
    raise TypeError(f"Argument `sparse_delta` must be a "
                    f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
  return self._lazy_read(
      gen_resource_variable_ops.resource_scatter_update(
          self.handle,
          sparse_delta.indices,
          ops.convert_to_tensor(sparse_delta.values, self.dtype),
          name=name))

1218 

def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
  """Assigns `tf.IndexedSlices` to this variable batch-wise.

  Analogous to `batch_gather`. This assumes that this variable and the
  sparse_delta IndexedSlices have a series of leading dimensions that are the
  same for all of them, and the updates are performed on the last dimension of
  indices. In other words, the dimensions should be the following:

  `num_prefix_dims = sparse_delta.indices.ndims - 1`
  `batch_dim = num_prefix_dims + 1`
  `sparse_delta.values.shape = sparse_delta.indices.shape + var.shape[
       batch_dim:]`

  where

  `sparse_delta.values.shape[:num_prefix_dims]`
  `== sparse_delta.indices.shape[:num_prefix_dims]`
  `== var.shape[:num_prefix_dims]`

  And the operation performed can be expressed as:

  `var[i_1, ..., i_n,
       sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.values[
          i_1, ..., i_n, j]`

  When sparse_delta.indices is a 1D tensor, this operation is equivalent to
  `scatter_update`.

  To avoid this operation, one can loop over the first `ndims` of the
  variable and use `scatter_update` on the subtensors that result from
  slicing the first dimension. This is a valid option for `ndims = 1`, but
  less efficient than this implementation.

  Args:
    sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
    use_locking: If `True`, use locking during the operation. Forwarded to
      `state_ops.batch_scatter_update`.
    name: the name of the operation.

  Returns:
    The updated variable.

  Raises:
    TypeError: if `sparse_delta` is not an `IndexedSlices`.
  """
  if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
    raise TypeError(f"Argument `sparse_delta` must be a "
                    f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
  return self._lazy_read(
      state_ops.batch_scatter_update(
          self,
          sparse_delta.indices,
          sparse_delta.values,
          use_locking=use_locking,
          name=name))

1273 

def scatter_nd_sub(self, indices, updates, name=None):
  """Applies sparse subtraction to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to subtract 4 scattered elements from a rank-1
  tensor with 8 elements. In Python, that update would look like this:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1] ,[7]])
      updates = tf.constant([9, 10, 11, 12])
      op = ref.scatter_nd_sub(indices, updates)
      with tf.compat.v1.Session() as sess:
        print(sess.run(op))
  ```

  The resulting update to ref would look like this:

      [1, -9, 3, -6, -4, 6, 7, -4]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices to be used in the operation.
    updates: The values to be used in the operation.
    name: the name of the operation.

  Returns:
    The updated variable.
  """
  return self._lazy_read(
      gen_state_ops.resource_scatter_nd_sub(
          self.handle,
          indices,
          ops.convert_to_tensor(updates, self.dtype),
          name=name))

1325 

def scatter_nd_add(self, indices, updates, name=None):
  """Applies sparse addition to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to add 4 scattered elements to a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1] ,[7]])
      updates = tf.constant([9, 10, 11, 12])
      add = ref.scatter_nd_add(indices, updates)
      with tf.compat.v1.Session() as sess:
        print(sess.run(add))
  ```

  The resulting update to ref would look like this:

      [1, 13, 3, 14, 14, 6, 7, 20]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices to be used in the operation.
    updates: The values to be used in the operation.
    name: the name of the operation.

  Returns:
    The updated variable.
  """
  return self._lazy_read(
      gen_state_ops.resource_scatter_nd_add(
          self.handle,
          indices,
          ops.convert_to_tensor(updates, self.dtype),
          name=name))

1377 

def scatter_nd_update(self, indices, updates, name=None):
  """Applies sparse assignment to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to update 4 scattered elements of a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1] ,[7]])
      updates = tf.constant([9, 10, 11, 12])
      op = ref.scatter_nd_update(indices, updates)
      with tf.compat.v1.Session() as sess:
        print(sess.run(op))
  ```

  The resulting update to ref would look like this:

      [1, 11, 3, 10, 9, 6, 7, 12]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices to be used in the operation.
    updates: The values to be used in the operation.
    name: the name of the operation.

  Returns:
    The updated variable.
  """
  return self._lazy_read(
      gen_state_ops.resource_scatter_nd_update(
          self.handle,
          indices,
          ops.convert_to_tensor(updates, self.dtype),
          name=name))

1429 

def scatter_nd_max(self, indices, updates, name=None):
  """Updates this variable with the max of `updates` and itself at `indices`.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices to be used in the operation.
    updates: The values to be used in the operation.
    name: the name of the operation.

  Returns:
    The updated variable.
  """
  return self._lazy_read(
      gen_state_ops.resource_scatter_nd_max(
          self.handle,
          indices,
          ops.convert_to_tensor(updates, self.dtype),
          name=name))

1465 

def scatter_nd_min(self, indices, updates, name=None):
  """Updates this variable with the min of `updates` and itself at `indices`.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices to be used in the operation.
    updates: The values to be used in the operation.
    name: the name of the operation.

  Returns:
    The updated variable.
  """
  return self._lazy_read(
      gen_state_ops.resource_scatter_nd_min(
          self.handle,
          indices,
          ops.convert_to_tensor(updates, self.dtype),
          name=name))

1501 

def _write_object_proto(self, proto, options):
  """Writes additional information of the variable into the SavedObject proto.

  Subclasses of ResourceVariables could choose to override this method to
  customize extra information to provide when saving a SavedModel.

  Ideally, this should contain the logic in
  write_object_proto_for_resource_variable but `DistributedValue` is an
  outlier at the moment. Once `DistributedValue` becomes a proper
  ResourceVariable, we should remove the helper method below.

  Args:
    proto: `SavedObject` proto to update.
    options: A `SaveOption` instance that configures save behavior.
  """
  write_object_proto_for_resource_variable(self, proto, options)

1518 

def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
                          end_mask, ellipsis_mask, new_axis_mask,
                          shrink_axis_mask):
  """Assigns `value` into a strided slice of this variable.

  Emits `resource_strided_slice_assign` on the variable's handle, in the
  handle's graph and under the variable's assign dependencies. It returns a
  lazy read of the variable with that op as the parent.
  """
  with _handle_graph(self.handle), self._assign_dependencies():
    value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
    slice_assign_op = gen_array_ops.resource_strided_slice_assign(
        ref=self.handle,
        begin=begin,
        end=end,
        strides=strides,
        value=value_tensor,
        name=name,
        begin_mask=begin_mask,
        end_mask=end_mask,
        ellipsis_mask=ellipsis_mask,
        new_axis_mask=new_axis_mask,
        shrink_axis_mask=shrink_axis_mask)
    return self._lazy_read(slice_assign_op)

1536 

def __complex__(self):
  """Converts the variable's current value to a Python `complex`."""
  current = self.value().numpy()
  return complex(current)

1539 

def __int__(self):
  """Converts the variable's current value to a Python `int`."""
  current = self.value().numpy()
  return int(current)

1542 

def __long__(self):
  """Converts the variable's current value to a Python integer.

  Python 3 has no separate `long` type (the old call raised `NameError`), so
  this uses `int`, which has arbitrary precision.
  """
  return int(self.value().numpy())

1545 

def __float__(self):
  """Converts the variable's current value to a Python `float`."""
  current = self.value().numpy()
  return float(current)

1548 

def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
  """Converts this variable to a dense `Tensor`.

  Args:
    dtype: Optional requested dtype; must be compatible with `self.dtype`.
    name: Unused.
    as_ref: If True, return the first input of the op produced by
      `read_value()` instead of `value()`.

  Returns:
    The tensor produced by `value()`, or the read op's first input when
    `as_ref` is True.

  Raises:
    ValueError: If `dtype` is given and incompatible with `self.dtype`.
  """
  del name
  if dtype is not None and not dtype.is_compatible_with(self.dtype):
    # The closing backtick after `tf.Variable` was missing before.
    raise ValueError(
        f"Incompatible type conversion requested to type {dtype.name} for "
        f"`tf.Variable` of type {self.dtype.name}. (Variable: {self})")
  if as_ref:
    return self.read_value().op.inputs[0]
  return self.value()

1559 

def __iadd__(self, unused_other):
  """Rejects in-place `+=`; use `assign_add` or a non-mutating `+` instead."""
  message = ("`variable += value` with `tf.Variable`s is not supported. "
             "Use `variable.assign_add(value)` to modify the variable, or "
             "`out = variable + value` if you need to get a new output Tensor.")
  raise RuntimeError(message)

1565 

def __isub__(self, unused_other):
  """Rejects in-place `-=`; use `assign_sub` or a non-mutating `-` instead.

  Raises:
    RuntimeError: Always. The message previously suggested `variable * value`
      as the non-mutating alternative to `-=`.
  """
  raise RuntimeError("`variable -= value` with `tf.Variable`s is not "
                     "supported. Use `variable.assign_sub(value)` to modify "
                     "the variable, or `out = variable - value` if you "
                     "need to get a new output Tensor.")

1571 

def __imul__(self, unused_other):
  """Rejects in-place `*=`; rebuild via `assign` or use a plain `*`."""
  message = ("`var *= value` with `tf.Variable`s is not supported. "
             "Use `var.assign(var * value)` to modify the variable, or "
             "`out = var * value` if you need to get a new output Tensor.")
  raise RuntimeError(message)

1577 

def __idiv__(self, unused_other):
  """Rejects in-place `/=` (legacy hook); use `assign` or a plain `/`."""
  message = ("`var /= value` with `tf.Variable`s is not supported. "
             "Use `var.assign(var / value)` to modify the variable, or "
             "`out = var / value` if you need to get a new output Tensor.")
  raise RuntimeError(message)

1583 

def __itruediv__(self, unused_other):
  """Rejects in-place `/=`; use `assign` or a plain `/` instead."""
  message = ("`var /= value` with `tf.Variable`s is not supported. "
             "Use `var.assign(var / value)` to modify the variable, or "
             "`out = var / value` if you need to get a new output Tensor.")
  raise RuntimeError(message)

1589 

def __irealdiv__(self, unused_other):
  """Rejects in-place real division; use `assign` or a plain `/` instead."""
  message = ("`var /= value` with `tf.Variable`s is not supported. "
             "Use `var.assign(var / value)` to modify the variable, or "
             "`out = var / value` if you need to get a new output Tensor.")
  raise RuntimeError(message)

1595 

def __ipow__(self, unused_other):
  """Rejects in-place `**=`; use `assign` or a plain `**` instead."""
  message = ("`var **= value` with `tf.Variable`s is not supported. "
             "Use `var.assign(var ** value)` to modify the variable, or "
             "`out = var ** value` if you need to get a new output Tensor.")
  raise RuntimeError(message)

1601 

1602 

class ResourceVariableGradient(
    composite_tensor_gradient.CompositeTensorGradient):
  """`CompositeTensorGradient` protocol implementation for ResourceVariable."""

  # TODO(b/246997907): update this method to return value.handle.
  def get_gradient_components(self, value):
    """Returns the part of `value` that takes part in gradient computation.

    Conceptually a ResourceVariable's gradient component is its handle
    tensor. For now the variable itself is returned, because the gradient
    infrastructure still special-cases ResourceVariables. Once those special
    cases are removed, this should return the handle tensor.

    Args:
      value: A `ResourceVariable`.

    Returns:
      `value`, unchanged.
    """
    return value

  def replace_gradient_components(self, value, component_grads):
    """Substitutes `component_grads` for the gradient components of `value`.

    A ResourceVariable's gradient is either None or a Tensor. Neither the
    TypeSpec of `value` nor its non-gradient components are needed here.

    Args:
      value: A `ResourceVariable` whose gradient components are compatible
        with `component_grads`.
      component_grads: A `Tensor` or None holding the gradient result.

    Returns:
      `component_grads` as-is (a `Tensor` or None).
    """
    del value  # Not needed to produce the result.
    return component_grads

1639 

1640 

class ResourceVariable(BaseResourceVariable, composite_tensor.CompositeTensor):
  """Variable based on resource handles.

  See the [Variables How To](https://tensorflow.org/guide/variables)
  for a high level overview.

  A `ResourceVariable` allows you to maintain state across subsequent calls to
  session.run.

  The `ResourceVariable` constructor requires an initial value for the variable,
  which can be a `Tensor` of any type and shape. The initial value defines the
  type and shape of the variable. After construction, the type and shape of
  the variable are fixed. The value can be changed using one of the assign
  methods.

  Just like any `Tensor`, variables created with
  `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the
  graph. Additionally, all the operators overloaded for the `Tensor` class are
  carried over to variables, so you can also add nodes to the graph by just
  doing arithmetic on variables.

  Unlike ref-based variables, a ResourceVariable has well-defined semantics.
  Each usage of a ResourceVariable in a TensorFlow graph adds a read_value
  operation to the graph. The Tensors returned by a read_value operation are
  guaranteed to see all modifications to the value of the variable which happen
  in any operation on which the read_value depends (either directly,
  indirectly, or via a control dependency). They are also guaranteed not to see
  any modification to the value of the variable from operations that depend on
  the read_value operation. Updates from operations that have no dependency
  relationship to the read_value operation might or might not be visible to
  read_value.

  For example, if there is more than one assignment to a ResourceVariable in
  a single session.run call, there is a well-defined value for each operation
  that uses the variable's value, provided the assignments and the read are
  connected by edges in the graph. Consider the following example, in which two
  writes can cause tf.Variable and tf.ResourceVariable to behave differently:

  ```python
  a = tf.Variable(1.0, use_resource=True)
  a.initializer.run()

  assign = a.assign(2.0)
  with tf.control_dependencies([assign]):
    b = a.read_value()
  with tf.control_dependencies([b]):
    other_assign = a.assign(3.0)
  with tf.control_dependencies([other_assign]):
    # Will print 2.0 because the value was read before other_assign ran. If
    # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed.
    tf.compat.v1.Print(b, [b]).eval()
  ```
  """

  def __init__(
      self,  # pylint: disable=super-init-not-called
      initial_value=None,
      trainable=None,
      collections=None,
      validate_shape=True,
      caching_device=None,
      name=None,
      dtype=None,
      variable_def=None,
      import_scope=None,
      constraint=None,
      distribute_strategy=None,
      synchronization=None,
      aggregation=None,
      shape=None,
      handle=None,
      experimental_enable_variable_lifting=None,
  ):
    """Creates a variable.

    Dispatches to one of three initializers: `_init_from_proto` when
    `variable_def` is given, `_init_from_handle` when `handle` is given, and
    `_init_from_args` otherwise.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. Can also be a callable with
        no argument that returns the initial value when called. (Note that
        initializer functions from init_ops.py must first be bound to a shape
        before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
        Defaults to `True`, unless `synchronization` is set to `ON_READ`, in
        which case it defaults to `False`.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type. If None,
        either the datatype will be kept (if initial_value is a Tensor) or
        float32 will be used (if it is a Python object convertible to a Tensor).
      variable_def: `VariableDef` protocol buffer. If not None, recreates the
        `ResourceVariable` object with its contents. `variable_def` and other
        arguments (except for import_scope) are mutually exclusive.
      import_scope: Optional `string`. Name scope to add to the
        ResourceVariable. Only used when `variable_def` is provided.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      distribute_strategy: The tf.distribute.Strategy this variable is being
        created inside of.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the variable
        can be assigned with values of different shapes.
      handle: (optional) The handle of a `tf.Variable`. If provided, only
        `trainable`, `shape`, `dtype`, and `handle` will be used to construct
        this `tf.Variable`.
      experimental_enable_variable_lifting: Whether to lift the variable out if
        it's in a `tf.function`. Default is `True`. When this argument
        is `True`, variable creation will follow the behavior and
        restrictions described
        [here](https://www.tensorflow.org/guide/function#creating_tfvariables).
        If this argument is `False`, that description doesn't apply,
        and you can freely create and use the variable in the
        `tf.function`, as if it's a "mutable `tf.Tensor`". You can't
        return the variable though.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`. Also raised when `variable_def`
        is combined with `initial_value`, or is used while executing eagerly.

    @compatibility(eager)
    When Eager Execution is enabled, the default for the `collections` argument
    is `None`, which signifies that this `Variable` will not be added to any
    collections.
    @end_compatibility
    """
    if variable_def:
      if initial_value is not None:
        raise ValueError(f"The variable_def and initial_value args to "
                         f"`tf.Variable` are mutually exclusive, but got both: "
                         f"variable_def={variable_def},\n"
                         f"initial_value={initial_value}")
      if context.executing_eagerly():
        raise ValueError(f"Creating a `tf.Variable` with a `variable_def` arg "
                         f"is not supported when eager execution is enabled. "
                         f"Got: variable_def={variable_def}")
      self._init_from_proto(
          variable_def,
          import_scope=import_scope,
          validate_shape=validate_shape)
    elif handle is not None:
      self._init_from_handle(trainable=trainable,
                             shape=shape,
                             dtype=dtype,
                             handle=handle)
    else:
      self._init_from_args(
          initial_value=initial_value,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          name=name,
          dtype=dtype,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation,
          shape=shape,
          distribute_strategy=distribute_strategy,
          validate_shape=validate_shape,
          experimental_enable_variable_lifting=experimental_enable_variable_lifting,
      )

  # CompositeTensor method
  @property
  def _type_spec(self):
    return VariableSpec.from_value(self)

  # CompositeTensor method
  def _shape_invariant_to_type_spec(self, shape):
    return VariableSpec(shape, self.dtype, self.trainable)

  # CompositeTensorGradient protocol
  __composite_gradient__ = ResourceVariableGradient()

  def _init_from_args(
      self,
      initial_value=None,
      trainable=None,
      collections=None,
      caching_device=None,
      name=None,
      dtype=None,
      constraint=None,
      synchronization=None,
      aggregation=None,
      distribute_strategy=None,
      shape=None,
      validate_shape=True,
      experimental_enable_variable_lifting=None,
  ):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound to
        a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
        Defaults to `True`, unless `synchronization` is set to `ON_READ`, in
        which case it defaults to `False`.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type. If None,
        either the datatype will be kept (if initial_value is a Tensor) or
        float32 will be used (if it is a Python object convertible to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      distribute_strategy: DistributionStrategy under which this variable was
        created.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the variable
        can be assigned with values of different shapes.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      experimental_enable_variable_lifting: Whether to lift the variable out if
        it's in a `tf.function`. Default is `True`. When this argument
        is `True`, variable creation will follow the behavior and
        restrictions described
        [here](https://www.tensorflow.org/guide/function#creating_tfvariables).
        If this argument is `False`, that description doesn't apply,
        and you can freely create and use the variable in the
        `tf.function`, as if it's a "mutable `tf.Tensor`". You can't
        return the variable though.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`. Also raised for a non-callable
        `constraint`, a `collections` value that is not a list/tuple/set, an
        `initial_value` incompatible with `shape`, or an `initial_value` that
        cannot be lifted out of a function or control-flow context.

    @compatibility(eager)
    When Eager Execution is enabled, variables are never added to collections.
    It is not implicitly added to the `GLOBAL_VARIABLES` or
    `TRAINABLE_VARIABLES` collections, and the `collections` argument is
    ignored.
    @end_compatibility
    """
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))
    if experimental_enable_variable_lifting is None:
      experimental_enable_variable_lifting = True
    if initial_value is None:
      raise ValueError("The `initial_value` arg to `tf.Variable` must "
                       "be specified except when you are providing a "
                       "`variable_def`. You provided neither.")
    init_from_fn = callable(initial_value)

    if isinstance(initial_value, ops.Tensor) and hasattr(
        initial_value, "graph") and initial_value.graph.building_function:
      raise ValueError(f"Argument `initial_value` ({initial_value}) could not "
                       "be lifted out of a `tf.function`. "
                       f"(Tried to create variable with name='{name}'). "
                       "To avoid this error, when constructing `tf.Variable`s "
                       "inside of `tf.function` you can create the "
                       "`initial_value` tensor in a "
                       "`tf.init_scope` or pass a callable `initial_value` "
                       "(e.g., `tf.Variable(lambda : "
                       "tf.truncated_normal([10, 40]))`). "
                       "Please file a feature request if this "
                       "restriction inconveniences you.")

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          f"collections argument to Variable constructor must be a list, "
          f"tuple, or set. Got {collections} of type {type(collections)}")
    if constraint is not None and not callable(constraint):
      raise ValueError(f"Argument `constraint` must be None or a callable. "
                       f"Got a {type(constraint)}: {constraint}")

    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    with ops.init_scope():
      self._in_graph_mode = not context.executing_eagerly()
    if experimental_enable_variable_lifting:
      maybe_init_scope = ops.init_scope
    else:
      maybe_init_scope = contextlib.nullcontext
    with maybe_init_scope():
      with ops.name_scope(
          name,
          "Variable", [] if init_from_fn else [initial_value],
          skip_on_eager=False) as name:
        # pylint: disable=protected-access
        handle_name = ops.name_from_scope_name(name)
        if self._in_graph_mode:
          shared_name = handle_name
          unique_id = shared_name
        else:
          # When in eager mode, use a uid for the shared_name, to prevent
          # accidental sharing.
          unique_id = "%s_%d" % (handle_name, ops.uid())
          shared_name = None  # Never shared
        # Use attr_scope and device(None) to simulate the behavior of
        # colocate_with when the variable we want to colocate with doesn't
        # yet exist.
        device_context_manager = (
            ops.device if self._in_graph_mode else ops.NullContextmanager)
        attr = attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(
                s=[compat.as_bytes("loc:@%s" % handle_name)]))
        with ops.get_default_graph()._attr_scope({"_class": attr}):
          with ops.name_scope("Initializer"), device_context_manager(None):
            if init_from_fn:
              initial_value = initial_value()
            if isinstance(initial_value, trackable.CheckpointInitialValue):
              self._maybe_initialize_trackable()
              self._update_uid = initial_value.checkpoint_position.restore_uid
              initial_value = initial_value.wrapped_value
            initial_value = ops.convert_to_tensor(
                initial_value, name="initial_value", dtype=dtype)
          if shape is not None:
            if not initial_value.shape.is_compatible_with(shape):
              raise ValueError(
                  f"In this `tf.Variable` creation, the initial value's shape "
                  f"({initial_value.shape}) is not compatible with "
                  f"the explicitly supplied `shape` argument ({shape}).")
          else:
            shape = initial_value.shape
          handle = eager_safe_variable_handle(
              initial_value=initial_value,
              shape=shape,
              shared_name=shared_name,
              name=name,
              graph_mode=self._in_graph_mode)
          handle._parent_trackable = weakref.ref(self)
          handle._name = handle_name + ":0"
          handle._unique_id = unique_id
        # pylint: disable=protected-access
        if (self._in_graph_mode and initial_value is not None and
            initial_value.op._get_control_flow_context() is not None):
          raise ValueError(
              f"The `initial_value` passed to `tf.Variable` {name} is from "
              f"inside a control-flow construct, such as a loop or "
              f"conditional. When creating a "
              f"`tf.Variable` inside a loop or conditional, use a lambda as "
              f"the `initial_value`. Got: initial_value=({initial_value})")
        # pylint: enable=protected-access
        dtype = initial_value.dtype.base_dtype

        if self._in_graph_mode:
          with ops.name_scope("IsInitialized"):
            is_initialized_op = (
                gen_resource_variable_ops.var_is_initialized_op(handle))
          if initial_value is not None:
            # pylint: disable=g-backslash-continuation
            with ops.name_scope("Assign") as n, \
                 ops.colocate_with(None, ignore_existing=True), \
                 ops.device(handle.device):
              # pylint: disable=protected-access
              initializer_op = (
                  gen_resource_variable_ops.assign_variable_op(
                      handle,
                      variables._try_guard_against_uninitialized_dependencies(
                          name, initial_value),
                      name=n))
              # pylint: enable=protected-access
            # pylint: enable=g-backslash-continuation
          with ops.name_scope("Read"):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(handle.device):
              value = gen_resource_variable_ops.read_variable_op(handle, dtype)
              _maybe_set_handle_data(dtype, handle, value)
            graph_element = value
            if caching_device is not None:
              # Variables may be created in a tf.device() or ops.colocate_with()
              # context. At the same time, users would expect caching device to
              # be independent of this context, and/or would not expect the
              # current device context to be merged with the caching device
              # spec. Therefore we reset the colocation stack before creating
              # the cached value. Note that resetting the colocation stack will
              # also reset the device stack.
              with ops.colocate_with(None, ignore_existing=True):
                with ops.device(caching_device):
                  cached_value = array_ops.identity(value)
            else:
              cached_value = None
        else:
          gen_resource_variable_ops.assign_variable_op(handle, initial_value)
          is_initialized_op = None
          initializer_op = None
          graph_element = None
          if caching_device:
            with ops.device(caching_device):
              cached_value = gen_resource_variable_ops.read_variable_op(
                  handle, dtype)
              _maybe_set_handle_data(dtype, handle, cached_value)
          else:
            cached_value = None

        if cached_value is not None:
          # Store the variable object so that the original variable can be
          # accessed to generate functions that are compatible with SavedModel.
          cached_value._cached_variable = weakref.ref(self)  # pylint: disable=protected-access

        if self._in_graph_mode:
          # Eager variables are only added to collections if they are part of an
          # eager variable store (otherwise in an interactive session they would
          # hog memory and cause OOM). This is done in ops/variable_scope.py.
          ops.add_to_collections(collections, self)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
      initial_value = initial_value if self._in_graph_mode else None
      super(ResourceVariable, self).__init__(
          trainable=trainable,
          shape=shape,
          dtype=dtype,
          handle=handle,
          synchronization=synchronization,
          constraint=constraint,
          aggregation=aggregation,
          distribute_strategy=distribute_strategy,
          name=name,
          unique_id=unique_id,
          handle_name=handle_name,
          graph_element=graph_element,
          initial_value=initial_value,
          initializer_op=initializer_op,
          is_initialized_op=is_initialized_op,
          cached_value=cached_value,
          caching_device=caching_device,
          validate_shape=validate_shape,
      )

  def _init_from_proto(self,
                       variable_def,
                       import_scope=None,
                       validate_shape=True):
    """Initializes from `VariableDef` proto.

    Args:
      variable_def: A `variable_pb2.VariableDef` with `is_resource` set.
      import_scope: Optional name scope prepended to every graph-element name
        stored in `variable_def` before it is looked up in the default graph.
      validate_shape: Stored as `self._validate_shape`.

    Raises:
      ValueError: If `variable_def` does not describe a resource variable.
    """
    # Note that init_from_proto is currently not supported in Eager mode.
    assert not context.executing_eagerly()
    self._in_graph_mode = True
    assert isinstance(variable_def, variable_pb2.VariableDef)
    if not variable_def.is_resource:
      raise ValueError(f"The `variable_def` you passed to `tf.Variable` "
                       f"describes a TF 1.x Reference Variable. Restoring a "
                       f"TF 1.x Reference Variable as a TF 2.x "
                       f"ResourceVariable is unsupported. "
                       f"Got variable_def={variable_def}")

    # Create from variable_def.
    g = ops.get_default_graph()
    self._handle = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.variable_name, import_scope=import_scope),
        allow_operation=False)
    self._shape = tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
    self._handle_name = self._handle.name
    self._unique_id = self._handle_name
    self._initializer_op = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.initializer_name, import_scope=import_scope))
    # Check whether initial_value_name exists for backwards compatibility.
    if (hasattr(variable_def, "initial_value_name") and
        variable_def.initial_value_name):
      self._initial_value = g.as_graph_element(
          ops.prepend_name_scope(
              variable_def.initial_value_name, import_scope=import_scope))
    else:
      self._initial_value = None
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            variable_def.synchronization, variable_def.aggregation,
            variable_def.trainable, variable_def.variable_name))
    self._synchronization = synchronization
    self._aggregation = aggregation
    self._trainable = trainable
    if variable_def.snapshot_name:
      snapshot = g.as_graph_element(
          ops.prepend_name_scope(
              variable_def.snapshot_name, import_scope=import_scope))
      # A snapshot that is not itself the ReadVariableOp is treated as the
      # cached value.
      if snapshot.op.type != "ReadVariableOp":
        self._cached_value = snapshot
      else:
        self._cached_value = None
      # Walk first inputs back until the ReadVariableOp that backs the snapshot.
      while snapshot.op.type != "ReadVariableOp":
        snapshot = snapshot.op.inputs[0]
      self._graph_element = snapshot
    else:
      self._cached_value = None
      # Legacy case for protos without the snapshot name; assume it's the
      # following.
      self._graph_element = g.get_tensor_by_name(self._handle.op.name +
                                                 "/Read/ReadVariableOp:0")
    if variable_def.HasField("save_slice_info_def"):
      self._save_slice_info = variables.Variable.SaveSliceInfo(
          save_slice_info_def=variable_def.save_slice_info_def,
          import_scope=import_scope)
    else:
      self._save_slice_info = None
    self._caching_device = None
    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
    self._constraint = None
    self._validate_shape = validate_shape

  def _init_from_handle(self,
                        trainable=None,
                        shape=None,
                        dtype=None,
                        handle=None):
    """Initializes this variable around an existing resource `handle`.

    Args:
      trainable: Passed through to `BaseResourceVariable.__init__`.
      shape: Variable shape. Also used to build handle data when `handle`
        carries none.
      dtype: Variable dtype. Also used to build handle data when `handle`
        carries none.
      handle: The resource handle tensor backing this variable.
    """
    handle_data = get_eager_safe_handle_data(handle)
    if not handle_data.is_set:
      # The handle may not have the handle shape and dtype if it was created
      # using tf.placeholder.
      handle_data = handle_data_util.create_handle_data(shape, dtype)
      handle_data_util.set_handle_data(handle, handle_data)
    # pylint: disable=protected-access
    if hasattr(handle, "_name") and isinstance(handle._name, str):
      handle_name = handle._name
      # Remove only a literal ":0" output suffix. `str.rstrip(":0")` would
      # strip any trailing run of ':' and '0' characters, truncating names
      # such as "Variable_10:0" to "Variable_1".
      if handle_name.endswith(":0"):
        handle_name = handle_name[:-len(":0")]
    else:
      handle_name = None
    # pylint: enable=protected-access
    unique_id = getattr(handle, "_unique_id", None)
    super().__init__(
        trainable=trainable, shape=shape, dtype=dtype, handle=handle,
        unique_id=unique_id, handle_name=handle_name)

2206 

2207 

class UninitializedVariable(BaseResourceVariable):
  """A resource variable whose handle is created without any initializer."""

  def __init__(  # pylint: disable=super-init-not-called
      self,
      trainable=None,
      caching_device=None,
      name=None,
      shape=None,
      dtype=None,
      constraint=None,
      synchronization=None,
      aggregation=None,
      extra_handle_data=None,
      distribute_strategy=None,
      **unused_kwargs):
    """Creates the variable handle without assigning it a value.

    Args:
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
        This class does not forward it to the base initializer.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      shape: The variable's shape.
      dtype: The variable's dtype.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      extra_handle_data: Optional, another resource handle or Tensor with handle
        data to merge with `shape` and `dtype`.
      distribute_strategy: The tf.distribute.Strategy this variable is being
        created inside of.
      **unused_kwargs: Accepted and ignored.
    """
    with ops.init_scope():
      # Eagerness is checked inside an init_scope, so graph mode here means
      # we are running in TF1 graph mode.
      in_graph_mode = not context.executing_eagerly()
      self._in_graph_mode = in_graph_mode
      with ops.name_scope(name, "Variable", skip_on_eager=False) as scope_name:
        base_name = ops.name_from_scope_name(scope_name)
        if not in_graph_mode:
          # Eager variables get a process-unique id and are never shared.
          unique_id = "%s_%d" % (base_name, ops.uid())
          shared_name = None
        else:
          shared_name = base_name
          unique_id = shared_name
        handle = _variable_handle_from_shape_and_dtype(
            shape=shape,
            dtype=dtype,
            shared_name=shared_name,
            name=scope_name,
            graph_mode=in_graph_mode,
            initial_value=extra_handle_data)
        # pylint: disable=protected-access
        handle._parent_trackable = weakref.ref(self)
        handle._name = base_name + ":0"
        handle._unique_id = unique_id
        # pylint: enable=protected-access

        graph_element = None
        if in_graph_mode:
          # Only TF1 graph mode needs an explicit read op as the graph element.
          with ops.name_scope("Read"):
            # Pin the read to the handle's device to avoid log messages.
            with ops.device(handle.device):
              read_op = gen_resource_variable_ops.read_variable_op(
                  handle, dtype)
              _maybe_set_handle_data(dtype, handle, read_op)
            graph_element = read_op
          ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
          # Deliberately NOT added to TRAINABLE_VARIABLES, even when trainable:
          # retraining or frozen use of imported SavedModels is controlled at
          # higher levels of model building.
    super(UninitializedVariable, self).__init__(
        distribute_strategy=distribute_strategy,
        shape=shape,
        dtype=dtype,
        unique_id=unique_id,
        handle_name=base_name,
        constraint=constraint,
        handle=handle,
        graph_element=graph_element,
        trainable=trainable,
        synchronization=synchronization,
        aggregation=aggregation,
        in_graph_mode=self._in_graph_mode)

2309 

2310 

# Register ResourceVariable with the pywrap type registry under the name
# "ResourceVariable", and expose the class to math_ops via its
# `_resource_variable_type` attribute.
_pywrap_utils.RegisterType("ResourceVariable", ResourceVariable)
setattr(math_ops, "_resource_variable_type", ResourceVariable)

2313 

2314 

def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Tensor-conversion hook that delegates to `var._dense_var_to_tensor`.

  Args:
    var: The variable being converted.
    dtype: Forwarded unchanged as `dtype=`.
    name: Forwarded unchanged as `name=`.
    as_ref: Forwarded unchanged as `as_ref=`.

  Returns:
    Whatever `var._dense_var_to_tensor` returns.
  """
  convert = var._dense_var_to_tensor  # pylint: disable=protected-access
  return convert(dtype=dtype, name=name, as_ref=as_ref)

2317 

2318 

# Let every BaseResourceVariable be used wherever a Tensor is expected: the
# registered `_dense_var_to_tensor` hook reads the variable's value on
# conversion.
tensor_conversion_registry.register_tensor_conversion_function(
    BaseResourceVariable,
    _dense_var_to_tensor,
)

2323 

2324 

class _UnreadVariable(BaseResourceVariable):
  """Represents a future for a read of a variable.

  Pretends to be the tensor if anyone looks. Every read and every mutation
  issued through this object runs under a control dependency on `parent_op`,
  so it is ordered after the op that produced this future.
  """

  def __init__(self, handle, dtype, shape, in_graph_mode, parent_op, unique_id):
    # Eager handles have no graph name to report.
    handle_name = "" if isinstance(handle, ops.EagerTensor) else handle.name
    # Only create a graph_element if we're in session.run-land as only
    # session.run requires a preexisting tensor to evaluate. Otherwise we can
    # avoid accidentally reading the variable.
    graph_element = None
    needs_graph_element = not (context.executing_eagerly() or
                               ops.inside_function())
    if needs_graph_element:
      with ops.control_dependencies([parent_op]):
        graph_element = gen_resource_variable_ops.read_variable_op(
            handle, dtype)
        _maybe_set_handle_data(dtype, handle, graph_element)
    # NOTE: `in_graph_mode` is accepted for signature compatibility but is not
    # forwarded to the base class here.
    super().__init__(
        handle=handle,
        shape=shape,
        handle_name=handle_name,
        unique_id=unique_id,
        dtype=dtype,
        graph_element=graph_element)
    self._parent_op = parent_op

  def _after_parent(self):
    """Returns a control-dependency scope on the op that produced this read."""
    return ops.control_dependencies([self._parent_op])

  @property
  def name(self):
    """The parent op's name in graph mode, else the fixed "UnreadVariable"."""
    if not self._in_graph_mode:
      return "UnreadVariable"
    return self._parent_op.name

  def value(self):
    return self._read_variable_op()

  def read_value(self):
    return self._read_variable_op()

  def _read_variable_op(self):
    # Issue a fresh read that is ordered after the parent op.
    with self._after_parent():
      snapshot = gen_resource_variable_ops.read_variable_op(
          self._handle, self._dtype)
      _maybe_set_handle_data(self._dtype, self._handle, snapshot)
      return snapshot

  # All mutators below delegate to the base class, wrapped so that the
  # resulting ops depend on the parent op.

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    with self._after_parent():
      return super().assign_sub(delta, use_locking, name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    with self._after_parent():
      return super().assign_add(delta, use_locking, name, read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    with self._after_parent():
      return super().assign(value, use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    with self._after_parent():
      return super().scatter_sub(sparse_delta, use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    with self._after_parent():
      return super().scatter_add(sparse_delta, use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    with self._after_parent():
      return super().scatter_max(sparse_delta, use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    with self._after_parent():
      return super().scatter_min(sparse_delta, use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    with self._after_parent():
      return super().scatter_mul(sparse_delta, use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    with self._after_parent():
      return super().scatter_div(sparse_delta, use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    with self._after_parent():
      return super().scatter_update(sparse_delta, use_locking, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    with self._after_parent():
      return super().batch_scatter_update(sparse_delta, use_locking, name)

  def scatter_nd_sub(self, indices, updates, name=None):
    with self._after_parent():
      return super().scatter_nd_sub(indices, updates, name)

  def scatter_nd_add(self, indices, updates, name=None):
    with self._after_parent():
      return super().scatter_nd_add(indices, updates, name)

  def scatter_nd_update(self, indices, updates, name=None):
    with self._after_parent():
      return super().scatter_nd_update(indices, updates, name)

  def scatter_nd_max(self, indices, updates, name=None):
    with self._after_parent():
      return super().scatter_nd_max(indices, updates, name)

  def scatter_nd_min(self, indices, updates, name=None):
    with self._after_parent():
      return super().scatter_nd_min(indices, updates, name)

  @property
  def op(self):
    """The op for this variable."""
    return self._parent_op

2455 

2456 

@ops.RegisterGradient("ReadVariableOp")
def _ReadGrad(_, grad):
  """Gradient for read op.

  The op argument is ignored; the incoming gradient is passed back unchanged.
  """
  passthrough = grad
  return passthrough

2461 

2462 

def variable_shape(handle, out_type=dtypes.int32):
  """Returns the shape of the variable behind `handle` as a 1-D tensor.

  If the handle's shape metadata is set and fully known (known rank, no -1
  dimensions), the shape is emitted as a constant. Otherwise the runtime
  `VariableShape` op is used.

  Args:
    handle: A resource handle tensor.
    out_type: dtype of the returned shape tensor.

  Returns:
    A 1-D tensor of dimension sizes with dtype `out_type`.
  """
  handle_data = get_eager_safe_handle_data(handle)
  if handle_data is not None and handle_data.is_set:
    shape_proto = handle_data.shape_and_type[0].shape
    dim_sizes = [dim.size for dim in shape_proto.dim]
    if not shape_proto.unknown_rank and -1 not in dim_sizes:
      return constant_op.constant(dim_sizes, dtype=out_type)
  # Metadata missing or partially unknown: ask the runtime.
  return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)

2471 

2472 

@ops.RegisterGradient("ResourceGather")
def _GatherGrad(op, grad):
  """Gradient for gather op.

  Returns an `IndexedSlices` gradient for the gathered variable, plus `None`
  for the indices input. NOTE(review): no op attributes are consulted; if
  ResourceGather is used with batch dimensions, confirm this reshape is
  still correct.
  """
  handle, indices = op.inputs[0], op.inputs[1]
  params_shape = variable_shape(handle)
  # Build appropriately shaped IndexedSlices: one row per gathered index.
  num_indices = array_ops.expand_dims(array_ops.size(indices), 0)
  values = array_ops.reshape(
      grad, array_ops.concat([num_indices, params_shape[1:]], 0))
  flat_indices = array_ops.reshape(indices, num_indices)
  slices = indexed_slices.IndexedSlices(values, flat_indices, params_shape)
  return (slices, None)

2485 

2486 

@tf_export("__internal__.ops.is_resource_variable", v1=[])
def is_resource_variable(var):
  """Returns True if `var` is to be considered a ResourceVariable.

  An object qualifies if it is a `BaseResourceVariable` instance, or if it
  defines a `_should_act_as_resource_variable` attribute. Only the presence
  of that attribute is checked, not its value.

  Args:
    var: Any object.

  Returns:
    A bool.
  """
  return isinstance(var, BaseResourceVariable) or hasattr(
      var, "_should_act_as_resource_variable")

2492 

2493 

def copy_to_graph_uninitialized(var):
  """Copies an existing variable to a new graph, with no initializer.

  Like `ResourceVariable.__deepcopy__`, but no initializer is set on the new
  variable. The copy is an `UninitializedVariable` built from `var`'s
  trainability, constraint, shape, dtype, shared name, synchronization and
  aggregation, with `var.handle` passed as `extra_handle_data`.

  Args:
    var: The variable to copy.

  Returns:
    The new `UninitializedVariable`, after `_maybe_initialize_trackable()`
    has been called on it.
  """
  # pylint: disable=protected-access
  ctor_kwargs = dict(
      trainable=var.trainable,
      constraint=var._constraint,
      shape=var.shape,
      dtype=var.dtype,
      name=var._shared_name,
      synchronization=var.synchronization,
      aggregation=var.aggregation,
      extra_handle_data=var.handle)
  copied = UninitializedVariable(**ctor_kwargs)
  copied._maybe_initialize_trackable()
  # pylint: enable=protected-access
  return copied

2511 

2512 

# Declare that these op types have no gradient, via ops.NotDifferentiable.
ops.NotDifferentiable("Assert")
ops.NotDifferentiable("VarIsInitializedOp")
ops.NotDifferentiable("VariableShape")

2516 

2517 

# TODO(b/246356867): This is the draft implementation. Currently VariableSpec is
# the only class using them. Move them to a separate file when necessary.
class StructurePattern:
  """Base class for the structure patterns `PLeaf` and `PList`."""


class PLeaf(StructurePattern):
  """Represents a singleton leaf StructurePattern.

  Each class gets exactly one instance, cached on the class as `instance`.
  The cache is looked up in the class's own `__dict__` rather than via
  `hasattr`. Otherwise a subclass would find the parent's cached instance
  and return an object of the wrong type.
  """

  def __new__(cls):
    instance = cls.__dict__.get("instance")
    if instance is None:
      instance = super().__new__(cls)
      cls.instance = instance
    return instance

  def __repr__(self):
    return f"{type(self).__name__}()"


class PList(StructurePattern):
  """Represents a list of StructurePatterns.

  Instances compare equal when their component lists are equal. Defining
  `__eq__` leaves them unhashable, matching their mutable `components` list.
  """

  def __init__(self, *components):
    self.components = list(components)

  def __eq__(self, other):
    return isinstance(other, PList) and self.components == other.components

  def __repr__(self):
    inner = ", ".join(repr(component) for component in self.components)
    return f"{type(self).__name__}({inner})"

2541 

2542 

class VariableSpec(tensor_spec.DenseSpec):
  """Describes a tf.Variable.

  A `VariableSpec` provides metadata describing the `tf.Variable` objects
  accepted or returned by TensorFlow 2.x APIs.
  """

  __slots__ = ["trainable", "alias_id"]

  value_type = property(lambda self: ResourceVariable)

  def __init__(self, shape, dtype=dtypes.float32, trainable=True,
               alias_id=None):
    """Creates a `VariableSpec`.

    Args:
      shape: Shape of the described variable (passed to `DenseSpec`).
      dtype: Dtype of the described variable; defaults to `dtypes.float32`.
      trainable: Whether the described variable is trainable.
      alias_id: Optional identifier. Two `VariableSpec`s are only equal or
        compatible when their `alias_id`s match.
    """
    super(VariableSpec, self).__init__(shape, dtype=dtype)
    self.trainable = trainable
    self.alias_id = alias_id

  def is_compatible_with(self, spec_or_value):
    """Returns True if `spec_or_value` is compatible with this `VariableSpec`.

    `spec_or_value` is considered to be compatible with this `VariableSpec` if

    * `spec_or_value` is a `Variable` or `VariableSpec`,
    * their shapes are compatible,
    * their dtypes are the same,
    * they are both trainable or not trainable,
    * they share the same alias_id if `spec_or_value` is a `VariableSpec`.

    Example:

    >>> v = tf.Variable([1., 2., 3.])
    >>> spec = VariableSpec([None])
    >>> spec.is_compatible_with(v)
    True
    >>> v = tf.Variable(1)
    >>> spec.is_compatible_with(v)
    False

    Args:
      spec_or_value: A VariableSpec or Variable to compare against.

    Returns:
      True if `spec_or_value` is compatible with this `VariableSpec`.
    """
    if not isinstance(spec_or_value, (type(self), self.value_type)):
      return False
    compatible = (self.shape.is_compatible_with(spec_or_value.shape) and
                  self.dtype == spec_or_value.dtype and
                  self.trainable == spec_or_value.trainable)
    if isinstance(spec_or_value, type(self)):
      # alias_id must be the same to be compatible.
      return compatible and self.alias_id == spec_or_value.alias_id
    return compatible

  @classmethod
  def from_value(cls, value):
    """Creates a `VariableSpec` from the given `Variable`.

    `value`'s shape, dtype, and trainable attributes will be used to create
    the new `VariableSpec`.

    Example:

    >>> v = tf.Variable([1., 2., 3.])
    >>> VariableSpec.from_value(v)
    VariableSpec(shape=(3,), dtype=tf.float32, trainable=True, alias_id=None)

    Args:
      value: A Variable.

    Returns:
      A `VariableSpec` created from `value`.
    """
    return cls(value.shape, dtype=value.dtype, trainable=value.trainable)

  def _to_components(self, value):
    """Returns the variable's resource handle as its sole component."""
    return [value.handle]

  def _from_components(self, components):
    """Rebuilds a `ResourceVariable` from its single resource-handle component.

    Args:
      components: A list or tuple holding exactly one resource tensor.

    Returns:
      A `ResourceVariable` wrapping that handle, with this spec's trainable,
      shape and dtype.

    Raises:
      TypeError: If `components` is not a list or tuple.
      ValueError: If `components` does not contain exactly one element, or
        that element is not a resource-dtype tensor.
    """
    if not isinstance(components, (list, tuple)):
      raise TypeError(f"Components of a ResourceVariable must be a list or "
                      f"tuple, got {components} instead.")
    if len(components) != 1:
      raise ValueError(f"Components of a ResourceVariable must only contain "
                       f"its resource handle, got {components} instead.")
    handle = components[0]
    if not isinstance(handle, ops.Tensor) or handle.dtype != dtypes.resource:
      raise ValueError(f"The handle of a ResourceVariable must be a resource "
                       f"tensor, got {handle} instead.")
    return ResourceVariable(trainable=self.trainable,
                            shape=self.shape,
                            dtype=self.dtype,
                            handle=handle)

  @property
  def _component_specs(self):
    """A single scalar resource-tensor spec (the variable handle)."""
    return [tensor_spec.TensorSpec([], dtypes.resource)]

  def _serialize(self):
    return self.shape, self.dtype, self.trainable, self.alias_id

  # TraceType method
  def is_subtype_of(self, other):
    if type(self) is not type(other):
      return False

    # Remove this once we add alias_id to all CompositeTensors with
    # ResourceVariable components.
    if self.alias_id is None and other.alias_id is None:
      return super().is_subtype_of(other)

    if self.alias_id is None or other.alias_id is None:
      raise NotImplementedError(f"VariableSpec.is_subtype_of doesn't support "
                                f"alias_id=None, got self: {self} and other: "
                                f"{other}.")

    return super().is_subtype_of(other)

  # TraceType method
  def most_specific_common_supertype(self, others):
    if any(type(self) is not type(other) for other in others):
      return None

    # It is a special case for tf.nest, which often takes CompositeTensors and
    # converts to TypeSpecs internally, such as tf.nest.assert_same_structure.
    if (self.alias_id is None and
        all(other.alias_id is None for other in others)):
      return super().most_specific_common_supertype(others)

    if self.alias_id is None or any(other.alias_id is None for other in others):
      raise NotImplementedError(f"VariableSpec.most_specific_common_supertype "
                                f"doesn't support alias_id=None, got self: "
                                f"{self} and others: {others}.")

    return super().most_specific_common_supertype(others)

  # TraceType method
  def placeholder_value(self, placeholder_context):
    if placeholder_context.unnest_only:
      return self

    name = self.name or placeholder_context.naming_scope
    context_graph = placeholder_context.context_graph
    if placeholder_context.has_placeholder(self.alias_id):
      # Get reference to the existing variable if alias_id already
      # exists in the PlaceholderContext
      variable = placeholder_context.get_placeholder(self.alias_id)
    else:
      spec = tensor_spec.TensorSpec([], dtypes.resource)
      spec_context = trace_type.InternalPlaceholderContext(
          context_graph.outer_graph)
      spec_context.update_naming_scope(name)
      placeholder = spec.placeholder_value(spec_context)
      variable = self._from_components([placeholder])
      # (b/262771247) ShardedVariable break without this and VariableSpecs
      # without alias_id are not TraceTypes.
      if self.alias_id is not None:
        placeholder_context.add_placeholder(self.alias_id, variable)
    # Capture the Variable's placeholder within the default graph of
    # the current thread.
    placeholder = context_graph.capture(variable.handle, name=name)
    placeholder.op._set_attr(  # pylint: disable=protected-access
        "_user_specified_name",
        attr_value_pb2.AttrValue(s=compat.as_bytes(name)))
    return variable

  def _to_tensors(self, value):
    assert isinstance(value, BaseResourceVariable)
    return [value.handle]

  def _get_structure(self):
    # shape, dtype, trainable, and alias_id are all leaves.
    return PList(PLeaf(), PLeaf(), PLeaf(), PLeaf())

  def __repr__(self):
    return (f"{type(self).__name__}(shape={self.shape}, dtype={self.dtype!r}, "
            f"trainable={self.trainable!r}, alias_id={self.alias_id!r})")

  def __hash__(self):
    return hash((self.shape, self.dtype, self.trainable, self.alias_id))

  def __eq__(self, other):
    return (type(self) is type(other) and self.shape == other.shape and
            self.dtype == other.dtype and self.trainable == other.trainable and
            self.alias_id == other.alias_id)

2728 

2729 

# Register a codec so nested_structure_coder can encode/decode VariableSpec
# using the built-in TypeSpecProto.VARIABLE_SPEC type-spec class.
nested_structure_coder.register_codec(
    nested_structure_coder.BuiltInTypeSpecCodec(
        VariableSpec, struct_pb2.TypeSpecProto.VARIABLE_SPEC
    )
)


# Register the VariableSpec class with the native utilities under the name
# "VariableSpec". NOTE(review): presumably for native-side type checks —
# confirm against _pywrap_utils.
_pywrap_utils.RegisterType("VariableSpec", VariableSpec)

2738 

2739 

2740def write_object_proto_for_resource_variable(resource_variable, 

2741 proto, 

2742 options, 

2743 enforce_naming=True): 

2744 """Writes additional information of the variable into the SavedObject proto. 

2745 

2746 This allows users to define a `hook` to provide extra information of the 

2747 variable to the SavedObject. 

2748 

2749 For example, DistributedVariable class would fill in components in the 

2750 distributed context. 

2751 

2752 Args: 

2753 resource_variable: A `ResourceVariable` or `DistributedValue` that has the 

2754 information to be saved into the proto. 

2755 proto: `SavedObject` proto to update. 

2756 options: A `SaveOption` instance that configures save behavior. 

2757 enforce_naming: A bool determining whether to check that names end in the 

2758 expected string ':0' 

2759 """ 

2760 proto.variable.SetInParent() 

2761 if enforce_naming and not resource_variable.name.endswith(":0"): 

2762 raise ValueError(f"Cowardly refusing to save variable " 

2763 f"{resource_variable.name} because of " 

2764 f"unexpected suffix in the name (expected ':0')" 

2765 f"which won't be restored.") 

2766 proto.variable.name = tensor_module.get_op_name(resource_variable.name) 

2767 proto.variable.trainable = resource_variable.trainable 

2768 proto.variable.dtype = resource_variable.dtype.as_datatype_enum 

2769 proto.variable.synchronization = resource_variable.synchronization.value 

2770 proto.variable.aggregation = resource_variable.aggregation.value 

2771 proto.variable.shape.CopyFrom(resource_variable.shape.as_proto()) 

2772 if options.experimental_variable_policy._save_variable_devices( # pylint: disable=protected-access 

2773 ): 

2774 if hasattr(resource_variable, "device"): 

2775 proto.variable.device = resource_variable.device