Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/serialization.py: 22%
205 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Legacy serialization logic for Keras models."""
17import threading
18import weakref
20import tensorflow.compat.v2 as tf
22from keras.src.utils import tf_contextlib
23from keras.src.utils import tf_inspect
25# isort: off
26from tensorflow.python.util.tf_export import keras_export
# Flag that determines whether to skip the NotImplementedError when calling
# get_config in custom models and layers. This is only enabled when saving to
# SavedModel, when the config isn't required.
_SKIP_FAILED_SERIALIZATION: bool = False

# Placeholder marker: a layer without a defined config is serialized as a
# dictionary containing this key.
_LAYER_UNDEFINED_CONFIG_KEY: str = "layer was saved without config"

# Key under which a unique, per-object ID is stored for shared objects.
#
# The ID lets the loader tell whether a config describes a brand-new object
# that must be created or merely refers to an object that was already
# created, so the network can be re-created with its sharing intact.
SHARED_OBJECT_KEY: str = "shared_object_id"

# Thread-local holders for the shared-object state (disabled flag, loading
# scope, saving scope). Attributes set on them are only visible to the thread
# that set them.
SHARED_OBJECT_DISABLED = threading.local()
SHARED_OBJECT_LOADING = threading.local()
SHARED_OBJECT_SAVING = threading.local()
# Attributes on a `threading.local` exist only for the thread that set them,
# so they cannot be initialized once at import time. They are instead read
# through accessor functions that fall back to a default value.
def _shared_object_disabled():
    """Return this thread's "shared objects disabled" flag (default False)."""
    disabled = getattr(SHARED_OBJECT_DISABLED, "disabled", False)
    return disabled
def _shared_object_loading_scope():
    """Get the current shared object loading scope in a threadsafe manner.

    Returns:
        The loading scope stored for the calling thread, or a new
        `NoopLoadingScope` if no scope has been stored for this thread.
    """
    return getattr(SHARED_OBJECT_LOADING, "scope", NoopLoadingScope())
def _shared_object_saving_scope():
    """Return the saving scope stored for this thread, or None if unset."""
    try:
        return SHARED_OBJECT_SAVING.scope
    except AttributeError:
        return None
class DisableSharedObjectScope:
    """A context manager for disabling handling of shared objects.

    Disables shared object handling for both saving and loading.

    Created primarily for use with `clone_model`, which does extra surgery that
    is incompatible with shared objects.

    On exit, the disabled flag and the loading/saving scopes that were current
    on entry are all restored. As a result, leaving a nested
    `DisableSharedObjectScope` does not re-enable shared object handling while
    an enclosing one is still active.
    """

    def __enter__(self):
        # Record the flag before overwriting it so `__exit__` can put back the
        # exact previous value instead of assuming it was False.
        self._orig_disabled = _shared_object_disabled()
        SHARED_OBJECT_DISABLED.disabled = True
        self._orig_loading_scope = _shared_object_loading_scope()
        self._orig_saving_scope = _shared_object_saving_scope()

    def __exit__(self, *args, **kwargs):
        SHARED_OBJECT_DISABLED.disabled = self._orig_disabled
        SHARED_OBJECT_LOADING.scope = self._orig_loading_scope
        SHARED_OBJECT_SAVING.scope = self._orig_saving_scope
class NoopLoadingScope:
    """Loading scope that remembers nothing; used when no tracking is active.

    `get` always answers `None` and `set` discards its arguments. This lets
    serialization code that doesn't care about shared objects (e.g. when
    serializing a single object) use the same scope API unconditionally.
    """

    def get(self, unused_object_id):
        del unused_object_id  # Accepted only to match the scope interface.
        return None

    def set(self, object_id, obj):
        del object_id, obj  # Intentionally ignored.
