Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/trackable/trackable_utils.py: 30%
64 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility methods for the trackable dependencies."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
20import collections
def pretty_print_node_path(path):
  """Renders a node path as a human-readable string.

  Args:
    path: a possibly empty sequence of objects exposing a `name` attribute,
      or any other falsy value.

  Returns:
    "root object" when `path` is falsy; otherwise "root." followed by the
    `name` of every element of `path`, joined with ".".
  """
  if path:
    joined_names = ".".join(node.name for node in path)
    return f"root.{joined_names}"
  return "root object"
class CyclicDependencyError(Exception):
  """Raised by `order_by_dependency` when the dependency graph has a cycle."""

  def __init__(self, leftover_dependency_map):
    """Creates a CyclicDependencyError.

    Args:
      leftover_dependency_map: mapping from each remaining key to the
        dependencies whose edges could not be topologically sorted.
    """
    super().__init__()
    # Edges that could not be removed by the topological sort.
    self.leftover_dependency_map = leftover_dependency_map
def order_by_dependency(dependency_map):
  """Topologically sorts the keys of a map so that dependencies appear first.

  Uses Kahn's algorithm:
  https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm

  The order is deterministic for a given `dependency_map`: ties are broken by
  the insertion order of the map's keys and of each key's dependency list.

  Args:
    dependency_map: a dict mapping values to a list of dependencies (other keys
      in the map). All keys and dependencies must be hashable types. Repeated
      entries in a dependency list are treated as a single edge.

  Returns:
    An iterator (produced by `reversed`) over the keys of `dependency_map`,
    ordered so that every key comes after all of its dependencies.

  Raises:
    CyclicDependencyError: if there is a cycle in the graph.
    ValueError: If there are values in the dependency map that are not keys in
      the map.
  """
  # Maps each key -> the keys that depend on it. These are the edges used
  # in Kahn's algorithm.
  reverse_dependency_map = collections.defaultdict(set)
  for x, deps in dependency_map.items():
    for dep in deps:
      reverse_dependency_map[dep].add(x)

  # Validate that all values in the dependency map are also keys.
  unknown_keys = reverse_dependency_map.keys() - dependency_map.keys()
  if unknown_keys:
    raise ValueError("Found values in the dependency map which are not keys: "
                     f"{unknown_keys}")

  # Generate the list sorted by objects without dependents -> dependencies.
  # The returned iterator reverses this.
  reversed_dependency_arr = []

  # Prefill `to_visit` with all nodes that do not have other objects depending
  # on them. A deque gives O(1) pops from the front (list.pop(0) is O(n)).
  to_visit = collections.deque(
      x for x in dependency_map if x not in reverse_dependency_map)

  while to_visit:
    x = to_visit.popleft()
    reversed_dependency_arr.append(x)
    # Deduplicate while preserving list order: each edge is stored once in the
    # reverse map, so a repeated dependency must be removed only once.
    # dict.fromkeys (unlike set) keeps the visiting order independent of hash
    # randomization, making the result reproducible across runs.
    for dep in dict.fromkeys(dependency_map[x]):
      edges = reverse_dependency_map[dep]
      edges.remove(x)
      if not edges:
        to_visit.append(dep)
        reverse_dependency_map.pop(dep)

  # Edges that were never removed belong to nodes that could not be sorted.
  if reverse_dependency_map:
    leftover_dependency_map = collections.defaultdict(list)
    for dep, xs in reverse_dependency_map.items():
      for x in xs:
        leftover_dependency_map[x].append(dep)
    raise CyclicDependencyError(leftover_dependency_map)

  return reversed(reversed_dependency_arr)
# Prefix for reserved name segments. `escape_local_name` doubles this
# character in user-specified names, so reserved names cannot collide with them.
_ESCAPE_CHAR = "."

# Keyword marking that the next part of a checkpoint variable name is a slot
# name. Checkpoint names for slot variables look like:
#
# <path to variable>/<_OPTIMIZER_SLOTS_NAME>/<path to optimizer>/<slot name>
#
# where <path to variable> is the full path from the checkpoint root to the
# variable the slot belongs to.
_OPTIMIZER_SLOTS_NAME = f"{_ESCAPE_CHAR}OPTIMIZER_SLOT"

