Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/trackable/base.py: 57%

237 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1"""An object-local variable management scheme.""" 

2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15# ============================================================================== 

16import collections 

17import weakref 

18 

19from tensorflow.python.eager import context 

20from tensorflow.python.framework import dtypes 

21from tensorflow.python.framework import ops 

22from tensorflow.python.ops import gen_control_flow_ops 

23from tensorflow.python.trackable import constants 

24from tensorflow.python.training.saving import saveable_object 

25from tensorflow.python.util import tf_contextlib 

26from tensorflow.python.util import tf_decorator 

27from tensorflow.python.util.tf_export import tf_export 

28 

# Module-level aliases of the shared checkpoint constants defined in
# `constants`.
(OBJECT_GRAPH_PROTO_KEY,
 VARIABLE_VALUE_KEY,
 OBJECT_CONFIG_JSON_KEY,
 SaveType) = (constants.OBJECT_GRAPH_PROTO_KEY,
              constants.VARIABLE_VALUE_KEY,
              constants.OBJECT_CONFIG_JSON_KEY,
              constants.SaveType)

33 

34 

@tf_export("__internal__.tracking.TrackableReference", v1=[])
class TrackableReference(object):
  """A named reference to a trackable object for use with the `Trackable` class.

  These references mark named `Trackable` dependencies of a `Trackable` object
  and should be created when overriding `Trackable._checkpoint_dependencies`.

  Instances unpack like a `(name, ref)` pair and compare equal to such tuples.

  Attributes:
    name: The local name for this dependency.
    ref: The `Trackable` object being referenced.
  """

  __slots__ = ("_name", "_ref")

  def __init__(self, name, ref):
    self._name = name
    self._ref = ref

  @property
  def name(self):
    return self._name

  @property
  def ref(self):
    return self._ref

  def __iter__(self):
    # Supports `name, ref = reference` unpacking.
    yield self.name
    yield self.ref

  def __repr__(self):
    return f"{self.__class__.__name__}(name={self.name}, ref={self.ref})"

  def __eq__(self, o):
    """Compares with another reference or with a `(name, ref)` tuple.

    For any other type this returns `NotImplemented` rather than `False`, so
    Python can try `o`'s reflected comparison. If that is also unsupported,
    Python falls back to identity, which yields `False` as before.
    """
    if isinstance(o, tuple):
      return (self.name, self.ref) == o
    if isinstance(o, TrackableReference):
      return self.name == o.name and self.ref == o.ref
    return NotImplemented

75 

76 

class WeakTrackableReference(TrackableReference):
  """TrackableReference that stores weak references.

  The referenced object is held through a `weakref.ref`, so this reference does
  not keep it alive. `ref` returns `None` once the object has been collected.
  """
  __slots__ = ()

  def __init__(self, name, ref):
    # Accept either a strong reference or an existing `weakref.ref`. The test
    # must be on `ref`: `self` is never a weakref, so checking it would always
    # re-wrap. Creating a weakref to a weakref raises TypeError.
    if not isinstance(ref, weakref.ref):
      ref = weakref.ref(ref)
    super(WeakTrackableReference, self).__init__(name=name, ref=ref)

  @property
  def ref(self):
    # Dereference the weakref; `None` if the target no longer exists.
    return self._ref()

89 

90 

# Per-dimension `shape` and `offset` of one shard of a variable's initial
# value. `CheckpointInitialValue` pairs them up to build a slice spec.
# TODO(bfontain): Update once sharded initialization interface is finalized.
ShardInfo = collections.namedtuple(
    "CheckpointInitialValueShardInfo", "shape offset")

94 

95 

@tf_export("__internal__.tracking.CheckpointInitialValueCallable", v1=[])
class CheckpointInitialValueCallable(object):
  """Initializer-style callable that produces a `CheckpointInitialValue`.

  Wraps a checkpoint position so it can be passed wherever an initializer
  callable is expected. See `CheckpointInitialValue` for details.
  """

  def __init__(self, checkpoint_position):
    self._checkpoint_position = checkpoint_position

  @property
  def checkpoint_position(self):
    return self._checkpoint_position

  @property
  def restore_uid(self):
    """The restore UID of the wrapped checkpoint position."""
    position = self._checkpoint_position
    return position.restore_uid

  def __call__(self, shape=None, dtype=None, shard_info=None):
    """Builds a `CheckpointInitialValue` for the wrapped position.

    The signature mirrors ordinary initializer callables, which take shape and
    dtype. `dtype` is accepted but unused; callers such as base_layer_utils.py's
    make_variable pass it in via a functools.partial wrapper.
    """
    del dtype  # Unused; accepted only for initializer-signature compatibility.
    return CheckpointInitialValue(
        self._checkpoint_position, shape, shard_info=shard_info)

121 

122 