class SharedObjectLoadingScope:
    """A context manager for keeping track of loaded objects.

    During the deserialization process, we may come across objects that are
    shared across multiple layers. In order to accurately restore the network
    structure to its original state, `SharedObjectLoadingScope` allows us to
    re-use shared objects rather than cloning them.

    The loading scope that was current when this scope was entered is
    reinstated on exit. A nested `SharedObjectLoadingScope` therefore no longer
    discards an enclosing scope (and the objects it has recorded) for the
    remainder of the enclosing load.
    """

    def __enter__(self):
        if _shared_object_disabled():
            # Nothing is installed while disabled; exiting then installs a
            # no-op scope.
            self._prev_scope = NoopLoadingScope()
            return NoopLoadingScope()

        # Remember the current scope so `__exit__` can restore it rather than
        # unconditionally dropping back to a no-op scope. Guard against
        # re-entering the same instance, which would otherwise leave it
        # installed forever.
        previous = _shared_object_loading_scope()
        self._prev_scope = NoopLoadingScope() if previous is self else previous
        SHARED_OBJECT_LOADING.scope = self
        self._obj_ids_to_obj = {}
        return self

    def get(self, object_id):
        """Given a shared object ID, returns a previously instantiated object.

        Args:
            object_id: shared object ID to use when attempting to find
                already-loaded object.

        Returns:
            The object, if we've seen this ID before. Else, `None`.
        """
        # Explicitly check for `None` internally to make external calling code a
        # bit cleaner.
        if object_id is None:
            return
        return self._obj_ids_to_obj.get(object_id)

    def set(self, object_id, obj):
        """Stores an instantiated object for future lookup and sharing.

        A `None` `object_id` is ignored, so callers may pass the (possibly
        absent) ID straight from a config.
        """
        if object_id is None:
            return
        self._obj_ids_to_obj[object_id] = obj

    def __exit__(self, *args, **kwargs):
        # Reinstate the scope recorded on entry. Fall back to a no-op scope if
        # this instance was never entered.
        previous = getattr(self, "_prev_scope", None)
        SHARED_OBJECT_LOADING.scope = (
            previous if previous is not None else NoopLoadingScope()
        )
class SharedObjectConfig(dict):
    """A `dict` config that counts how many times it has been referenced.

    When the same config is referenced a second time, its shared object ID is
    written into it under `SHARED_OBJECT_KEY`. This is what allows the shared
    object to be reconstructed properly at load time.

    Subclassing `collections.UserDict` or `collections.Mapping` would usually
    be preferable. However, python's json encoder does not support `Mapping`s,
    and staying JSON-serializable is essential here. Subclassing `dict`
    directly is safe because no core `dict` method is overridden; the class
    only adds reference counting on top.
    """

    def __init__(self, base_config, object_id, **kwargs):
        super().__init__(base_config, **kwargs)
        self.object_id = object_id
        self.ref_count = 1

    def increment_ref_count(self):
        # The shared object ID is attached only on the transition from one to
        # two references. Configs of objects referenced just once stay
        # unchanged, which makes backwards compatibility breakage less likely.
        is_first_repeat = self.ref_count == 1
        if is_first_repeat:
            self[SHARED_OBJECT_KEY] = self.object_id
        self.ref_count = self.ref_count + 1
class SharedObjectSavingScope:
    """Keeps track of shared object configs when serializing.

    Only the outermost active scope on a thread does any bookkeeping. Entering
    a scope while another one is active hands back the active one and marks
    this scope as a passthrough, so leaving it does not uninstall the outer
    scope.
    """

    def __enter__(self):
        if _shared_object_disabled():
            return None

        # Serialization can be triggered at many levels, so a saving scope may
        # be opened while another one is already active. The outermost scope
        # wins and inner ones are ignored, since there is not (yet) a
        # reasonable use case for having these nested and distinct.
        outer_scope = _shared_object_saving_scope()
        self._passthrough = outer_scope is not None
        if self._passthrough:
            return outer_scope

        SHARED_OBJECT_SAVING.scope = self
        self._shared_objects_config = weakref.WeakKeyDictionary()
        self._next_id = 0
        return self

    def get_config(self, obj):
        """Return the `SharedObjectConfig` already recorded for `obj`, if any.

        A hit counts as one more reference to the object, so the config's
        reference count is incremented before it is returned.

        Args:
            obj: The object whose `SharedObjectConfig` should be looked up.

        Returns:
            The previously created `SharedObjectConfig` for `obj`, or `None`
            if `obj` has not been seen or cannot be used as a key.
        """
        try:
            found = self._shared_objects_config[obj]
        except (TypeError, KeyError):
            # Objects that are unhashable (e.g. a subclass of
            # `AbstractBaseClass` that has not overridden `__hash__`) or cannot
            # be weakly referenced raise `TypeError`. Such objects, like
            # unseen ones, simply get no shared object support.
            return None
        else:
            found.increment_ref_count()
            return found

    def create_config(self, base_config, obj):
        """Create, record and return a new `SharedObjectConfig` for `obj`."""
        new_config = SharedObjectConfig(base_config, self._next_id)
        self._next_id = self._next_id + 1
        try:
            self._shared_objects_config[obj] = new_config
        except TypeError:
            # As in `get_config`: objects that cannot be used as keys are
            # serialized without shared object support.
            pass
        return new_config

    def __exit__(self, *args, **kwargs):
        if getattr(self, "_passthrough", False):
            return
        SHARED_OBJECT_SAVING.scope = None
