Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ref_variable.py: 33%

318 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""RefVariable class.""" 

16 

17from tensorflow.core.framework import attr_value_pb2 

18from tensorflow.core.framework import variable_pb2 

19from tensorflow.python.eager import context 

20from tensorflow.python.framework import indexed_slices 

21from tensorflow.python.framework import ops 

22from tensorflow.python.framework import tensor_conversion_registry 

23from tensorflow.python.framework import tensor_shape 

24from tensorflow.python.ops import array_ops 

25from tensorflow.python.ops import gen_array_ops 

26from tensorflow.python.ops import gen_state_ops 

27from tensorflow.python.ops import resource_variable_ops 

28from tensorflow.python.ops import state_ops 

29from tensorflow.python.ops import variable_v1 

30from tensorflow.python.ops import variables 

31from tensorflow.python.platform import tf_logging as logging 

32from tensorflow.python.trackable import base as trackable 

33from tensorflow.python.types import core 

34from tensorflow.python.util import compat 

35from tensorflow.python.util import lazy_loader 

36from tensorflow.python.util.deprecation import deprecated 

37 

38 

39variable_scope = lazy_loader.LazyLoader( 

40 "variable_scope", globals(), 

41 "tensorflow.python.ops.variable_scope") 

42 

43 

def default_variable_creator(next_creator=None, **kwargs):
  """Builds a `ResourceVariable` or a `RefVariable` from creator kwargs.

  This creator does not delegate to another creator, so `next_creator` must
  be `None`.

  Whether a resource variable is built is decided by, in order: the
  `use_resource` kwarg, the current variable scope's `use_resource` setting,
  and `variable_scope._DEFAULT_USE_RESOURCE`. A resource variable is always
  built when executing eagerly.

  Args:
    next_creator: Must be `None`.
    **kwargs: Variable constructor arguments. Absent keys default to `None`,
      except `validate_shape`, which defaults to `True`.

  Returns:
    A `resource_variable_ops.ResourceVariable` or a `RefVariable`.
  """
  assert next_creator is None

  use_resource = kwargs.get("use_resource", None)
  if use_resource is None:
    use_resource = variable_scope.get_variable_scope().use_resource
  if use_resource is None:
    use_resource = variable_scope._DEFAULT_USE_RESOURCE  # pylint: disable=protected-access
  use_resource = use_resource or context.executing_eagerly()

  # Arguments understood by both variable classes.
  common_args = dict(
      initial_value=kwargs.get("initial_value", None),
      trainable=kwargs.get("trainable", None),
      collections=kwargs.get("collections", None),
      validate_shape=kwargs.get("validate_shape", True),
      caching_device=kwargs.get("caching_device", None),
      name=kwargs.get("name", None),
      dtype=kwargs.get("dtype", None),
      constraint=kwargs.get("constraint", None),
      variable_def=kwargs.get("variable_def", None),
      import_scope=kwargs.get("import_scope", None),
      synchronization=kwargs.get("synchronization", None),
      aggregation=kwargs.get("aggregation", None),
      shape=kwargs.get("shape", None))

  if use_resource:
    # Only resource variables accept a distribution strategy.
    return resource_variable_ops.ResourceVariable(
        distribute_strategy=kwargs.get("distribute_strategy", None),
        **common_args)
  # Only reference variables accept (and ignore) `expected_shape`.
  return RefVariable(
      expected_shape=kwargs.get("expected_shape", None), **common_args)

101 

102 

# Install this module's `default_variable_creator` on the `variable_v1`
# module, replacing whatever that attribute held before.
# NOTE(review): presumably `variable_v1` falls back to this attribute when no
# custom variable creator is active; confirm against variable_v1.
variable_v1.default_variable_creator = default_variable_creator

104 

105 

106def _to_proto_fn(v, export_scope=None): 

107 """Converts Variable and ResourceVariable to VariableDef for collections.""" 

108 return v.to_proto(export_scope=export_scope) 

109 

110 

def _from_proto_fn(v, import_scope=None):
  """Rebuilds a variable object from a `VariableDef` stored in a collection.

  Args:
    v: A `VariableDef` proto.
    import_scope: Optional name scope forwarded to `from_proto`.

  Returns:
    `ResourceVariable.from_proto(...)` when `v.is_resource` is set, otherwise
    `variable_v1.VariableV1.from_proto(...)`.
  """
  variable_cls = (
      resource_variable_ops.ResourceVariable
      if v.is_resource else variable_v1.VariableV1)
  return variable_cls.from_proto(v, import_scope=import_scope)

117 

118 

# Every variable-holding collection is (de)serialized through `VariableDef`
# protos using the same pair of converters.
for _collection_key in (
    ops.GraphKeys.GLOBAL_VARIABLES,
    ops.GraphKeys.TRAINABLE_VARIABLES,
    ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.MODEL_VARIABLES,
    ops.GraphKeys.GLOBAL_STEP,
    ops.GraphKeys.METRIC_VARIABLES,
):
  ops.register_proto_function(
      _collection_key,
      proto_type=variable_pb2.VariableDef,
      to_proto=_to_proto_fn,
      from_proto=_from_proto_fn)
# Do not leave the loop variable behind as a module attribute.
del _collection_key

154 

155 

156# TODO(apassos): do not repeat all comments here 

157class RefVariable(variable_v1.VariableV1, core.Tensor): 

158 """Ref-based implementation of variables.""" 

159 

160 def __init__( 

161 self, # pylint: disable=super-init-not-called 

162 initial_value=None, 

163 trainable=None, 

164 collections=None, 

165 validate_shape=True, 

166 caching_device=None, 

167 name=None, 

168 variable_def=None, 

169 dtype=None, 

170 expected_shape=None, 

171 import_scope=None, 

172 constraint=None, 

173 synchronization=None, 

174 aggregation=None, 

175 shape=None): 

176 """Creates a new variable with value `initial_value`. 

