Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/serialization_lib.py: 16%
289 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Object config serialization and deserialization logic."""
17import importlib
18import inspect
19import threading
20import types
21import warnings
23import numpy as np
24import tensorflow.compat.v2 as tf
26from keras.src.saving import object_registration
27from keras.src.saving.legacy import serialization as legacy_serialization
28from keras.src.saving.legacy.saved_model.utils import in_tf_saved_model_scope
29from keras.src.utils import generic_utils
31# isort: off
32from tensorflow.python.util import tf_export
33from tensorflow.python.util.tf_export import keras_export
# Python scalar types that `serialize_keras_object` returns unchanged.
PLAIN_TYPES = (str, int, float, bool)

# Per-thread state. `SHARED_OBJECTS` is managed by `ObjectSharingScope`;
# `SAFE_MODE` is managed by `SafeModeScope` and read by `in_safe_mode()`.
SHARED_OBJECTS = threading.local()
SAFE_MODE = threading.local()

# TODO(nkovela): Debug serialization of decorated functions inside lambdas
# to allow for serialization of custom_gradient.
# Objects whose public class module contains one of these strings are
# returned as-is by `serialize_keras_object` instead of being serialized.
NON_SERIALIZABLE_CLASS_MODULES = ("tensorflow.python.ops.custom_gradient",)

# Keras sub-modules searched (as `keras.<module>.<name>`) when resolving a
# built-in function that was serialized by name only (e.g. 'acc' or 'tanh').
BUILTIN_MODULES = (
    "activations", "constraints", "initializers", "losses",
    "metrics", "optimizers", "regularizers",
)
class Config:
    """Bag of keyword arguments that serializes like a plain dict.

    The keyword arguments given to the constructor are kept on `self.config`;
    `serialize()` passes that dict to `serialize_keras_object`.
    """

    def __init__(self, **kwargs):
        self.config = kwargs

    def serialize(self):
        """Return `serialize_keras_object(self.config)`."""
        payload = self.config
        return serialize_keras_object(payload)
class SafeModeScope:
    """Context manager that temporarily overrides the thread-local safe mode.

    Nested deserialization calls read the flag through `in_safe_mode()`, so
    entering this scope propagates `safe_mode` to them. On exit the value
    that was active on entry (possibly `None`) is restored.
    """

    def __init__(self, safe_mode=True):
        self.safe_mode = safe_mode

    def __enter__(self):
        # Remember the flag that was active before this scope.
        previous_flag = in_safe_mode()
        self.original_value = previous_flag
        SAFE_MODE.safe_mode = self.safe_mode

    def __exit__(self, *args, **kwargs):
        restored_flag = self.original_value
        SAFE_MODE.safe_mode = restored_flag
@keras_export("keras.__internal__.enable_unsafe_deserialization")
def enable_unsafe_deserialization():
    """Turn off safe mode, allowing deserialization of lambdas.

    Sets the `SAFE_MODE.safe_mode` flag to `False`. `SAFE_MODE` is a
    `threading.local`, so this affects the calling thread. While the flag
    is set, `deserialize_keras_object` uses it instead of its own
    `safe_mode` argument.
    """
    SAFE_MODE.safe_mode = False
def in_safe_mode():
    """Return the current thread's safe-mode flag, or `None` if never set."""
    try:
        return SAFE_MODE.safe_mode
    except AttributeError:
        return None
class ObjectSharingScope:
    """Scope that enables detection and reuse of previously seen objects.

    While active, `SHARED_OBJECTS.enabled` is `True` and two fresh lookup
    tables are installed: `id_to_obj_map` (used during deserialization) and
    `id_to_config_map` (used during serialization). Both tables are reset
    again on exit.
    """

    @staticmethod
    def _reset_shared_state(enabled):
        # Flip the flag first, then install empty lookup tables.
        SHARED_OBJECTS.enabled = enabled
        SHARED_OBJECTS.id_to_obj_map = {}
        SHARED_OBJECTS.id_to_config_map = {}

    def __enter__(self):
        self._reset_shared_state(True)

    def __exit__(self, *args, **kwargs):
        self._reset_shared_state(False)
def get_shared_object(obj_id):
    """Return the object recorded under `obj_id` during deserialization.

    Returns `None` outside an active `ObjectSharingScope` or when nothing
    was recorded under that id.
    """
    if not getattr(SHARED_OBJECTS, "enabled", False):
        return None
    return SHARED_OBJECTS.id_to_obj_map.get(obj_id)
