Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/graph_view.py: 38%
53 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""Manages a graph of Trackable objects."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16import copy
17import weakref
19from tensorflow.python.checkpoint import save_util_v1
20from tensorflow.python.checkpoint import trackable_view
21from tensorflow.python.trackable import base
22from tensorflow.python.util.tf_export import tf_export
@tf_export("__internal__.tracking.ObjectGraphView", v1=[])
class ObjectGraphView(trackable_view.TrackableView):
  """Gathers and serializes an object graph.

  The view stores only a weak reference to its root object, and may expose
  extra "attached" dependencies as children of the root that the root object
  itself does not track.
  """

  def __init__(self, root, attached_dependencies=None):
    """Configure the graph view.

    Args:
      root: A `Trackable` object whose variables (including the variables of
        dependencies, recursively) should be saved. May be a weak reference.
      attached_dependencies: List of dependencies to attach to the root object.
        Used when saving a Checkpoint with a defined root object. To avoid
        reference cycles, this should use the WeakTrackableReference class.
    """
    trackable_view.TrackableView.__init__(self, root)
    # ObjectGraphView should never contain a strong reference to root, since it
    # may result in a cycle:
    #   root -> deferred dependencies -> CheckpointPosition
    #   -> CheckpointRestoreCoordinator -> ObjectGraphView -> root
    self._root_ref = (root if isinstance(root, weakref.ref)
                      else weakref.ref(root))
    self._attached_dependencies = attached_dependencies

  def __deepcopy__(self, memo):
    # By default, weak references are not copied by `copy.deepcopy`, so a naive
    # deepcopy would leave the copy pointing at the ORIGINAL root. To fix this,
    # we first deep-copy the referent itself, then pre-seed the memo so that
    # copying `self._root_ref` below yields a weak reference to that copy.
    strong_root = self._root_ref()
    if strong_root is not None:
      strong_copy = copy.deepcopy(strong_root, memo)
      memo[id(self._root_ref)] = weakref.ref(strong_copy)
    # super() does not have a __deepcopy__, so we need to re-implement it:
    # allocate an uninitialized instance and deep-copy every attribute into it.
    copied = super().__new__(type(self))
    # Register the copy before recursing so self-references resolve to it.
    memo[id(self)] = copied
    for key, value in vars(self).items():
      setattr(copied, key, copy.deepcopy(value, memo))
    return copied

  def list_children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns list of all child trackables attached to obj.

    Args:
      obj: A `Trackable` object.
      save_type: A string, can be 'savedmodel' or 'checkpoint'.
      **kwargs: kwargs to use when retrieving the object's children.

    Returns:
      List of all children attached to the object, as
      `base.TrackableReference` entries. When `obj` is the root, any
      `attached_dependencies` are appended after the object's own children.
    """
    # Use the parent implementation directly: `self.children` is overridden
    # below in terms of `list_children`, so calling it here would recurse.
    own_children = super().children(obj, save_type, **kwargs)
    children = [base.TrackableReference(name, ref)
                for name, ref in own_children.items()]

    # GraphView objects may define children of the root object that are not
    # actually attached, e.g. a Checkpoint object's save_counter.
    if obj is self.root and self._attached_dependencies:
      children.extend(self._attached_dependencies)
    return children

  def children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns all child trackables attached to obj.

    Args:
      obj: A `Trackable` object.
      save_type: A string, can be 'savedmodel' or 'checkpoint'. Note that this
        value is currently NOT forwarded to `list_children`.
      **kwargs: kwargs to use when retrieving the object's children.

    Returns:
      Dictionary of all children attached to the object with name to trackable.
      If several children share a name, the last one listed wins.
    """
    # NOTE(review): `save_type` is not forwarded here, so `list_children` always
    # falls back to its own default. Forwarding it looks like the intended
    # contract, but subclasses elsewhere may override `list_children(self, obj)`
    # without a `save_type` parameter, and passing it would then raise
    # TypeError. Confirm all overrides before changing this call.
    return {name: ref for name, ref in self.list_children(obj, **kwargs)}

  @property
  def attached_dependencies(self):
    """Returns list of dependencies that should be saved in the checkpoint.

    These dependencies are not tracked by root, but are in the checkpoint.
    This is defined when the user creates a Checkpoint with both root and kwargs
    set.

    Returns:
      A list of TrackableReferences.
    """
    return self._attached_dependencies

  @property
  def root(self):
    """The root object, dereferenced from the stored weak reference.

    Raises:
      AssertionError: If the root object has already been garbage collected
        (only when assertions are enabled).
    """
    if isinstance(self._root_ref, weakref.ref):
      derefed = self._root_ref()
      assert derefed is not None, (
          "The root object of this ObjectGraphView is no longer alive.")
      return derefed
    else:
      # Defensive: `__init__` always stores a `weakref.ref`, but a plain
      # object is still returned as-is if one was assigned.
      return self._root_ref

  def breadth_first_traversal(self):
    """Public entry point; delegates to `_breadth_first_traversal`."""
    return self._breadth_first_traversal()

  def _breadth_first_traversal(self):
    """Find shortest paths to all dependencies of self.root."""
    return super()._descendants_with_paths()

  def serialize_object_graph(self, saveables_cache=None):
    """Determine checkpoint keys for variables and build a serialized graph.

    Non-slot variables are keyed based on a shortest path from the root saveable
    to the object which owns the variable (i.e. the one which called
    `Trackable._add_variable` to create it).

    Slot variables are keyed based on a shortest path to the variable being
    slotted for, a shortest path to their optimizer, and the slot name.

    Args:
      saveables_cache: An optional cache storing previously created
        SaveableObjects created for each Trackable. Maps Trackables to a
        dictionary of attribute names to Trackable.

    Returns:
      A tuple of (named_variables, object_graph_proto, feed_additions):
        named_variables: A dictionary mapping names to variable objects.
        object_graph_proto: A TrackableObjectGraph protocol buffer
          containing the serialized object graph and variable references.
        feed_additions: A dictionary mapping from Tensors to values which should
          be fed when saving.

    Raises:
      ValueError: If there are invalid characters in an optimizer's slot names.
    """
    # The helper returns a 4-tuple; its last element is dropped to keep this
    # method's 3-tuple return contract.
    named_saveable_objects, object_graph_proto, feed_additions, _ = (
        save_util_v1.serialize_object_graph_with_registered_savers(
            self, saveables_cache))
    return named_saveable_objects, object_graph_proto, feed_additions

  def frozen_saveable_objects(self,
                              object_map=None,
                              to_graph=None,
                              call_with_mapped_captures=None):
    """Creates SaveableObjects with the current object graph frozen.

    Args are forwarded unchanged to
    `save_util_v1.frozen_saveables_and_savers`.

    Returns:
      The first element of the result of
      `save_util_v1.frozen_saveables_and_savers`.
    """
    return save_util_v1.frozen_saveables_and_savers(
        self, object_map, to_graph, call_with_mapped_captures)[0]