# Keyword separating the path to an object from the name of one of its
# attributes in checkpoint names:
#
# <path to object>/<OBJECT_ATTRIBUTES_NAME>/<name of attribute>
OBJECT_ATTRIBUTES_NAME = f"{_ESCAPE_CHAR}ATTRIBUTES"

# Key under which the save/restore functions of Trackables that define
# `_serialize_to_tensors` and `_restore_from_tensors` are referenced. It is
# written as the key in the
# `SavedObject.saveable_objects<string, SaveableObject>` map in the SavedModel.
SERIALIZE_TO_TENSORS_NAME = f"{_ESCAPE_CHAR}TENSORS"
def escape_local_name(name):
  """Escapes a local name so it can be used as one segment of a checkpoint key.

  Every `_ESCAPE_CHAR` is doubled, then every "/" is replaced by
  `_ESCAPE_CHAR` followed by "S", so the result contains no "/".
  """
  # Slashes must remain legal in local names for compatibility: this naming
  # scheme was patched into APIs such as Layer.add_variable, which already
  # accepted them. Slashes also separate the edges traversed to reach a
  # variable, so they are escaped here.
  doubled_escape = _ESCAPE_CHAR * 2
  escaped_slash = _ESCAPE_CHAR + "S"
  return name.replace(_ESCAPE_CHAR, doubled_escape).replace("/", escaped_slash)
def object_path_to_string(node_path_arr):
  """Converts a list of nodes to a string.

  Each node's `name` is escaped with `escape_local_name`, and the escaped
  names are joined with "/".
  """
  escaped_names = [escape_local_name(node.name) for node in node_path_arr]
  return "/".join(escaped_names)
def checkpoint_key(object_path, local_name):
  """Returns the checkpoint key for a local attribute of an object.

  The key has the form "<object_path>/<OBJECT_ATTRIBUTES_NAME>/<suffix>".
  The suffix is the escaped `local_name`, except when `local_name` is
  `SERIALIZE_TO_TENSORS_NAME`, in which case it is empty.
  """
  if local_name == SERIALIZE_TO_TENSORS_NAME:
    # Trackables using the `_serialize_to_tensors` API supply their own key(s)
    # as the suffix, so the suffix emitted here is empty.
    key_suffix = ""
  else:
    key_suffix = escape_local_name(local_name)
  return f"{object_path}/{OBJECT_ATTRIBUTES_NAME}/{key_suffix}"
def extract_object_name(key):
  """Returns the part of a checkpoint key before its first "/.ATTRIBUTES".

  Raises:
    ValueError: if `key` does not contain "/" + OBJECT_ATTRIBUTES_NAME.
  """
  end = key.index("/" + OBJECT_ATTRIBUTES_NAME)
  return key[:end]
def extract_local_name(key, prefix=None):
  """Returns the substring of `key` after ".ATTRIBUTES/" (and `prefix`).

  "Local name" refers to the keys of `Trackable._serialize_to_tensors`.

  Args:
    key: checkpoint key to search.
    prefix: optional string expected right after ".ATTRIBUTES/"; it is
      stripped from the result as well. `None` or "" means no prefix.

  Returns:
    The remainder of `key` after the first occurrence of
    `OBJECT_ATTRIBUTES_NAME + "/" + prefix`, or `key` unchanged when that
    marker is absent.
  """
  marker = OBJECT_ATTRIBUTES_NAME + "/" + (prefix or "")
  try:
    position = key.index(marker)
  except ValueError:
    # If checkpoint is saved from TF1, return key as is.
    return key
  return key[position + len(marker):]
def slot_variable_key(variable_path, optimizer_path, slot_name):
  """Returns checkpoint key for a slot variable.

  Key layout:

    <variable path>/<_OPTIMIZER_SLOTS_NAME>/<optimizer path>/<slot name>

  where <variable path> is exactly the checkpoint name used for the original
  variable (the path from the checkpoint root plus the local name in the
  object that owns it). Only `slot_name` is escaped here; both paths are
  inserted as given. Slot variables are only saved if the variable they slot
  for is also being saved.
  """
  escaped_slot = escape_local_name(slot_name)
  return (f"{variable_path}/{_OPTIMIZER_SLOTS_NAME}/"
          f"{optimizer_path}/{escaped_slot}")