Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/trackable/data_structures.py: 29%

578 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1"""Trackable data structures.""" 

2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15# ============================================================================== 

16import collections 

17import copy 

18import sys 

19 

20try: 

21 import wrapt 

22except ImportError: 

23 # Fall back to the build-time dependency if the system package is not available. 

24 from .....third_party import wrapt # pylint: disable=relative-beyond-top-level 

25 

26from tensorflow.python.eager import def_function 

27from tensorflow.python.eager import function as defun 

28from tensorflow.python.ops import variables 

29from tensorflow.python.trackable import base 

30from tensorflow.python.trackable import layer_utils 

31from tensorflow.python.util.compat import collections_abc 

32from tensorflow.python.util.tf_export import tf_export 

33 

34 

class NoDependency:
  """Marks a value so that assigning it to a `Trackable` adds no dependency.

  Example usage:
  ```python
  obj = Trackable()
  obj.has_dependency = tf.Variable(0., name="dep")
  obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
  assert obj.no_dependency.name == "nodep:0"
  ```

  Here `obj` records a dependency on the variable "dep" only. Both attributes
  end up holding the bare `Variable` objects, because the `NoDependency`
  marker is stripped when the value is assigned.

  `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
  dependencies: wrapping a `Layer` in `NoDependency` will assign the
  (unwrapped) `Layer` to the attribute without a checkpoint dependency, but
  the `Model` will still track the `Layer` (so it will appear in
  `Model.layers`, and its variables will appear in `Model.variables`).

  Attributes:
    value: The object being assigned without a dependency.
  """

  # No per-instance __dict__: the wrapper only ever carries `value`.
  __slots__ = ["value"]

  def __init__(self, value):
    self.value = value

60 

61 

def _should_wrap_tuple(t):
  """Returns True if tuple `t` holds anything that needs a trackable wrapper.

  An element triggers wrapping when it is a `NoDependency` marker (which must
  be stripped out of the tuple), a `Trackable`, a plain `dict`,
  `collections.OrderedDict` or `list`, or a nested tuple that itself needs
  wrapping. Tuples are immutable, so a tuple with none of these is left as-is.
  """
  # pylint: disable=unidiomatic-typecheck
  # Exact type checking (rather than isinstance) so that list/dict subclasses
  # with custom logic, e.g. collections.Counter, are not treated as plain
  # containers.
  plain_containers = (dict, collections.OrderedDict, list)
  return any(
      isinstance(element, (NoDependency, base.Trackable))
      or type(element) in plain_containers
      or (isinstance(element, tuple) and _should_wrap_tuple(element))
      for element in t)
  # pylint: enable=unidiomatic-typecheck

84 

85 

@tf_export("__internal__.tracking.wrap", v1=[])
def wrap_or_unwrap(value):
  """Wraps input value into trackable data structures.

  This is mostly useful for containers like list, dict, etc, which could
  contain trackable objects in it. Wrapped data structure will be tracked when
  associated with a `tf.Module`, so that save model/checkpoint can properly
  track the dependency.

  It will also unwrap NoDependency objects.

  Args:
    value: the input object to be wrapped.

  Returns:
    The unwrapped payload for a `NoDependency`, the object itself if it is
    already `Trackable`, a `_DictWrapper` for an exact `dict`/`OrderedDict`, a
    `ListWrapper` for an exact `list`, a `_TupleWrapper` for a tuple that
    contains something needing wrapping, and `value` unchanged otherwise.
  """
  # pylint: disable=unidiomatic-typecheck
  # Exact type checking to avoid mucking up custom logic in list/dict
  # subclasses, e.g. collections.Counter.
  if isinstance(value, NoDependency):
    # Strip the marker; callers get the raw object back.
    return value.value
  if isinstance(value, base.Trackable):
    # Already trackable: skip conversion.
    return value
  value_type = type(value)
  if value_type == dict or value_type == collections.OrderedDict:
    return _DictWrapper(value)
  if value_type == list:
    return ListWrapper(value)
  if isinstance(value, tuple) and _should_wrap_tuple(value):
    # There are trackable elements or data structures. Wrap the tuple.
    return _TupleWrapper(value)
  return value
  # pylint: enable=unidiomatic-typecheck

122 

123 

@tf_export("__internal__.tracking.sticky_attribute_assignment", v1=[])
def sticky_attribute_assignment(trackable, name, value):
  """Adds dependencies, generally called from __setattr__.

  This behavior is shared between Trackable and Model.

  Respects NoDependency indicators, but otherwise makes trackable objects
  out of common data structures and tracks objects by their attribute names.

  Args:
    trackable: The object to add dependencies to (generally the one having
      an attribute assigned).
    name: The attribute name being assigned.
    value: The value being assigned. Not necessarily a trackable object.

  Returns:
    The value which should be stored in the attribute (unwrapped from a
    NoDependency object if necessary, otherwise possibly wrapped into a
    trackable data structure).
  """
  # The NoDependency check must happen before unwrapping discards the marker.
  wants_dependency = not isinstance(value, NoDependency)
  value = wrap_or_unwrap(value)
  if wants_dependency and isinstance(value, base.Trackable):
    trackable._track_trackable(  # pylint: disable=protected-access
        value, name=name,
        # Allow the user to switch the Trackable which is tracked by this
        # name, since assigning a new variable to an attribute has
        # historically been fine (e.g. Adam did this).
        overwrite=True)
  return value

