Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/serialization.py: 48%
105 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Layer serialization/deserialization functions."""
17import threading
19import tensorflow.compat.v2 as tf
21from keras.src.engine import base_layer
22from keras.src.engine import input_layer
23from keras.src.engine import input_spec
24from keras.src.layers import activation
25from keras.src.layers import attention
26from keras.src.layers import convolutional
27from keras.src.layers import core
28from keras.src.layers import locally_connected
29from keras.src.layers import merging
30from keras.src.layers import pooling
31from keras.src.layers import regularization
32from keras.src.layers import reshaping
33from keras.src.layers import rnn
34from keras.src.layers.normalization import batch_normalization
35from keras.src.layers.normalization import batch_normalization_v1
36from keras.src.layers.normalization import group_normalization
37from keras.src.layers.normalization import layer_normalization
38from keras.src.layers.normalization import unit_normalization
39from keras.src.layers.preprocessing import category_encoding
40from keras.src.layers.preprocessing import discretization
41from keras.src.layers.preprocessing import hashed_crossing
42from keras.src.layers.preprocessing import hashing
43from keras.src.layers.preprocessing import image_preprocessing
44from keras.src.layers.preprocessing import integer_lookup
45from keras.src.layers.preprocessing import (
46 normalization as preprocessing_normalization,
47)
48from keras.src.layers.preprocessing import string_lookup
49from keras.src.layers.preprocessing import text_vectorization
50from keras.src.layers.rnn import cell_wrappers
51from keras.src.layers.rnn import gru
52from keras.src.layers.rnn import lstm
53from keras.src.metrics import base_metric
54from keras.src.saving import serialization_lib
55from keras.src.saving.legacy import serialization as legacy_serialization
56from keras.src.saving.legacy.saved_model import json_utils
57from keras.src.utils import generic_utils
58from keras.src.utils import tf_inspect as inspect
60# isort: off
61from tensorflow.python.util.tf_export import keras_export
# Modules whose `Layer` subclasses are registered for deserialization by
# `populate_deserializable_objects()` regardless of the TF2 setting.
# NOTE(review): order presumably matters if two modules export the same
# class name (later entries would win) — confirm against
# `generic_utils.populate_dict_with_module_objects`.
ALL_MODULES = (
    base_layer,
    input_layer,
    activation,
    attention,
    convolutional,
    core,
    locally_connected,
    merging,
    batch_normalization_v1,
    group_normalization,
    layer_normalization,
    unit_normalization,
    pooling,
    image_preprocessing,
    regularization,
    reshaping,
    rnn,
    hashing,
    hashed_crossing,
    category_encoding,
    discretization,
    integer_lookup,
    preprocessing_normalization,
    string_lookup,
    text_vectorization,
)
# Modules scanned only when TF2 behavior is enabled. They are applied after
# `ALL_MODULES` so that their classes (e.g. the V2 `BatchNormalization`)
# replace the V1 objects registered under the same names.
ALL_V2_MODULES = (
    batch_normalization,
    layer_normalization,
    cell_wrappers,
    gru,
    lstm,
)
# ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
# thread-local to avoid concurrent mutations.
# Each thread lazily gets `LOCAL.ALL_OBJECTS` (name -> object registry) and
# `LOCAL.GENERATED_WITH_V2` (the TF2 flag the registry was built under),
# both set by `populate_deserializable_objects()`.
LOCAL = threading.local()
def populate_deserializable_objects():
    """Populates the thread-local dict `LOCAL.ALL_OBJECTS` with built-in layers.

    The registry maps names to built-in layer classes, backward-compatible
    aliases, model classes, and the functional merging helpers. It is cached
    per thread together with the TF2-enabled flag it was generated under.
    The registry is rebuilt whenever that flag no longer matches
    `tf.__internal__.tf2.enabled()`.

    If population fails part-way (for example because one of the deferred
    imports raises), the cache is reset before the exception propagates. A
    later call therefore retries instead of treating a partially filled
    registry as complete.
    """
    if not hasattr(LOCAL, "ALL_OBJECTS"):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    # Query the TF2 flag once so the cache key and every V1/V2 decision made
    # while populating are derived from the same value.
    tf2_enabled = tf.__internal__.tf2.enabled()
    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2_enabled:
        # Objects dict is already generated for the proper TF version:
        # do nothing.
        return

    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = tf2_enabled
    try:
        _populate_all_objects(LOCAL.ALL_OBJECTS, tf2_enabled)
    except BaseException:
        # Do not leave a half-built registry that the early-return check
        # above would accept as complete on the next call.
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None
        raise


def _populate_all_objects(objects, tf2_enabled):
    """Fills `objects` in place with every built-in layer and alias.

    Args:
        objects: The dict to populate (the calling thread's
            `LOCAL.ALL_OBJECTS`).
        tf2_enabled: Whether TF2 behavior is enabled. This selects whether
            the V2 overrides are applied and which `DenseFeatures`
            implementation is registered.
    """
    base_cls = base_layer.Layer

    def is_layer_class(obj):
        return inspect.isclass(obj) and issubclass(obj, base_cls)

    generic_utils.populate_dict_with_module_objects(
        objects, ALL_MODULES, obj_filter=is_layer_class
    )

    # Overwrite certain V1 objects with V2 versions
    if tf2_enabled:
        generic_utils.populate_dict_with_module_objects(
            objects, ALL_V2_MODULES, obj_filter=is_layer_class
        )

    # These deserialization aliases are added for backward compatibility,
    # as in TF 1.13, "BatchNormalizationV1" and "BatchNormalizationV2"
    # were used as class name for v1 and v2 version of BatchNormalization,
    # respectively. Here we explicitly convert them to their canonical names.
    objects["BatchNormalizationV1"] = batch_normalization_v1.BatchNormalization
    objects["BatchNormalizationV2"] = batch_normalization.BatchNormalization

    # Prevent circular dependencies.
    from keras.src import models
    from keras.src.feature_column.sequence_feature_column import (
        SequenceFeatures,
    )
    from keras.src.premade_models.linear import LinearModel
    from keras.src.premade_models.wide_deep import WideDeepModel

    objects["Input"] = input_layer.Input
    objects["InputSpec"] = input_spec.InputSpec
    objects["Functional"] = models.Functional
    objects["Model"] = models.Model
    objects["SequenceFeatures"] = SequenceFeatures
    objects["Sequential"] = models.Sequential
    objects["LinearModel"] = LinearModel
    objects["WideDeepModel"] = WideDeepModel

    if tf2_enabled:
        from keras.src.feature_column.dense_features_v2 import DenseFeatures
    else:
        from keras.src.feature_column.dense_features import DenseFeatures
    objects["DenseFeatures"] = DenseFeatures

    # Merging layers, function versions.
    objects["add"] = merging.add
    objects["subtract"] = merging.subtract
    objects["multiply"] = merging.multiply
    objects["average"] = merging.average
    objects["maximum"] = merging.maximum
    objects["minimum"] = merging.minimum
    objects["concatenate"] = merging.concatenate
    objects["dot"] = merging.dot
@keras_export("keras.layers.serialize")
def serialize(layer, use_legacy_format=False):
    """Serializes a `Layer` object into a JSON-compatible representation.

    Args:
        layer: The `Layer` object to serialize.
        use_legacy_format: Boolean. If `True`, serialize with the legacy
            `keras.src.saving.legacy.serialization` implementation instead
            of `serialization_lib`. Defaults to `False`.

    Returns:
        A JSON-serializable dict representing the object's config.

    Raises:
        ValueError: If `layer` is a metric (`base_metric.Metric`). Metrics
            must go through `keras.metrics.serialize()` instead.

    Example:

    ```python
    from pprint import pprint
    model = tf.keras.models.Sequential()
    model.add(tf.keras.Input(shape=(16,)))
    model.add(tf.keras.layers.Dense(32, activation='relu'))

    pprint(tf.keras.layers.serialize(model))
    # prints the configuration of the model, as a dict.
    ```
    """
    # Metrics are rejected explicitly so callers are pointed at the
    # dedicated metrics serialization API.
    if isinstance(layer, base_metric.Metric):
        raise ValueError(
            f"Cannot serialize {layer} since it is a metric. "
            "Please use the `keras.metrics.serialize()` and "
            "`keras.metrics.deserialize()` APIs to serialize "
            "and deserialize metrics."
        )
    if use_legacy_format:
        return legacy_serialization.serialize_keras_object(layer)
    return serialization_lib.serialize_keras_object(layer)
@keras_export("keras.layers.deserialize")
def deserialize(config, custom_objects=None, use_legacy_format=False):
    """Instantiates a layer from a config dictionary.

    Args:
        config: dict of the form {'class_name': str, 'config': dict}
        custom_objects: dict mapping class names (or function names) of custom
            (non-Keras) objects to class/functions
        use_legacy_format: Boolean. If `True`, deserialize with the legacy
            `keras.src.saving.legacy.serialization` implementation instead
            of `serialization_lib`. Defaults to `False`.

    Returns:
        Layer instance (may be Model, Sequential, Network, Layer...)

    Raises:
        ValueError: If `config` is empty (falsy).

    Example:

    ```python
    # Configuration of Dense(32, activation='relu')
    config = {
      'class_name': 'Dense',
      'config': {
        'activation': 'relu',
        'activity_regularizer': None,
        'bias_constraint': None,
        'bias_initializer': {'class_name': 'Zeros', 'config': {}},
        'bias_regularizer': None,
        'dtype': 'float32',
        'kernel_constraint': None,
        'kernel_initializer': {'class_name': 'GlorotUniform',
                               'config': {'seed': None}},
        'kernel_regularizer': None,
        'name': 'dense',
        'trainable': True,
        'units': 32,
        'use_bias': True
      }
    }
    dense_layer = tf.keras.layers.deserialize(config)
    ```
    """
    # Validate the input before doing the (potentially expensive) registry
    # population.
    if not config:
        raise ValueError(
            f"Cannot deserialize empty config. Received: config={config}"
        )
    populate_deserializable_objects()

    # Both implementations take the same arguments; pick one and call it
    # once so the two code paths cannot drift apart.
    deserialize_fn = (
        legacy_serialization.deserialize_keras_object
        if use_legacy_format
        else serialization_lib.deserialize_keras_object
    )
    return deserialize_fn(
        config,
        module_objects=LOCAL.ALL_OBJECTS,
        custom_objects=custom_objects,
        printable_module_name="layer",
    )
def get_builtin_layer(class_name):
    """Returns class if `class_name` is registered, else returns None.

    Args:
        class_name: Name to look up in the built-in registry.

    Returns:
        The object registered under `class_name`, or `None` if the name is
        not registered.
    """
    # Always go through `populate_deserializable_objects()`. It returns
    # immediately when the registry is already built for the current TF2
    # mode and rebuilds it otherwise. Lookups here therefore never see a
    # stale V1/V2 registry and agree with `deserialize()`.
    populate_deserializable_objects()
    return LOCAL.ALL_OBJECTS.get(class_name)
def deserialize_from_json(json_string, custom_objects=None):
    """Builds a layer from its JSON string representation.

    Args:
        json_string: JSON string holding a serialized layer config.
        custom_objects: Optional mapping of custom (non-Keras) class or
            function names to objects. It is forwarded to both the JSON
            decoder and `deserialize()`.

    Returns:
        Whatever `deserialize()` returns for the decoded config.
    """
    # The decoder receives the built-in registry as `module_objects`, so the
    # registry has to be populated before decoding.
    populate_deserializable_objects()
    registry = LOCAL.ALL_OBJECTS
    layer_config = json_utils.decode_and_deserialize(
        json_string, module_objects=registry, custom_objects=custom_objects
    )
    return deserialize(layer_config, custom_objects=custom_objects)