def record_object_after_serialization(obj, config):
    """Track `config` for `obj` so repeated objects can be marked as shared.

    A `"__main__"` module is normalized to `None`. Inside an active
    `ObjectSharingScope`, the first config seen for an object is remembered.
    A later config for the same object causes both the remembered config and
    the new one to be tagged with `"shared_object_id"`. Mutates `config` in
    place.
    """
    if config["module"] == "__main__":
        config["module"] = None  # Ensures module is None when no module found
    if not getattr(SHARED_OBJECTS, "enabled", False):
        return  # Not in a sharing scope

    obj_id = int(id(obj))
    seen_configs = SHARED_OBJECTS.id_to_config_map
    if obj_id in seen_configs:
        # Second sighting: mark both occurrences as the same shared object.
        config["shared_object_id"] = obj_id
        seen_configs[obj_id]["shared_object_id"] = obj_id
    else:
        seen_configs[obj_id] = config
def record_object_after_deserialization(obj, obj_id):
    """Remember `obj` under `obj_id` so later configs can reuse it.

    This has no effect outside an active `ObjectSharingScope`.
    """
    if getattr(SHARED_OBJECTS, "enabled", False):
        SHARED_OBJECTS.id_to_obj_map[obj_id] = obj
@keras_export(
    "keras.saving.serialize_keras_object", "keras.utils.serialize_keras_object"
)
def serialize_keras_object(obj):
    """Retrieve the config dict by serializing the Keras object.

    `serialize_keras_object()` serializes a Keras object to a python dictionary
    that represents the object, and is a reciprocal function of
    `deserialize_keras_object()`. See `deserialize_keras_object()` for more
    information about the config format.

    Args:
        obj: the Keras object to serialize.

    Returns:
        A python dict that represents the object. The python dict can be
        deserialized via `deserialize_keras_object()`.

    Raises:
        TypeError: If `obj` has no `get_config()` method and no `__name__`,
            or if its `get_config()` does not return a dict.
    """
    # Fall back to legacy serialization for all TF1 users or if
    # wrapped by in_tf_saved_model_scope() to explicitly use legacy
    # saved_model logic.
    if not tf.__internal__.tf2.enabled() or in_tf_saved_model_scope():
        return legacy_serialization.serialize_keras_object(obj)

    if obj is None:
        return obj

    if isinstance(obj, PLAIN_TYPES):
        return obj

    if isinstance(obj, (list, tuple)):
        config_arr = [serialize_keras_object(x) for x in obj]
        return tuple(config_arr) if isinstance(obj, tuple) else config_arr
    if isinstance(obj, dict):
        return serialize_dict(obj)

    # Special cases:
    if isinstance(obj, bytes):
        return {
            "class_name": "__bytes__",
            "config": {"value": obj.decode("utf-8")},
        }
    if isinstance(obj, tf.TensorShape):
        return obj.as_list() if obj._dims is not None else None
    if isinstance(obj, tf.Tensor):
        return {
            "class_name": "__tensor__",
            "config": {
                "value": obj.numpy().tolist(),
                "dtype": obj.dtype.name,
            },
        }
    if type(obj).__module__ == np.__name__:
        if isinstance(obj, np.ndarray) and obj.ndim > 0:
            return {
                "class_name": "__numpy__",
                "config": {
                    "value": obj.tolist(),
                    "dtype": obj.dtype.name,
                },
            }
        else:
            # Treat numpy floats / etc as plain types.
            return obj.item()
    if isinstance(obj, tf.DType):
        return obj.name
    if isinstance(obj, tf.compat.v1.Dimension):
        return obj.value
    if isinstance(obj, types.FunctionType) and obj.__name__ == "<lambda>":
        # `inspect.getsource` raises OSError when the source is unavailable
        # (e.g. lambdas built with `eval`/`exec`). The source only feeds the
        # warning text, so fall back to `repr` instead of aborting.
        try:
            lambda_source = inspect.getsource(obj)
        except (OSError, TypeError):
            lambda_source = repr(obj)
        warnings.warn(
            "The object being serialized includes a `lambda`. This is unsafe. "
            "In order to reload the object, you will have to pass "
            "`safe_mode=False` to the loading function. "
            "Please avoid using `lambda` in the "
            "future, and use named Python functions instead. "
            f"This is the `lambda` being serialized: {lambda_source}",
            stacklevel=2,
        )
        return {
            "class_name": "__lambda__",
            "config": {
                "value": generic_utils.func_dump(obj),
            },
        }
    if isinstance(obj, tf.TypeSpec):
        ts_config = obj._serialize()
        # TensorShape and tf.DType conversion
        ts_config = list(
            map(
                lambda x: x.as_list()
                if isinstance(x, tf.TensorShape)
                else (x.name if isinstance(x, tf.DType) else x),
                ts_config,
            )
        )
        return {
            "class_name": "__typespec__",
            "spec_name": obj.__class__.__name__,
            "module": obj.__class__.__module__,
            "config": ts_config,
            "registered_name": None,
        }

    inner_config = _get_class_or_fn_config(obj)
    config_with_public_class = serialize_with_public_class(
        obj.__class__, inner_config
    )

    # TODO(nkovela): Add TF ops dispatch handler serialization for
    # ops.EagerTensor that contains nested numpy array.
    # Target: NetworkConstructionTest.test_constant_initializer_with_numpy
    if isinstance(inner_config, str) and inner_config == "op_dispatch_handler":
        return obj

    if config_with_public_class is not None:
        # Special case for non-serializable class modules
        if any(
            mod in config_with_public_class["module"]
            for mod in NON_SERIALIZABLE_CLASS_MODULES
        ):
            return obj

        get_build_and_compile_config(obj, config_with_public_class)
        record_object_after_serialization(obj, config_with_public_class)
        return config_with_public_class

    # Any custom object or otherwise non-exported object
    if isinstance(obj, types.FunctionType):
        module = obj.__module__
    else:
        module = obj.__class__.__module__
    class_name = obj.__class__.__name__

    if module == "builtins":
        registered_name = None
    else:
        if isinstance(obj, types.FunctionType):
            registered_name = object_registration.get_registered_name(obj)
        else:
            registered_name = object_registration.get_registered_name(
                obj.__class__
            )

    config = {
        "module": module,
        "class_name": class_name,
        "config": inner_config,
        "registered_name": registered_name,
    }
    get_build_and_compile_config(obj, config)
    record_object_after_serialization(obj, config)
    return config