177 

178 The new variable is added to the graph collections listed in `collections`, 

179 which defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 

180 

181 If `trainable` is `True` the variable is also added to the graph collection 

182 `GraphKeys.TRAINABLE_VARIABLES`. 

183 

184 This constructor creates both a `variable` Op and an `assign` Op to set the 

185 variable to its initial value. 

186 

187 Args: 

188 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 

189 which is the initial value for the Variable. The initial value must have 

190 a shape specified unless `validate_shape` is set to False. Can also be a 

191 callable with no argument that returns the initial value when called. In 

192 that case, `dtype` must be specified. (Note that initializer functions 

193 from init_ops.py must first be bound to a shape before being used here.) 

194 trainable: If `True`, also adds the variable to the graph collection 

195 `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default 

196 list of variables to use by the `Optimizer` classes. Defaults to `True`, 

197 unless `synchronization` is set to `ON_READ`, in which case it defaults 

198 to `False`. 

199 collections: List of graph collections keys. The new variable is added to 

200 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 

201 validate_shape: If `False`, allows the variable to be initialized with a 

202 value of unknown shape. If `True`, the default, the shape of 

203 `initial_value` must be known. 

204 caching_device: Optional device string describing where the Variable 

205 should be cached for reading. Defaults to the Variable's device. If not 

206 `None`, caches on another device. Typical use is to cache on the device 

207 where the Ops using the Variable reside, to deduplicate copying through 

208 `Switch` and other conditional statements. 

209 name: Optional name for the variable. Defaults to `'Variable'` and gets 

210 uniquified automatically. 

211 variable_def: `VariableDef` protocol buffer. If not `None`, recreates the 

212 Variable object with its contents, referencing the variable's nodes in 

213 the graph, which must already exist. The graph is not changed. 

214 `variable_def` and the other arguments are mutually exclusive. 

215 dtype: If set, initial_value will be converted to the given type. If 

216 `None`, either the datatype will be kept (if `initial_value` is a 

217 Tensor), or `convert_to_tensor` will decide. 

218 expected_shape: A TensorShape. If set, initial_value is expected to have 

219 this shape. 

220 import_scope: Optional `string`. Name scope to add to the `Variable.` Only 

221 used when initializing from protocol buffer. 

222 constraint: An optional projection function to be applied to the variable 

223 after being updated by an `Optimizer` (e.g. used to implement norm 

224 constraints or value constraints for layer weights). The function must 

225 take as input the unprojected Tensor representing the value of the 

226 variable and return the Tensor for the projected value (which must have 

227 the same shape). Constraints are not safe to use when doing asynchronous 

228 distributed training. 

229 synchronization: Indicates when a distributed a variable will be 

230 aggregated. Accepted values are constants defined in the class 

231 `tf.VariableSynchronization`. By default the synchronization is set to 

232 `AUTO` and the current `DistributionStrategy` chooses when to 

233 synchronize. 

234 aggregation: Indicates how a distributed variable will be aggregated. 

235 Accepted values are constants defined in the class 

236 `tf.VariableAggregation`. 

237 shape: (optional) The shape of this variable. If None, the shape of 

238 `initial_value` will be used. When setting this argument to 

239 `tf.TensorShape(None)` (representing an unspecified shape), the variable 

240 can be assigned with values of different shapes. 

241 

242 Raises: 

243 ValueError: If both `variable_def` and initial_value are specified. 

244 ValueError: If the initial value is not specified, or does not have a 

245 shape and `validate_shape` is `True`. 

246 RuntimeError: If eager execution is enabled. 

247 """ 

248 self._in_graph_mode = True 

249 if variable_def: 

250 # If variable_def is provided, recreates the variable from its fields. 

251 if initial_value: 

252 raise ValueError("variable_def and initial_value are mutually " 

253 "exclusive.") 

254 self._init_from_proto(variable_def, import_scope=import_scope) 

255 else: 

256 # Create from initial_value. 

257 self._init_from_args( 

258 initial_value=initial_value, 

259 trainable=trainable, 

260 collections=collections, 

261 validate_shape=validate_shape, 

262 caching_device=caching_device, 

263 name=name, 

264 dtype=dtype, 

265 expected_shape=expected_shape, 

266 constraint=constraint, 

267 synchronization=synchronization, 

268 aggregation=aggregation, 

269 shape=shape) 

270 

271 def __repr__(self): 

272 if context.executing_eagerly() and not self._in_graph_mode: 

273 return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % ( 

274 self.name, self.get_shape(), self.dtype.name, 

275 ops.numpy_text(self.read_value(), is_repr=True)) 

276 else: 

277 return "<tf.Variable '%s' shape=%s dtype=%s>" % ( 

278 self.name, self.get_shape(), self.dtype.name) 

279 

  def _init_from_args(self,
                      initial_value=None,
                      trainable=None,
                      collections=None,
                      validate_shape=True,
                      caching_device=None,
                      name=None,
                      dtype=None,
                      expected_shape=None,
                      constraint=None,
                      synchronization=None,
                      aggregation=None,
                      shape=None):
    """Creates a new variable from arguments.

    Builds the variable op, its initializer (`assign`) op and a "read"
    snapshot in the default graph, then adds `self` to `collections`.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound to
        a shape before being used here.)
      trainable: If `True`, also adds the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default
        list of variables to use by the `Optimizer` classes. Defaults to `True`,
        unless `synchronization` is set to `ON_READ`, in which case it defaults
        to `False`.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type. If None,
        either the datatype will be kept (if initial_value is a Tensor) or
        float32 will be used (if it is a Python object convertible to a Tensor).
      expected_shape: Deprecated. Ignored.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the variable
        can be assigned with values of different shapes.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      ValueError: If `collections` is not a list, tuple or set, if
        `constraint` is given but not callable, or if a non-callable
        `initial_value` was produced inside a control-flow construct.
      RuntimeError: If lifted into the eager context.
    """
    # `expected_shape` is accepted for signature compatibility but unused.
    _ = expected_shape
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    # Store the graph key so optimizers know how to only retrieve variables from
    # this graph.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    # Unwrap a checkpoint-restore wrapper, remembering its restore uid.
    if isinstance(initial_value, trackable.CheckpointInitialValue):
      self._maybe_initialize_trackable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    # Validated values replace the arguments. Per the docstring, `trainable`
    # may be defaulted here based on `synchronization`.
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))
    self._synchronization = synchronization
    self._aggregation = aggregation
    self._trainable = trainable
    # Copy into a new list so the caller's `collections` is not mutated.
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    with ops.init_scope():
      # Ensure that we weren't lifted into the eager context.
      if context.executing_eagerly():
        raise RuntimeError(
            "Reference variables are not supported when eager execution is "
            "enabled. Please run `tf.compat.v1.enable_resource_variables()` to "
            "switch to resource variables.")
      with ops.name_scope(name, "Variable",
                          [] if init_from_fn else [initial_value]) as name:

        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
          true_name = ops.name_from_scope_name(name)  # pylint: disable=protected-access
          attr = attr_value_pb2.AttrValue(
              list=attr_value_pb2.AttrValue.ListValue(
                  s=[compat.as_bytes("loc:@%s" % true_name)]))
          # pylint: disable=protected-access
          with ops.get_default_graph()._attr_scope({"_class": attr}):
            # Run the initializer callable under an "Initializer" sub-scope
            # with no device set.
            with ops.name_scope("Initializer"), ops.device(None):
              initial_value = initial_value()
              # The callable may itself return a checkpoint-restore wrapper.
              if isinstance(initial_value, trackable.CheckpointInitialValue):
                self._maybe_initialize_trackable()
                self._update_uid = initial_value.checkpoint_position.restore_uid
                initial_value = initial_value.wrapped_value
              self._initial_value = ops.convert_to_tensor(
                  initial_value, name="initial_value", dtype=dtype)
              if shape is None:
                shape = (
                    self._initial_value.get_shape()
                    if validate_shape else tensor_shape.unknown_shape())
            # Created outside the "Initializer" sub-scope so the variable op is
            # named directly under `name`.
            self._variable = state_ops.variable_op_v2(
                shape, self._initial_value.dtype.base_dtype, name=name)
          # pylint: enable=protected-access

        # Or get the initial value from a Tensor or Python object.
        else:
          self._initial_value = ops.convert_to_tensor(
              initial_value, name="initial_value", dtype=dtype)
          # pylint: disable=protected-access
          if self._initial_value.op._get_control_flow_context() is not None:
            raise ValueError(
                "Initializer for variable %s is from inside a control-flow "
                "construct, such as a loop or conditional. When creating a "
                "variable inside a loop or conditional, use a lambda as the "
                "initializer." % name)
          if shape is None:
            # pylint: enable=protected-access
            shape = (
                self._initial_value.get_shape()
                if validate_shape else tensor_shape.unknown_shape())
          # In this case, the variable op can't be created until after the
          # initial_value has been converted to a Tensor with a known type.
          self._variable = state_ops.variable_op_v2(
              shape, self._initial_value.dtype.base_dtype, name=name)

        # Cache the name in `self`, because some APIs call `Variable.name` in a
        # tight loop, and this halves the cost.
        self._name = self._variable.name

        # With validate_shape, require the initial value's static shape to be
        # fully defined. This only checks; nothing is overridden here.
        if validate_shape:
          initial_value_shape = self._initial_value.get_shape()
          if not initial_value_shape.is_fully_defined():
            raise ValueError("initial_value must have a shape specified: %s" %
                             self._initial_value)

        # If 'initial_value' makes use of other variables, make sure we don't
        # have an issue if these other variables aren't initialized first by
        # using their initialized_value() method.
        self._initializer_op = state_ops.assign(
            self._variable,
            variables._try_guard_against_uninitialized_dependencies(  # pylint: disable=protected-access
                name, self._initial_value),
            validate_shape=validate_shape).op

        # TODO(vrv): Change this class to not take caching_device, but
        # to take the op to colocate the snapshot with, so we can use
        # colocation rather than devices.
        # The "read" identity becomes `self._snapshot`, which `value()` returns.
        if caching_device is not None:
          with ops.device(caching_device):
            self._snapshot = array_ops.identity(self._variable, name="read")
        else:
          with ops.colocate_with(self._variable.op):
            self._snapshot = array_ops.identity(self._variable, name="read")
      ops.add_to_collections(collections, self)

    self._caching_device = caching_device
    self._save_slice_info = None
    self._constraint = constraint

469 

470 def _init_from_proto(self, variable_def, import_scope=None): 

471 """Recreates the Variable object from a `VariableDef` protocol buffer. 

472 

473 Args: 

474 variable_def: `VariableDef` protocol buffer, describing a variable whose 

475 nodes already exists in the graph. 

476 import_scope: Optional `string`. Name scope to add. 

477 """ 

478 assert isinstance(variable_def, variable_pb2.VariableDef) 

479 # Create from variable_def. 

480 g = ops.get_default_graph() 

481 self._variable = g.as_graph_element( 

482 ops.prepend_name_scope( 

483 variable_def.variable_name, import_scope=import_scope)) 

484 self._name = self._variable.name 

485 self._initializer_op = g.as_graph_element( 

486 ops.prepend_name_scope( 

487 variable_def.initializer_name, import_scope=import_scope)) 

488 # Tests whether initial_value_name exists first for backwards compatibility. 

489 if (hasattr(variable_def, "initial_value_name") and 

490 variable_def.initial_value_name): 

491 self._initial_value = g.as_graph_element( 

492 ops.prepend_name_scope( 

493 variable_def.initial_value_name, import_scope=import_scope)) 

494 else: 

495 self._initial_value = None 

496 synchronization, aggregation, trainable = ( 

497 variables.validate_synchronization_aggregation_trainable( 

498 variable_def.synchronization, variable_def.aggregation, 

499 variable_def.trainable, variable_def.variable_name)) 

500 self._synchronization = synchronization 

501 self._aggregation = aggregation 

502 self._trainable = trainable 

503 self._snapshot = g.as_graph_element( 

504 ops.prepend_name_scope( 

505 variable_def.snapshot_name, import_scope=import_scope)) 

506 if variable_def.HasField("save_slice_info_def"): 

507 self._save_slice_info = variables.Variable.SaveSliceInfo( 

508 save_slice_info_def=variable_def.save_slice_info_def, 

509 import_scope=import_scope) 

510 else: 

511 self._save_slice_info = None 

512 self._caching_device = None 

513 self._constraint = None 

514 

515 def _as_graph_element(self): 

516 """Conversion function for Graph.as_graph_element().""" 

517 return self._variable 

518 

519 def value(self): 

520 """Returns the last snapshot of this variable. 

521 

522 You usually do not need to call this method as all ops that need the value 

523 of the variable call it automatically through a `convert_to_tensor()` call. 

524 

525 Returns a `Tensor` which holds the value of the variable. You can not 

526 assign a new value to this tensor as it is not a reference to the variable. 

527 

528 To avoid copies, if the consumer of the returned value is on the same device 

529 as the variable, this actually returns the live value of the variable, not 

530 a copy. Updates to the variable are seen by the consumer. If the consumer 

531 is on a different device it will get a copy of the variable. 

532 

533 Returns: 

534 A `Tensor` containing the value of the variable. 

535 """ 

536 return self._snapshot 

537 

538 def read_value(self): 

539 """Returns the value of this variable, read in the current context. 

540 

541 Can be different from value() if it's on another device, with control 

542 dependencies, etc. 

543 

544 Returns: 

545 A `Tensor` containing the value of the variable. 

546 """ 

547 return array_ops.identity(self._variable, name="read") 

548 

549 def _ref(self): 

550 """Returns a reference to this variable. 

551 

552 You usually do not need to call this method as all ops that need a reference 

553 to the variable call it automatically. 

554 

555 Returns is a `Tensor` which holds a reference to the variable. You can 

556 assign a new value to the variable by passing the tensor to an assign op. 

557 See `tf.Variable.value` if you want to get the value of the 

558 variable. 

559 

560 Returns: 

561 A `Tensor` that is a reference to the variable. 

562 """ 

563 return self._variable 

564 

565 def set_shape(self, shape): 

566 """Overrides the shape for this variable. 

567 

568 Args: 

569 shape: the `TensorShape` representing the overridden shape. 

570 """ 

571 self._ref().set_shape(shape) 

572 self.value().set_shape(shape) 

573 

574 @property 

575 def trainable(self): 

576 return self._trainable 

577 

578 @property 

579 def synchronization(self): 

580 return self._synchronization 

581 

582 @property 

583 def aggregation(self): 

584 return self._aggregation 

585 

586 def eval(self, session=None): 

587 """In a session, computes and returns the value of this variable. 

588 

589 This is not a graph construction method, it does not add ops to the graph. 

590 

591 This convenience method requires a session where the graph 

592 containing this variable has been launched. If no session is 

593 passed, the default session is used. See `tf.compat.v1.Session` for more 

594 information on launching a graph and on sessions. 

595 

596 ```python 

