Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/saved_model/json_utils.py: 21%

101 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

"""Utils for creating and loading the Layer metadata for SavedModel.

These are required to retain the original format of the build input shape,
since layers and models may have different build behaviors depending on
whether the shape is a list, tuple, or TensorShape. For example,
Network.build() will create separate inputs if the given input_shape is a
list, and will create a single input if the given shape is a tuple.
"""

23 

24import collections 

25import enum 

26import functools 

27import json 

28 

29import numpy as np 

30import tensorflow.compat.v2 as tf 

31import wrapt 

32 

33from keras.src.saving import serialization_lib 

34from keras.src.saving.legacy import serialization 

35from keras.src.saving.legacy.saved_model.utils import in_tf_saved_model_scope 

36 

37# isort: off 

38from tensorflow.python.framework import type_spec_registry 

39 

# String constant "_EXTENSION_TYPE_SPEC". It is not referenced by the code in
# this module; presumably kept for external importers (TODO confirm before
# removing).
_EXTENSION_TYPE_SPEC: str = "_EXTENSION_TYPE_SPEC"

41 

42 

class Encoder(json.JSONEncoder):
    """JSON encoder that is aware of TensorShapes and tuples.

    This class only encodes; decoding is done by the module-level `decode`
    and `decode_and_deserialize` functions.
    """

    def default(self, obj):
        """Converts objects that the base JSON encoder cannot handle.

        A `tf.TensorShape` becomes a dict tagged with
        `"class_name": "TensorShape"` whose `items` is the dimension list, or
        `None` when the rank is unknown. Every other object is delegated to
        `get_json_type`.
        """
        if not isinstance(obj, tf.TensorShape):
            return get_json_type(obj)
        dims = None if obj.rank is None else obj.as_list()
        return {"class_name": "TensorShape", "items": dims}

    def encode(self, obj):
        """Encodes `obj` after tagging every tuple so it survives JSON."""
        tagged = _encode_tuple(obj)
        return super().encode(tagged)

56 

57 

def _encode_tuple(x):
    """Recursively tags tuples so they can be told apart from lists in JSON.

    Every tuple is replaced by `{"class_name": "__tuple__", "items": (...)}`,
    with its items tagged as well. List elements and dict values are walked
    recursively; dict keys and any other value are left unchanged.

    Args:
      x: The value to walk.

    Returns:
      A structure mirroring `x` in which every nested tuple is wrapped in a
      tagged dict.
    """
    if isinstance(x, tuple):
        tagged_items = tuple(map(_encode_tuple, x))
        return {"class_name": "__tuple__", "items": tagged_items}
    if isinstance(x, list):
        return list(map(_encode_tuple, x))
    if isinstance(x, dict):
        return dict((k, _encode_tuple(v)) for k, v in x.items())
    return x

70 

71 

def decode(json_string):
    """Parses `json_string`, rebuilding tagged objects.

    Every decoded dict is passed through `_decode_helper` with its default
    options, so Keras-object configs are not deserialized and stay as dicts.

    Args:
      json_string: The JSON document to parse.

    Returns:
      The decoded Python structure.
    """
    hook = _decode_helper
    return json.loads(json_string, object_hook=hook)

74 

75 

def decode_and_deserialize(
    json_string, module_objects=None, custom_objects=None
):
    """Parses `json_string` and also deserializes any Keras objects in it.

    Args:
      json_string: The JSON document to parse.
      module_objects: Optional dict of built-in objects used to resolve names
        during deserialization.
      custom_objects: Optional dict of user-provided objects used to resolve
        names during deserialization.

    Returns:
      The decoded structure, with passively-serialized Keras objects
      deserialized where possible.
    """
    hook = functools.partial(
        _decode_helper,
        deserialize=True,
        custom_objects=custom_objects,
        module_objects=module_objects,
    )
    return json.loads(json_string, object_hook=hook)

89 

90 

