Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/registration/registration.py: 55%
97 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Serialization Registration for SavedModel.
17revived_types registration will be migrated to this infrastructure.
19See the Advanced saving section in go/savedmodel-configurability.
20This API is approved for TF internal use only.
21"""
22import collections
23import re
25from tensorflow.python.util import tf_inspect
28# Only allow valid file/directory characters
29_VALID_REGISTERED_NAME = re.compile(r"^[a-zA-Z0-9._-]+$")
32class _PredicateRegistry(object):
33 """Registry with predicate-based lookup.
35 See the documentation for `register_checkpoint_saver` and
36 `register_serializable` for reasons why predicates are required over a
37 class-based registry.
39 Since this class is used for global registries, each object must be registered
40 to unique names (an error is raised if there are naming conflicts). The lookup
41 searches the predicates in reverse order, so that later-registered predicates
42 are executed first.
43 """
44 __slots__ = ("_registry_name", "_registered_map", "_registered_predicates",
45 "_registered_names")
47 def __init__(self, name):
48 self._registry_name = name
49 # Maps registered name -> object
50 self._registered_map = {}
51 # Maps registered name -> predicate
52 self._registered_predicates = {}
53 # Stores names in the order of registration
54 self._registered_names = []
56 @property
57 def name(self):
58 return self._registry_name
60 def register(self, package, name, predicate, candidate):
61 """Registers a candidate object under the package, name and predicate."""
62 if not isinstance(package, str) or not isinstance(name, str):
63 raise TypeError(
64 f"The package and name registered to a {self.name} must be strings, "
65 f"got: package={type(package)}, name={type(name)}")
66 if not callable(predicate):
67 raise TypeError(
68 f"The predicate registered to a {self.name} must be callable, "
69 f"got: {type(predicate)}")
70 registered_name = package + "." + name
71 if not _VALID_REGISTERED_NAME.match(registered_name):
72 raise ValueError(
73 f"Invalid registered {self.name}. Please check that the package and "
74 f"name follow the regex '{_VALID_REGISTERED_NAME.pattern}': "
75 f"(package='{package}', name='{name}')")
76 if registered_name in self._registered_map:
77 raise ValueError(
78 f"The name '{registered_name}' has already been registered to a "
79 f"{self.name}. Found: {self._registered_map[registered_name]}")
81 self._registered_map[registered_name] = candidate
82 self._registered_predicates[registered_name] = predicate
83 self._registered_names.append(registered_name)
85 def lookup(self, obj):
86 """Looks up the registered object using the predicate.
88 Args:
89 obj: Object to pass to each of the registered predicates to look up the
90 registered object.
91 Returns:
92 The object registered with the first passing predicate.
93 Raises:
94 LookupError if the object does not match any of the predicate functions.
95 """
96 return self._registered_map[self.get_registered_name(obj)]
98 def name_lookup(self, registered_name):
99 """Looks up the registered object using the registered name."""
100 try:
101 return self._registered_map[registered_name]
102 except KeyError:
103 raise LookupError(f"The {self.name} registry does not have name "
104 f"'{registered_name}' registered.")
106 def get_registered_name(self, obj):
107 for registered_name in reversed(self._registered_names):
108 predicate = self._registered_predicates[registered_name]
109 if predicate(obj):
110 return registered_name
111 raise LookupError(f"Could not find matching {self.name} for {type(obj)}.")
113 def get_predicate(self, registered_name):
114 try:
115 return self._registered_predicates[registered_name]
116 except KeyError:
117 raise LookupError(f"The {self.name} registry does not have name "
118 f"'{registered_name}' registered.")
120 def get_registrations(self):
121 return self._registered_predicates
123_class_registry = _PredicateRegistry("serializable class")
124_saver_registry = _PredicateRegistry("checkpoint saver")
def get_registered_class_name(obj):
  """Returns the serializable-class registration name for `obj`, or None.

  Args:
    obj: Object passed to each predicate in the serializable-class registry.

  Returns:
    The registered name of the most recently registered predicate that accepts
    `obj`, or `None` if no predicate matches.
  """
  try:
    found_name = _class_registry.get_registered_name(obj)
  except LookupError:
    found_name = None
  return found_name
def get_registered_class(registered_name):
  """Returns the class registered as `registered_name`, or None if unknown.

  Args:
    registered_name: Name in "<package>.<name>" form, as produced at
      registration time.
  """
  try:
    registered_cls = _class_registry.name_lookup(registered_name)
  except LookupError:
    registered_cls = None
  return registered_cls
def register_serializable(package="Custom", name=None, predicate=None):
  """Decorator for registering a serializable class.

  THIS METHOD IS STILL EXPERIMENTAL AND MAY CHANGE AT ANY TIME.

  Registered classes will be saved with a name generated by combining the
  `package` and `name` arguments. When loading a SavedModel, modules saved with
  this registered name will be created using the `_deserialize_from_proto`
  method.

  The default predicate is `isinstance(obj, registered_class)`, so instances
  of the registered class *and of its subclasses* are saved/restored with the
  `serialize_from_proto`/`deserialize_from_proto` methods. To narrow or widen
  the set of matching objects, pass a custom `predicate` argument:

  ```python
  class A(tf.Module):
    pass

  register_serializable(
      package="Example", predicate=lambda obj: type(obj) is A)(A)
  ```

  Args:
    package: The package that this class belongs to.
    name: The name to serialize this class under in this package. If None, the
      class's name will be used.
    predicate: An optional function that takes a single Trackable argument, and
      determines whether that object should be serialized with this `package`
      and `name`. The default predicate is `isinstance(obj, cls)` where `cls`
      is the decorated class. Predicates are executed in the reverse order
      that they are added (later registrations are checked first).

  Returns:
    A decorator that registers the decorated class with the passed names and
    predicate.

  Raises (from the returned decorator):
    TypeError: if the decorated object is not a class, or the package/name are
      not strings.
    ValueError: if the resulting registered name is invalid or already taken.
  """
  def decorator(arg):
    """Registers a class with the serialization framework."""
    if not tf_inspect.isclass(arg):
      raise TypeError("Registered serializable must be a class: {}".format(arg))

    class_name = name if name is not None else arg.__name__
    # Resolve the predicate per decorated class. Rebinding the enclosing
    # `predicate` would make every later use of this same decorator object
    # reuse the first class's default `isinstance` check.
    if predicate is None:
      class_predicate = lambda x: isinstance(x, arg)
    else:
      class_predicate = predicate
    _class_registry.register(package, class_name, class_predicate, arg)
    return arg

  return decorator
192RegisteredSaver = collections.namedtuple(
193 "RegisteredSaver", ["name", "predicate", "save_fn", "restore_fn"])
194_REGISTERED_SAVERS = {}
195_REGISTERED_SAVER_NAMES = [] # Stores names in the order of registration
def register_checkpoint_saver(package="Custom",
                              name=None,
                              predicate=None,
                              save_fn=None,
                              restore_fn=None,
                              strict_predicate_restore=True):
  """Registers functions which checkpoint and restore objects with custom steps.

  If you have a class that requires complicated coordination between multiple
  objects when checkpointing, then you will need to register a custom saver
  and restore function. An example of this is a custom Variable class that
  splits the variable across different objects and devices, and needs to write
  checkpoints that are compatible with different configurations of devices.

  The registered save and restore functions are used in checkpoints and
  SavedModel.

  Please make sure you are familiar with the concepts in the [Checkpointing
  guide](https://www.tensorflow.org/guide/checkpoint), and ops used to save the
  V2 checkpoint format:

  * io_ops.SaveV2
  * io_ops.MergeV2Checkpoints
  * io_ops.RestoreV2

  **Predicate**

  The predicate is a filter that will run on every `Trackable` object connected
  to the root object. This function determines whether a `Trackable` should use
  the registered functions.

  Example: `lambda x: isinstance(x, CustomClass)`

  **Custom save function**

  This is how checkpoint saving works normally:
  1. Gather all of the Trackables with saveable values.
  2. For each Trackable, gather all of the saveable tensors.
  3. Save checkpoint shards (grouping tensors by device) with SaveV2
  4. Merge the shards with MergeV2Checkpoints. This combines all of the shards'
     metadata, and renames them to follow the standard shard pattern.

  When a saver is registered, Trackables that pass the registered `predicate`
  are automatically marked as having saveable values. Next, the custom save
  function replaces steps 2 and 3 of the saving process. Finally, the shards
  returned by the custom save function are merged with the other shards.

  The save function takes in a dictionary of `Trackables` and a `file_prefix`
  string. The function should save checkpoint shards using the SaveV2 op, and
  return a list of the shard prefixes. SaveV2 is currently required to work
  correctly, because the code merges all of the returned shards, and the
  `restore_fn` will only be given the prefix of the merged checkpoint. If you
  need to be able to save and restore from unmerged shards, please file a
  feature request.

  Specification and example of the save function:

  ```
  def save_fn(trackables, file_prefix):
    # trackables: A dictionary mapping unique string identifiers to trackables
    # file_prefix: A unique file prefix generated using the registered name.
    ...
    # Gather the tensors to save.
    ...
    io_ops.SaveV2(file_prefix, tensor_names, shapes_and_slices, tensors)
    return file_prefix  # Returns a tensor or a list of string tensors
  ```

  The save function is executed before the unregistered save ops.

  **Custom restore function**

  Normal checkpoint restore behavior:
  1. Gather all of the Trackables that have saveable values.
  2. For each Trackable, get the names of the desired tensors to extract from
     the checkpoint.
  3. Use RestoreV2 to read the saved values, and pass the restored tensors to
     the corresponding Trackables.

  The custom restore function replaces steps 2 and 3.

  The restore function also takes a dictionary of `Trackables` and a
  `merged_prefix` string. The `merged_prefix` is different from the
  `file_prefix`, since it contains the renamed shard paths. To read from the
  merged checkpoint, you must use `RestoreV2(merged_prefix, ...)`.

  Specification:

  ```
  def restore_fn(trackables, merged_prefix):
    # trackables: A dictionary mapping unique string identifiers to Trackables
    # merged_prefix: File prefix of the merged shard names.

    restored_tensors = io_ops.restore_v2(
        merged_prefix, tensor_names, shapes_and_slices, dtypes)
    ...
    # Restore the checkpoint values for the given Trackables.
  ```

  The restore function is executed after the non-registered restore ops.

  Args:
    package: Optional, the package that this class belongs to.
    name: (Required) The name of this saver, which is saved to the checkpoint.
      When a checkpoint is restored, the name and package are used to find the
      matching restore function. The name and package are also used to
      generate a unique file prefix that is passed to the save_fn.
    predicate: (Required) A function that returns a boolean indicating whether a
      `Trackable` object should be checkpointed with this function. Predicates
      are executed in the reverse order that they are added (later registrations
      are checked first).
    save_fn: (Required) A function that takes a dictionary of trackables and a
      file prefix as the arguments, writes the checkpoint shards for the given
      Trackables, and returns the list of shard prefixes.
    restore_fn: (Required) A function that takes a dictionary of trackables and
      a merged file prefix as the arguments and restores the trackable values.
    strict_predicate_restore: If this is `True` (default), then an error will be
      raised if the predicate fails during checkpoint restoration. If this is
      `False`, checkpoint restoration will skip running the restore function.
      This value is generally set to `False` when the predicate does not pass on
      the Trackables after being saved/loaded from SavedModel.

  Raises:
    TypeError: if `save_fn` or `restore_fn` is not callable, if `package` or
      `name` is not a string (e.g. `name` left as `None`), or if `predicate` is
      not callable.
    ValueError: if the package and name are invalid or already registered.
  """
  # Validate the functions here; package/name/predicate are validated by the
  # registry itself.
  if not callable(save_fn):
    raise TypeError(f"The save_fn must be callable, got: {type(save_fn)}")
  if not callable(restore_fn):
    raise TypeError(f"The restore_fn must be callable, got: {type(restore_fn)}")

  # Stored as a positional tuple; get_save_function / get_restore_function /
  # get_strict_predicate_restore index into it as [0], [1], [2].
  _saver_registry.register(package, name, predicate, (save_fn, restore_fn,
                                                      strict_predicate_restore))
def get_registered_saver_name(trackable):
  """Returns the name of the registered saver to use with Trackable.

  Args:
    trackable: Object passed to each registered checkpoint-saver predicate.

  Returns:
    The registered name of the most recently registered saver whose predicate
    accepts `trackable`, or `None` when no predicate matches.
  """
  try:
    saver_name = _saver_registry.get_registered_name(trackable)
  except LookupError:
    saver_name = None
  return saver_name
def get_save_function(registered_name):
  """Returns the save function registered to `registered_name`.

  Raises:
    LookupError: if no saver is registered under `registered_name`.
  """
  # Entries are (save_fn, restore_fn, strict_predicate_restore) tuples.
  saver_entry = _saver_registry.name_lookup(registered_name)
  return saver_entry[0]
def get_restore_function(registered_name):
  """Returns the restore function registered to `registered_name`.

  Raises:
    LookupError: if no saver is registered under `registered_name`.
  """
  # Entries are (save_fn, restore_fn, strict_predicate_restore) tuples.
  saver_entry = _saver_registry.name_lookup(registered_name)
  return saver_entry[1]
def get_strict_predicate_restore(registered_name):
  """Returns the `strict_predicate_restore` flag registered to the name.

  A `False` flag means the registered restore may be skipped when the
  predicate fails during restoration.

  Raises:
    LookupError: if no saver is registered under `registered_name`.
  """
  # Entries are (save_fn, restore_fn, strict_predicate_restore) tuples.
  saver_entry = _saver_registry.name_lookup(registered_name)
  return saver_entry[2]
def validate_restore_function(trackable, registered_name):
  """Checks that `trackable` can be restored with the named registered saver.

  A checkpoint written through a registered saver records that saver's name;
  the same saver must be registered again at load time. `registered_name` is
  expected to be the name read back from the checkpoint.

  Args:
    trackable: A `Trackable` object.
    registered_name: String name of the expected registered saver. This argument
      should be set using the name saved in a checkpoint.

  Raises:
    ValueError if the saver could not be found, or if the predicate associated
    with the saver does not pass.
  """
  try:
    _saver_registry.name_lookup(registered_name)
  except LookupError:
    raise ValueError(
        f"Error when restoring object {trackable} from checkpoint. This "
        "object was saved using a registered saver named "
        f"'{registered_name}', but this saver cannot be found in the "
        "current context.")

  saver_predicate = _saver_registry.get_predicate(registered_name)
  if saver_predicate(trackable):
    return
  raise ValueError(
      f"Object {trackable} was saved with the registered saver named "
      f"'{registered_name}'. However, this saver cannot be used to restore the "
      "object because the predicate does not pass.")