Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/trackable/data_structures.py: 29%
578 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""Trackable data structures."""
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16import collections
17import copy
18import sys
20try:
21 import wrapt
22except ImportError:
23 # Fall back to the build-time dependency if the system package is not available.
24 from .....third_party import wrapt # pylint: disable=relative-beyond-top-level
26from tensorflow.python.eager import def_function
27from tensorflow.python.eager import function as defun
28from tensorflow.python.ops import variables
29from tensorflow.python.trackable import base
30from tensorflow.python.trackable import layer_utils
31from tensorflow.python.util.compat import collections_abc
32from tensorflow.python.util.tf_export import tf_export
class NoDependency:
  """Marks a value so that assigning it to a `Trackable` adds no dependency.

  Example:

  ```python
  obj = Trackable()
  obj.has_dependency = tf.Variable(0., name="dep")
  obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
  assert obj.no_dependency.name == "nodep:0"
  ```

  Here `obj` depends only on the variable "dep". Both attributes hold plain,
  unwrapped `Variable` objects.

  `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
  dependencies. A `Layer` wrapped in `NoDependency` is assigned unwrapped and
  without a checkpoint dependency, yet the `Model` still tracks it: it appears
  in `Model.layers`, and its variables appear in `Model.variables`.

  Attributes:
    value: The wrapped object, handed back when the marker is unwrapped.
  """

  # No per-instance __dict__; the marker only ever holds a single reference.
  __slots__ = ["value"]

  def __init__(self, value):
    self.value = value
62def _should_wrap_tuple(t):
63 """Determine if a tuple has any trackable components."""
64 # pylint: disable=unidiomatic-typecheck
65 # Exact type checking to avoid mucking up custom logic in list/dict
66 # subclasses, e.g. collections.Counter.
67 for element in t:
68 if isinstance(element, NoDependency):
69 return True # We should remove the NoDependency object from the tuple.
70 if isinstance(element, base.Trackable):
71 return True
72 if type(element) == dict:
73 return True
74 if type(element) == collections.OrderedDict:
75 return True
76 if type(element) == list:
77 return True
78 if isinstance(element, tuple) and _should_wrap_tuple(element):
79 return True
80 # There are no trackable elements or data structures. Tuples are immutable, so
81 # mutation isn't a concern. Don't wrap.
82 return False
83 # pylint: enable=unidiomatic-typecheck
@tf_export("__internal__.tracking.wrap", v1=[])
def wrap_or_unwrap(value):
  """Wraps input value into trackable data structures.

  Containers such as lists and dicts may hold trackable objects. Once wrapped,
  they are tracked when attached to a `tf.Module`, so that SavedModel and
  checkpoints can follow the dependency. `NoDependency` markers are unwrapped
  instead.

  Args:
    value: the input object to be wrapped.

  Returns:
    The wrapped trackable data structure, the unwrapped `NoDependency`
    payload, or `value` itself when no conversion applies.
  """
  # pylint: disable=unidiomatic-typecheck
  if isinstance(value, NoDependency):
    return value.value
  if isinstance(value, base.Trackable):
    # Already trackable: no conversion needed.
    return value
  # Exact type checks so that list/dict subclasses with custom logic (e.g.
  # collections.Counter) are passed through untouched.
  value_type = type(value)
  if value_type == dict or value_type == collections.OrderedDict:
    return _DictWrapper(value)
  if value_type == list:
    return ListWrapper(value)
  if isinstance(value, tuple) and _should_wrap_tuple(value):
    # The tuple holds trackable elements or data structures.
    return _TupleWrapper(value)
  return value
  # pylint: enable=unidiomatic-typecheck
@tf_export("__internal__.tracking.sticky_attribute_assignment", v1=[])
def sticky_attribute_assignment(trackable, name, value):
  """Adds dependencies, generally called from __setattr__.

  This behavior is shared between Trackable and Model. `NoDependency` markers
  are honored. Otherwise, common data structures are converted into trackable
  ones, and trackable values are tracked under the attribute name.

  Args:
    trackable: The object to add dependencies to (generally the one having
      an attribute assigned).
    name: The attribute name being assigned.
    value: The value being assigned. Not necessarily a trackable object.

  Returns:
    The value which should be stored in the attribute (unwrapped from a
    NoDependency object if necessary).
  """
  wants_dependency = not isinstance(value, NoDependency)
  stored = wrap_or_unwrap(value)
  if wants_dependency and isinstance(stored, base.Trackable):
    trackable._track_trackable(  # pylint: disable=protected-access
        stored, name=name,
        # Re-assigning an attribute may replace the tracked object under this
        # name; assigning a new variable to an attribute has historically been
        # fine (e.g. Adam did this).
        overwrite=True)
  return stored
160class _UntrackableError(ValueError):
162 def __init__(self, value): # pylint: disable=super-init-not-called
163 self._value = value
165 def __str__(self):
166 return ("Only trackable objects (such as Layers or Optimizers) may be "
167 f"stored in a List object. Got {self._value}, which does not "
168 "inherit from Trackable.")
@tf_export("__internal__.tracking.TrackableDataStructure", v1=[])
class TrackableDataStructure(base.Trackable):
  """Base class for data structures which contain trackable objects.

  Subclasses supply `_values`. This class derives layer, weight, update and
  loss aggregates from those values. Elements that are `variables.Variable`
  instances are also recorded separately, so the weight properties can report
  them alongside the variables of nested trackables.
  """

  def __init__(self):
    # wrapt.ObjectProxy keeps attributes whose names start with "_self_" on
    # the proxy itself. Every attribute added here or in a subclass MUST use
    # that prefix, because extending `__slots__` on a subclass of ObjectProxy
    # breaks in a variety of ways.
    self._self_trainable = True
    self._self_extra_variables = []
    self._self_attribute_sentinel = layer_utils.AttributeSentinel(True)

  @property
  def _attribute_sentinel(self):
    return self._self_attribute_sentinel

  @property
  def trainable(self):
    return self._self_trainable

  @trainable.setter
  def trainable(self, value):
    self._self_trainable = value

  def _track_value(self, value, name):
    """Adds a dependency on `value` under `name`.

    Args:
      value: The element being stored. It may be wrapped or unwrapped by
        `sticky_attribute_assignment`.
      name: The dependency name.

    Returns:
      The value that should actually be stored.

    Raises:
      _UntrackableError: If the resulting value is not `Trackable`.
    """
    tracked = sticky_attribute_assignment(
        trackable=self, value=value, name=name)
    if isinstance(tracked, variables.Variable):
      # Recorded so the weight properties can report directly held variables.
      self._self_extra_variables.append(tracked)
    if not isinstance(tracked, base.Trackable):
      raise _UntrackableError(tracked)
    if hasattr(tracked, "_use_resource_variables"):
      # In subclassed models, legacy layers (tf.layers) must always use
      # resource variables.
      tracked._use_resource_variables = True  # pylint: disable=protected-access
    child_sentinel = getattr(tracked, "_attribute_sentinel", None)
    if child_sentinel:
      child_sentinel.add_parent(self._attribute_sentinel)
    return tracked

  @property
  def _values(self):
    """An iterable/sequence which may contain trackable objects."""
    raise NotImplementedError("Abstract method")

  @property
  def _layers(self):
    """All Layers and Layer containers, including empty containers."""
    # Filtered on every access so wrappers use the current contents of the
    # object they wrap, even if it got out of sync.
    return [
        candidate for candidate in self._values
        if isinstance(candidate, TrackableDataStructure)
        or layer_utils.is_layer(candidate)
        or layer_utils.has_weights(candidate)
    ]

  @property
  def layers(self):
    return list(layer_utils.filter_empty_layer_containers(self._layers))

  @property
  def trainable_weights(self):
    if not self._self_trainable:
      return []
    nested = [
        variable
        for child in self._values
        if isinstance(child, base.Trackable)
        and hasattr(child, "trainable_variables")
        for variable in child.trainable_variables
    ]
    direct = [v for v in self._self_extra_variables if v.trainable]
    return nested + direct

  @property
  def non_trainable_weights(self):
    extras = self._self_extra_variables
    direct_trainable = [v for v in extras if v.trainable]
    direct_frozen = [v for v in extras if not v.trainable]
    nested_frozen = [
        variable
        for child in self._values
        if isinstance(child, base.Trackable)
        and hasattr(child, "non_trainable_variables")
        for variable in child.non_trainable_variables
    ]
    if self._self_trainable:
      return nested_frozen + direct_frozen
    # A non-trainable container reports every variable as non-trainable.
    # Ordering: all trainable variables first, then all non-trainable ones.
    nested_trainable = [
        variable
        for child in self._values
        if isinstance(child, base.Trackable)
        and hasattr(child, "trainable_variables")
        for variable in child.trainable_variables
    ]
    return (nested_trainable + direct_trainable +
            nested_frozen + direct_frozen)

  @property
  def weights(self):
    return self.trainable_weights + self.non_trainable_weights

  @property
  def trainable_variables(self):
    return self.trainable_weights

  @property
  def non_trainable_variables(self):
    return self.non_trainable_weights

  @property
  def variables(self):
    return self.weights

  @property
  def updates(self):
    """Aggregate updates from any `Layer` instances."""
    # Updates and conditional losses are forwarded as-is instead of being
    # filtered by inputs: a container never has inputs of its own.
    return [
        update
        for layer in self.layers if hasattr(layer, "updates")
        for update in layer.updates
    ]

  @property
  def losses(self):
    """Aggregate losses from any `Layer` instances."""
    return [
        loss
        for layer in self.layers if hasattr(layer, "losses")
        for loss in layer.losses
    ]

  def __hash__(self):
    # Identity hashing lets these structures serve as set members / dict keys.
    return id(self)

  def __eq__(self, other):
    # As with Tensors, equality is identity, to support set/dict membership.
    return self is other
class List(TrackableDataStructure, collections_abc.Sequence):
  """An append-only sequence type which is trackable.

  Holds checkpoint dependencies on its elements, which must themselves be
  trackable. It also forwards `Layer` metadata such as updates and losses.

  `List` is purely a container. It tells a `tf.keras.Model` or other trackable
  object about its contents, but never calls the `Layer` instances stored in
  it. For `Layer` instances that should be applied one after another, use
  `tf.keras.Sequential`.

  Example usage:
  ```python
  class HasList(tf.keras.Model):

    def __init__(self):
      super().__init__()
      self.layer_list = List([layers.Dense(3)])
      self.layer_list.append(layers.Dense(4))

    def call(self, x):
      aggregation = 0.
      for l in self.layer_list:
        x = l(x)
        aggregation += tf.reduce_sum(x)
      return aggregation
  ```

  This wrapping is needed because `Trackable` objects do not (yet) inspect
  regular Python data structures deeply. For example, assigning a regular list
  (`self.layer_list = [layers.Dense(3)]`) creates no checkpoint dependency and
  does not add the `Layer` instance's weights to its parent `Model`.
  """

  def __init__(self, *args, **kwargs):
    """Construct a new sequence. Arguments are passed to `list()`."""
    super().__init__()
    self._storage = self._make_storage(*args, **kwargs)
    # Replace every element in place with its tracked (possibly wrapped) form.
    for position, item in enumerate(self._storage):
      self._storage[position] = self._track_value(
          item, name=self._name_element(position))

  def copy(self):
    duplicate = copy.copy(self._storage)
    return type(self)(duplicate)

  def __copy__(self):
    return self.copy()

  def __deepcopy__(self, memo):
    duplicate = copy.deepcopy(self._storage, memo)
    return type(self)(duplicate)

  def _make_storage(self, *args, **kwargs):
    """Determines the backing storage (overridden in subclasses)."""
    return list(*args, **kwargs)

  def _name_element(self, index):
    # Dependency names are the decimal element positions.
    return "%d" % (index,)

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    return self

  def append(self, value):
    """Add a new trackable value."""
    next_name = self._name_element(len(self._storage))
    value = self._track_value(value, next_name)
    self._storage.append(value)

  def extend(self, values):
    """Add a sequence of trackable values."""
    for item in values:
      self.append(item)

  def __iadd__(self, values):
    self.extend(values)
    return self

  def __add__(self, other):
    # Produces a plain list; `other` may be another List or any sequence.
    return self._storage + getattr(other, "_storage", other)

  def __imul__(self, y):
    if y <= 0:
      raise ValueError(
          f"List only supports append, multiplying in place by {y} removes "
          "elements.")
    original_length = len(self._storage)
    for _ in range(y - 1):
      # Re-append the original elements so each repetition gets fresh names.
      for position in range(original_length):
        self.append(self._storage[position])
    return self

  def __mul__(self, n):
    return self._storage * n

  def __rmul__(self, n):
    return self * n

  def __radd__(self, other):
    return other + self._storage

  def __getitem__(self, key):
    return self._storage[key]

  def __getslice__(self, i, j):
    return self._storage[slice(i, j)]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "List(%s)" % (repr(self._storage),)

  def __sizeof__(self):
    return super().__sizeof__() + sys.getsizeof(self._storage)
447# TODO(tomhennigan) Update to collections.UserList?
448# TODO(allenl): Try switching this to wrapt.ObjectProxy again when we drop
449# Python 3.4 support (may still be tricky).
class ListWrapper(
    List,
    collections_abc.MutableSequence,
    # Shadowed, but there for isinstance checks.
    list):
  """Wraps the built-in `list` to support restore-on-create for variables.

  Unlike `List`, this sequence type is mutable in the same ways built-in lists
  are. `List` raises an error immediately; this type instead records
  problematic mutations and refuses to save. An example is assigning a new
  element to an occupied position, so that two elements get the same name at
  different times.

  On assignment to an attribute of a Model or Trackable object, Python
  lists are replaced with ListWrapper. Wrapping a list in a
  `NoDependency` object prevents this.
  """

  def __init__(self, wrapped_list):
    """Construct a new list wrapper.

    Args:
      wrapped_list: The initial value of the data structure. A shallow copy may
        be maintained for error checking. `wrapped_list` itself should not be
        modified directly after constructing the `ListWrapper`, and if changes
        are detected the `ListWrapper` will throw an exception on save.
    """
    # Monotonic flags: once set, this object would not restore correctly, so
    # saving raises rather than suggest that restoring it will work. They are
    # set before List.__init__ because element tracking happens there.
    self._non_append_mutation_value = False
    self._external_modification_value = False
    super().__init__(wrapped_list)
    self._last_wrapped_list_snapshot = list(self._storage)

  @property
  def _non_append_mutation(self):
    return self._non_append_mutation_value

  @_non_append_mutation.setter
  def _non_append_mutation(self, value):
    # For saving, Trackable only checks whether a mutation occurred at some
    # point; how the object reached that "dirty" state does not matter. The
    # attribute cache, by contrast, must be invalidated immediately, since a
    # caller could query an attribute right away and must not get a stale
    # cached value.
    self._attribute_sentinel.invalidate_all()
    self._non_append_mutation_value = value

  @property
  def _external_modification(self):
    return self._external_modification_value

  @_external_modification.setter
  def _external_modification(self, value):
    # Invalidate for the same reason as `_non_append_mutation`.
    self._attribute_sentinel.invalidate_all()
    self._external_modification_value = value

  # pylint: disable=protected-access
  def __copy__(self):
    copied = super().__copy__()
    # Carry the "unsaveable" flags over to the copy.
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied

  def __deepcopy__(self, memo):
    copied = super().__deepcopy__(memo)
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied
  # pylint: enable=protected-access

  def __reduce_ex__(self, protocol):
    # Pickle by re-wrapping the underlying storage.
    return (self.__class__,
            (self._storage,))

  def _make_storage(self, wrapped_list):
    """Use the user's original list for storage."""
    return wrapped_list

  def _check_external_modification(self):
    """Checks for any changes to the wrapped list not through the wrapper."""
    if self._external_modification or self._non_append_mutation:
      return
    if self._storage != self._last_wrapped_list_snapshot:
      self._external_modification = True
      self._last_wrapped_list_snapshot = None

  def _update_snapshot(self):
    """Acknowledges tracked changes to the wrapped list."""

    # Mutation tracking for attributes reuses the same infrastructure as
    # Trackable mutation tracking.
    self._attribute_sentinel.invalidate_all()
    if self._external_modification or self._non_append_mutation:
      return
    self._last_wrapped_list_snapshot = list(self._storage)

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns the children to save, or raises if the list is unsaveable.

    Raises:
      ValueError: If a non-append mutation happened, or the wrapped list was
        modified outside the wrapper.
    """
    self._check_external_modification()
    if self._non_append_mutation:
      raise ValueError(
          f"Unable to save the object {self} (a list wrapper constructed to "
          "track trackable TensorFlow objects). A list element was replaced "
          "(__setitem__, __setslice__), deleted (__delitem__, __delslice__), "
          "or moved (sort). In order to support restoration on object "
          "creation, tracking is exclusively for append-only data structures."
          "\n\nIf you don't need this list checkpointed, wrap it in a "
          "non-trackable object; it will be subsequently ignored.")
    if self._external_modification:
      raise ValueError(
          f"Unable to save the object {self} (a list wrapper constructed to "
          "track trackable TensorFlow objects). The wrapped list was modified "
          f"outside the wrapper (its final value was {self._storage}, its value"
          " when a checkpoint dependency was added was "
          f"{self._last_wrapped_list_snapshot}), which breaks "
          "restoration on object creation.\n\nIf you don't need this list "
          "checkpointed, wrap it in a NoDependency object; it will be "
          "subsequently ignored.")
    children = super()._trackable_children(save_type, **kwargs)

    if save_type == base.SaveType.SAVEDMODEL:
      # Add functions to be serialized.
      children.update({
          str(key): value
          for key, value in enumerate(self)
          if _is_function(value)
      })

    return children

  def _has_mutation_or_trackable(self):
    """Short-circuits a check for trackables if there's already a mutation."""
    if self._non_append_mutation:
      return True
    return any(isinstance(element, base.Trackable) for element in self._storage)

  def __delitem__(self, key):
    self._check_external_modification()
    if self._has_mutation_or_trackable():
      self._non_append_mutation = True
    del self._storage[key]
    self._update_snapshot()

  def __setitem__(self, key, value):
    self._check_external_modification()

    if isinstance(key, slice):
      # Note: this is quite inefficient, but the list API supports a broad range
      # of slice setters (e.g. truncate, extend, replace) and imitating this
      # for a range of Python versions is non-trivial.
      storage_copy = list(self._storage)
      self._storage[key] = value

      len_before = len(storage_copy)
      len_now = len(self._storage)
      for i in range(max(len_before, len_now)):
        value_now = self._storage[i] if i < len_now else None
        value_before = storage_copy[i] if i < len_before else None

        if isinstance(value_before, base.Trackable):
          # Overwriting (or truncating away) a trackable breaks append-only
          # naming.
          self._non_append_mutation = True

        if value_now is not None and value_now != value_before:
          self._storage[i] = self._track_value(self._storage[i],
                                               self._name_element(i))

    else:
      if isinstance(self._storage[key], base.Trackable):
        self._non_append_mutation = True
      # NOTE(review): the dependency name is derived from `key` as given, so a
      # negative index (e.g. -1) yields a name like "-1" rather than the
      # element's position; confirm whether this is intended.
      self._storage[key] = self._track_value(value, self._name_element(key))

    self._update_snapshot()

  def append(self, value):
    """Add a new trackable value."""
    self._check_external_modification()
    super().append(value)
    self._update_snapshot()

  def extend(self, values):
    """Add a sequence of trackable values."""
    self._check_external_modification()
    super().extend(values)
    self._update_snapshot()

  def __imul__(self, y):
    if y <= 0:
      # Multiplying by a non-positive number empties the list; that is a
      # non-append mutation if anything trackable was stored.
      self._check_external_modification()
      if self._has_mutation_or_trackable():
        self._non_append_mutation = True
      self._storage *= y
      self._update_snapshot()
      return self

    # Relies on super() calling append, which updates the snapshot.
    return super().__imul__(y)

  def __eq__(self, other):
    return self._storage == getattr(other, "_storage", other)

  def __ne__(self, other):
    return self._storage != getattr(other, "_storage", other)

  def __lt__(self, other):
    return self._storage < getattr(other, "_storage", other)

  def __le__(self, other):
    return self._storage <= getattr(other, "_storage", other)

  def __gt__(self, other):
    return self._storage > getattr(other, "_storage", other)

  def __ge__(self, other):
    return self._storage >= getattr(other, "_storage", other)

  def __hash__(self):
    # List wrappers need to compare like regular lists, and so like regular
    # lists they don't belong in hash tables.
    raise TypeError("unhashable type: 'ListWrapper'")

  def insert(self, index, obj):
    self._check_external_modification()
    if (self._has_mutation_or_trackable() or isinstance(obj, base.Trackable)):
      self._non_append_mutation = True
    self._storage.insert(index, obj)
    self._update_snapshot()

  def sort(self, *, key=None, reverse=False):
    """Sorts the wrapped list in place, with the same arguments as `list.sort`.

    Reordering elements is a non-append mutation, so a list that holds
    trackable objects (or was already mutated) can no longer be saved.

    Args:
      key: Optional one-argument function extracting a comparison key from
        each element.
      reverse: If True, sort in descending order.
    """
    self._check_external_modification()
    if self._has_mutation_or_trackable():
      self._non_append_mutation = True
    self._storage.sort(key=key, reverse=reverse)
    self._update_snapshot()

  def __setslice__(self, i, j, y):
    self.__setitem__(slice(i, j), y)

  def __delslice__(self, i, j):
    self._check_external_modification()
    if self._has_mutation_or_trackable():
      self._non_append_mutation = True
    del self._storage[slice(i, j)]
    self._update_snapshot()

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects."""
    try:
      value = super()._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
    return "ListWrapper(%s)" % (repr(self._storage),)
class Mapping(TrackableDataStructure, collections_abc.Mapping):
  """An append-only trackable mapping data structure with string keys.

  Each value, which must itself be trackable, becomes a checkpoint dependency
  named after its key. Once a key has been added, it can be neither deleted
  nor rebound to another value.
  """

  def __init__(self, *args, **kwargs):
    """Construct a new mapping. Arguments are passed to `dict()`."""
    super().__init__()
    self._storage = self._make_storage(*args, **kwargs)
    # Track every initial value; the tracked forms replace the originals.
    tracked = {}
    for key, value in self._storage.items():
      tracked[key] = self._track_value(value, name=self._name_element(key))
    self._storage.update(tracked)

  def __copy__(self):
    duplicate = copy.copy(self._storage)
    return type(self)(duplicate)

  def __deepcopy__(self, memo):
    duplicate = copy.deepcopy(self._storage, memo)
    return type(self)(duplicate)

  def _make_storage(self, *args, **kwargs):
    return dict(*args, **kwargs)

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    # Deterministic order: sort entries by key, then keep only the values.
    entries = sorted(self.items(), key=lambda entry: entry[0])
    if not entries:
      return []
    return tuple(entry[1] for entry in entries)

  def _name_element(self, key):
    if isinstance(key, str):
      return str(key)
    raise TypeError(
        f"Mapping accepts only string keys, but got a key {repr(key)}.")

  def __setitem__(self, key, value):
    value = self._track_value(value, name=self._name_element(key))
    # setdefault only inserts when the key is new; otherwise it returns the
    # existing value, which signals an illegal overwrite.
    current_value = self._storage.setdefault(key, value)
    if current_value is not value:
      raise ValueError(
          "Mappings are an append-only data structure. Tried to overwrite the "
          f"key '{key}' with value {value}, but it already contains "
          f"{current_value}")

  def update(self, *args, **kwargs):
    for key, value in dict(*args, **kwargs).items():
      self[key] = value

  def __getitem__(self, key):
    return self._storage[key]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "Mapping(%s)" % (repr(self._storage),)

  def __iter__(self):
    return iter(self._storage)
class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
  """Wraps built-in dicts to support restore-on-create for variables.

  _DictWrapper relates to Mapping as ListWrapper relates to List. Unlike
  Mapping, _DictWrapper accepts non-string keys and values and allows arbitrary
  mutations, such as deleting keys or reassigning values. As with ListWrapper,
  some of these mutations make _DictWrapper raise an exception on save.
  """

  def __init__(self, wrapped_dict=None):
    if wrapped_dict is None:
      # Allow zero-argument construction, e.g. from session.run's re-wrapping.
      wrapped_dict = {}
    elif not isinstance(wrapped_dict, collections_abc.Mapping):
      # Allow construction from a sequence, e.g. from nest.pack_sequence_as.
      wrapped_dict = dict(wrapped_dict)
    wrapt.ObjectProxy.__init__(self, wrapped_dict)
    TrackableDataStructure.__init__(self)
    self._self_non_string_key = False
    self._self_external_modification = False
    storage = self.__wrapped__
    tracked = {}
    for key, value in storage.items():
      tracked[key] = self._track_value(value, name=self._name_element(key))
    storage.update(tracked)
    self._update_snapshot()

  def __reduce_ex__(self, protocol):
    # Pickle by re-wrapping the underlying dict.
    return (self.__class__,
            (self.__wrapped__,))

  def __getattribute__(self, name):
    class_attr = getattr(type(self), name, None)
    if isinstance(class_attr, property):
      # Properties are resolved on the wrapper, bypassing ObjectProxy. Whether
      # this is needed appears to depend on the Python version (not the wrapt
      # version); 3.4 in particular looked properties up on the wrapped object
      # without it.
      return object.__getattribute__(self, name)
    return super().__getattribute__(name)

  def copy(self):
    return copy.copy(self)

  # pylint: disable=protected-access
  def __copy__(self):
    duplicate = _DictWrapper(copy.copy(self.__wrapped__))
    # Carry the "unsaveable" flags over to the copy.
    duplicate._self_external_modification = self._self_external_modification
    duplicate._self_non_string_key = self._self_non_string_key
    return duplicate

  def __deepcopy__(self, memo):
    duplicate = _DictWrapper(copy.deepcopy(self.__wrapped__, memo))
    duplicate._self_external_modification = self._self_external_modification
    duplicate._self_non_string_key = self._self_non_string_key
    return duplicate
  # pylint: enable=protected-access

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    # Deterministic order: sort entries by key, then keep only the values.
    entries = sorted(self.items(), key=lambda entry: entry[0])
    if not entries:
      return []
    return tuple(entry[1] for entry in entries)

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Check that the object is saveable before listing its dependencies."""
    self._check_self_external_modification()
    if self._self_non_string_key:
      raise ValueError(
          f"Unable to save the object {self} (a dictionary wrapper constructed "
          "automatically on attribute assignment). The wrapped dictionary "
          "contains a non-string key which maps to a trackable object or "
          "mutable data structure.\n\nIf you don't need this dictionary "
          "checkpointed, wrap it in a non-trackable "
          "object; it will be subsequently ignored.")
    if self._self_external_modification:
      raise ValueError(
          f"Unable to save the object {self} (a dictionary wrapper constructed "
          "automatically on attribute assignment). The wrapped dictionary was "
          f"modified outside the wrapper (its final value was {self}, its value"
          " when a checkpoint dependency was added was "
          f"{self._self_last_wrapped_dict_snapshot}), which breaks "
          "restoration on object creation.\n\nIf you don't need this "
          "dictionary checkpointed, wrap it in a "
          "non-trackable object; it will be subsequently ignored.")
    # Every cause of dirtiness raised above, so this must hold.
    assert not self._dirty
    children = super()._trackable_children(save_type, **kwargs)
    if save_type == base.SaveType.SAVEDMODEL:
      # Functions stored as values are serialized as children as well.
      children.update(
          {key: value for key, value in self.items() if _is_function(value)})
    return children

  @property
  def _dirty(self):
    """Check if there has already been a mutation which prevents saving."""
    return (self._self_external_modification
            or self._self_non_string_key)

  def _check_self_external_modification(self):
    """Checks for any changes to the wrapped dict not through the wrapper."""
    if self._dirty:
      return
    if self != self._self_last_wrapped_dict_snapshot:
      self._self_external_modification = True
      self._self_last_wrapped_dict_snapshot = None

  def _update_snapshot(self):
    """Acknowledges tracked changes to the wrapped dict."""
    self._attribute_sentinel.invalidate_all()
    if self._dirty:
      return
    self._self_last_wrapped_dict_snapshot = dict(self)

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects."""
    string_key = isinstance(name, str)
    if not string_key:
      name = "-non_string_key"
    no_dependency = isinstance(value, NoDependency)
    try:
      value = super()._track_value(value=value, name=name)
      if not (string_key or no_dependency):
        # A non-string key maps to a trackable value, so this data structure
        # cannot be saved.
        self._self_non_string_key = True
      return value
    except ValueError:
      # Even if this value isn't trackable, NoDependency objects must still be
      # unwrapped.
      return sticky_attribute_assignment(
          trackable=self, value=value, name=name)

  def _name_element(self, key):
    """Tells TrackableDataStructure to use keys as names as-is."""
    return key

  def __setitem__(self, key, value):
    """Allow any modifications, but possibly mark the wrapper as unsaveable."""
    self._check_self_external_modification()
    self._maybe_initialize_trackable()
    skip_dependency = isinstance(value, NoDependency)
    if isinstance(key, str):
      value = self._track_value(value, name=key)
    else:
      value = wrap_or_unwrap(value)
      # Non-string keys are fine unless a dependency would be needed: the
      # value is trackable and was not wrapped in NoDependency.
      if not skip_dependency and isinstance(value, base.Trackable):
        self._self_non_string_key = True
    self.__wrapped__[key] = value

    self._update_snapshot()

  def __delitem__(self, key):
    self._check_self_external_modification()
    del self.__wrapped__[key]
    self._update_snapshot()

  def __repr__(self):
    return "DictWrapper(%s)" % (repr(self.__wrapped__),)

  def __hash__(self):
    raise TypeError("unhashable type: 'DictWrapper'")

  def __eq__(self, other):
    # Undo TrackableDataStructure's identity equality and compare like the
    # wrapped dict, as wrapt does.
    return self.__wrapped__ == other

  def update(self, *args, **kwargs):
    for key, value in dict(*args, **kwargs).items():
      self[key] = value
class _TupleWrapper(TrackableDataStructure, wrapt.ObjectProxy):
  """Trackable wrapper for tuples and namedtuples.

  Each element that was not passed as a `NoDependency` gets a dependency
  named by its position ("0", "1", ...) and, for namedtuples, also one named
  by its field.
  """

  def __init__(self, original_wrapped_tuple=()):
    """Wraps `original_wrapped_tuple`, substituting and tracking its elements.

    Args:
      original_wrapped_tuple: The tuple or namedtuple to wrap. Every element
        is passed through `wrap_or_unwrap`. Elements that are `NoDependency`
        instances are stored but get no dependency.
    """
    # Parallel per-element lists: whether to add a dependency, and the
    # substituted element that will be stored in the wrapped tuple.
    add_dependency = []
    substituted_wrapped_tuple = []
    for element in original_wrapped_tuple:
      if isinstance(element, NoDependency):
        add_dependency.append(False)
      else:
        add_dependency.append(True)
      substituted_wrapped_tuple.append(wrap_or_unwrap(element))
    try:
      fields = original_wrapped_tuple._fields
    except AttributeError:
      # Not a namedtuple
      is_namedtuple = False
    else:
      is_namedtuple = True
    original_type = type(original_wrapped_tuple)
    # Flag to poison saving if we can't re-construct a namedtuple because its
    # __new__ takes different keyword arguments than its _fields.
    # NOTE(review): this is set before wrapt.ObjectProxy.__init__ runs, which
    # presumably relies on wrapt storing `_self_`-prefixed attributes on the
    # proxy itself; keep this ordering in mind when editing.
    self._self_tuple_is_constructable = True
    if is_namedtuple:
      try:
        # NamedTuples take N arguments, unlike tuple which takes a sequence.
        substituted_wrapped_tuple = original_type(
            **dict(zip(fields, substituted_wrapped_tuple)))
      except TypeError:
        # The namedtuple can't be rebuilt from its _fields. Wrap the original
        # object unchanged, add no dependencies, and mark it unconstructable
        # (`_trackable_children` raises on this flag).
        wrapt.ObjectProxy.__init__(self, original_wrapped_tuple)
        TrackableDataStructure.__init__(self)
        self._self_tuple_is_constructable = False
        return
    else:
      substituted_wrapped_tuple = original_type(substituted_wrapped_tuple)
    wrapt.ObjectProxy.__init__(self, substituted_wrapped_tuple)
    TrackableDataStructure.__init__(self)

    if is_namedtuple:
      # For namedtuples, also track by names for compatibility with
      # dictionaries.
      for name, should_depend, element in zip(
          fields, add_dependency, substituted_wrapped_tuple):
        if should_depend:
          self._track_value(element, name=name)

    # Track by index as well, for compatibility with lists.
    for index, (should_depend, element) in enumerate(
        zip(add_dependency, substituted_wrapped_tuple)):
      if should_depend:
        self._track_value(element, name="%d" % (index,))

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    # The proxy itself stands in for the sequence of wrapped elements.
    return self

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects.

    Args:
      value: The element to track.
      name: The dependency name (from `__init__`, a positional index string
        or a namedtuple field name).

    Returns:
      The result of the parent `_track_value`, or, if that raised ValueError,
      the result of `sticky_attribute_assignment`.
    """
    try:
      value = super()._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
    return "_TupleWrapper(%s)" % (repr(self.__wrapped__),)

  def __hash__(self):
    # Override the TrackableDataStructure hash forwarding and go back to
    # the wrapt implementation.
    return hash(self.__wrapped__)

  def __eq__(self, other):
    # Override the TrackableDataStructure "== -> is" forwarding and go back to
    # the wrapt implementation.
    return self.__wrapped__ == other

  def __copy__(self):
    # Builds a fresh wrapper (re-running __init__ and its dependency
    # tracking) around a shallow copy of the wrapped tuple.
    return _TupleWrapper(copy.copy(self.__wrapped__))

  def __deepcopy__(self, memo):
    # Builds a fresh wrapper around a deep copy of the wrapped tuple. The new
    # wrapper itself is not recorded in `memo`.
    return _TupleWrapper(copy.deepcopy(self.__wrapped__, memo))

  def __reduce_ex__(self, protocol):
    # NOTE(review): `__getattribute__` below prefers attributes of the wrapped
    # tuple, and tuples have both `__class__` and `__reduce_ex__`. So
    # `self.__class__` here likely resolves to the wrapped type rather than
    # `_TupleWrapper`, and an instance lookup of `__reduce_ex__` (as pickle
    # does) may return the wrapped tuple's method instead of this one.
    # Confirm before relying on this method.
    return (self.__class__,
            (self.__wrapped__,))

  # imul and iadd are the only tuple-relevant in-place operators. They need to
  # be special-cased to avoid mutating the original proxy object.
  def __imul__(self, y):
    """Avoid running self.__wrapped__ *= y, which mutates `self`."""
    # Returns the plain result of the tuple operation, not a new wrapper.
    return self.__wrapped__ * y

  def __iadd__(self, y):
    """Avoid running self.__wrapped__ += y, which mutates `self`."""
    # Returns the plain result of the tuple operation, not a new wrapper.
    return self.__wrapped__ + y

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns the children to save, refusing unconstructable namedtuples.

    Raises:
      ValueError: If `__init__` could not rebuild the wrapped namedtuple from
        its `_fields`.
    """
    if not self._self_tuple_is_constructable:
      raise ValueError(
          f"Unable to save because the namedtuple {self.__wrapped__} is not "
          "constructable from its _fields (i.e. __new__ is overridden). "
          f"Expected keyword arguments {self.__wrapped__._fields}. If you do "
          "not need to save this object, consider wrapping it in a custom "
          "object that does not inherit from tuple.")
    return super()._trackable_children(save_type, **kwargs)

  def __getattribute__(self, name):
    # Lookup order: (1) attributes of the wrapped tuple, (2) properties
    # defined on the wrapper class, (3) the normal proxy lookup.
    # `__wrapped__` is excluded from (1) because reading `self.__wrapped__`
    # re-enters this method with that name.
    if name != "__wrapped__" and hasattr(self.__wrapped__, name):
      # Prefer attributes on the wrapped object when they conflict with
      # attributes on the wrapper object.
      return getattr(self.__wrapped__, name)

    if (hasattr(type(self), name)
        and isinstance(getattr(type(self), name), property)):
      # Bypass ObjectProxy for properties. Whether this workaround is necessary
      # appears to depend on the Python version but not the wrapt version: 3.4
      # in particular seems to look up properties on the wrapped object instead
      # of the wrapper without this logic.
      return object.__getattribute__(self, name)
    else:
      return super().__getattribute__(name)
def _is_function(x):
  """Return True if `x` is a `def_function.Function` or a
  `defun.ConcreteFunction`, False otherwise."""
  function_types = (def_function.Function, defun.ConcreteFunction)
  return isinstance(x, function_types)
def set_list_item(list_object, index_string, value):
  """Assign `value` at the integer position encoded by `index_string`.

  `list_object` is first padded with `None` so that the position exists.

  Args:
    list_object: The list to modify in place.
    index_string: Anything `int()` accepts, e.g. "3".
    value: The item to place at that position.

  Raises:
    ValueError: If `index_string` is not an integer literal.
  """
  position = int(index_string)
  shortfall = position + 1 - len(list_object)
  if shortfall > 0:
    list_object.extend([None] * shortfall)
  list_object[position] = value
def set_tuple_item(list_object, index_string, value):
  """Assign `value` positionally, ignoring names that are not integers.

  This works like a positional list assignment with `None` padding. A name
  that `int()` rejects with ValueError (such as a namedtuple field name) is
  skipped silently.

  Args:
    list_object: The list, standing in for a tuple, to modify in place.
    index_string: An integer literal such as "1", or a field name.
    value: The item to store.
  """
  try:
    position = int(index_string)
  except ValueError:
    # Ignore namedtuple fields.
    return
  missing = position + 1 - len(list_object)
  if missing > 0:
    list_object.extend([None] * missing)
  list_object[position] = value