def _decode_helper(
    obj, deserialize=False, module_objects=None, custom_objects=None
):
    """A decoding helper that is TF-object aware.

    It is used as a `json.loads` `object_hook` by `decode` and
    `decode_and_deserialize`. It can also be called directly on a raw decoded
    structure, in which case nested tagged values are decoded with the same
    options as `obj` itself.

    Args:
      obj: A decoded dictionary that may represent an object.
      deserialize: Boolean, defaults to False. When True, deserializes any
        Keras objects found in `obj`.
      module_objects: A dictionary of built-in objects to look the name up in.
        Generally, `module_objects` is provided by midlevel library
        implementers.
      custom_objects: A dictionary of custom objects to look the name up in.
        Generally, `custom_objects` is provided by the end user.

    Returns:
      The decoded object. Non-dict values, dicts without a `class_name` key,
      and dicts whose `class_name` is not recognized are returned unchanged.
    """
    if not (isinstance(obj, dict) and "class_name" in obj):
        return obj

    # Decode nested values with the same options as `obj`. The original code
    # recursed with default options, silently disabling `deserialize` (and
    # dropping the lookup dicts) for nested values on direct calls.
    decode_nested = functools.partial(
        _decode_helper,
        deserialize=deserialize,
        module_objects=module_objects,
        custom_objects=custom_objects,
    )
    class_name = obj["class_name"]

    if class_name == "TensorShape":
        return tf.TensorShape(obj["items"])
    if class_name == "TypeSpec":
        spec_cls = type_spec_registry.lookup(obj["type_spec"])
        return spec_cls._deserialize(decode_nested(obj["serialized"]))
    if class_name == "CompositeTensor":
        spec = obj["spec"]
        tensors = [
            tf.constant(value, dtype=tf.dtypes.as_dtype(dtype))
            for dtype, value in obj["tensors"]
        ]
        return tf.nest.pack_sequence_as(
            decode_nested(spec), tensors, expand_composites=True
        )
    if class_name == "__tuple__":
        return tuple(decode_nested(item) for item in obj["items"])
    if class_name == "__ellipsis__":
        return Ellipsis
    if deserialize and "__passive_serialization__" in obj:
        # __passive_serialization__ is added by the JSON encoder when
        # encoding an object that has a `get_config()` method.
        try:
            # Configs without a "module" key, or any config inside a TF
            # SavedModel scope, go through the legacy deserializer.
            if in_tf_saved_model_scope() or "module" not in obj:
                return serialization.deserialize_keras_object(
                    obj,
                    module_objects=module_objects,
                    custom_objects=custom_objects,
                )
            return serialization_lib.deserialize_keras_object(
                obj,
                module_objects=module_objects,
                custom_objects=custom_objects,
            )
        except ValueError:
            # Deliberate best effort: an object that cannot be deserialized
            # is kept as its raw config dict instead of failing the decode.
            return obj
    if class_name == "__bytes__":
        return obj["value"].encode("utf-8")
    return obj

151 

152 

def get_json_type(obj):
    """Serializes any object to a JSON-serializable structure.

    Used by `Encoder.default` for objects the base JSON encoder cannot handle.

    Args:
      obj: the object to serialize

    Returns:
      JSON-serializable structure representing `obj`.

    Raises:
      TypeError: if `obj` cannot be serialized.
      ValueError: if `obj` is a `tf.TypeSpec` whose class has not been
        registered.
    """
    # Serializable Keras class instances, e.g. optimizers and layers.
    if hasattr(obj, "get_config"):
        serialized = serialization.serialize_keras_object(obj)
        serialized["__passive_serialization__"] = True
        return serialized

    # NumPy arrays (including subclasses such as masked arrays) and scalars.
    # The old `type(obj).__module__ == "numpy"` test also matched non-scalar
    # objects like ufuncs and dtypes, which have no `.item()`. It also missed
    # ndarray subclasses that live in numpy submodules.
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.dtype):
        # Stored by name, like `tf.DType` below.
        return obj.name

    # Functions and classes (e.g. a loss function) are stored by name. Classes
    # are callable, so they are handled here as well. Callables without a
    # `__name__` (e.g. `functools.partial`) fall through to the checks below
    # instead of raising AttributeError.
    if callable(obj):
        name = getattr(obj, "__name__", None)
        if name is not None:
            return name

    if isinstance(obj, tf.compat.v1.Dimension):
        return obj.value

    if isinstance(obj, tf.TensorShape):
        # `as_list()` raises for an unknown rank; encode that case as None,
        # matching how `Encoder.default` encodes the shape's items.
        return obj.as_list() if obj.rank is not None else None

    if isinstance(obj, tf.DType):
        return obj.name

    if isinstance(obj, collections.abc.Mapping):
        return dict(obj)

    if obj is Ellipsis:
        return {"class_name": "__ellipsis__"}

    if isinstance(obj, wrapt.ObjectProxy):
        return obj.__wrapped__

    if isinstance(obj, tf.TypeSpec):
        # Guard only the registry lookup, so a ValueError from `_serialize()`
        # is not misreported as an unregistered TypeSpec class.
        try:
            type_spec_name = type_spec_registry.get_name(type(obj))
        except ValueError as err:
            raise ValueError(
                f"Unable to serialize {obj} to JSON, because the TypeSpec "
                f"class {type(obj)} has not been registered."
            ) from err
        return {
            "class_name": "TypeSpec",
            "type_spec": type_spec_name,
            "serialized": obj._serialize(),
        }

    if isinstance(obj, tf.__internal__.CompositeTensor):
        spec = tf.type_spec_from_value(obj)
        # (dtype name, nested-list value) for each component tensor.
        tensors = [
            (tensor.dtype.name, tensor.numpy().tolist())
            for tensor in tf.nest.flatten(obj, expand_composites=True)
        ]
        return {
            "class_name": "CompositeTensor",
            "spec": get_json_type(spec),
            "tensors": tensors,
        }

    if isinstance(obj, enum.Enum):
        return obj.value

    if isinstance(obj, bytes):
        return {"class_name": "__bytes__", "value": obj.decode("utf-8")}

    raise TypeError(
        f"Unable to serialize {obj} to JSON. Unrecognized type {type(obj)}."
    )

238