@tf_export("__internal__.tracking.CheckpointInitialValue", v1=[])
class CheckpointInitialValue(object):
  """Tensor wrapper for managing update UIDs in `Variables`.

  When supplied as an initial value, objects of this type let a `Variable`
  (`Variable`, `ResourceVariable`, etc.) know the UID of the restore that the
  initial value came from. This allows deferred restorations to be sequenced
  in the order the user specified them. It also lets us fall back on
  assignment if an initial value is not set (e.g. due to a custom getter
  interfering).

  See the comments in `Trackable._add_variable_with_custom_getter` for more
  information about how `CheckpointInitialValue` is used.
  """

  def __init__(self, checkpoint_position, shape=None, shard_info=None):
    # Without shard info, an empty shape-and-slice spec is requested. With it,
    # the spec is the full shape ("d0 d1 ...") followed by a space and one
    # "offset,size" pair per dimension, joined by ":".
    shape_and_slice = ""
    if shard_info:
      dims = " ".join("%d" % dim for dim in shape)
      extents = ":".join(
          "%d,%d" % (start, length)
          for start, length in zip(shard_info.offset, shard_info.shape))
      shape_and_slice = dims + " " + extents
    restored = checkpoint_position.value_tensors(
        {VARIABLE_VALUE_KEY: shape_and_slice})
    self.wrapped_value = restored[VARIABLE_VALUE_KEY]
    self._checkpoint_position = checkpoint_position

  def __tf_tensor__(self, dtype=None, name=None):
    # Tensor-conversion hook: ignores the requested dtype and name.
    del dtype, name
    return self.wrapped_value

  @property
  def checkpoint_position(self):
    return self._checkpoint_position

157 

158 

class NoRestoreSaveable(saveable_object.SaveableObject):
  """Embeds a tensor in a checkpoint without any restore ops.

  `restore` does nothing except return a `no_op`.
  """

  def __init__(self, tensor, name, dtype=None, device=None):
    # A single save spec for `tensor`, with an empty slice spec string.
    specs = [
        saveable_object.SaveSpec(tensor, "", name, dtype=dtype, device=device)
    ]
    super(NoRestoreSaveable, self).__init__(tensor, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    del restored_tensors, restored_shapes  # Nothing is restored.
    return gen_control_flow_ops.no_op()

169 

170 

# Record describing a slot variable restoration:
#   optimizer_id: the checkpoint proto id of the optimizer object.
#   slot_variable_id: the checkpoint proto id of the slot variable.
#   slot_name: the slot's name.
_SlotVariableRestoration = collections.namedtuple(
    "_SlotVariableRestoration",
    ("optimizer_id", "slot_variable_id", "slot_name"))

180 

181 

@tf_export("__internal__.tracking.no_automatic_dependency_tracking", v1=[])
def no_automatic_dependency_tracking(method):
  """Disables automatic dependency tracking on attribute assignment.

  Use this to decorate any method of a Trackable object. Attribute assignment
  in that method will not add dependencies (also respected in Model). It is
  harmless on classes that do no automatic dependency tracking. That makes it
  safe to use in base classes whose subclasses may also inherit from
  Trackable.

  Args:
    method: The method to decorate.

  Returns:
    A decorated method that turns automatic dependency tracking off for the
    object it is called on, and restores the previous setting afterwards (not
    thread safe).
  """

  def _method_wrapper(self, *args, **kwargs):
    # Save the current flag (True when never set). The `finally` restores it
    # even if `method` raises.
    saved = getattr(self, "_self_setattr_tracking", True)
    self._self_setattr_tracking = False  # pylint: disable=protected-access
    try:
      return method(self, *args, **kwargs)
    finally:
      self._self_setattr_tracking = saved  # pylint: disable=protected-access

  return tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)

211 

212 

@tf_contextlib.contextmanager
def no_manual_dependency_tracking_scope(obj):
  """A context that disables manual dependency tracking for the given `obj`.

  Library methods sometimes track objects on their own. When you want to do
  that tracking yourself instead, use this context manager to suppress the
  library method's tracking.

  For example:

  class TestLayer(tf.keras.Layer):
    def build(self):
      with no_manual_dependency_tracking_scope(self):
        var = self.add_variable("name1")  # Creates a var and doesn't track it
      self._track_trackable("name2", var)  # We track variable with name `name2`

  Args:
    obj: A trackable object.

  Yields:
    A scope in which the object doesn't track dependencies manually.
  """
  # pylint: disable=protected-access
  saved = getattr(obj, "_manual_tracking", True)
  obj._manual_tracking = False
  try:
    yield
  finally:
    # Restore whatever was in effect before entering, even on error.
    obj._manual_tracking = saved

243 

244 

@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
  """A context that disables automatic dependency tracking when assigning attrs.

  Objects that inherit from AutoTrackable automatically create dependencies on
  trackable objects through attribute assignment. They also wrap data
  structures (lists or dicts) in trackable classes. This scope temporarily
  disables that behavior. It works like the decorator
  `no_automatic_dependency_tracking`.

  Example usage:
  ```
  model = tf.keras.Model()
  model.arr1 = []  # Creates a ListWrapper object
  with no_automatic_dependency_tracking_scope(model):
    model.arr2 = []  # Creates a regular, untracked python list
  ```

  Args:
    obj: A trackable object.

  Yields:
    A scope in which the object doesn't track dependencies.
  """
  saved = getattr(obj, "_setattr_tracking", True)
  obj._setattr_tracking = False  # pylint: disable=protected-access
  try:
    yield
  finally:
    # Restore the previous setting, even if the body raised.
    obj._setattr_tracking = saved  # pylint: disable=protected-access