def serialize_keras_class_and_config(
    cls_name, cls_config, obj=None, shared_object_id=None
):
    """Returns the serialization of the class with the given config.

    Args:
        cls_name: Name of the class being serialized.
        cls_config: The class's config, stored under the `"config"` key.
        obj: Optional object being serialized. When given and a
            `SharedObjectSavingScope` is active, it is used to detect objects
            that have already been serialized.
        shared_object_id: Optional shared object ID to keep in the result.
            Some branches of the load path call this function and already
            have an ID they want to retain.

    Returns:
        A dict with `"class_name"` and `"config"` keys, or the
        `SharedObjectConfig` tracked for `obj` by the active saving scope.
    """
    serialized = {"class_name": cls_name, "config": cls_config}
    if shared_object_id is not None:
        serialized[SHARED_OBJECT_KEY] = shared_object_id

    saving_scope = _shared_object_saving_scope()
    if saving_scope is None or obj is None:
        return serialized

    # If this object was already serialized under the active scope, reuse that
    # config. The lookup counts as another reference, which stores an extra ID
    # field in it so the shared object relationship can be re-created at load
    # time.
    existing = saving_scope.get_config(obj)
    if existing is not None:
        return existing
    return saving_scope.create_config(serialized, obj)
@tf_contextlib.contextmanager
def skip_failed_serialization():
    """Context manager under which a failing `get_config` is tolerated.

    While active, `serialize_keras_object` returns a placeholder config instead
    of re-raising the `NotImplementedError` from `get_config`. The previous
    value of the flag is restored on exit, even if the body raises.
    """
    global _SKIP_FAILED_SERIALIZATION
    saved_flag = _SKIP_FAILED_SERIALIZATION
    try:
        _SKIP_FAILED_SERIALIZATION = True
        yield
    finally:
        _SKIP_FAILED_SERIALIZATION = saved_flag
@keras_export("keras.utils.legacy.serialize_keras_object")
def serialize_keras_object(instance):
    """Serialize a Keras object into a JSON-compatible representation.

    Calls to `serialize_keras_object` while underneath the
    `SharedObjectSavingScope` context manager will cause any objects re-used
    across multiple layers to be saved with a special shared object ID. This
    allows the network to be re-created properly during deserialization.

    Args:
        instance: The object to serialize.

    Returns:
        `None` if `instance` is `None`. For objects without `get_config` that
        have a `__name__` (e.g. functions), their registered name as a string.
        Otherwise, a dict-like, JSON-compatible representation of the object's
        config.

    Raises:
        NotImplementedError: If `instance.get_config()` is not implemented and
            failed serialization is not being skipped (see
            `skip_failed_serialization`).
        ValueError: If `instance` has neither `get_config` nor `__name__`.
    """
    from keras.src.saving import object_registration

    _, instance = tf.__internal__.decorator.unwrap(instance)
    if instance is None:
        return None

    if hasattr(instance, "get_config"):
        name = object_registration.get_registered_name(instance.__class__)
        try:
            config = instance.get_config()
        except NotImplementedError:
            if _SKIP_FAILED_SERIALIZATION:
                return serialize_keras_class_and_config(
                    name, {_LAYER_UNDEFINED_CONFIG_KEY: True}
                )
            raise

        serialization_config = {}
        for key, item in config.items():
            if isinstance(item, str):
                serialization_config[key] = item
                continue

            # Any object of a different type needs to be converted to string or
            # dict for serialization (e.g. custom functions, custom classes).
            try:
                serialized_item = serialize_keras_object(item)
            except ValueError:
                # Values with neither `get_config` nor `__name__` are kept
                # as-is.
                serialization_config[key] = item
                continue
            if isinstance(serialized_item, dict) and not isinstance(
                item, dict
            ):
                # Mark configs of nested objects so the deserializer knows to
                # turn them back into objects.
                serialized_item["__passive_serialization__"] = True
            serialization_config[key] = serialized_item

        # `name` was already resolved above; it does not need to be looked up
        # again.
        return serialize_keras_class_and_config(
            name, serialization_config, instance
        )

    if hasattr(instance, "__name__"):
        return object_registration.get_registered_name(instance)
    raise ValueError(
        f"Cannot serialize {instance} because it doesn't implement "
        "`get_config()`."
    )