158 

159 

160class _UntrackableError(ValueError): 

161 

162 def __init__(self, value): # pylint: disable=super-init-not-called 

163 self._value = value 

164 

165 def __str__(self): 

166 return ("Only trackable objects (such as Layers or Optimizers) may be " 

167 f"stored in a List object. Got {self._value}, which does not " 

168 "inherit from Trackable.") 

169 

170 

@tf_export("__internal__.tracking.TrackableDataStructure", v1=[])
class TrackableDataStructure(base.Trackable):
  """Base class for data structures which contain trackable objects.

  Subclasses provide `_values`. This class derives layer, variable, update
  and loss aggregates from those values.
  """

  def __init__(self):
    # Attributes prefixed with "_self_" for compatibility with
    # wrapt.ObjectProxy. All additional attrs MUST conform to this pattern, as
    # extending `__slots__` on a subclass of ObjectProxy breaks in a variety of
    # ways.
    self._self_trainable = True
    self._self_extra_variables = []
    self._self_attribute_sentinel = layer_utils.AttributeSentinel(True)

  @property
  def _attribute_sentinel(self):
    return self._self_attribute_sentinel

  @property
  def trainable(self):
    return self._self_trainable

  @trainable.setter
  def trainable(self, value):
    self._self_trainable = value

  def _track_value(self, value, name):
    """Adds a dependency on `value` and returns the value to store.

    Args:
      value: The value being added; may be wrapped or unwrapped on the way.
      name: The dependency name to track it under.

    Returns:
      The (possibly wrapped/unwrapped) value.

    Raises:
      _UntrackableError: if the resulting value is not `Trackable`.
    """
    tracked = sticky_attribute_assignment(
        trackable=self, value=value, name=name)
    if isinstance(tracked, variables.Variable):
      # Bare variables are reported through the weight properties below.
      self._self_extra_variables.append(tracked)
    if not isinstance(tracked, base.Trackable):
      raise _UntrackableError(tracked)
    if hasattr(tracked, "_use_resource_variables"):
      # In subclassed models, legacy layers (tf.layers) must always use
      # resource variables.
      tracked._use_resource_variables = True  # pylint: disable=protected-access
    child_sentinel = getattr(tracked, "_attribute_sentinel", None)
    if child_sentinel:
      child_sentinel.add_parent(self._attribute_sentinel)
    return tracked

  @property
  def _values(self):
    """An iterable/sequence which may contain trackable objects."""
    raise NotImplementedError("Abstract method")

  @property
  def _layers(self):
    """All Layers and Layer containers, including empty containers."""
    # Filter objects on demand so that wrapper objects use values from the thing
    # they're wrapping if out of sync.
    return [
        obj for obj in self._values
        if isinstance(obj, TrackableDataStructure)
        or layer_utils.is_layer(obj)
        or layer_utils.has_weights(obj)
    ]

  @property
  def layers(self):
    return list(layer_utils.filter_empty_layer_containers(self._layers))

  def _gather_from_values(self, attribute):
    """Concatenates `attribute` over every `Trackable` value that has it."""
    gathered = []
    for child in self._values:
      if isinstance(child, base.Trackable) and hasattr(child, attribute):
        gathered += getattr(child, attribute)
    return gathered

  @property
  def trainable_weights(self):
    if not self._self_trainable:
      return []
    from_children = self._gather_from_values("trainable_variables")
    extra = [v for v in self._self_extra_variables if v.trainable]
    return from_children + extra

  @property
  def non_trainable_weights(self):
    extra_trainable = [v for v in self._self_extra_variables if v.trainable]
    extra_non_trainable = [
        v for v in self._self_extra_variables if not v.trainable
    ]
    from_children = self._gather_from_values("non_trainable_variables")
    if self._self_trainable:
      return from_children + extra_non_trainable
    # A non-trainable container reports everything here. Return order is all
    # trainable vars, then all non-trainable vars.
    return (self._gather_from_values("trainable_variables") + extra_trainable
            + from_children + extra_non_trainable)

  @property
  def weights(self):
    return self.trainable_weights + self.non_trainable_weights

  @property
  def trainable_variables(self):
    return self.trainable_weights

  @property
  def non_trainable_variables(self):
    return self.non_trainable_weights

  @property
  def variables(self):
    return self.weights

  @property
  def updates(self):
    """Aggregate updates from any `Layer` instances."""
    # Updates and conditional losses are forwarded as-is rather than being
    # filtered based on inputs, since this is just a container and won't ever
    # have any inputs.
    collected_updates = []
    for member in self.layers:
      if hasattr(member, "updates"):
        collected_updates += member.updates
    return collected_updates

  @property
  def losses(self):
    """Aggregate losses from any `Layer` instances."""
    collected_losses = []
    for member in self.layers:
      if hasattr(member, "losses"):
        collected_losses += member.losses
    return collected_losses

  def __hash__(self):
    # Support object-identity hashing, so these structures can be used as keys
    # in sets/dicts.
    return id(self)

  def __eq__(self, other):
    # Similar to Tensors, trackable data structures use object-identity
    # equality to support set/dict membership.
    return self is other