275 

276 

277@tf_export("__internal__.tracking.Trackable", v1=[]) 

278class Trackable(object): 

279 """Base class for `Trackable` objects without automatic dependencies. 

280 

281 This class has no __setattr__ override for performance reasons. Dependencies 

282 must be added explicitly. Unless attribute assignment is performance-critical, 

283 use `AutoTrackable` instead. Use `Trackable` for `isinstance` 

284 checks. 

285 """ 

286 

287 # For compatibility with wrapt.ObjectProxy, attributes are all prefixed with 

288 # _self_. We have some properties to forward semi-public attributes to their 

289 # _self_ equivalents. 

290 

291 @property 

292 def _setattr_tracking(self): 

293 if not hasattr(self, "_self_setattr_tracking"): 

294 self._self_setattr_tracking = True 

295 return self._self_setattr_tracking 

296 

297 @_setattr_tracking.setter 

298 def _setattr_tracking(self, value): 

299 self._self_setattr_tracking = value 

300 

301 @property 

302 def _update_uid(self): 

303 return self._self_update_uid 

304 

305 @_update_uid.setter 

306 def _update_uid(self, value): 

307 self._self_update_uid = value 

308 

309 @property 

310 def _unconditional_checkpoint_dependencies(self): 

311 return self._self_unconditional_checkpoint_dependencies 

312 

313 @property 

314 def _unconditional_dependency_names(self): 

315 return self._self_unconditional_dependency_names 

316 

317 @property 

318 def _name_based_restores(self): 

319 return self._self_name_based_restores 

320 

# Trackable does not do automatic dependency tracking, but uses the
# no_automatic_dependency_tracking decorator so it can avoid adding
# dependencies if a subclass is AutoTrackable / inherits from Model (both of
# which have __setattr__ overrides).

@no_automatic_dependency_tracking
def _maybe_initialize_trackable(self):
  """Lazily sets up this object's dependency-management state.

  This is not `__init__`, because most objects would forget to call it.
  Methods that need the state call this first. Repeated calls are no-ops.

  Raises:
    AssertionError: If an update UID was set before initialization ran.
  """
  # The dependency list doubles as the "already initialized" marker, so
  # TensorFlow objects need not call Trackable.__init__().
  if hasattr(self, "_self_unconditional_checkpoint_dependencies"):
    return
  # TrackableReference objects declared directly on this object. Some
  # subclasses, notably Optimizers, override `_checkpoint_dependencies` to add
  # conditional dependencies (e.g. based on the current graph when saving).
  self._self_unconditional_checkpoint_dependencies = []
  # Local name -> tracked Trackable object.
  self._self_unconditional_dependency_names = {}
  # Local name -> list of CheckpointPositions for dependencies this object may
  # eventually gain. Optimizers manage their conditional deferred dependencies
  # separately.
  self._self_unconditional_deferred_dependencies = {}
  # UID of the highest assignment to this object. It ensures the last
  # requested assignment determines the final value, and must not exist yet.
  if hasattr(self, "_self_update_uid"):
    raise AssertionError(
        "Internal error: the object had an update UID set before its "
        "initialization code was run.")
  self._self_update_uid = -1
  # When executing eagerly, holds _NameBasedRestoreCoordinator instances.
  # They are checked whenever variables or other saveables are created, and
  # are passed on recursively to all dependencies. This is needed because a
  # name-based restore cannot know in advance which subgraph it applies to.
  # Only used for restore-on-create under eager execution; unused when
  # building graphs.
  self._self_name_based_restores = set()
  # SaveableObject factories, filled in when the object is loaded from a
  # SavedModel. Custom classes should override
  # `_gather_saveables_for_checkpoint` rather than use this attribute.
  self._self_saveable_object_factories = {}

368 

369 @property 

370 def _object_identifier(self): 

371 """String used to identify this object in a SavedModel. 

372 

373 THIS FIELD HAS BEEN DEPRECATED IN FAVOR OF THE NAME REGISTERED WITH 

374 `register_serializable`. 

375 

376 Generally, the object identifier is constant across objects of the same 

377 class, while the metadata field is used for instance-specific data. 

378 

379 Returns: 

380 String object identifier. 

381 """ 

382 return "_generic_user_object" 

383 

384 def _no_dependency(self, value): 

385 """If automatic dependency tracking is enabled, ignores `value`.""" 

386 return value 

387 

388 def _name_based_attribute_restore(self, checkpoint): 

389 """Restore the object's attributes from a name-based checkpoint.""" 

390 self._self_name_based_restores.add(checkpoint) 

391 if self._self_update_uid < checkpoint.restore_uid: 

392 checkpoint.eager_restore(self) 

393 self._self_update_uid = checkpoint.restore_uid 

