Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/trackable/trackable_utils.py: 30%

64 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utility methods for the trackable dependencies.""" 

16from __future__ import absolute_import 

17from __future__ import division 

18from __future__ import print_function 

19 

20import collections 

21 

22 

def pretty_print_node_path(path):
  """Formats a node path as a human-readable dotted string.

  Args:
    path: a sequence of objects exposing a `name` attribute, ordered from the
      root downward. May be empty (or otherwise falsy).

  Returns:
    "root object" when `path` is falsy; otherwise "root." followed by the
    `name` of every element joined with ".".
  """
  if path:
    names = (node.name for node in path)
    return "root." + ".".join(names)
  return "root object"

28 

29 

class CyclicDependencyError(Exception):
  """Raised by `order_by_dependency` when the dependency graph has a cycle."""

  def __init__(self, leftover_dependency_map):
    """Creates a CyclicDependencyError.

    Args:
      leftover_dependency_map: a mapping from each key that could not be
        topologically sorted to the list of its dependencies that were still
        unresolved when sorting stopped.
    """
    # Leftover edges that were not able to be topologically sorted.
    self.leftover_dependency_map = leftover_dependency_map
    super(CyclicDependencyError, self).__init__()

37 

38 

def order_by_dependency(dependency_map):
  """Topologically sorts the keys of a map so that dependencies appear first.

  Uses Kahn's algorithm:
  https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm

  Args:
    dependency_map: a dict mapping values to a list of dependencies (other keys
      in the map). All keys and dependencies must be hashable types.

  Returns:
    An iterator over the keys of dependency_map in which every key comes after
    all of its dependencies. This is a `reversed(...)` iterator, not a list;
    wrap it in `list()` if it must be traversed more than once.

  Raises:
    CyclicDependencyError: if there is a cycle in the graph. Its
      `leftover_dependency_map` maps each key that could not be sorted to its
      unresolved dependencies.
    ValueError: If there are values in the dependency map that are not keys in
      the map.
  """
  # Maps each key -> the set of keys that depend on it. These are the edges
  # used in Kahn's algorithm.
  reverse_dependency_map = collections.defaultdict(set)
  for x, deps in dependency_map.items():
    for dep in deps:
      reverse_dependency_map[dep].add(x)

  # Validate that all values in the dependency map are also keys.
  unknown_keys = reverse_dependency_map.keys() - dependency_map.keys()
  if unknown_keys:
    raise ValueError("Found values in the dependency map which are not keys: "
                     f"{unknown_keys}")

  # Generate the list sorted by objects without dependents -> dependencies.
  # The returned iterator reverses this.
  reversed_dependency_arr = []

  # Prefill `to_visit` with all nodes that do not have other objects depending
  # on them. A deque gives O(1) FIFO pops; `list.pop(0)` is O(n) per pop and
  # made this loop quadratic in the number of keys.
  to_visit = collections.deque(
      x for x in dependency_map if x not in reverse_dependency_map)

  while to_visit:
    x = to_visit.popleft()
    reversed_dependency_arr.append(x)
    # `set(...)` de-duplicates repeated dependencies so each edge is removed
    # exactly once.
    for dep in set(dependency_map[x]):
      edges = reverse_dependency_map[dep]
      edges.remove(x)
      if not edges:
        # Nothing left depends on `dep`, so it is ready to be emitted.
        to_visit.append(dep)
        reverse_dependency_map.pop(dep)

  # Remaining edges mean some nodes were never freed: the graph contains a
  # cycle (leftovers may also include nodes merely blocked by that cycle).
  if reverse_dependency_map:
    leftover_dependency_map = collections.defaultdict(list)
    for dep, xs in reverse_dependency_map.items():
      for x in xs:
        leftover_dependency_map[x].append(dep)
    raise CyclicDependencyError(leftover_dependency_map)

  return reversed(reversed_dependency_arr)

96 

97 

# Prefix for reserved name components ("For avoiding conflicts with
# user-specified names"). `escape_local_name` doubles any occurrence of it in
# user names so they cannot collide with the keywords below.
_ESCAPE_CHAR = "."

# Keyword for identifying that the next bit of a checkpoint variable name is a
# slot name. Checkpoint names for slot variables look like:
#
#   <path to variable>/<_OPTIMIZER_SLOTS_NAME>/<path to optimizer>/<slot name>
#
# Where <path to variable> is a full path from the checkpoint root to the
# variable being slotted for.
_OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"

# Keyword for separating the path to an object from the name of an
# attribute in checkpoint names. Used like:
#
#   <path to object>/<OBJECT_ATTRIBUTES_NAME>/<name of attribute>
OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"

# A constant string that is used to reference the save and restore functions of
# Trackable objects that define `_serialize_to_tensors` and
# `_restore_from_tensors`. This is written as the key in the
# `SavedObject.saveable_objects<string, SaveableObject>` map in the SavedModel.
SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"

118 

119 

def escape_local_name(name):
  """Escapes a local name so it can be used as one component of a key path.

  Slashes are still accepted in local names for compatibility, since this
  naming scheme was patched into APIs such as Layer.add_variable that already
  allowed them. Slashes are also used to separate the edges traversed to reach
  a variable, so they must be escaped: every escape character is doubled, and
  then every "/" becomes the escape character followed by "S".

  Args:
    name: the local name to escape.

  Returns:
    The escaped name.
  """
  # Doubling has to happen first; otherwise the escape sequence produced for
  # "/" would itself be escaped a second time.
  doubled = name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR * 2)
  return doubled.replace("/", _ESCAPE_CHAR + "S")

128 

129 

def object_path_to_string(node_path_arr):
  """Converts a list of nodes to a string.

  Args:
    node_path_arr: an iterable of objects exposing a `name` attribute.

  Returns:
    Each node's name, escaped with `escape_local_name`, joined with "/".
  """
  escaped_names = [escape_local_name(node.name) for node in node_path_arr]
  return "/".join(escaped_names)

134 

135 

def checkpoint_key(object_path, local_name):
  """Returns the checkpoint key for a local attribute of an object.

  The key has the form "<object_path>/<OBJECT_ATTRIBUTES_NAME>/<suffix>".
  Normally the suffix is the escaped `local_name`. When `local_name` equals
  SERIALIZE_TO_TENSORS_NAME, the Trackable uses the `_serialize_to_tensors`
  API, whose returned key(s) are meant to serve as the suffix, so an empty
  suffix is used here.

  Args:
    object_path: path string of the owning object.
    local_name: name of the attribute within that object.

  Returns:
    The checkpoint key string.
  """
  escaped = escape_local_name(local_name)
  suffix = "" if local_name == SERIALIZE_TO_TENSORS_NAME else escaped
  return f"{object_path}/{OBJECT_ATTRIBUTES_NAME}/{suffix}"

146 

147 

def extract_object_name(key):
  """Substrings the checkpoint key to the start of "/.ATTRIBUTES".

  Args:
    key: a checkpoint key string.

  Returns:
    Everything in `key` before the first occurrence of "/" followed by
    OBJECT_ATTRIBUTES_NAME.

  Raises:
    ValueError: if `key` does not contain that separator (raised by
      `str.index`).
  """
  separator = f"/{OBJECT_ATTRIBUTES_NAME}"
  end = key.index(separator)
  return key[:end]

152 

153 

def extract_local_name(key, prefix=None):
  """Returns the substring after ".ATTRIBUTES/" (plus `prefix`) in the key.

  "Local name" refers to the keys of `Trackable._serialize_to_tensors`.

  Args:
    key: a checkpoint key string.
    prefix: optional string expected to directly follow
      OBJECT_ATTRIBUTES_NAME + "/" in `key`; it is stripped from the result as
      well. `None` is treated as "".

  Returns:
    The part of `key` after the first occurrence of
    OBJECT_ATTRIBUTES_NAME + "/" + prefix, or `key` unchanged when that
    substring does not appear.
  """
  prefix = prefix or ""
  search_key = OBJECT_ATTRIBUTES_NAME + "/" + prefix
  try:
    start = key.index(search_key)
  except ValueError:
    # If checkpoint is saved from TF1, return key as is.
    return key
  return key[start + len(search_key):]

164 

165 

def slot_variable_key(variable_path, optimizer_path, slot_name):
  """Returns checkpoint key for a slot variable.

  Slot variable keys look like:

    <variable name>/<_OPTIMIZER_SLOTS_NAME>/<optimizer path>/<slot name>

  where <variable name> is exactly the checkpoint name used for the original
  variable, including the path from the checkpoint root and the local name in
  the object which owns it. The slot name is escaped with
  `escape_local_name`. Note that we only save slot variables if the variable
  it's slotting for is also being saved.
  """
  escaped_slot = escape_local_name(slot_name)
  return (f"{variable_path}/{_OPTIMIZER_SLOTS_NAME}/"
          f"{optimizer_path}/{escaped_slot}")