326 

327 

class List(TrackableDataStructure, collections_abc.Sequence):
  """An append-only sequence type which is trackable.

  Maintains checkpoint dependencies on its contents (which must also be
  trackable), and forwards any `Layer` metadata such as updates and losses.

  Note that `List` is purely a container. It lets a `tf.keras.Model` or
  other trackable object know about its contents, but does not call any
  `Layer` instances which are added to it. To indicate a sequence of `Layer`
  instances which should be called sequentially, use `tf.keras.Sequential`.

  Example usage:
  ```python
  class HasList(tf.keras.Model):

    def __init__(self):
      super().__init__()
      self.layer_list = List([layers.Dense(3)])
      self.layer_list.append(layers.Dense(4))

    def call(self, x):
      aggregation = 0.
      for l in self.layer_list:
        x = l(x)
        aggregation += tf.reduce_sum(x)
      return aggregation
  ```

  This kind of wrapping is necessary because `Trackable` objects do not
  (yet) deeply inspect regular Python data structures, so for example assigning
  a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a
  checkpoint dependency and does not add the `Layer` instance's weights to its
  parent `Model`.
  """

  def __init__(self, *args, **kwargs):
    """Construct a new sequence. Arguments are passed to `list()`."""
    super().__init__()
    self._storage = self._make_storage(*args, **kwargs)
    # Each element is tracked under its index ("0", "1", ...).
    for index, element in enumerate(self._storage):
      self._storage[index] = self._track_value(
          element, name=self._name_element(index))

  def copy(self):
    """Returns a new instance of this type over a shallow copy of storage."""
    return type(self)(copy.copy(self._storage))

  def __copy__(self):
    return self.copy()

  def __deepcopy__(self, memo):
    return type(self)(copy.deepcopy(self._storage, memo))

  def _make_storage(self, *args, **kwargs):
    """Determines the backing storage (overridden in subclasses)."""
    return list(*args, **kwargs)

  def _name_element(self, index):
    """Returns the dependency name for position `index` (its decimal form)."""
    return "%d" % (index,)

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    return self

  def append(self, value):
    """Add a new trackable value, named after the position it will occupy."""
    value = self._track_value(value, self._name_element(len(self._storage)))
    self._storage.append(value)

  def extend(self, values):
    """Add a sequence of trackable values.

    `values` is materialized before anything is appended. Otherwise extending
    a list with itself (`l.extend(l)` or `l += l`) would keep iterating over
    the elements it had just appended and never terminate. Taking a snapshot
    first gives the same result as built-in `list.extend`.
    """
    for value in list(values):
      self.append(value)

  def __iadd__(self, values):
    self.extend(values)
    return self

  def __add__(self, other):
    # Concatenation produces a plain, untracked list.
    return self._storage + getattr(other, "_storage", other)

  def __imul__(self, y):
    """Repeats the contents in place by appending copies of the references.

    Raises:
      ValueError: if `y <= 0`, since that would remove elements.
    """
    if y <= 0:
      raise ValueError(
          f"List only supports append, multiplying in place by {y} removes "
          "elements.")

    n = len(self._storage)
    for _ in range(y - 1):
      for i in range(n):
        self.append(self._storage[i])

    return self

  def __mul__(self, n):
    return self._storage * n

  def __rmul__(self, n):
    return self * n

  def __radd__(self, other):
    return other + self._storage

  def __getitem__(self, key):
    return self._storage[key]

  def __getslice__(self, i, j):
    # Python 2 slicing hook; Python 3 routes slices through __getitem__.
    return self._storage[slice(i, j)]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "List(%s)" % (repr(self._storage),)

  def __sizeof__(self):
    return super().__sizeof__() + sys.getsizeof(self._storage)

445 

446 

447# TODO(tomhennigan) Update to collections.UserList? 

448# TODO(allenl): Try switching this to wrapt.ObjectProxy again when we drop 

449# Python 3.4 support (may still be tricky). 