394 

395 @property 

396 def _checkpoint_dependencies(self): 

397 """All dependencies of this object. 

398 

399 May be overridden to include conditional dependencies. 

400 

401 Returns: 

402 A list of `TrackableReference` objects indicating named 

403 `Trackable` dependencies which should be saved along with this 

404 object. 

405 """ 

406 return self._self_unconditional_checkpoint_dependencies 

407 

408 @property 

409 def _deferred_dependencies(self): 

410 """A dictionary with deferred dependencies. 

411 

412 Stores restorations for other Trackable objects on which this object 

413 may eventually depend. May be overridden by sub-classes (e.g. Optimizers use 

414 conditional dependencies based the current graph, and so need separate 

415 management of deferred dependencies too). 

416 

417 Returns: 

418 A dictionary mapping from local name to a list of CheckpointPosition 

419 objects. 

420 """ 

421 return self._self_unconditional_deferred_dependencies 

422 

423 def _lookup_dependency(self, name): 

424 """Look up a dependency by name. 

425 

426 May be overridden to include conditional dependencies. 

427 

428 Args: 

429 name: The local name of the dependency. 

430 

431 Returns: 

432 A `Trackable` object, or `None` if no dependency by this name was 

433 found. 

434 """ 

435 return self._self_unconditional_dependency_names.get(name, None) 

436 

def _add_variable_with_custom_getter(self,
                                     name,
                                     shape=None,
                                     dtype=dtypes.float32,
                                     initializer=None,
                                     getter=None,
                                     overwrite=False,
                                     **kwargs_for_getter):
  """Restore-on-create for a variable to be saved with this `Trackable`.

  Suppose the user has requested that this object, or another `Trackable` that
  depends on it, be restored from a checkpoint before the variable exists
  (deferred loading). Then `initializer` may be ignored and the checkpointed
  value used instead.

  Args:
    name: A name for the variable. Must be unique within this object.
    shape: The shape of the variable.
    dtype: The data type of the variable.
    initializer: The initializer to use. Ignored if there is a deferred
      restoration stored in the Trackable.
    getter: The getter to wrap which actually fetches the variable.
    overwrite: If True, disables unique name and type checks.
    **kwargs_for_getter: Passed to the getter.

  Returns:
    The new variable object.

  Raises:
    ValueError: If the variable name is not unique.
  """
  self._maybe_initialize_trackable()
  with ops.init_scope():
    # Only when executing eagerly can a single checkpointed Tensor be used
    # directly as the initializer, instead of initializing and then assigning.
    # The preload returns None when there is nothing to restore.
    checkpoint_initializer = (
        self._preload_simple_restoration(name=name)
        if context.executing_eagerly() else None)
    if checkpoint_initializer is not None:
      # Several Trackables may "create" the same variable via custom getters.
      # The one with the highest restore UID (called last) must supply the
      # final initializer, so keep `initializer` if it carries a newer
      # restore. This is best effort: a custom getter that overwrites the
      # initializer is caught later by `_track_trackable`.
      caller_is_newer = (
          isinstance(initializer, CheckpointInitialValueCallable) and
          initializer.restore_uid > checkpoint_initializer.restore_uid)
      if not caller_is_newer:
        initializer = checkpoint_initializer
  new_variable = getter(
      name=name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      **kwargs_for_getter)

  if overwrite and not isinstance(new_variable, Trackable):
    # TODO(allenl): Some variable types are not yet supported. Remove this
    # fallback once all get_variable() return types are Trackable.
    return new_variable
  # If an initializer was set and the variable used it, tracking will not
  # assign again. Tracking adds the dependency and runs any non-trivial
  # queued restoration, including for slot variables.
  return self._track_trackable(new_variable, name=name, overwrite=overwrite)

505 

506 def _preload_simple_restoration(self, name): 

507 """Return a dependency's value for restore-on-create. 

508 

509 Note the restoration is not deleted; if for some reason preload is called 

510 and then not assigned to the variable (for example because a custom getter 

511 overrides the initializer), the assignment will still happen once the 

512 variable is tracked (determined based on checkpoint.restore_uid). 

513 

514 Args: 

515 name: The object-local name of the dependency holding the variable's 

516 value. 

517 

518 Returns: 

519 An callable for use as a variable's initializer/initial_value, or None if 

520 one should not be set (either because there was no variable with this name 

521 in the checkpoint or because it needs more complex deserialization). Any 

522 non-trivial deserialization will happen when the variable object is 

523 tracked. 

524 """ 

525 deferred_dependencies_list = self._deferred_dependencies.get(name, ()) 

526 if not deferred_dependencies_list: 

527 # Nothing to do; we don't have a restore for this dependency queued up. 

528 return 

529 for checkpoint_position in deferred_dependencies_list: 

530 if not checkpoint_position.is_simple_variable(): 

531 # If _any_ pending restoration is too complicated to fit in an 

532 # initializer (because it has dependencies, or because there are 

533 # multiple Tensors to restore), bail and let the general tracking code 

534 # handle it. 

535 return None 