def get_build_and_compile_config(obj, config):
    """Attach `obj`'s serialized build and compile configs to `config`.

    The build config is handled first, then the compile config. A config is
    skipped if `obj` lacks its `get_build_config()` / `get_compile_config()`
    method or if that method returns `None`. Otherwise the result is passed
    through `serialize_dict` and stored under `config["build_config"]` /
    `config["compile_config"]`. `config` is mutated in place and `None` is
    returned.
    """
    getters = (
        ("get_build_config", "build_config"),
        ("get_compile_config", "compile_config"),
    )
    for getter_name, config_key in getters:
        if not hasattr(obj, getter_name):
            continue
        extracted = getattr(obj, getter_name)()
        if extracted is not None:
            config[config_key] = serialize_dict(extracted)
def serialize_with_public_class(cls, inner_config=None):
    """Describe `cls` if it is publicly exported or registered.

    If `cls` has a canonical `keras.*` API name, the config uses that name's
    module path and final component, with `registered_name=None`. Otherwise,
    if `cls` was registered via `keras.saving.register_keras_serializable()`,
    the config uses the class's Python module, `__name__` and registered
    name.

    Args:
        cls: The class to describe.
        inner_config: Value stored under the `"config"` key.

    Returns:
        A config dict, or `None` when `cls` is neither exported nor
        registered.
    """
    # The `keras.*` exported name, such as "keras.optimizers.Adam".
    public_name = tf_export.get_canonical_name_for_symbol(
        cls, api_name="keras"
    )

    if public_name is not None:
        # "keras.optimizers.Adam" -> ("keras.optimizers", "Adam")
        public_module, _, public_class = public_name.rpartition(".")
        return {
            "module": public_module,
            "class_name": public_class,
            "config": inner_config,
            "registered_name": None,
        }

    # Custom or unknown class: only serializable if registered.
    registration = object_registration.get_registered_name(cls)
    if registration is None:
        return None
    return {
        "module": cls.__module__,
        "class_name": cls.__name__,
        "config": inner_config,
        "registered_name": registration,
    }
