Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/trackable_view.py: 38%

40 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1"""Manages a Trackable object graph.""" 

2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15# ============================================================================== 

16import collections 

17import weakref 

18 

19from tensorflow.python.trackable import base 

20from tensorflow.python.trackable import converter 

21from tensorflow.python.util import object_identity 

22from tensorflow.python.util.tf_export import tf_export 

23 

24 

@tf_export("train.TrackableView", v1=[])
class TrackableView(object):
  """Gathers and serializes a trackable view.

  The view keeps only a weak reference to its root object (see `__init__`).

  Example usage:

  >>> class SimpleModule(tf.Module):
  ...   def __init__(self, name=None):
  ...     super().__init__(name=name)
  ...     self.a_var = tf.Variable(5.0)
  ...     self.b_var = tf.Variable(4.0)
  ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

  >>> root = SimpleModule(name="root")
  >>> root.leaf = SimpleModule(name="leaf")
  >>> trackable_view = tf.train.TrackableView(root)

  Pass root to tf.train.TrackableView.children() to get the dictionary of all
  children directly linked to root by name.
  >>> trackable_view_children = trackable_view.children(root)
  >>> for item in trackable_view_children.items():
  ...   print(item)
  ('a_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
  ('b_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
  ('vars', ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
  numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
  ('leaf', ...)

  """

  def __init__(self, root):
    """Configure the trackable view.

    Args:
      root: A `Trackable` object whose variables (including the variables of
        dependencies, recursively) should be saved. May be a weak reference.
    """
    # TrackableView should never contain a strong reference to root, since it
    # may result in a cycle:
    #   root -> deferred dependencies -> CheckpointPosition
    #   -> CheckpointRestoreCoordinator -> TrackableView -> root
    self._root_ref = (root if isinstance(root, weakref.ref)
                      else weakref.ref(root))

  @classmethod
  def children(cls, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns all child trackables attached to obj.

    Each child reported by `obj._trackable_children` is passed through
    `converter.convert_to_trackable` (with `obj` as its parent) before being
    returned.

    Args:
      obj: A `Trackable` object.
      save_type: A `base.SaveType` value, 'savedmodel' or 'checkpoint'.
        Defaults to `base.SaveType.CHECKPOINT`.
      **kwargs: kwargs to use when retrieving the object's children.

    Returns:
      Dictionary mapping each child's name to its (converted) trackable.
    """
    # pylint: disable=protected-access
    obj._maybe_initialize_trackable()
    return {
        name: converter.convert_to_trackable(ref, parent=obj)
        for name, ref in obj._trackable_children(save_type, **kwargs).items()
    }

  @property
  def root(self):
    """The root `Trackable` of this view.

    Returns:
      The object the stored weak reference points to. If `_root_ref` is not a
      weak reference, it is returned as-is.

    Raises:
      AssertionError: If the weakly-referenced root has been garbage
        collected.
    """
    if isinstance(self._root_ref, weakref.ref):
      derefed = self._root_ref()
      # Explicit check instead of `assert`: an assert is stripped under
      # `python -O`, which would silently hand `None` to callers.
      if derefed is None:
        raise AssertionError(
            "The root object of this TrackableView has been garbage "
            "collected; the view only holds a weak reference to it.")
      return derefed
    # NOTE(review): `__init__` always stores a weakref, so this branch is only
    # reachable if `_root_ref` is reassigned (presumably by a subclass).
    return self._root_ref

  def descendants(self):
    """Returns a list of all nodes from self.root using a breadth first traversal."""
    return self._descendants_with_paths()[0]

  def _descendants_with_paths(self):
    """Returns all nodes reachable from self.root and their paths, in BFS order.

    Children are collected with `children` using its default `save_type`.

    Returns:
      A tuple `(bfs_sorted, node_paths)`:
        bfs_sorted: List of every trackable reachable from the root, in
          breadth-first order, starting with the root itself.
        node_paths: `object_identity.ObjectIdentityDictionary` mapping each of
          those trackables to a tuple of `base.TrackableReference`s that
          spells out the path from the root. The root maps to `()`.
    """
    # Dereference the weak root once so the whole traversal uses one object.
    root = self.root
    bfs_sorted = []
    to_visit = collections.deque([root])
    # Presumably compares keys by identity rather than `==` (per the
    # `object_identity` module); doubles as the visited set.
    node_paths = object_identity.ObjectIdentityDictionary()
    node_paths[root] = ()
    while to_visit:
      current_trackable = to_visit.popleft()
      bfs_sorted.append(current_trackable)
      for name, dependency in self.children(current_trackable).items():
        # FIFO order means the first path recorded for a node is a shortest
        # one; already-seen nodes are not re-enqueued, so shared children and
        # cycles are visited only once.
        if dependency not in node_paths:
          node_paths[dependency] = (
              node_paths[current_trackable] +
              (base.TrackableReference(name, dependency),))
          to_visit.append(dependency)
    return bfs_sorted, node_paths