536 checkpoint_position = max( 

537 deferred_dependencies_list, 

538 key=lambda restore: restore.checkpoint.restore_uid) 

539 return CheckpointInitialValueCallable( 

540 checkpoint_position=checkpoint_position) 

541 

def _track_trackable(self, trackable, name, overwrite=False):
  """Declares a dependency on another `Trackable` object.

  Checkpoints for this object will then include variables from `trackable`.
  Variables in a checkpoint are matched to `Trackable`s by the names used when
  the checkpoint was written. To keep existing checkpoints loadable, neither
  variable names nor dependency names (the `name` passed here) may change.

  Args:
    trackable: A `Trackable` which this object depends on.
    name: A local name for `trackable`, used for loading checkpoints into the
      correct objects.
    overwrite: Whether silently replacing an existing dependency is OK. Used
      by __setattr__, where raising on attribute reassignment would be
      inappropriate.

  Returns:
    `trackable`, so a dependency can be declared and assigned to an attribute
    in one statement.

  Raises:
    TypeError: If `trackable` does not inherit from `Trackable`.
    ValueError: If another object is already tracked by this name.
  """
  self._maybe_initialize_trackable()
  if not isinstance(trackable, Trackable):
    raise TypeError(
        "Trackable._track_trackable() can only be used to track objects of "
        f"type Trackable. Got type {type(trackable)}.")
  # Manual tracking may be switched off by no_manual_dependency_tracking_scope.
  if not getattr(self, "_manual_tracking", True):
    return trackable
  new_reference = TrackableReference(name=name, ref=trackable)
  existing = self._lookup_dependency(name)
  references = self._self_unconditional_checkpoint_dependencies
  if existing is None:
    # Brand-new dependency: record it and apply any restore waiting for it.
    references.append(new_reference)
    self._handle_deferred_dependencies(name=name, trackable=trackable)
  elif existing is not trackable:
    if not overwrite:
      raise ValueError(
          f"Called Trackable._track_trackable() with name='{name}', "
          "but a Trackable with this name is already declared as a "
          "dependency. Names must be unique (or overwrite=True).")
    # Odd, but allowed so __setattr__ can rebind an attribute: swap in the new
    # reference wherever the old name appears.
    for position, (old_name, _) in enumerate(references):
      if name == old_name:
        references[position] = new_reference
  self._self_unconditional_dependency_names[name] = trackable
  return trackable

596 

597 def _handle_deferred_dependencies(self, name, trackable): 

598 """Pop and load any deferred checkpoint restores into `trackable`. 

599 

600 This method does not add a new dependency on `trackable`, but it does 

601 check if any outstanding/deferred dependencies have been queued waiting for 

602 this dependency to be added (matched based on `name`). If so, 

603 `trackable` and its dependencies are restored. The restorations are 

604 considered fulfilled and so are deleted. 

605 

606 `_track_trackable` is more appropriate for adding a 

607 normal/unconditional dependency, and includes handling for deferred 

608 restorations. This method allows objects such as `Optimizer` to use the same 

609 restoration logic while managing conditional dependencies themselves, by 

610 overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the 

611 object's dependencies based on the context it is saved/restored in (a single 

612 optimizer instance can have state associated with multiple graphs). 

613 

614 Args: 

615 name: The name of the dependency within this object (`self`), used to 

616 match `trackable` with values saved in a checkpoint. 

617 trackable: The Trackable object to restore (inheriting from `Trackable`). 

618 """ 

619 self._maybe_initialize_trackable() 

620 trackable._maybe_initialize_trackable() # pylint: disable=protected-access 

621 deferred_dependencies_list = self._deferred_dependencies.pop(name, ()) 

622 for checkpoint_position in sorted( 

623 deferred_dependencies_list, 

624 key=lambda restore: restore.checkpoint.restore_uid, 

625 reverse=True): 

626 checkpoint_position.restore(trackable) 

627 

628 # Pass on any name-based restores queued in this object. 

629 for name_based_restore in sorted( 

630 self._self_name_based_restores, 

631 key=lambda checkpoint: checkpoint.restore_uid, 

632 reverse=True): 

633 trackable._name_based_attribute_restore(name_based_restore) # pylint: disable=protected-access 

634 

635 def _gather_saveables_for_checkpoint(self): 

636 """Returns a dictionary of values to checkpoint with this object. 

637 

638 NOTE: This method is deprecated, prefer implementing `_serialize_to_tensors` 

639 and `_restore_from_tensors` instead. This method is only used in the 

640 deprecated `tf.compat.v1.train.Saver`. 

641 

642 Keys in the returned dictionary are local to this object and in a separate 

643 namespace from dependencies. Values may either be `SaveableObject` factories 

644 or variables easily converted to `SaveableObject`s (as in 

645 `tf.compat.v1.train.Saver`'s 

646 `var_list` constructor argument). 

647 

648 `SaveableObjects` have a name set, which Trackable needs to generate 

649 itself. So rather than returning `SaveableObjects` directly, this method 

650 should return a dictionary of callables which take `name` arguments and 

651 return `SaveableObjects` with that name. 

652 

653 If this object may also be passed to the global-name-based 

654 `tf.compat.v1.train.Saver`, 

655 the returned callables should have a default value for their name argument 

656 (i.e. be callable with no arguments). 

657 

658 Returned values must be saved only by this object; if any value may be 

659 shared, it should instead be a dependency. For example, variable objects 

660 save their own values with the key `VARIABLE_VALUE_KEY`, but objects which 

661 reference variables simply add a dependency. 

662 

663 Returns: 

664 The dictionary mapping attribute names to `SaveableObject` factories 

665 described above. For example: 

666 {VARIABLE_VALUE_KEY: 

667 lambda name="global_name_for_this_object": 

668 SaveableObject(name=name, ...)} 

669 """ 