def serialize_with_public_fn(fn, config, fn_module_name=None):
    """Describe function `fn` if its module is known, exported or registered.

    Resolution order:
    1. If `fn_module_name` is truthy, it is used as the module and
       `registered_name` is set to `config`.
    2. If `fn` has a canonical `keras.*` API name, that name's module path
       is used and `registered_name` is set to `config`.
    3. Otherwise, `fn.__module__` is used with the name returned by
       `object_registration.get_registered_name(fn)`. This requires that
       name to be truthy or the module to be `"builtins"`; if neither
       holds, `None` is returned.

    Returns:
        A config dict with `"class_name": "function"`, or `None`.
    """

    def _function_config(module, registered_name):
        return {
            "module": module,
            "class_name": "function",
            "config": config,
            "registered_name": registered_name,
        }

    if fn_module_name:
        return _function_config(fn_module_name, config)

    public_name = tf_export.get_canonical_name_for_symbol(
        fn, api_name="keras"
    )
    if public_name:
        return _function_config(public_name.rpartition(".")[0], config)

    registered_name = object_registration.get_registered_name(fn)
    if registered_name or fn.__module__ == "builtins":
        return _function_config(fn.__module__, registered_name)
    return None
369def _get_class_or_fn_config(obj):
370 """Return the object's config depending on its type."""
371 # Functions / lambdas:
372 if isinstance(obj, types.FunctionType):
373 return obj.__name__
374 # All classes:
375 if hasattr(obj, "get_config"):
376 config = obj.get_config()
377 if not isinstance(config, dict):
378 raise TypeError(
379 f"The `get_config()` method of {obj} should return "
380 f"a dict. It returned: {config}"
381 )
382 return serialize_dict(config)
383 elif hasattr(obj, "__name__"):
384 return object_registration.get_registered_name(obj)
385 else:
386 raise TypeError(
387 f"Cannot serialize object {obj} of type {type(obj)}. "
388 "To be serializable, "
389 "a class must implement the `get_config()` method."
390 )
def serialize_dict(obj):
    """Return a new dict with every value passed through `serialize_keras_object`.

    Keys are kept as-is.
    """
    serialized = {}
    for key, value in obj.items():
        serialized[key] = serialize_keras_object(value)
    return serialized