class ListWrapper(
    List,
    collections_abc.MutableSequence,
    # Shadowed, but there for isinstance checks.
    list):
  """Wraps the built-in `list` to support restore-on-create for variables.

  Unlike `List`, this sequence type is mutable in the same ways built-in lists
  are. Instead of throwing an error immediately like `List`, it records
  problematic mutations (e.g. assigning a new element to a position already
  occupied, meaning both elements get the same names at different times) and
  refuses to save.

  On assignment to an attribute of a Model or Trackable object, Python
  lists are replaced with ListWrapper. Wrapping a list in a
  `NoDependency` object prevents this.
  """

  def __init__(self, wrapped_list):
    """Construct a new list wrapper.

    Args:
      wrapped_list: The initial value of the data structure. A shallow copy may
        be maintained for error checking. `wrapped_list` itself should not be
        modified directly after constructing the `ListWrapper`, and if changes
        are detected the `ListWrapper` will throw an exception on save.
    """
    # Monotonic flags which indicate this object would not be restored properly,
    # and therefore should throw an error on save to avoid giving the impression
    # that restoring it will work.
    self._non_append_mutation_value = False
    self._external_modification_value = False
    super().__init__(wrapped_list)
    self._last_wrapped_list_snapshot = list(self._storage)

  @property
  def _non_append_mutation(self):
    return self._non_append_mutation_value

  @_non_append_mutation.setter
  def _non_append_mutation(self, value):
    # Trackable only cares that a mutation occurred at some point; when
    # attempting to save it checks whether a mutation occurred and the object is
    # in a "dirty" state but otherwise the specifics of how it got to that state
    # are ignored. By contrast, the attribute cache needs to signal the mutation
    # immediately since a caller could query the value of an attribute (And
    # should not hit the cached value since the mutation may have affected the
    # result.)
    self._attribute_sentinel.invalidate_all()
    self._non_append_mutation_value = value

  @property
  def _external_modification(self):
    return self._external_modification_value

  @_external_modification.setter
  def _external_modification(self, value):
    # Invalidate for the same reason as `_non_append_mutation`
    self._attribute_sentinel.invalidate_all()
    self._external_modification_value = value

  # pylint: disable=protected-access
  def __copy__(self):
    copied = super().__copy__()
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied

  def __deepcopy__(self, memo):
    copied = super().__deepcopy__(memo)
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied
  # pylint: enable=protected-access

  def __reduce_ex__(self, protocol):
    return (self.__class__,
            (self._storage,))

  def _make_storage(self, wrapped_list):
    """Use the user's original list for storage."""
    return wrapped_list

  def _check_external_modification(self):
    """Checks for any changes to the wrapped list not through the wrapper.

    When a change is detected, the wrapper is flagged as externally modified.
    The last acknowledged snapshot is kept rather than cleared, so the
    save-time error in `_trackable_children` can show what the list looked
    like when it was last in sync. Clearing it made that message always print
    `None`. `_update_snapshot` stops refreshing the snapshot once the wrapper
    is flagged.
    """
    if self._external_modification or self._non_append_mutation:
      return
    if self._storage != self._last_wrapped_list_snapshot:
      self._external_modification = True

  def _update_snapshot(self):
    """Acknowledges tracked changes to the wrapped list."""

    # Mutation tracking for attributes reuses the same infrastructure as
    # Trackable mutation tracking.
    self._attribute_sentinel.invalidate_all()
    if self._external_modification or self._non_append_mutation:
      return
    self._last_wrapped_list_snapshot = list(self._storage)

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns trackable children, refusing to save after unsafe mutations.

    Raises:
      ValueError: if the list saw a non-append mutation through the wrapper,
        or was modified directly outside the wrapper.
    """
    self._check_external_modification()
    if self._non_append_mutation:
      raise ValueError(
          f"Unable to save the object {self} (a list wrapper constructed to "
          "track trackable TensorFlow objects). A list element was replaced "
          "(__setitem__, __setslice__), deleted (__delitem__, __delslice__), "
          "or moved (sort). In order to support restoration on object "
          "creation, tracking is exclusively for append-only data structures."
          "\n\nIf you don't need this list checkpointed, wrap it in a "
          "non-trackable object; it will be subsequently ignored.")
    if self._external_modification:
      raise ValueError(
          f"Unable to save the object {self} (a list wrapper constructed to "
          "track trackable TensorFlow objects). The wrapped list was modified "
          f"outside the wrapper (its final value was {self._storage}, its value"
          " when a checkpoint dependency was added was "
          f"{self._last_wrapped_list_snapshot}), which breaks "
          "restoration on object creation.\n\nIf you don't need this list "
          "checkpointed, wrap it in a NoDependency object; it will be "
          "subsequently ignored.")
    children = super()._trackable_children(save_type, **kwargs)

    if save_type == base.SaveType.SAVEDMODEL:
      # Add functions to be serialized.
      children.update({
          str(key): value
          for key, value in enumerate(self)
          if _is_function(value)
      })

    return children

  def _has_mutation_or_trackable(self):
    """Short-circuits a check for trackables if there's already a mutation."""
    if self._non_append_mutation:
      return True
    return any(isinstance(element, base.Trackable) for element in self._storage)

  def __delitem__(self, key):
    self._check_external_modification()
    if self._has_mutation_or_trackable():
      self._non_append_mutation = True
    del self._storage[key]
    self._update_snapshot()

  def __setitem__(self, key, value):
    """Sets an element or slice, tracking new values by absolute position.

    Replacing a position that held a `Trackable` is a non-append mutation and
    makes the wrapper unsaveable.
    """
    self._check_external_modification()

    if isinstance(key, slice):
      # Note: this is quite inefficient, but the list API supports a broad range
      # of slice setters (e.g. truncate, extend, replace) and imitating this
      # for a range of Python versions is non-trivial.
      storage_copy = list(self._storage)
      self._storage[key] = value

      len_before = len(storage_copy)
      len_now = len(self._storage)
      for i in range(max(len_before, len_now)):
        value_now = self._storage[i] if i < len_now else None
        value_before = storage_copy[i] if i < len_before else None

        if isinstance(value_before, base.Trackable):
          self._non_append_mutation = True

        if value_now is not None and value_now != value_before:
          self._storage[i] = self._track_value(self._storage[i],
                                               self._name_element(i))

    else:
      # Also raises IndexError/TypeError for invalid keys, as list does.
      if isinstance(self._storage[key], base.Trackable):
        self._non_append_mutation = True
      # Name the dependency after the absolute position, which is the name
      # `append` would have given it. Negative keys are normalized through
      # `range(...)[key]`; otherwise `lw[-1] = x` would record the name "-1".
      index = range(len(self._storage))[key]
      self._storage[key] = self._track_value(value, self._name_element(index))

    self._update_snapshot()

  def append(self, value):
    """Add a new trackable value."""
    self._check_external_modification()
    super().append(value)
    self._update_snapshot()

  def extend(self, values):
    """Add a sequence of trackable values."""
    self._check_external_modification()
    super().extend(values)
    self._update_snapshot()

  def __imul__(self, y):
    if y <= 0:
      self._check_external_modification()
      if self._has_mutation_or_trackable():
        self._non_append_mutation = True
      self._storage *= y
      self._update_snapshot()
      return self

    # Relies on super() calling append, which updates the snapshot.
    return super().__imul__(y)

  def __eq__(self, other):
    return self._storage == getattr(other, "_storage", other)

  def __ne__(self, other):
    return self._storage != getattr(other, "_storage", other)

  def __lt__(self, other):
    return self._storage < getattr(other, "_storage", other)

  def __le__(self, other):
    return self._storage <= getattr(other, "_storage", other)

  def __gt__(self, other):
    return self._storage > getattr(other, "_storage", other)

  def __ge__(self, other):
    return self._storage >= getattr(other, "_storage", other)

  def __hash__(self):
    # List wrappers need to compare like regular lists, and so like regular
    # lists they don't belong in hash tables.
    raise TypeError("unhashable type: 'ListWrapper'")

  def insert(self, index, obj):
    self._check_external_modification()
    if (self._has_mutation_or_trackable() or isinstance(obj, base.Trackable)):
      self._non_append_mutation = True
    self._storage.insert(index, obj)
    self._update_snapshot()

  def sort(self, *, key=None, reverse=False):
    """Sorts in place, accepting the same keyword arguments as `list.sort`.

    Reordering trackable elements changes their index-based names, so this is
    recorded as a non-append mutation whenever trackables are present.
    """
    self._check_external_modification()
    if self._has_mutation_or_trackable():
      self._non_append_mutation = True
    self._storage.sort(key=key, reverse=reverse)
    self._update_snapshot()

  def __setslice__(self, i, j, y):
    self.__setitem__(slice(i, j), y)

  def __delslice__(self, i, j):
    self._check_external_modification()
    if self._has_mutation_or_trackable():
      self._non_append_mutation = True
    del self._storage[slice(i, j)]
    self._update_snapshot()

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects."""
    try:
      value = super()._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
    return "ListWrapper(%s)" % (repr(self._storage),)

711 

712 

class Mapping(TrackableDataStructure, collections_abc.Mapping):
  """An append-only trackable mapping data structure with string keys.

  Maintains checkpoint dependencies on its contents (which must also be
  trackable), named based on its keys.

  Note that once a key has been added, it may not be deleted or replaced.
  """

  def __init__(self, *args, **kwargs):
    """Construct a new mapping. Arguments are passed to `dict()`."""
    super().__init__()
    self._storage = self._make_storage(*args, **kwargs)
    self._storage.update(
        {key: self._track_value(
            value, name=self._name_element(key))
         for key, value in self._storage.items()})

  def __copy__(self):
    return type(self)(copy.copy(self._storage))

  def __deepcopy__(self, memo):
    return type(self)(copy.deepcopy(self._storage, memo))

  def _make_storage(self, *args, **kwargs):
    return dict(*args, **kwargs)

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    # Sort items deterministically by key
    ordered = list(zip(*sorted(self.items(), key=lambda it: it[0])))
    if ordered:
      return ordered[1]
    return []

  def _name_element(self, key):
    """Returns the dependency name for `key`.

    Raises:
      TypeError: if `key` is not a string.
    """
    if not isinstance(key, str):
      raise TypeError(
          f"Mapping accepts only string keys, but got a key {repr(key)}.")
    return str(key)

  def __setitem__(self, key, value):
    """Adds `value` under a new string `key`.

    Assigning the object already stored under `key` is a no-op.

    Raises:
      TypeError: if `key` is not a string.
      ValueError: if `key` already holds a different value (the mapping is
        append-only), or if `value` is not trackable.
    """
    name = self._name_element(key)
    if key in self._storage:
      current_value = self._storage[key]
      # Reject before calling `_track_value`. Tracking uses `overwrite=True`,
      # so tracking first would re-point the dependency named `key` at a value
      # this mapping refuses to store. It would also append Variables to the
      # extra-variable list a second time.
      unwrapped = value.value if isinstance(value, NoDependency) else value
      if unwrapped is current_value:
        return
      raise ValueError(
          "Mappings are an append-only data structure. Tried to overwrite the "
          f"key '{key}' with value {value}, but it already contains "
          f"{current_value}")
    self._storage[key] = self._track_value(value, name=name)

  def update(self, *args, **kwargs):
    """Adds each key/value from `dict(*args, **kwargs)` via `__setitem__`."""
    for key, value in dict(*args, **kwargs).items():
      self[key] = value

  def __getitem__(self, key):
    return self._storage[key]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "Mapping(%s)" % (repr(self._storage),)

  def __iter__(self):
    return iter(self._storage)

780 

781 

782class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy): 

783 """Wraps built-in dicts to support restore-on-create for variables. 