670 return getattr(self, "_self_saveable_object_factories", {}) 

671 

672 def _serialize_to_tensors(self): 

673 """Gathers tensors to save to the checkpoint. 

674 

675 You should only override `_serialize_to_tensors` and `_restore_from_tensors` 

676 if you are defining a custom resource or variable with custom ops. 

677 

678 Otherwise, please store the state of your trackable in `tf.Variable` objects 

679 and add them to Trackable object hierarchy using `setattr` (for subclasses 

680 of `AutoTrackable`) or overriding the `_trackable_children` method. 

681 

682 For an example of a valid implementation of these two methods, please see 

683 `DenseHashTable`. 

684 

685 **Invalid implementation** 

686 

687 ```` 

688 class NamedTrackable(Trackable): 

689 def __init__(self, name: str): 

690 self.name = name 

691 def _serialize_to_tensors(self): 

692 return {"name": self.name} 

693 def _restore_from_tensors(self, restored_tensors): 

694 self.name = restored_tensors["name"] 

695 ``` 

696 

697 In this example, `NamedTrackable` can be saved and restored from 

698 checkpoints, but is incompatible with SavedModel, which tries to convert 

699 the serialize/restore functions into tf.functions. This fails because 

700 attribute assignment (`self.attr = new_value`) is not graph-friendly. 

701 

702 **Suggested fix** 

703 

704 ``` 

705 class NamedTrackable(Trackable): 

706 def __init__(self, name: str): 

707 self.name = tf.Variable(name) 

708 

709 def _trackable_children(self): 

710 return {"name": self.name} 

711 ``` 

712 

713 If the `name` attribute should be saved to the checkpoint, then convert it 

714 a `tf.Variable`. 

715 

716 **TF1 Saver Compatibility** 

717 If your Trackable needs to be comatible with `tf.compat.v1.train.Saver`, 

718 implement `_gather_saveables_from_checkpoint`. 

719 

720 Returns: 

721 A dictionary mapping names to tensors. 

722 """ 

723 raise NotImplementedError 

724 

725 def _restore_from_tensors(self, restored_tensors): 

726 """Restores checkpointed values to this `Trackable`. 

727 

728 Please see the documentation for `Trackable._serialize_to_tensors`. 

729 

730 Args: 

731 restored_tensors: A dictionary mapping names to tensors. The keys to this 

732 dictionary matches the names passed to _serialize_to_tensors. 

733 

734 Returns: 

735 An op that runs the restoration. 

736 """ 

737 raise NotImplementedError 

738 

739 def _serialize_to_proto(self, object_proto=None, **kwargs): 

740 """Returns a proto of any type to be saved into the SavedModel. 

741 

742 Trackable classes decorated with `register_serializable` should overwrite 

743 this method to save metadata for this object to the SavedModel. The proto 

744 returned by this function will be passed to `_deserialize_from_proto` in the 

745 form of a `google.protobuf.Any` proto. 

746 

747 This data is only saved and used by the Python API. Existing C++ loading 

748 APIs such as `tensorflow::LoadSavedModel` will not read this field at all. 

749 

750 Args: 

751 object_proto: A `SavedObject` proto that may be filled by this function. 

752 Only the core serializable types (Variable, Function, Constant, Asset) 

753 should modify this argument. 

754 **kwargs: Future keyword arguments passed to the object during saving. 

755 

756 Returns: 

757 A proto that serializes this class's type. 

758 """ 

759 del object_proto, kwargs # Unused. 

760 

761 return None 

762 

763 @classmethod 

764 def _deserialize_from_proto(cls, 

765 proto=None, 

766 dependencies=None, 

767 object_proto=None, 

768 export_dir=None, 

769 asset_file_def=None, 

770 operation_attributes=None, 

771 **kwargs): 