@keras_export(
    "keras.saving.deserialize_keras_object",
    "keras.utils.deserialize_keras_object",
)
def deserialize_keras_object(
    config, custom_objects=None, safe_mode=True, **kwargs
):
    """Retrieve the object by deserializing the config dict.

    The config dict is a Python dictionary that consists of a set of key-value
    pairs, and represents a Keras object, such as an `Optimizer`, `Layer`,
    `Metrics`, etc. The saving and loading library uses the following keys to
    record information of a Keras object:

    - `class_name`: String. This is the name of the class,
      as exactly defined in the source
      code, such as "LossesContainer".
    - `config`: Dict. Library-defined or user-defined key-value pairs that store
      the configuration of the object, as obtained by `object.get_config()`.
    - `module`: String. The path of the python module, such as
      "keras.engine.compile_utils". Built-in Keras classes
      expect to have prefix `keras`.
    - `registered_name`: String. The key the class is registered under via
      `keras.saving.register_keras_serializable(package, name)` API. The key has
      the format of '{package}>{name}', where `package` and `name` are the
      arguments passed to `register_keras_serializable()`. If `name` is not
      provided, it uses the class name. If `registered_name` successfully
      resolves to a class (that was registered), the `class_name` value is
      not used to look up the class; the `config` value is still passed to
      the class's `from_config()`. `registered_name` is only used for
      non-built-in classes.

    For example, the following dictionary represents the built-in Adam optimizer
    with the relevant config:

    ```python
    dict_structure = {
        "class_name": "Adam",
        "config": {
            "amsgrad": False,
            "beta_1": 0.8999999761581421,
            "beta_2": 0.9990000128746033,
            "decay": 0.0,
            "epsilon": 1e-07,
            "learning_rate": 0.0010000000474974513,
            "name": "Adam"
        },
        "module": "keras.optimizers",
        "registered_name": None
    }
    # Returns an `Adam` instance identical to the original one.
    deserialize_keras_object(dict_structure)
    ```

    If the class does not have an exported Keras namespace, the library tracks
    it by its `module` and `class_name`. For example:

    ```python
    dict_structure = {
      "class_name": "LossesContainer",
      "config": {
          "losses": [...],
          "total_loss_mean": {...},
      },
      "module": "keras.engine.compile_utils",
      "registered_name": "LossesContainer"
    }

    # Returns a `LossesContainer` instance identical to the original one.
    deserialize_keras_object(dict_structure)
    ```

    And the following dictionary represents a user-customized `MeanSquaredError`
    loss:

    ```python
    @keras.saving.register_keras_serializable(package='my_package')
    class ModifiedMeanSquaredError(keras.losses.MeanSquaredError):
      ...

    dict_structure = {
        "class_name": "ModifiedMeanSquaredError",
        "config": {
            "fn": "mean_squared_error",
            "name": "mean_squared_error",
            "reduction": "auto"
        },
        "registered_name": "my_package>ModifiedMeanSquaredError"
    }
    # Returns the `ModifiedMeanSquaredError` object
    deserialize_keras_object(dict_structure)
    ```

    Args:
        config: Python dict describing the object.
        custom_objects: Python dict containing a mapping between custom
            object names and the corresponding classes or functions.
        safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization.
            When `safe_mode=False`, loading an object has the potential to
            trigger arbitrary code execution. This argument is only
            applicable to the Keras v3 model format. Defaults to `True`.
            A safe-mode flag set on the current thread (see `SafeModeScope`)
            takes precedence over this argument.
        **kwargs: Only two legacy keyword arguments are accepted. The first is
            `module_objects`, a dict mapping names to classes/functions that
            is used to resolve configs by name. The second is
            `printable_module_name`, which is forwarded to the legacy
            deserializer. Any other keyword raises `ValueError`.

    Returns:
        The object described by the `config` dictionary.

    Raises:
        ValueError: If unsupported keyword arguments are passed, if a dict
            config used with `module_objects` lacks `class_name`, or if a
            `lambda` is encountered while safe mode is enabled.
        TypeError: If the config cannot be parsed, or if the referenced class
            or function cannot be located or reconstructed.
    """
    safe_scope_arg = in_safe_mode()  # Enforces SafeModeScope
    safe_mode = safe_scope_arg if safe_scope_arg is not None else safe_mode

    module_objects = kwargs.pop("module_objects", None)
    custom_objects = custom_objects or {}
    tlco = object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
    gco = object_registration._GLOBAL_CUSTOM_OBJECTS
    # On key collisions the later mapping wins: global registrations override
    # thread-local ones, which override the `custom_objects` argument.
    custom_objects = {**custom_objects, **tlco, **gco}

    # Optional deprecated argument for legacy deserialization call
    printable_module_name = kwargs.pop("printable_module_name", "object")
    if kwargs:
        raise ValueError(
            "The following argument(s) are not supported: "
            f"{list(kwargs.keys())}"
        )

    # Fall back to legacy deserialization for all TF1 users or if
    # wrapped by in_tf_saved_model_scope() to explicitly use legacy
    # saved_model logic.
    if not tf.__internal__.tf2.enabled() or in_tf_saved_model_scope():
        return legacy_serialization.deserialize_keras_object(
            config, module_objects, custom_objects, printable_module_name
        )

    if config is None:
        return None

    if (
        isinstance(config, str)
        and custom_objects
        and custom_objects.get(config) is not None
    ):
        # This is to deserialize plain functions which are serialized as
        # string names by legacy saving formats.
        return custom_objects[config]

    if isinstance(config, (list, tuple)):
        return [
            deserialize_keras_object(
                x, custom_objects=custom_objects, safe_mode=safe_mode
            )
            for x in config
        ]

    if module_objects is not None:
        inner_config, fn_module_name, has_custom_object = None, None, False
        if isinstance(config, dict):
            if "config" in config:
                inner_config = config["config"]
            if "class_name" not in config:
                raise ValueError(
                    f"Unknown `config` as a `dict`, config={config}"
                )

            # Check case where config is function or class and in custom objects
            if custom_objects and (
                config["class_name"] in custom_objects
                or config.get("registered_name") in custom_objects
                or (
                    isinstance(inner_config, str)
                    and inner_config in custom_objects
                )
            ):
                has_custom_object = True

            # Case where config is function but not in custom objects
            elif config["class_name"] == "function":
                fn_module_name = config["module"]
                if fn_module_name == "builtins":
                    config = config["config"]
                else:
                    config = config["registered_name"]

            # Case where config is class but not in custom objects
            else:
                if config.get("module", "_") is None:
                    raise TypeError(
                        "Cannot deserialize object of type "
                        f"`{config['class_name']}`. If "
                        f"`{config['class_name']}` is a custom class, please "
                        "register it using the "
                        "`@keras.saving.register_keras_serializable()` "
                        "decorator."
                    )
                config = config["class_name"]
        if not has_custom_object:
            # Return if not found in either module objects or custom objects
            if config not in module_objects:
                # Object has already been deserialized
                return config
            # Propagate `safe_mode` so nested `from_config` calls honour the
            # caller's choice instead of silently re-enabling safe mode.
            if isinstance(module_objects[config], types.FunctionType):
                return deserialize_keras_object(
                    serialize_with_public_fn(
                        module_objects[config], config, fn_module_name
                    ),
                    custom_objects=custom_objects,
                    safe_mode=safe_mode,
                )
            return deserialize_keras_object(
                serialize_with_public_class(
                    module_objects[config], inner_config=inner_config
                ),
                custom_objects=custom_objects,
                safe_mode=safe_mode,
            )

    if isinstance(config, PLAIN_TYPES):
        return config
    if not isinstance(config, dict):
        raise TypeError(f"Could not parse config: {config}")

    if "class_name" not in config or "config" not in config:
        return {
            key: deserialize_keras_object(
                value, custom_objects=custom_objects, safe_mode=safe_mode
            )
            for key, value in config.items()
        }

    class_name = config["class_name"]
    inner_config = config["config"] or {}

    # Special cases:
    if class_name == "__tensor__":
        return tf.constant(inner_config["value"], dtype=inner_config["dtype"])
    if class_name == "__numpy__":
        return np.array(inner_config["value"], dtype=inner_config["dtype"])
    if class_name == "__bytes__":
        return inner_config["value"].encode("utf-8")
    if class_name == "__lambda__":
        if safe_mode:
            raise ValueError(
                "Requested the deserialization of a `lambda` object. "
                "This carries a potential risk of arbitrary code execution "
                "and thus it is disallowed by default. If you trust the "
                "source of the saved model, you can pass `safe_mode=False` to "
                "the loading function in order to allow `lambda` loading."
            )
        return generic_utils.func_load(inner_config["value"])
    if class_name == "__typespec__":
        obj = _retrieve_class_or_fn(
            config["spec_name"],
            config["registered_name"],
            config["module"],
            obj_type="class",
            full_config=config,
            custom_objects=custom_objects,
        )

        def _restore_component(x):
            # Undo the TensorShape -> list and tf.DType -> name conversions
            # performed by `serialize_keras_object`. Only strings that name a
            # `tf.DType` are converted, so other strings (e.g. "cast") are
            # not turned into unrelated TF attributes.
            if isinstance(x, list):
                return tf.TensorShape(x)
            if isinstance(x, str):
                dtype = getattr(tf.dtypes, x, None)
                if isinstance(dtype, tf.DType):
                    return dtype
            return x

        return obj._deserialize(
            tuple(_restore_component(x) for x in inner_config)
        )

    # Below: classes and functions.
    module = config.get("module", None)
    registered_name = config.get("registered_name", class_name)

    if class_name == "function":
        fn_name = inner_config
        return _retrieve_class_or_fn(
            fn_name,
            registered_name,
            module,
            obj_type="function",
            full_config=config,
            custom_objects=custom_objects,
        )

    # Below, handling of all classes.
    # First, is it a shared object?
    if "shared_object_id" in config:
        obj = get_shared_object(config["shared_object_id"])
        if obj is not None:
            return obj

    cls = _retrieve_class_or_fn(
        class_name,
        registered_name,
        module,
        obj_type="class",
        full_config=config,
        custom_objects=custom_objects,
    )

    if isinstance(cls, types.FunctionType):
        return cls
    if not hasattr(cls, "from_config"):
        raise TypeError(
            f"Unable to reconstruct an instance of '{class_name}' because "
            f"the class is missing a `from_config()` method. "
            f"Full object config: {config}"
        )

    # Instantiate the class from its config inside a custom object scope
    # so that we can catch any custom objects that the config refers to.
    custom_obj_scope = object_registration.custom_object_scope(custom_objects)
    safe_mode_scope = SafeModeScope(safe_mode)
    with custom_obj_scope, safe_mode_scope:
        instance = cls.from_config(inner_config)
        build_config = config.get("build_config", None)
        if build_config:
            instance.build_from_config(build_config)
        compile_config = config.get("compile_config", None)
        if compile_config:
            instance.compile_from_config(compile_config)

    if "shared_object_id" in config:
        record_object_after_deserialization(
            instance, config["shared_object_id"]
        )
    return instance
