Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/trackable_view.py: 38%
40 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""Manages a Trackable object graph."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16import collections
17import weakref
19from tensorflow.python.trackable import base
20from tensorflow.python.trackable import converter
21from tensorflow.python.util import object_identity
22from tensorflow.python.util.tf_export import tf_export
@tf_export("train.TrackableView", v1=[])
class TrackableView(object):
  """Gathers and serializes a trackable view.

  Example usage:

  >>> class SimpleModule(tf.Module):
  ...   def __init__(self, name=None):
  ...     super().__init__(name=name)
  ...     self.a_var = tf.Variable(5.0)
  ...     self.b_var = tf.Variable(4.0)
  ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

  >>> root = SimpleModule(name="root")
  >>> root.leaf = SimpleModule(name="leaf")
  >>> trackable_view = tf.train.TrackableView(root)

  Pass root to tf.train.TrackableView.children() to get the dictionary of all
  children directly linked to root by name.

  >>> trackable_view_children = trackable_view.children(root)
  >>> for item in trackable_view_children.items():
  ...   print(item)
  ('a_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
  ('b_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
  ('vars', ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
  numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
  ('leaf', ...)

  """

  def __init__(self, root):
    """Configure the trackable view.

    Args:
      root: A `Trackable` object whose variables (including the variables of
        dependencies, recursively) should be saved. May be a weak reference.
    """
    # TrackableView should never contain a strong reference to root, since it
    # may result in a cycle:
    #   root -> deferred dependencies -> CheckpointPosition
    #   -> CheckpointRestoreCoordinator -> TrackableView -> root
    self._root_ref = (root if isinstance(root, weakref.ref)
                      else weakref.ref(root))

  @classmethod
  def children(cls, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns all child trackables attached to obj.

    Each child reported by `obj._trackable_children` is passed through
    `converter.convert_to_trackable` (with `obj` as parent) before being
    returned.

    Args:
      obj: A `Trackable` object.
      save_type: A string, can be 'savedmodel' or 'checkpoint'. Defaults to
        `base.SaveType.CHECKPOINT`.
      **kwargs: kwargs to use when retrieving the object's children.

    Returns:
      Dictionary of all children attached to the object with name to trackable.
    """
    # pylint: disable=protected-access
    obj._maybe_initialize_trackable()
    return {
        name: converter.convert_to_trackable(ref, parent=obj)
        for name, ref in obj._trackable_children(save_type, **kwargs).items()
    }
    # pylint: enable=protected-access

  @property
  def root(self):
    """The root object of this view.

    Returns:
      The object passed to `__init__`, dereferenced if it is held by weak
      reference.

    Raises:
      AssertionError: If the root is held by weak reference and the referent
        no longer exists.
    """
    root_ref = self._root_ref
    if not isinstance(root_ref, weakref.ref):
      # `__init__` always stores a weakref; this branch only matters if
      # `_root_ref` is reassigned elsewhere. NOTE(review): confirm whether any
      # subclass still relies on storing a strong reference here.
      return root_ref
    derefed = root_ref()
    if derefed is None:
      # Raise explicitly instead of using `assert`, which `python -O` strips
      # (that would make this property silently return None). AssertionError
      # is kept so the exception type matches the previous behavior.
      raise AssertionError(
          "The root object of this TrackableView no longer exists: it was held "
          "by weak reference and has been garbage collected.")
    return derefed

  def descendants(self):
    """Returns a list of all nodes from self.root using a breadth first traversal.

    The root itself is the first element of the list.
    """
    return self._descendants_with_paths()[0]

  def _descendants_with_paths(self):
    """Returns all nodes reachable from self.root and the path to each one.

    Nodes are visited breadth first, starting from the root. Each node is
    recorded only the first time it is reached, so the traversal terminates
    even if the object graph contains cycles.

    Returns:
      A tuple `(bfs_sorted, node_paths)`:
        bfs_sorted: A list of all reachable nodes in breadth-first order, with
          the root first.
        node_paths: An `object_identity.ObjectIdentityDictionary` mapping each
          node to a tuple of `base.TrackableReference`s that spell out the path
          from the root to that node. The root maps to the empty tuple.
    """
    # Dereference the (weakly held) root once for the whole traversal.
    root = self.root
    bfs_sorted = []
    to_visit = collections.deque([root])
    # `node_paths` doubles as the visited set.
    node_paths = object_identity.ObjectIdentityDictionary()
    node_paths[root] = ()
    while to_visit:
      current_trackable = to_visit.popleft()
      bfs_sorted.append(current_trackable)
      # Call through `self` so an instance-level override of `children` is
      # used.
      for name, dependency in self.children(current_trackable).items():
        if dependency not in node_paths:
          # First visit in BFS order, so this is a path with the fewest edges.
          node_paths[dependency] = (
              node_paths[current_trackable] +
              (base.TrackableReference(name, dependency),))
          to_visit.append(dependency)
    return bfs_sorted, node_paths