772 """Returns a new object restored by the SavedModel. 

773 

774 Trackable classes decorated with `register_serializable` should overwrite 

775 this method to change how the object is loaded from SavedModel. By default, 

776 the object is initialized with no arguments. 

777 

778 Example: 

779 

780 ``` 

781 def _serialize_to_proto(self, **unused_kwargs): 

782 return Message(name="a") 

783 

784 @classmethod 

785 def _deserialize_from_proto(cls, proto, **unused_kwargs): 

786 if proto.Is(Message.DESCRIPTOR): 

787 unpacked = Message() 

788 proto.Unpack(unpacked) 

789 return cls(unpacked.name) 

790 else: 

791 return cls() 

792 ``` 

793 

794 This function is only used by the Python API. C++ and TensorFlow Serving do 

795 not have access to your registered class and cannot execute any of the 

796 non-tf.functions attached to the Python class. However, all signatures and 

797 tf.functions are still accessible. 

798 

799 **Avoid creating duplicate trackables** 

800 

801 SavedModel is saved by recursively gathering all of the trackables and their 

802 children. SavedModel loading reverses those steps by creating all 

803 trackables, then reconnecting the children trackables to their parents using 

804 `Trackable._add_trackable_child`. 

805 

806 That means that if `_deserialize_from_proto` calls the `__init__` function, 

807 which creates all of the children trackables, then those children end up 

808 being created *twice*. 

809 

810 To avoid this, structure your code so that Trackables are not created 

811 when deserialized from SavedModel: 

812 

813 ``` 

814 @register_serializable() 

815 class Serializable(trackable): 

816 def __init __(self, from_proto=False): 

817 create_non_trackable_objects() 

818 if not from_proto: 

819 create_variables_and_other_trackables() 

820 

821 def _deserialize_from_proto(cls, **kwargs): 

822 return cls(from_proto=True) 

823 

824 def _add_trackable_child(self, name, value): 

825 self.__setattr__(name, value) 

826 ``` 

827 

828 Args: 

829 proto: A `google.protobuf.Any` proto read from the `SavedModel`. 

830 dependencies: A dictionary mapping names to dependencies (see 

831 `_deserialization_dependencies`) 

832 object_proto: The `SavedObject` proto for this object. 

833 export_dir: The `SavedModel` directory 

834 asset_file_def: The `MetaGraphDef`'s `asset_file_def` field. 

835 operation_attributes: Dictionary mapping nodes to attribute from the 

836 imported `GraphDef`. 

837 **kwargs: Future keyword arguments passed to the object when loading. 

838 

839 Returns: 

840 A new object. 

841 """ 

842 del (proto, dependencies, object_proto, export_dir, asset_file_def, 

843 operation_attributes, kwargs) 

844 

845 return cls() 

846 

847 def _add_trackable_child(self, name, value): 

848 """Restores a connection between trackables when loading from SavedModel. 

849 

850 SavedModel stores both the object metadata and its list of children. When 

851 loading, this function is used along with `_deserialize_from_proto` to load 

852 objects from the SavedModel: First, all saved objects are created with 

853 `_deserialize_from_proto`. After that is complete, the children are 

854 connected using `_add_trackable_child`. 

855 

856 **Example** 

857 

858 `tf.Module`, `tf.keras.Model` and Keras layers use `__setattr__` to track 

859 children. This is why users can call `model.v = tf.Variable(...)`, and the 

860 variable will be automatically saved to the checkpoint. The implementation 

861 of this method for the listed objects is: 

862 

863 ``` 

864 def _add_trackable_child(self, name, value): 

865 self.__setattr__(name, value) 

866 ``` 

867 

868 Args: 

869 name: The name of the connection between the parent and child `Trackable`. 

870 value: The child `Trackable` object. 

871 """ 

872 self._track_trackable(value, name, overwrite=True) 

873 

874 def _deserialization_dependencies(self, children): 

875 """Returns a dictionary containing `Trackables` that this object depends on. 

876 

877 Dependencies define the order to serialize and deserialize objects in the 

878 SavedModel. For example: 

879 

880 class A(Trackable): 

881 b = B() 

882 def _deserialization_dependencies(self, children): 

883 return {'b': self.b} 

884 

885 class B(Trackable): 

886 pass 

887 

888 We say that object `a=A()` depends on `a.b`. 

889 

890 Dependencies are guaranteed to be serialized and deserialized before the 

891 object depending on them. The following methods use dependencies: 

892 - `_deserialize_from_proto` [loading] 

893 

894 SavedModel loads with the bottom-up approach, by first creating all objects 

895 in the order defined by the dependencies, then connecting the children. 

896 

897 Unlike `_trackable_children`, this function does not define the 

898 `SavedObjectGraph`. It only changes the order in which things are 

899 saved/loaded. Therefore, if there are dependencies that are not in the 

900 `SavedObjectGraph`, saving will fail. 

901 

902 Args: 

903 children: Dict returned from `_trackable_children`. 

904 

905 Returns: 

906 A dictionary mapping names to `Trackable`. 

907 """ 

908 del children # Unused. 

909 return {} 

910 