def class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name="object",
):
    """Returns the class and config for a serialized keras object.

    Args:
        config: A dict with `"class_name"` and `"config"` keys.
        module_objects: Optional dictionary of built-in objects to look names
            up in.
        custom_objects: Optional dictionary of custom objects to look names up
            in.
        printable_module_name: Human-readable name of the object type, used in
            error messages.

    Returns:
        A `(cls, cls_config)` tuple. `cls` is the object registered under
        `config["class_name"]`, and `cls_config` is `config["config"]`. When
        `cls_config` is a dict, it is updated in place: passively serialized
        nested objects are deserialized, and strings naming registered custom
        functions are replaced by those functions.

    Raises:
        ValueError: If `config` is malformed or its class cannot be found.
    """
    from keras.src.saving import object_registration

    if (
        not isinstance(config, dict)
        or "class_name" not in config
        or "config" not in config
    ):
        raise ValueError(
            f"Improper config format for {config}. "
            "Expecting python dict contains `class_name` and `config` as keys"
        )

    class_name = config["class_name"]
    cls = object_registration.get_registered_object(
        class_name, custom_objects, module_objects
    )
    if cls is None:
        raise ValueError(
            f"Unknown {printable_module_name}: '{class_name}'. "
            "Please ensure you are using a `keras.utils.custom_object_scope` "
            "and that this object is included in the scope. See "
            "https://www.tensorflow.org/guide/keras/save_and_serialize"
            "#registering_the_custom_object for details."
        )

    cls_config = config["config"]
    # Check if `cls_config` is a list. If it is a list, return the class and the
    # associated class configs for recursively deserialization. This case will
    # happen on the old version of sequential model (e.g. `keras_version` ==
    # "2.0.6"), which is serialized in a different structure, for example
    # "{'class_name': 'Sequential',
    #   'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}".
    if isinstance(cls_config, list):
        return (cls, cls_config)

    deserialized_objects = {}
    for key, item in cls_config.items():
        if key == "name":
            # Assume that the value of 'name' is a string that should not be
            # deserialized as a function. This avoids the corner case where
            # cls_config['name'] has an identical name to a custom function and
            # gets converted into that function.
            deserialized_objects[key] = item
        elif isinstance(item, dict) and "__passive_serialization__" in item:
            deserialized_objects[key] = deserialize_keras_object(
                item,
                module_objects=module_objects,
                custom_objects=custom_objects,
                printable_module_name="config_item",
            )
        elif isinstance(item, str):
            # Handle custom functions here. When saving functions, we only save
            # the function's name as a string. If we find a matching string in
            # the custom objects during deserialization, we convert the string
            # back to the original function.
            # Note that a potential issue is that a string field could have a
            # naming conflict with a custom function name, but this should be a
            # rare case. This issue does not occur if a string field has a
            # naming conflict with a custom object, since the config of an
            # object will always be a dict.
            # TODO(momernick): Should this also have 'module_objects'?
            registered = object_registration.get_registered_object(
                item, custom_objects
            )
            if tf_inspect.isfunction(registered):
                deserialized_objects[key] = registered

    # Write the replacements back into `cls_config` in place.
    for key, value in deserialized_objects.items():
        cls_config[key] = value

    return (cls, cls_config)