597 v = tf.Variable([1, 2]) 

598 init = tf.compat.v1.global_variables_initializer() 

599 

600 with tf.compat.v1.Session() as sess: 

601 sess.run(init) 

602 # Usage passing the session explicitly. 

603 print(v.eval(sess)) 

604 # Usage with the default session. The 'with' block 

605 # above makes 'sess' the default session. 

606 print(v.eval()) 

607 ``` 

608 

609 Args: 

610 session: The session to use to evaluate this variable. If none, the 

611 default session is used. 

612 

613 Returns: 

614 A numpy `ndarray` with a copy of the value of this variable. 

615 """ 

616 return self._variable.eval(session=session) 

617 

618 @property 

619 def initial_value(self): 

620 """Returns the Tensor used as the initial value for the variable. 

621 

622 Note that this is different from `initialized_value()` which runs 

623 the op that initializes the variable before returning its value. 

624 This method returns the tensor that is used by the op that initializes 

625 the variable. 

626 

627 Returns: 

628 A `Tensor`. 

629 """ 

630 return self._initial_value 

631 

632 @property 

633 def constraint(self): 

634 """Returns the constraint function associated with this variable. 

635 

636 Returns: 

637 The constraint function that was passed to the variable constructor. 

638 Can be `None` if no constraint was passed. 

639 """ 

640 return self._constraint 

641 

642 def assign(self, value, use_locking=False, name=None, read_value=True): 

643 """Assigns a new value to the variable. 

644 

645 This is essentially a shortcut for `assign(self, value)`. 

646 

647 Args: 

648 value: A `Tensor`. The new value for this variable. 

649 use_locking: If `True`, use locking during the assignment. 

650 name: The name of the operation to be created 

651 read_value: if True, will return something which evaluates to the new 

652 value of the variable; if False will return the assign op. 

653 

654 Returns: 

655 A `Tensor` that will hold the new value of this variable after 

656 the assignment has completed. 

657 """ 

658 assign = state_ops.assign( 

659 self._variable, value, use_locking=use_locking, name=name) 

660 if read_value: 

661 return assign 

662 return assign.op 

663 

664 def assign_add(self, delta, use_locking=False, name=None, read_value=True): 

665 """Adds a value to this variable. 

666 

667 This is essentially a shortcut for `assign_add(self, delta)`. 

668 

669 Args: 

670 delta: A `Tensor`. The value to add to this variable. 

671 use_locking: If `True`, use locking during the operation. 

672 name: The name of the operation to be created 

673 read_value: if True, will return something which evaluates to the new 

674 value of the variable; if False will return the assign op. 

675 

676 Returns: 

677 A `Tensor` that will hold the new value of this variable after 

678 the addition has completed. 

679 """ 

680 assign = state_ops.assign_add( 

681 self._variable, delta, use_locking=use_locking, name=name) 

682 if read_value: 

683 return assign 

684 return assign.op 

685 

686 def assign_sub(self, delta, use_locking=False, name=None, read_value=True): 

687 """Subtracts a value from this variable. 

688 

689 This is essentially a shortcut for `assign_sub(self, delta)`. 

690 

691 Args: 

692 delta: A `Tensor`. The value to subtract from this variable. 

693 use_locking: If `True`, use locking during the operation. 

694 name: The name of the operation to be created 

695 read_value: if True, will return something which evaluates to the new 

696 value of the variable; if False will return the assign op. 

697 

698 Returns: 

699 A `Tensor` that will hold the new value of this variable after 

700 the subtraction has completed. 

701 """ 

702 assign = state_ops.assign_sub( 

703 self._variable, delta, use_locking=use_locking, name=name) 

704 if read_value: 

705 return assign 

706 return assign.op 

707 

708 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 

709 """Subtracts `tf.IndexedSlices` from this variable. 

710 

711 Args: 

712 sparse_delta: `tf.IndexedSlices` to be subtracted from this variable. 

713 use_locking: If `True`, use locking during the operation. 

714 name: the name of the operation. 

715 

716 Returns: 

717 A `Tensor` that will hold the new value of this variable after 

718 the scattered subtraction has completed. 

719 

720 Raises: 

721 TypeError: if `sparse_delta` is not an `IndexedSlices`. 

722 """ 

723 if not isinstance(sparse_delta, indexed_slices.IndexedSlices): 

724 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 

725 return gen_state_ops.scatter_sub( 

726 self._variable, 

727 sparse_delta.indices, 

728 sparse_delta.values, 

729 use_locking=use_locking, 

730 name=name) 

731 

def scatter_add(self, sparse_delta, use_locking=False, name=None):
  """Adds `tf.IndexedSlices` to this variable.

  Args:
    sparse_delta: `tf.IndexedSlices` to be added to this variable.
    use_locking: If `True`, use locking during the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered addition has completed.

  Raises:
    TypeError: if `sparse_delta` is not an `IndexedSlices`.
  """
  if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
    # Wrap in a 1-tuple: a bare tuple operand would be consumed as the `%`
    # argument list and produce a formatting error instead of this message.
    raise TypeError("sparse_delta is not IndexedSlices: %s" % (sparse_delta,))
  return gen_state_ops.scatter_add(
      self._variable,
      sparse_delta.indices,
      sparse_delta.values,
      use_locking=use_locking,
      name=name)

755 

def scatter_max(self, sparse_delta, use_locking=False, name=None):
  """Updates this variable with the max of `tf.IndexedSlices` and itself.

  Args:
    sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
      variable.
    use_locking: If `True`, use locking during the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered maximization has completed.

  Raises:
    TypeError: if `sparse_delta` is not an `IndexedSlices`.
  """
  if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
    # Wrap in a 1-tuple: a bare tuple operand would be consumed as the `%`
    # argument list and produce a formatting error instead of this message.
    raise TypeError("sparse_delta is not IndexedSlices: %s" % (sparse_delta,))
  return gen_state_ops.scatter_max(
      self._variable,
      sparse_delta.indices,
      sparse_delta.values,
      use_locking=use_locking,
      name=name)

780 

def scatter_min(self, sparse_delta, use_locking=False, name=None):
  """Updates this variable with the min of `tf.IndexedSlices` and itself.

  Args:
    sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
      variable.
    use_locking: If `True`, use locking during the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered minimization has completed.

  Raises:
    TypeError: if `sparse_delta` is not an `IndexedSlices`.
  """
  if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
    # Wrap in a 1-tuple: a bare tuple operand would be consumed as the `%`
    # argument list and produce a formatting error instead of this message.
    raise TypeError("sparse_delta is not IndexedSlices: %s" % (sparse_delta,))
  return gen_state_ops.scatter_min(
      self._variable,
      sparse_delta.indices,
      sparse_delta.values,
      use_locking=use_locking,
      name=name)

805 

def scatter_mul(self, sparse_delta, use_locking=False, name=None):
  """Multiply this variable by `tf.IndexedSlices`.

  Args:
    sparse_delta: `tf.IndexedSlices` to multiply this variable by.
    use_locking: If `True`, use locking during the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered multiplication has completed.

  Raises:
    TypeError: if `sparse_delta` is not an `IndexedSlices`.
  """
  if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
    # Wrap in a 1-tuple: a bare tuple operand would be consumed as the `%`
    # argument list and produce a formatting error instead of this message.
    raise TypeError("sparse_delta is not IndexedSlices: %s" % (sparse_delta,))
  return gen_state_ops.scatter_mul(
      self._variable,
      sparse_delta.indices,
      sparse_delta.values,
      use_locking=use_locking,
      name=name)

829 

def scatter_div(self, sparse_delta, use_locking=False, name=None):
  """Divide this variable by `tf.IndexedSlices`.

  Args:
    sparse_delta: `tf.IndexedSlices` to divide this variable by.
    use_locking: If `True`, use locking during the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered division has completed.

  Raises:
    TypeError: if `sparse_delta` is not an `IndexedSlices`.
  """
  if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
    # Wrap in a 1-tuple: a bare tuple operand would be consumed as the `%`
    # argument list and produce a formatting error instead of this message.
    raise TypeError("sparse_delta is not IndexedSlices: %s" % (sparse_delta,))
  return gen_state_ops.scatter_div(
      self._variable,
      sparse_delta.indices,
      sparse_delta.values,
      use_locking=use_locking,
      name=name)

853 

def scatter_update(self, sparse_delta, use_locking=False, name=None):
  """Assigns `tf.IndexedSlices` to this variable.

  Args:
    sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
    use_locking: If `True`, use locking during the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered assignment has completed.

  Raises:
    TypeError: if `sparse_delta` is not an `IndexedSlices`.
  """
  if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
    # Wrap in a 1-tuple: a bare tuple operand would be consumed as the `%`
    # argument list and produce a formatting error instead of this message.
    raise TypeError("sparse_delta is not IndexedSlices: %s" % (sparse_delta,))
  return gen_state_ops.scatter_update(
      self._variable,
      sparse_delta.indices,
      sparse_delta.values,
      use_locking=use_locking,
      name=name)

877 

def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
  """Assigns `tf.IndexedSlices` to this variable batch-wise.

  Analogous to `batch_gather`. This assumes that this variable and the
  sparse_delta IndexedSlices have a series of leading dimensions that are the
  same for all of them, and the updates are performed on the last dimension of
  indices. In other words, the dimensions should be the following:

  `num_prefix_dims = sparse_delta.indices.ndims - 1`
  `batch_dim = num_prefix_dims + 1`
  `sparse_delta.values.shape = sparse_delta.indices.shape + var.shape[
       batch_dim:]`

  where

  `sparse_delta.values.shape[:num_prefix_dims]`
  `== sparse_delta.indices.shape[:num_prefix_dims]`
  `== var.shape[:num_prefix_dims]`

  And the operation performed can be expressed as:

  `var[i_1, ..., i_n,
       sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.values[
          i_1, ..., i_n, j]`

  When sparse_delta.indices is a 1D tensor, this operation is equivalent to
  `scatter_update`.

  To avoid this operation one can loop over the first `ndims` of the
  variable and use `scatter_update` on the subtensors that result from
  slicing the first dimension. This is a valid option for `ndims = 1`, but
  less efficient than this implementation.

  Args:
    sparse_delta: `tf.IndexedSlices` to be assigned to this variable. Its
      `values` are the updates written at `indices`.
    use_locking: If `True`, use locking during the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered assignment has completed.

  Raises:
    TypeError: if `sparse_delta` is not an `IndexedSlices`.
  """
  # Enforce the documented contract, matching the other scatter_* methods.
  if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
    raise TypeError("sparse_delta is not IndexedSlices: %s" % (sparse_delta,))
  # Unlike the other scatter_* methods, this passes the variable itself
  # (not `self._variable`) to the state_ops helper.
  return state_ops.batch_scatter_update(
      self,
      sparse_delta.indices,
      sparse_delta.values,
      use_locking=use_locking,
      name=name)

929 

def scatter_nd_sub(self, indices, updates, name=None):
  """Applies sparse subtraction to individual values or slices in a Variable.

  Below, `ref` denotes this variable. `ref` is a `Tensor` with rank `P` and
  `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to subtract 4 scattered elements from a rank-1
  tensor with 8 elements. In Python, that update would look like this:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1], [7]])
      updates = tf.constant([9, 10, 11, 12])
      op = ref.scatter_nd_sub(indices, updates)
      with tf.compat.v1.Session() as sess:
        sess.run(ref.initializer)
        print(sess.run(op))
  ```

  The resulting update to ref would look like this:

      [1, -9, 3, -6, -4, 6, 7, -4]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices to be used in the operation.
    updates: The values to be used in the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered subtraction has completed.
  """
  # Locking is always requested; this method exposes no `use_locking` knob.
  return gen_state_ops.scatter_nd_sub(
      self._variable, indices, updates, use_locking=True, name=name)

978 

def scatter_nd_add(self, indices, updates, name=None):
  """Applies sparse addition to individual values or slices in a Variable.

  Below, `ref` denotes this variable. `ref` is a `Tensor` with rank `P` and
  `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to add 4 scattered elements to a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1], [7]])
      updates = tf.constant([9, 10, 11, 12])
      add = ref.scatter_nd_add(indices, updates)
      with tf.compat.v1.Session() as sess:
        sess.run(ref.initializer)
        print(sess.run(add))
  ```

  The resulting update to ref would look like this:

      [1, 13, 3, 14, 14, 6, 7, 20]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices to be used in the operation.
    updates: The values to be used in the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered addition has completed.
  """
  # Locking is always requested; this method exposes no `use_locking` knob.
  return gen_state_ops.scatter_nd_add(
      self._variable, indices, updates, use_locking=True, name=name)

1027 

def scatter_nd_update(self, indices, updates, name=None):
  """Applies sparse assignment to individual values or slices in a Variable.

  Below, `ref` denotes this variable. `ref` is a `Tensor` with rank `P` and
  `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to assign 4 scattered elements in a rank-1
  tensor with 8 elements. In Python, that update would look like this:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1], [7]])
      updates = tf.constant([9, 10, 11, 12])
      op = ref.scatter_nd_update(indices, updates)
      with tf.compat.v1.Session() as sess:
        sess.run(ref.initializer)
        print(sess.run(op))
  ```

  The resulting update to ref would look like this:

      [1, 11, 3, 10, 9, 6, 7, 12]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices to be used in the operation.
    updates: The values to be used in the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered assignment has completed.
  """
  # Locking is always requested; this method exposes no `use_locking` knob.
  return gen_state_ops.scatter_nd_update(
      self._variable, indices, updates, use_locking=True, name=name)

1076 

def scatter_nd_max(self, indices, updates, name=None):
  """Applies sparse maximum to individual values or slices in a Variable.

  Each element or slice of this variable selected by `indices` is updated
  with the max of its current value and the matching entry of `updates`.

  Below, `ref` denotes this variable. `ref` is a `Tensor` with rank `P` and
  `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices to be used in the operation.
    updates: The values to be used in the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered maximization has completed.
  """
  # Locking is always requested; this method exposes no `use_locking` knob.
  return gen_state_ops.scatter_nd_max(
      self._variable, indices, updates, use_locking=True, name=name)

1109 

def scatter_nd_min(self, indices, updates, name=None):
  """Applies sparse minimum to individual values or slices in a Variable.

  Each element or slice of this variable selected by `indices` is updated
  with the min of its current value and the matching entry of `updates`.

  Below, `ref` denotes this variable. `ref` is a `Tensor` with rank `P` and
  `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices to be used in the operation.
    updates: The values to be used in the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered minimization has completed.
  """
  # Locking is always requested; this method exposes no `use_locking` knob.
  return gen_state_ops.scatter_nd_min(
      self._variable, indices, updates, use_locking=True, name=name)

1142 

def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
                          end_mask, ellipsis_mask, new_axis_mask,
                          shrink_axis_mask):
  """Writes `value` into a strided slice of this variable.

  Every argument is forwarded unchanged to
  `gen_array_ops.strided_slice_assign`. The target is `self._ref()`.
  """
  target = self._ref()
  masks = dict(
      begin_mask=begin_mask,
      end_mask=end_mask,
      ellipsis_mask=ellipsis_mask,
      new_axis_mask=new_axis_mask,
      shrink_axis_mask=shrink_axis_mask)
  return gen_array_ops.strided_slice_assign(
      ref=target,
      begin=begin,
      end=end,
      strides=strides,
      value=value,
      name=name,
      **masks)

1158 

@deprecated(None, "Prefer Dataset.range instead.")
def count_up_to(self, limit):
  """Increments this variable until it reaches `limit`.

  Each run of the returned op tries to add `1` to the variable. When that
  increment would take the variable above `limit`, the op raises
  `OutOfRangeError` instead.

  When no error is raised, the op produces the variable's value from before
  the increment.

  This is a thin wrapper around `state_ops.count_up_to` applied to
  `self._variable`.

  Args:
    limit: value at which incrementing the variable raises an error.

  Returns:
    A `Tensor` that will hold the variable value before the increment. If no
    other Op modifies this variable, the values produced will all be
    distinct.
  """
  previous_value = state_ops.count_up_to(self._variable, limit=limit)
  return previous_value

1181 

# Conversion to tensor.
@staticmethod
def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
  """Converts variable `v` to a tensor for the conversion registry.

  Args:
    v: the variable being converted.
    dtype: optional requested dtype. It must be compatible with `v.dtype`.
    name: ignored; accepted to satisfy the conversion-function signature.
    as_ref: when true, return `v._ref()`; otherwise return `v.value()`.

  Raises:
    ValueError: if `dtype` is given and incompatible with `v.dtype`.
  """
  del name  # Unused.
  incompatible = bool(dtype) and not dtype.is_compatible_with(v.dtype)
  if incompatible:
    raise ValueError(
        "Incompatible type conversion requested to type '%s' for variable "
        "of type '%s'" % (dtype.name, v.dtype.name))
  return v._ref() if as_ref else v.value()  # pylint: disable=protected-access

1195 

# NOTE(mrry): This enables the Variable's overloaded "right" binary
# operators to run when the left operand is an ndarray, because it
# accords the Variable class higher priority than an ndarray, or a
# numpy matrix.
# TODO(mrry): Convert this to using numpy's __array_ufunc__ mechanism
# (NEP 13; `__numpy_ufunc__` was its pre-release name), which allows more
# control over how Variables interact with ndarrays.
__array_priority__ = 100

1204 

@property
def name(self):
  """The name of this variable, as stored in `self._name`."""
  variable_name = self._name
  return variable_name

1209 

@property
def initializer(self):
  """The op that initializes this variable (`self._initializer_op`)."""
  init_op = self._initializer_op
  return init_op

1214 

@property
def device(self):
  """The device of this variable, read from the wrapped `_variable`."""
  underlying = self._variable
  return underlying.device

1219 

@property
def dtype(self):
  """The `DType` of this variable, read from the wrapped `_variable`."""
  underlying = self._variable
  return underlying.dtype

1224 

@property
def op(self):
  """The `Operation` of this variable, read from the wrapped `_variable`."""
  underlying = self._variable
  return underlying.op

1229 

@property
def graph(self):
  """The `Graph` of this variable, read from the wrapped `_variable`."""
  underlying = self._variable
  return underlying.graph

1234 

@property
def _distribute_strategy(self):
  """The `tf.distribute.Strategy` that this variable was created under.

  Always `None`: ref variables are never created inside a strategy.
  """
  strategy = None
  return strategy

1239 

@property
def shape(self):
  """The static shape of this variable.

  Delegates to `get_shape()` on the wrapped `_variable`.

  Returns:
    A `TensorShape`.
  """
  underlying = self._variable
  return underlying.get_shape()

1248 

def to_proto(self, export_scope=None):
  """Serializes this variable into a `VariableDef` protocol buffer.

  Args:
    export_scope: Optional `string`. Name scope to remove.

  Returns:
    A `VariableDef` protocol buffer, or `None` if the `Variable` is not
    in the specified name scope.
  """
  in_scope = (export_scope is None or
              self._variable.name.startswith(export_scope))
  if not in_scope:
    return None

  def _strip(full_name):
    # Every name recorded below goes through the same scope stripping.
    return ops.strip_name_scope(full_name, export_scope)

  var_def = variable_pb2.VariableDef()
  var_def.variable_name = _strip(self._variable.name)
  if self._initial_value is not None:
    # For backwards compatibility.
    var_def.initial_value_name = _strip(self._initial_value.name)
  var_def.trainable = self.trainable
  var_def.synchronization = self.synchronization.value
  var_def.aggregation = self.aggregation.value
  var_def.initializer_name = _strip(self.initializer.name)
  var_def.snapshot_name = _strip(self._snapshot.name)
  if self._save_slice_info:
    slice_proto = self._save_slice_info.to_proto(export_scope=export_scope)
    var_def.save_slice_info_def.MergeFrom(slice_proto)
  return var_def

1280 

def __iadd__(self, other):
  """Handles `var += other` by returning `self + other` (no assignment).

  A one-time deprecation warning is logged first.
  """
  message = ("Variable += will be deprecated. Use variable.assign_add"
             " if you want assignment to the variable value or 'x = x + y'"
             " if you want a new python Tensor object.")
  logging.log_first_n(logging.WARN, message, 1)
  return self + other

1287 

def __isub__(self, other):
  """Handles `var -= other` by returning `self - other` (no assignment).

  A one-time deprecation warning is logged first.
  """
  message = ("Variable -= will be deprecated. Use variable.assign_sub"
             " if you want assignment to the variable value or 'x = x - y'"
             " if you want a new python Tensor object.")
  logging.log_first_n(logging.WARN, message, 1)
  return self - other

1294 

def __imul__(self, other):
  """Handles `var *= other` by returning `self * other` (no assignment).

  A one-time deprecation warning is logged first.
  """
  message = ("Variable *= will be deprecated. Use `var.assign(var * other)`"
             " if you want assignment to the variable value or `x = x * y`"
             " if you want a new python Tensor object.")
  logging.log_first_n(logging.WARN, message, 1)
  return self * other

1302 

def __idiv__(self, other):
  """Legacy `/=` hook: returns `self / other` without assigning.

  Python 3 dispatches `/=` to `__itruediv__`, so this only runs when called
  explicitly. A one-time deprecation warning is logged first.
  """
  message = ("Variable /= will be deprecated. Use `var.assign(var / other)`"
             " if you want assignment to the variable value or `x = x / y`"
             " if you want a new python Tensor object.")
  logging.log_first_n(logging.WARN, message, 1)
  return self / other

1310 

def __itruediv__(self, other):
  """Handles `var /= other` by returning `self / other` (no assignment).

  A one-time deprecation warning is logged first.
  """
  message = ("Variable /= will be deprecated. Use `var.assign(var / other)`"
             " if you want assignment to the variable value or `x = x / y`"
             " if you want a new python Tensor object.")
  logging.log_first_n(logging.WARN, message, 1)
  return self / other

1318 

def __irealdiv__(self, other):
  """Returns `self / other` after a one-time deprecation warning.

  `__irealdiv__` is not part of Python's data model, so the interpreter
  never invokes it for an operator; it only runs when called explicitly.
  """
  message = ("Variable /= will be deprecated. Use `var.assign(var / other)`"
             " if you want assignment to the variable value or `x = x / y`"
             " if you want a new python Tensor object.")
  logging.log_first_n(logging.WARN, message, 1)
  return self / other

1326 

def __ipow__(self, other):
  """Handles `var **= other` by returning `self ** other` (no assignment).

  A one-time deprecation warning is logged first.
  """
  message = ("Variable **= will be deprecated. Use `var.assign(var ** other)`"
             " if you want assignment to the variable value or `x = x ** y`"
             " if you want a new python Tensor object.")
  logging.log_first_n(logging.WARN, message, 1)
  return self ** other

1334 

def _serialize_to_tensors(self):
  """Implements Trackable._serialize_to_tensors.

  Returns:
    A one-entry dict mapping `trackable.VARIABLE_VALUE_KEY` to this variable.
  """
  tensors = {}
  tensors[trackable.VARIABLE_VALUE_KEY] = self
  return tensors

1338 

def _restore_from_tensors(self, restored_tensors):
  """Implements Trackable._restore_from_tensors.

  Assigns the value stored under `trackable.VARIABLE_VALUE_KEY` to this
  variable. Shape validation is requested only when this variable's static
  shape is fully defined.
  """
  new_value = restored_tensors[trackable.VARIABLE_VALUE_KEY]
  shape_known = self.get_shape().is_fully_defined()
  return state_ops.assign(self, new_value, validate_shape=shape_known)

1346 

1347 

# Let RefVariable be used wherever a Tensor is expected: the registry calls
# RefVariable._TensorConversionFunction, which returns the variable's value,
# or its ref when a reference is requested.
tensor_conversion_registry.register_tensor_conversion_function(
    RefVariable,
    RefVariable._TensorConversionFunction,  # pylint: disable=protected-access
)


# Hand RefVariable to variable_v1 via set_variable_from_proto_fn (presumably
# used when rebuilding variables from a VariableDef proto; verify there).
variable_v1.set_variable_from_proto_fn(RefVariable)