def _trackable_children(self,
                        save_type=SaveType.CHECKPOINT,
                        cache=None,
                        **kwargs):
  """Returns this object's `Trackable` attributes.

  This method is used to build the object graph (or the object hierarchy,
  in pickling terms) for checkpoint save/restore, and `SavedModel` export.

  Override this method to define the children of this instance. Please read
  the implementation restrictions:

  **Rule 1: All children must be convertible to `Trackable`.**

  Must pass `isinstance` check or `converter.convert_to_trackable`.

  **Rule 2: [Checkpoint-only] Do not create new objects.**

  When saving to a `SavedModel`, this method is called *exactly once* for each
  `Trackable` in the object graph. When saving or restoring from a checkpoint,
  this method may be called *multiple times*. Thus, this method may create
  new Trackables when `save_type == SaveType.SAVEDMODEL` but not when
  `save_type == SaveType.CHECKPOINT`.

  When saving to `SavedModel`, new `Trackable` children can be created to save
  non-Trackable attributes to the `SavedModel`. In the example below, `hyper`
  is a regular python float hyperparameter. To save this value, a new Variable
  is created to store the value of `hyper`:

  ```
  def __init__(self):
    self.hyper = 1e-5

  def _trackable_children(self, save_type, **unused_kwargs):
    # Correct implementation
    children = {}
    if save_type == SaveType.SAVEDMODEL:
      children['hyper'] = tf.Variable(self.hyper)
    return children
  ```

  An incorrect implementation of `_trackable_children` is shown below. This
  function would cause failures when loading the checkpoint, and calling
  `load_status.assert_consumed()` or
  `load_status.assert_existing_objects_matched`. If you want a value to be
  saved in the checkpoint, hyper must be defined as a `tf.Variable` from the
  start.

  ```
  def _trackable_children(self, save_type, **unused_kwargs):
    # Incorrect implementation
    return {'hyper': tf.Variable(self.hyper)}
  ```

  **Rule 3: [`SavedModel`-only] Watch out for un-traced tf.functions.**

  At the beginning of `_trackable_children`, always call
  `get_concrete_function()` for any `tf.function` that has an input signature.

  When `tf.functions` are saved to `SavedModel`, any `tf.functions` that have
  an input signature and have never been called are traced at export time in
  order to copy the op graph into the `SavedModel`. `tf.functions` that are
  traced for the first time are allowed to create new state:

  ```
  @tf.function(input_signature=[])
  def fn(self):
    if self.v is None:
      self.v = tf.Variable(1.)
    return self.v
  ```

  A problem occurs when there is a `Trackable` that returns `fn` as one of its
  children and `self.v` has not been created yet. When `fn` is traced,
  `self.v` is added to the `Trackable`, but `SavedModel` does not see this
  modification since the `Trackable`'s children have already been gathered.

  Therefore, as a precaution, call `get_concrete_function()` at the very
  start of `_trackable_children` to ensure that the function is traced:

  ```
  def _trackable_children(self, save_type, **unused_kwargs):
    self.fn.get_concrete_function()
    return {"v": self.v, "fn": self.fn}
  ```

  Args:
    save_type: A `SaveType` value, either `SaveType.SAVEDMODEL` ('savedmodel')
      or `SaveType.CHECKPOINT` ('checkpoint'). Defaults to
      `SaveType.CHECKPOINT`.
    cache: May be `None`, or a dictionary. When
      `save_type == SaveType.SAVEDMODEL`, a new cache is created at the start
      of the SavedModel export, and shared between all `Trackables` in the
      same object graph. This cache may be used for advanced saving
      functionality.
    **kwargs: Additional kwargs that may be added at a later time.

  Returns:
    Dictionary mapping names to child trackables. This base implementation
    ignores all arguments and returns the (name, ref) pairs currently held in
    `self._checkpoint_dependencies`.
  """
  del save_type, cache, kwargs  # Unused.

  # Called before `_checkpoint_dependencies` is read; presumably this lazily
  # sets up the tracking state (inferred from the helper's name — confirm).
  self._maybe_initialize_trackable()
  # Each dependency entry is unpacked as a (name, ref) pair.
  return {name: ref for name, ref in self._checkpoint_dependencies}

1015 

1016 def _export_to_saved_model_graph(self, 

1017 object_map, 

1018 tensor_map, 

1019 options, 

1020 **kwargs): 

1021 """Creates a copy of this object's tensors onto SavedModel graph. 

1022 

1023 Needs to be overridden if the class contains tensors that must be saved 

1024 into the graph. This method should update the `object_map` and `tensor_map` 

1025 dictionaries. 

1026 

1027 This method is called on all nodes in the Trackable Graph (generated by 

1028 `_trackable_children`). The nodes are traversed in the order defined by 

1029 `_deserialization_dependencies` 

1030 

1031 All usages of _map_resources should be migrated to this method. 

1032 

1033 Args: 

1034 object_map: A dictionary that maps original Trackables to the copied 

1035 Trackables. This only needs to be updated if the object is a 

1036 tf.function, or if the copied tensors are necessary for checkpointing 

1037 this object. 

1038 tensor_map: Dictionary mapping original tensors to copied tensors. 

1039 options: A `tf.saved_model.SaveOptions` object. 

1040 **kwargs: Additional kwargs that may be added at a later time. 

1041 

1042 Returns: 

1043 Flat list of original tensors that have been copied. 

1044 """ 

1045 _, _, _ = object_map, tensor_map, options 

1046 del kwargs 

1047 return []