@keras_export("keras.utils.legacy.deserialize_keras_object")
def deserialize_keras_object(
    identifier,
    module_objects=None,
    custom_objects=None,
    printable_module_name="object",
):
    """Turns the serialized form of a Keras object back into an actual object.

    This function is for mid-level library implementers rather than end users.

    Importantly, this utility requires you to provide the dict of
    `module_objects` to use for looking up the object config; this is not
    populated by default. If you need a deserialization utility that has
    preexisting knowledge of built-in Keras objects, use e.g.
    `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`,
    etc.

    Calling `deserialize_keras_object` while underneath the
    `SharedObjectLoadingScope` context manager will cause any already-seen
    shared objects to be returned as-is rather than creating a new object.

    Args:
        identifier: the serialized form of the object. Accepted forms are
            `None`, a config dict with `class_name`/`config` keys, a name
            (string), or an already-deserialized function.
        module_objects: A dictionary of built-in objects to look the name up in.
            Generally, `module_objects` is provided by midlevel library
            implementers.
        custom_objects: A dictionary of custom objects to look the name up in.
            Generally, `custom_objects` is provided by the end user.
        printable_module_name: A human-readable string representing the type of
            the object. Printed in case of exception.

    Returns:
        The deserialized object (`None` if `identifier` is `None`).

    Raises:
        ValueError: If `identifier` is a malformed config, names an unknown
            object, or is of an unsupported type.

    Example:

    A mid-level library implementer might want to implement a utility for
    retrieving an object from its config, as such:

    ```python
    def deserialize(config, custom_objects=None):
        return deserialize_keras_object(
            config,
            module_objects=globals(),
            custom_objects=custom_objects,
            printable_module_name="MyObjectType",
        )
    ```

    This is how e.g. `keras.layers.deserialize()` is implemented.
    """
    from keras.src.saving import object_registration

    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name
        )

        # If this object has already been loaded (i.e. it's shared between
        # multiple objects), return the already-loaded object.
        shared_object_id = config.get(SHARED_OBJECT_KEY)
        shared_object = _shared_object_loading_scope().get(shared_object_id)
        if shared_object is not None:
            return shared_object

        custom_objects = custom_objects or {}
        if hasattr(cls, "from_config"):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            if "custom_objects" in arg_spec.args:
                # Pass every known custom object explicitly. Later sources win
                # on name clashes: global < thread-local < caller-provided.
                tlco = object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
                deserialized_obj = cls.from_config(
                    cls_config,
                    custom_objects={
                        **object_registration._GLOBAL_CUSTOM_OBJECTS,
                        **tlco,
                        **custom_objects,
                    },
                )
            else:
                with object_registration.CustomObjectScope(custom_objects):
                    deserialized_obj = cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # In this case, by convention, `config` holds
            # the kwargs of the function.
            with object_registration.CustomObjectScope(custom_objects):
                deserialized_obj = cls(**cls_config)

        # Add object to shared objects, in case we find it referenced again.
        _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

        return deserialized_obj

    elif isinstance(identifier, str):
        object_name = identifier
        thread_local_objects = (
            object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
        )
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in thread_local_objects:
            obj = thread_local_objects[object_name]
        elif object_name in object_registration._GLOBAL_CUSTOM_OBJECTS:
            obj = object_registration._GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` defaults to `None`. Treat that as "no built-in
            # objects" so an unknown name produces the ValueError below
            # instead of an AttributeError.
            obj = (module_objects or {}).get(object_name)
            if obj is None:
                raise ValueError(
                    f"Unknown {printable_module_name}: '{object_name}'. "
                    "Please ensure you are using a "
                    "`keras.utils.custom_object_scope` "
                    "and that this object is included in the scope. See "
                    "https://www.tensorflow.org/guide/keras/save_and_serialize"
                    "#registering_the_custom_object for details."
                )

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj
    elif tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError(
            "Could not interpret serialized "
            f"{printable_module_name}: {identifier}"
        )
def validate_config(config):
    """Return True if `config` appears to be a usable layer config.

    A usable config is a dict that is not the placeholder produced for layers
    whose `get_config` is not implemented.
    """
    if not isinstance(config, dict):
        return False
    return _LAYER_UNDEFINED_CONFIG_KEY not in config
def is_default(method):
    """Check if a method is decorated with the `default` wrapper.

    Returns the method's `_is_default` attribute, or False when it has none.
    """
    try:
        return method._is_default
    except AttributeError:
        return False