Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/saved_model/json_utils.py: 21%
101 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utils for creating and loading the Layer metadata for SavedModel.

These are required to retain the original format of the build input shape,
since layers and models may have different build behaviors depending on whether
the shape is a list, tuple, or TensorShape. For example, Network.build() will
create separate inputs if the given input_shape is a list, and a single input
if it is a tuple.
22"""
24import collections
25import enum
26import functools
27import json
29import numpy as np
30import tensorflow.compat.v2 as tf
31import wrapt
33from keras.src.saving import serialization_lib
34from keras.src.saving.legacy import serialization
35from keras.src.saving.legacy.saved_model.utils import in_tf_saved_model_scope
37# isort: off
38from tensorflow.python.framework import type_spec_registry
# String constant whose value equals its own name. NOTE(review): it is not
# referenced anywhere in this module; presumably it is a marker key for
# extension-type specs used by other SavedModel code (inferred from the name
# only) -- confirm before changing or removing.
_EXTENSION_TYPE_SPEC = "_EXTENSION_TYPE_SPEC"
class Encoder(json.JSONEncoder):
    """JSON encoder that handles TensorShapes and tuples.

    Tuples are tagged with ``_encode_tuple`` before encoding so they can be
    told apart from lists when the text is read back with ``decode``.
    Values the stock encoder cannot handle are converted in ``default``.
    """

    def default(self, obj):
        """Encodes objects for types that aren't handled by the default
        encoder.

        A ``tf.TensorShape`` becomes
        ``{"class_name": "TensorShape", "items": ...}``, where ``items`` is
        the dimension list, or ``None`` when the rank is unknown. Everything
        else is delegated to ``get_json_type``.
        """
        if isinstance(obj, tf.TensorShape):
            items = obj.as_list() if obj.rank is not None else None
            return {"class_name": "TensorShape", "items": items}
        return get_json_type(obj)

    def iterencode(self, o, _one_shot=False):
        """Tags tuples in ``o``, then encodes it chunk by chunk.

        The tagging lives here rather than in ``encode`` because
        ``json.dump`` calls ``iterencode`` directly. The inherited
        ``encode`` (used by ``json.dumps``) also routes through this method,
        so tuples are tagged exactly once on both paths.
        """
        return super().iterencode(_encode_tuple(o), _one_shot)
58def _encode_tuple(x):
59 if isinstance(x, tuple):
60 return {
61 "class_name": "__tuple__",
62 "items": tuple(_encode_tuple(i) for i in x),
63 }
64 elif isinstance(x, list):
65 return [_encode_tuple(i) for i in x]
66 elif isinstance(x, dict):
67 return {key: _encode_tuple(value) for key, value in x.items()}
68 else:
69 return x
def decode(json_string):
    """Parses JSON text, reviving values tagged by the encoder.

    Every decoded JSON object is passed through ``_decode_helper`` with its
    defaults, so Keras configs stay plain dicts. Use
    ``decode_and_deserialize`` to turn those into objects as well.
    """
    hook = _decode_helper
    decoded = json.loads(json_string, object_hook=hook)
    return decoded
def decode_and_deserialize(
    json_string, module_objects=None, custom_objects=None
):
    """Decodes the JSON and deserializes any Keras objects found in the dict.

    Args:
        json_string: JSON text to parse.
        module_objects: Optional dictionary of built-in objects, forwarded to
            ``_decode_helper`` for name lookup.
        custom_objects: Optional dictionary of custom objects, forwarded to
            ``_decode_helper`` for name lookup.

    Returns:
        The decoded structure. Dicts tagged with ``__passive_serialization__``
        are replaced by deserialized Keras objects. A dict whose
        deserialization raises ``ValueError`` is left as a dict.
    """
    hook = functools.partial(
        _decode_helper,
        deserialize=True,
        module_objects=module_objects,
        custom_objects=custom_objects,
    )
    return json.loads(json_string, object_hook=hook)
91def _decode_helper(
92 obj, deserialize=False, module_objects=None, custom_objects=None
93):
94 """A decoding helper that is TF-object aware.
96 Args:
97 obj: A decoded dictionary that may represent an object.
98 deserialize: Boolean, defaults to False. When True, deserializes any Keras
99 objects found in `obj`.
100 module_objects: A dictionary of built-in objects to look the name up in.
101 Generally, `module_objects` is provided by midlevel library
102 implementers.
103 custom_objects: A dictionary of custom objects to look the name up in.
104 Generally, `custom_objects` is provided by the end user.
106 Returns:
107 The decoded object.
108 """
109 if isinstance(obj, dict) and "class_name" in obj:
110 if obj["class_name"] == "TensorShape":
111 return tf.TensorShape(obj["items"])
112 elif obj["class_name"] == "TypeSpec":
113 return type_spec_registry.lookup(obj["type_spec"])._deserialize(
114 _decode_helper(obj["serialized"])
115 )
116 elif obj["class_name"] == "CompositeTensor":
117 spec = obj["spec"]
118 tensors = []
119 for dtype, tensor in obj["tensors"]:
120 tensors.append(
121 tf.constant(tensor, dtype=tf.dtypes.as_dtype(dtype))
122 )
123 return tf.nest.pack_sequence_as(
124 _decode_helper(spec), tensors, expand_composites=True
125 )
126 elif obj["class_name"] == "__tuple__":
127 return tuple(_decode_helper(i) for i in obj["items"])
128 elif obj["class_name"] == "__ellipsis__":
129 return Ellipsis
130 elif deserialize and "__passive_serialization__" in obj:
131 # __passive_serialization__ is added by the JSON encoder when
132 # encoding an object that has a `get_config()` method.
133 try:
134 if in_tf_saved_model_scope() or "module" not in obj:
135 return serialization.deserialize_keras_object(
136 obj,
137 module_objects=module_objects,
138 custom_objects=custom_objects,
139 )
140 else:
141 return serialization_lib.deserialize_keras_object(
142 obj,
143 module_objects=module_objects,
144 custom_objects=custom_objects,
145 )
146 except ValueError:
147 pass
148 elif obj["class_name"] == "__bytes__":
149 return obj["value"].encode("utf-8")
150 return obj
def get_json_type(obj):
    """Serializes any object to a JSON-serializable structure.

    This is the fallback used by ``Encoder.default`` for values the stock
    JSON encoder cannot handle. The checks run in order and the first match
    wins. The returned value is not always JSON-native (e.g. the target of a
    ``wrapt.ObjectProxy``, or the values of a ``Mapping``). When this is used
    as a ``default`` hook, ``json`` passes such values back through the hook.

    Args:
        obj: the object to serialize

    Returns:
        JSON-serializable structure representing `obj`.

    Raises:
        TypeError: if `obj` cannot be serialized.
        ValueError: if `obj` is a `tf.TypeSpec` whose class has not been
            registered with the TypeSpec registry.
    """
    # If obj is a serializable Keras class instance (e.g. optimizer, layer),
    # emit its config and flag it so `_decode_helper` can deserialize it.
    if hasattr(obj, "get_config"):
        serialized = serialization.serialize_keras_object(obj)
        serialized["__passive_serialization__"] = True
        return serialized

    # If obj is any numpy type: arrays become nested lists; other numpy
    # objects are converted with `.item()` (numpy scalars -> Python scalars).
    if type(obj).__module__ == np.__name__:
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return obj.item()

    # Misc functions (e.g. loss function). Classes are callable too, so they
    # are named here as well. A callable without a `__name__` (such as a
    # `functools.partial` object) is not named here. It falls through to the
    # remaining checks, so an unsupported one ends in the documented
    # TypeError below rather than an AttributeError.
    if callable(obj) and hasattr(obj, "__name__"):
        return obj.__name__

    # If obj is a python 'type'. NOTE(review): real classes are callable and
    # already handled above; this only matches instances of some other class
    # that happens to be named "type".
    if type(obj).__name__ == type.__name__:
        return obj.__name__

    if isinstance(obj, tf.compat.v1.Dimension):
        return obj.value

    if isinstance(obj, tf.TensorShape):
        return obj.as_list()

    if isinstance(obj, tf.DType):
        return obj.name

    if isinstance(obj, collections.abc.Mapping):
        return dict(obj)

    if obj is Ellipsis:
        return {"class_name": "__ellipsis__"}

    # Unwrap proxies; the wrapped object is serialized on its own merits.
    if isinstance(obj, wrapt.ObjectProxy):
        return obj.__wrapped__

    if isinstance(obj, tf.TypeSpec):
        # Only the registry lookup signals "not registered". Errors raised
        # by `_serialize()` propagate unchanged instead of being
        # misreported as a registration problem.
        try:
            type_spec_name = type_spec_registry.get_name(type(obj))
        except ValueError as err:
            raise ValueError(
                f"Unable to serialize {obj} to JSON, because the TypeSpec "
                f"class {type(obj)} has not been registered."
            ) from err
        return {
            "class_name": "TypeSpec",
            "type_spec": type_spec_name,
            "serialized": obj._serialize(),
        }

    if isinstance(obj, tf.__internal__.CompositeTensor):
        spec = tf.type_spec_from_value(obj)
        # Each component is stored as (dtype name, nested-list value);
        # `.numpy()` requires the component tensors to be eager.
        tensors = [
            (tensor.dtype.name, tensor.numpy().tolist())
            for tensor in tf.nest.flatten(obj, expand_composites=True)
        ]
        return {
            "class_name": "CompositeTensor",
            "spec": get_json_type(spec),
            "tensors": tensors,
        }

    if isinstance(obj, enum.Enum):
        return obj.value

    if isinstance(obj, bytes):
        # NOTE(review): bytes that are not valid UTF-8 raise
        # UnicodeDecodeError here; only UTF-8 payloads round-trip.
        return {"class_name": "__bytes__", "value": obj.decode("utf-8")}

    raise TypeError(
        f"Unable to serialize {obj} to JSON. Unrecognized type {type(obj)}."
    )