784 

785 _DictWrapper is to Mapping as ListWrapper is to List. Unlike Mapping, 

786 _DictWrapper allows non-string keys and values and arbitrary mutations (delete 

787 keys, reassign values). Like ListWrapper, these mutations mean that 

788 _DictWrapper will raise an exception on save. 

789 """ 

790 

791 def __init__(self, wrapped_dict=None): 

792 if wrapped_dict is None: 

793 # Allow zero-argument construction, e.g. from session.run's re-wrapping. 

794 wrapped_dict = {} 

795 if not isinstance(wrapped_dict, collections_abc.Mapping): 

796 # Allow construction from a sequence, e.g. from nest.pack_sequence_as. 

797 wrapped_dict = dict(wrapped_dict) 

798 wrapt.ObjectProxy.__init__(self, wrapped_dict) 

799 TrackableDataStructure.__init__(self) 

800 self._self_non_string_key = False 

801 self._self_external_modification = False 

802 self.__wrapped__.update( 

803 {key: self._track_value( 

804 value, name=self._name_element(key)) 

805 for key, value in self.__wrapped__.items()}) 

806 self._update_snapshot() 

807 

808 def __reduce_ex__(self, protocol): 

809 return (self.__class__, 

810 (self.__wrapped__,)) 

811 

812 def __getattribute__(self, name): 

813 if (hasattr(type(self), name) 

814 and isinstance(getattr(type(self), name), property)): 

815 # Bypass ObjectProxy for properties. Whether this workaround is necessary 

816 # appears to depend on the Python version but not the wrapt version: 3.4 

817 # in particular seems to look up properties on the wrapped object instead 

818 # of the wrapper without this logic. 

819 return object.__getattribute__(self, name) 

820 else: 

821 return super().__getattribute__(name) 

822 

823 def copy(self): 

824 return copy.copy(self) 

825 

826 # pylint: disable=protected-access 

827 def __copy__(self): 

828 copied = _DictWrapper(copy.copy(self.__wrapped__)) 

829 copied._self_external_modification = self._self_external_modification 

830 copied._self_non_string_key = self._self_non_string_key 

831 return copied 

832 

833 def __deepcopy__(self, memo): 

834 copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo)) 

835 copied._self_external_modification = self._self_external_modification 

836 copied._self_non_string_key = self._self_non_string_key 

837 return copied 

838 # pylint: enable=protected-access 

839 

840 @property 

841 def _values(self): 

842 """Collect values for TrackableDataStructure.""" 

843 # Sort items deterministically by key 

844 ordered = list(zip(*sorted(self.items(), key=lambda it: it[0]))) 

845 if ordered: 

846 return ordered[1] 

847 return [] 

848 

849 def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs): 

850 """Check that the object is saveable before listing its dependencies.""" 

851 self._check_self_external_modification() 

852 if self._self_non_string_key: 

853 raise ValueError( 

854 f"Unable to save the object {self} (a dictionary wrapper constructed " 

855 "automatically on attribute assignment). The wrapped dictionary " 

856 "contains a non-string key which maps to a trackable object or " 

857 "mutable data structure.\n\nIf you don't need this dictionary " 

858 "checkpointed, wrap it in a non-trackable " 

859 "object; it will be subsequently ignored.") 

860 if self._self_external_modification: 

861 raise ValueError( 

862 f"Unable to save the object {self} (a dictionary wrapper constructed " 

863 "automatically on attribute assignment). The wrapped dictionary was " 

864 f"modified outside the wrapper (its final value was {self}, its value" 

865 " when a checkpoint dependency was added was " 

866 f"{self._self_last_wrapped_dict_snapshot}), which breaks " 

867 "restoration on object creation.\n\nIf you don't need this " 

868 "dictionary checkpointed, wrap it in a " 

869 "non-trackable object; it will be subsequently ignored.") 

870 assert not self._dirty # Any reason for dirtiness should have an exception. 

871 children = super()._trackable_children(save_type, **kwargs) 

872 

873 if save_type == base.SaveType.SAVEDMODEL: 

874 # Add functions to be serialized. 

875 children.update( 

876 {key: value for key, value in self.items() if _is_function(value)}) 

877 

878 return children 

879 

880 @property 

881 def _dirty(self): 

882 """Check if there has already been a mutation which prevents saving.""" 

883 return (self._self_external_modification 

884 or self._self_non_string_key) 

885 

886 def _check_self_external_modification(self): 

887 """Checks for any changes to the wrapped dict not through the wrapper.""" 

888 if self._dirty: 

889 return 

890 if self != self._self_last_wrapped_dict_snapshot: 

891 self._self_external_modification = True 

892 self._self_last_wrapped_dict_snapshot = None 

893 

894 def _update_snapshot(self): 

895 """Acknowledges tracked changes to the wrapped dict.""" 

896 self._attribute_sentinel.invalidate_all() 

897 if self._dirty: 

898 return 

899 self._self_last_wrapped_dict_snapshot = dict(self) 

900 

901 def _track_value(self, value, name): 

902 """Allows storage of non-trackable objects.""" 

903 if isinstance(name, str): 

904 string_key = True 

905 else: 

906 name = "-non_string_key" 

907 string_key = False 

908 try: 

909 no_dependency = isinstance(value, NoDependency) 

910 value = super()._track_value(value=value, name=name) 

911 if not (string_key or no_dependency): 

912 # A non-string key maps to a trackable value. This data structure 

913 # is not saveable. 

914 self._self_non_string_key = True 

915 return value 

916 except ValueError: 

917 # Even if this value isn't trackable, we need to make sure 

918 # NoDependency objects get unwrapped. 

919 return sticky_attribute_assignment( 

920 trackable=self, value=value, name=name) 

921 

922 def _name_element(self, key): 

923 """Tells TrackableDataStructure to use keys as names as-is.""" 

924 return key 

925 

926 def __setitem__(self, key, value): 

927 """Allow any modifications, but possibly mark the wrapper as unsaveable.""" 

928 self._check_self_external_modification() 

929 self._maybe_initialize_trackable() 

930 no_dep = isinstance(value, NoDependency) 

931 if isinstance(key, str): 

932 value = self._track_value(value, name=key) 

933 else: 

934 value = wrap_or_unwrap(value) 

935 if not no_dep and isinstance(value, base.Trackable): 

936 # Non-string keys are OK as long as we have no reason to add a 

937 # dependency on the value (either because the value is not 

938 # trackable, or because it was wrapped in a NoDependency object). 

939 self._self_non_string_key = True 

940 self.__wrapped__[key] = value 

941 

942 self._update_snapshot() 

943 

944 def __delitem__(self, key): 

945 self._check_self_external_modification() 

946 del self.__wrapped__[key] 

947 self._update_snapshot() 

948 

949 def __repr__(self): 

950 return "DictWrapper(%s)" % (repr(self.__wrapped__),) 

951 

952 def __hash__(self): 

953 raise TypeError("unhashable type: 'DictWrapper'") 

954 

955 def __eq__(self, other): 

956 # Override the TrackableDataStructure "== -> is" forwarding and go back to 

957 # the wrapt implementation. 

958 return self.__wrapped__ == other 

959 

960 def update(self, *args, **kwargs): 

961 for key, value in dict(*args, **kwargs).items(): 

962 self[key] = value 

963 

964 

class _TupleWrapper(TrackableDataStructure, wrapt.ObjectProxy):
  """Trackable wrapper for tuples and namedtuples.

  Each element is passed through `wrap_or_unwrap` when the wrapper is built.
  Dependencies are tracked by index ("0", "1", ...) and, for namedtuples, also
  by field name. Elements given as `NoDependency` are stored but not tracked.
  """

  def __init__(self, original_wrapped_tuple=()):
    """Builds the proxy around a copy of the tuple with substituted elements.

    Args:
      original_wrapped_tuple: The tuple or namedtuple to wrap.
    """
    # One flag per element: False for `NoDependency`-wrapped elements, which
    # are stored but must not become dependencies.
    add_dependency = []
    substituted_wrapped_tuple = []
    for element in original_wrapped_tuple:
      if isinstance(element, NoDependency):
        add_dependency.append(False)
      else:
        add_dependency.append(True)
      substituted_wrapped_tuple.append(wrap_or_unwrap(element))
    try:
      fields = original_wrapped_tuple._fields
    except AttributeError:
      # Not a namedtuple
      is_namedtuple = False
    else:
      is_namedtuple = True
    original_type = type(original_wrapped_tuple)
    # Flag to poison saving if we can't re-construct a namedtuple because its
    # __new__ takes different keyword arguments than its _fields.
    self._self_tuple_is_constructable = True
    if is_namedtuple:
      try:
        # NamedTuples take N arguments, unlike tuple which takes a sequence.
        substituted_wrapped_tuple = original_type(
            **dict(zip(fields, substituted_wrapped_tuple)))
      except TypeError:
        # Fallback: proxy the ORIGINAL tuple (elements not substituted), track
        # no dependencies, and mark the object as unsaveable so that
        # `_trackable_children` raises.
        wrapt.ObjectProxy.__init__(self, original_wrapped_tuple)
        TrackableDataStructure.__init__(self)
        self._self_tuple_is_constructable = False
        return
    else:
      substituted_wrapped_tuple = original_type(substituted_wrapped_tuple)
    wrapt.ObjectProxy.__init__(self, substituted_wrapped_tuple)
    TrackableDataStructure.__init__(self)

    if is_namedtuple:
      # For namedtuples, also track by names for compatibility with
      # dictionaries.
      for name, should_depend, element in zip(
          fields, add_dependency, substituted_wrapped_tuple):
        if should_depend:
          self._track_value(element, name=name)

    # Track by index as well, for compatibility with lists.
    for index, (should_depend, element) in enumerate(
        zip(add_dependency, substituted_wrapped_tuple)):
      if should_depend:
        self._track_value(element, name="%d" % (index,))

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    # The (proxied) tuple is itself the sequence of values.
    return self

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects.

    Returns:
      The value as returned by the base tracker, or, if that raised
      ValueError, as returned by `sticky_attribute_assignment`.
    """
    try:
      value = super()._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
    return "_TupleWrapper(%s)" % (repr(self.__wrapped__),)

  def __hash__(self):
    # Override the TrackableDataStructure hash forwarding and go back to
    # the wrapt implementation.
    return hash(self.__wrapped__)

  def __eq__(self, other):
    # Override the TrackableDataStructure "== -> is" forwarding and go back to
    # the wrapt implementation.
    return self.__wrapped__ == other

  def __copy__(self):
    # Re-wraps a copy of the wrapped tuple; `__init__` re-processes elements.
    return _TupleWrapper(copy.copy(self.__wrapped__))

  def __deepcopy__(self, memo):
    # Re-wraps a deep copy of the wrapped tuple; `__init__` re-processes
    # elements.
    return _TupleWrapper(copy.deepcopy(self.__wrapped__, memo))

  def __reduce_ex__(self, protocol):
    # NOTE(review): `__getattribute__` below returns the wrapped tuple's
    # attribute for any name the tuple also has. Tuples define both
    # `__reduce_ex__` and `__class__`, so instance-level lookups of either
    # likely resolve on the wrapped tuple. Confirm that this method (and
    # `self.__class__` here) are reached as intended.
    return (self.__class__,
            (self.__wrapped__,))

  # imul and iadd are the only tuple-relevant in-place operators. They need to
  # be special-cased to avoid mutating the original proxy object.
  def __imul__(self, y):
    """Avoid running self.__wrapped__ *= y, which mutates `self`."""
    # The result is a new, unwrapped sequence; the proxy is left untouched.
    return self.__wrapped__ * y

  def __iadd__(self, y):
    """Avoid running self.__wrapped__ += y, which mutates `self`."""
    # The result is a new, unwrapped sequence; the proxy is left untouched.
    return self.__wrapped__ + y

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Lists dependencies, refusing namedtuples that cannot be rebuilt.

    Raises:
      ValueError: If `__init__` could not re-construct the namedtuple from its
        `_fields` (see the fallback in `__init__`).
    """
    if not self._self_tuple_is_constructable:
      raise ValueError(
          f"Unable to save because the namedtuple {self.__wrapped__} is not "
          "constructable from its _fields (i.e. __new__ is overridden). "
          f"Expected keyword arguments {self.__wrapped__._fields}. If you do "
          "not need to save this object, consider wrapping it in a custom "
          "object that does not inherit from tuple.")
    return super()._trackable_children(save_type, **kwargs)

  def __getattribute__(self, name):
    """Resolves attributes, giving the wrapped tuple priority.

    Every name except `__wrapped__` is first looked up on the wrapped tuple
    (so e.g. `_fields`, `count` and `index` come from the tuple). Only names
    the tuple lacks fall through to the wrapper.
    """
    if name != "__wrapped__" and hasattr(self.__wrapped__, name):
      # Prefer attributes on the wrapped object when they conflict with
      # attributes on the wrapper object.
      return getattr(self.__wrapped__, name)

    if (hasattr(type(self), name)
        and isinstance(getattr(type(self), name), property)):
      # Bypass ObjectProxy for properties. Whether this workaround is necessary
      # appears to depend on the Python version but not the wrapt version: 3.4
      # in particular seems to look up properties on the wrapped object instead
      # of the wrapper without this logic.
      return object.__getattribute__(self, name)
    else:
      return super().__getattribute__(name)

1091 

1092 

def _is_function(x):
  """Returns True if `x` is a `def_function.Function` or a ConcreteFunction."""
  function_types = (def_function.Function, defun.ConcreteFunction)
  return isinstance(x, function_types)

1095 

1096 

def set_list_item(list_object, index_string, value):
  """Assigns `value` at the integer index parsed from `index_string`.

  If `list_object` is too short, it is first extended with `None` so that the
  index exists.

  Args:
    list_object: A mutable sequence supporting `len`, `extend` and item
      assignment.
    index_string: The index, as a string (or anything `int()` accepts).
    value: The value to store.

  Raises:
    ValueError: If `index_string` is not an integer literal.
  """
  position = int(index_string)
  shortfall = position + 1 - len(list_object)
  if shortfall > 0:
    list_object.extend([None] * shortfall)
  list_object[position] = value

1102 

1103 

def set_tuple_item(list_object, index_string, value):
  """Like `set_list_item`, but silently skips non-integer names.

  Namedtuple elements are also tracked under their field names; those names
  are not indices and are ignored here.

  Args:
    list_object: A mutable sequence supporting `len`, `extend` and item
      assignment.
    index_string: The element name; only integer literals are applied.
    value: The value to store.
  """
  try:
    position = int(index_string)
  except ValueError:
    # Ignore namedtuple fields.
    return
  shortfall = position + 1 - len(list_object)
  if shortfall > 0:
    list_object.extend([None] * shortfall)
  list_object[position] = value