Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/object_registration.py: 57%
51 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python utilities required by Keras."""
17import inspect
18import threading
20# isort: off
21from tensorflow.python.util.tf_export import keras_export
# Process-wide registry: serialized key ("package>name") -> registered object.
_GLOBAL_CUSTOM_OBJECTS = {}
# Reverse registry: registered object -> its serialized key.
_GLOBAL_CUSTOM_NAMES = {}
# Thread-local custom objects set by custom_object_scope.
_THREAD_LOCAL_CUSTOM_OBJECTS = threading.local()
@keras_export(
    "keras.saving.custom_object_scope",
    "keras.utils.custom_object_scope",
    "keras.utils.CustomObjectScope",
)
class CustomObjectScope:
    """Context manager exposing custom objects to Keras deserialization.

    Inside a `with custom_object_scope(objects_dict)` block, Keras methods
    such as `tf.keras.models.load_model` or
    `tf.keras.models.model_from_config` can deserialize any custom object
    referenced by a saved config (e.g. a custom layer or metric).

    The objects are stored in thread-local state. Whatever was visible
    before the block was entered is restored when the block exits.

    Example:

    Consider a custom regularizer `my_regularizer`:

    ```python
    layer = Dense(3, kernel_regularizer=my_regularizer)
    # Config contains a reference to `my_regularizer`
    config = layer.get_config()
    ...
    # Later:
    with custom_object_scope({'my_regularizer': my_regularizer}):
        layer = Dense.from_config(config)
    ```

    Args:
        *args: Dictionary or dictionaries of `{name: object}` pairs.
    """

    def __init__(self, *args):
        # Filled in by `__enter__`; holds the state to restore on exit.
        self.backup = None
        self.custom_objects = args

    def __enter__(self):
        scoped = _THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
        # Snapshot before mutating so `__exit__` can undo this scope.
        self.backup = scoped.copy()
        # Later dictionaries override earlier ones on key collisions.
        for mapping in self.custom_objects:
            scoped.update(mapping)
        return self

    def __exit__(self, *args, **kwargs):
        scoped = _THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
        # Drop everything, then put back the snapshot taken on entry.
        scoped.clear()
        scoped.update(self.backup)
@keras_export(
    "keras.saving.get_custom_objects", "keras.utils.get_custom_objects"
)
def get_custom_objects():
    """Returns the global dictionary of custom objects itself, not a copy.

    Mutating the returned dictionary changes the global registry. Custom
    objects set using `custom_object_scope` are not added to the global
    dictionary of custom objects, and will not appear in the returned
    dictionary.

    Example:

    ```python
    get_custom_objects().clear()
    get_custom_objects()['MyObject'] = MyObject
    ```

    Returns:
        Global dictionary mapping registered class names to classes.
    """
    return _GLOBAL_CUSTOM_OBJECTS
@keras_export(
    "keras.saving.register_keras_serializable",
    "keras.utils.register_keras_serializable",
)
def register_keras_serializable(package="Custom", name=None):
    """Registers an object with the Keras serialization framework.

    This decorator adds the decorated class or function to the global Keras
    custom object dictionary, so that it can be serialized and deserialized
    without needing an entry in the user-provided custom object dict. It
    also records the string key that Keras uses to refer to the object.

    Note that to be serialized and deserialized, classes must implement the
    `get_config()` method. Functions do not have this requirement.

    The object is registered under the key `'package>name'`, where `name`
    defaults to the object's `__name__` if not passed.

    Example:

    ```python
    # Note that `'my_package'` is used as the `package` argument here, and since
    # the `name` argument is not provided, `'MyDense'` is used as the `name`.
    @keras.saving.register_keras_serializable('my_package')
    class MyDense(keras.layers.Dense):
        pass

    assert keras.saving.get_registered_object('my_package>MyDense') == MyDense
    assert keras.saving.get_registered_name(MyDense) == 'my_package>MyDense'
    ```

    Args:
        package: The package that this class belongs to. This is used for
            the `key` (which is `"package>name"`) to identify the class.
            Note that this is the first argument passed into the decorator.
        name: The name to serialize this class under in this package. If not
            provided or `None`, the class' name will be used (note that this
            is the case when the decorator is used with only one argument,
            which becomes the `package`).

    Returns:
        A decorator that registers the decorated class with the passed names.

    Raises:
        ValueError: Raised by the returned decorator when it is applied to a
            class that has no `get_config` attribute.
    """

    def decorator(arg):
        """Registers a class with the Keras serialization framework."""
        if name is None:
            short_name = arg.__name__
        else:
            short_name = name
        key = package + ">" + short_name

        # Only classes need `get_config`; plain functions are exempt.
        if inspect.isclass(arg):
            if not hasattr(arg, "get_config"):
                raise ValueError(
                    "Cannot register a class that does not have a "
                    "get_config() method."
                )

        # Keep the forward and reverse registries in step.
        _GLOBAL_CUSTOM_OBJECTS[key] = arg
        _GLOBAL_CUSTOM_NAMES[arg] = key

        return arg

    return decorator
@keras_export(
    "keras.saving.get_registered_name", "keras.utils.get_registered_name"
)
def get_registered_name(obj):
    """Returns the name registered to an object within the Keras framework.

    This function is part of the Keras serialization and deserialization
    framework. It maps objects to the string names used for them during
    serialization/deserialization.

    Args:
        obj: The object to look up.

    Returns:
        The name associated with the object, or the object's `__name__` if
        the object is not registered.
    """
    if obj not in _GLOBAL_CUSTOM_NAMES:
        # Unregistered objects fall back to their plain Python name.
        return obj.__name__
    return _GLOBAL_CUSTOM_NAMES[obj]
@keras_export(
    "keras.saving.get_registered_object", "keras.utils.get_registered_object"
)
def get_registered_object(name, custom_objects=None, module_objects=None):
    """Returns the class associated with `name` if it is registered with Keras.

    This function is part of the Keras serialization and deserialization
    framework. It maps strings to the objects associated with them for
    serialization/deserialization.

    The first match wins, searching in this order: objects set by an active
    `custom_object_scope` on the current thread, the global custom object
    dictionary, `custom_objects`, and finally `module_objects`.

    Example:

    ```python
    def from_config(cls, config, custom_objects=None):
        if 'my_custom_object_name' in config:
            config['hidden_cls'] = tf.keras.saving.get_registered_object(
                config['my_custom_object_name'], custom_objects=custom_objects)
    ```

    Args:
        name: The name to look up.
        custom_objects: A dictionary of custom objects to look the name up
            in. Generally, custom_objects is provided by the user.
        module_objects: A dictionary of custom objects to look the name up
            in. Generally, module_objects is provided by midlevel library
            implementers.

    Returns:
        An instantiable class associated with `name`, or `None` if no such
        class exists.
    """
    scoped = _THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
    if name in scoped:
        return scoped[name]
    if name in _GLOBAL_CUSTOM_OBJECTS:
        return _GLOBAL_CUSTOM_OBJECTS[name]
    # Caller-supplied dictionaries are optional; skip any that are None/empty.
    for supplied in (custom_objects, module_objects):
        if supplied and name in supplied:
            return supplied[name]
    return None
# Aliases
# Lowercase spelling of `CustomObjectScope`; both names refer to the same class.
custom_object_scope = CustomObjectScope