def _retrieve_class_or_fn(
    name, registered_name, module, obj_type, full_config, custom_objects=None
):
    """Resolve a serialized class or function from its identifiers.

    Lookup order:
    1. `object_registration.get_registered_object`, which is also given
       `custom_objects`. Functions are looked up by `name`; classes are
       looked up by `registered_name`.
    2. When `module` is `keras` or `keras.*`, the public TF/Keras API symbol
       `<module>.<name>` (prefixed with `compat.v1.` for
       `__internal__.legacy` paths).
    3. For functions whose module is `builtins`:
       - `keras.<mod>.<name>` for each entry of `BUILTIN_MODULES`.
       - Then a `custom_objects` key equal to the function name or ending
         in `>` + name (the `package>name` registration form).
       - Finally, any key that merely ends with the function name.
    4. Importing `module` and reading `name` from it, falling back to
       `registered_name`.

    Args:
        name: Class or function name from the config.
        registered_name: Registration key from the config (may be `None`).
        module: Module path from the config (may be `None` or empty).
        obj_type: `"function"` or `"class"`; used for lookups and messages.
        full_config: The complete config dict, used for the builtin-function
            fallback and for error messages.
        custom_objects: Optional dict of user-supplied objects.

    Returns:
        The resolved class or function.

    Raises:
        TypeError: If `module` cannot be imported or nothing can be found.
    """
    # If there is a custom object registered via
    # `register_keras_serializable()`, that takes precedence.
    if obj_type == "function":
        custom_obj = object_registration.get_registered_object(
            name, custom_objects=custom_objects
        )
    else:
        custom_obj = object_registration.get_registered_object(
            registered_name, custom_objects=custom_objects
        )
    if custom_obj is not None:
        return custom_obj

    if module:
        # If it's a Keras built-in object,
        # we cannot always use direct import, because the exported
        # module name might not match the package structure
        # (e.g. experimental symbols).
        if module == "keras" or module.startswith("keras."):
            api_name = module + "." + name

            # Legacy internal APIs are stored in TF API naming dict
            # with `compat.v1` prefix
            if "__internal__.legacy" in api_name:
                api_name = "compat.v1." + api_name

            obj = tf_export.get_symbol_from_name(api_name)
            if obj is not None:
                return obj

        # Configs of Keras built-in functions do not contain identifying
        # information other than their name (e.g. 'acc' or 'tanh'). This
        # special case searches the Keras modules that contain built-ins to
        # retrieve the corresponding function from the identifying string.
        if obj_type == "function" and module == "builtins":
            for mod in BUILTIN_MODULES:
                obj = tf_export.get_symbol_from_name(
                    "keras." + mod + "." + name
                )
                if obj is not None:
                    return obj

            # Retrieval of registered custom function in a package.
            # `custom_objects` defaults to None, so guard before `.items()`.
            fn_name = full_config["config"]
            candidates = list((custom_objects or {}).items())
            # Prefer an exact name or a `package>name` registration so that,
            # e.g., "relu" does not resolve to "pkg>leaky_relu".
            for key, value in candidates:
                if key == fn_name or key.endswith(">" + fn_name):
                    return value
            # Fall back to the historical, looser suffix match.
            for key, value in candidates:
                if key.endswith(fn_name):
                    return value

        # Otherwise, attempt to retrieve the class object given the `module`
        # and `class_name`. Import the module, find the class.
        try:
            mod = importlib.import_module(module)
        except ModuleNotFoundError as err:
            raise TypeError(
                f"Could not deserialize {obj_type} '{name}' because "
                f"its parent module {module} cannot be imported. "
                f"Full object config: {full_config}"
            ) from err
        obj = vars(mod).get(name, None)

        # Special case for keras.metrics.metrics
        if obj is None and registered_name is not None:
            obj = vars(mod).get(registered_name, None)

        if obj is not None:
            return obj

    raise TypeError(
        f"Could not locate {obj_type} '{name}'. "
        "Make sure custom classes are decorated with "
        "`@keras.saving.register_keras_serializable()`. "
        f"Full object config: {full_config}"
    )