Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/json_utils.py: 24%
71 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utils for creating and loading the Layer metadata for SavedModel.
17These are required to retain the original format of the build input shape, since
18layers and models may have different build behaviors depending on if the shape
19is a list, tuple, or TensorShape. For example, Network.build() will create
20separate inputs if the given input_shape is a list, and will create a single
21input if the given shape is a tuple.
22"""
24import collections
25import enum
26import json
27import numpy as np
28import wrapt
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework import type_spec_registry
33from tensorflow.python.types import internal
class Encoder(json.JSONEncoder):
    """JSON encoder that understands `TensorShape`s and preserves tuples.

    `encode` first rewrites every tuple in its input into a tagged dict (via
    `_encode_tuple`) so that tuples are not flattened into plain JSON arrays.
    Objects the base encoder cannot handle are routed through `default`.

    NOTE(review): only `encode` is overridden. `json.dump` goes through
    `iterencode`, so tuples would not be tagged on that path -- confirm that
    callers use `json.dumps(..., cls=Encoder)`.
    """

    def default(self, obj):  # pylint: disable=method-hidden
        """Returns a JSON-compatible stand-in for `obj`.

        A `TensorShape` becomes ``{'class_name': 'TensorShape', 'items': ...}``,
        where ``items`` is ``None`` when ``obj.rank`` is ``None`` and
        ``obj.as_list()`` otherwise. Every other object is delegated to
        `get_json_type`.
        """
        if not isinstance(obj, tensor_shape.TensorShape):
            return get_json_type(obj)
        dims = None if obj.rank is None else obj.as_list()
        return {'class_name': 'TensorShape', 'items': dims}

    def encode(self, obj):
        """Serializes `obj` to a JSON string, tagging tuples beforehand."""
        tagged = _encode_tuple(obj)
        return super().encode(tagged)
def _encode_tuple(x):
    """Recursively replaces tuples with tagged dicts so they survive JSON.

    A tuple becomes ``{'class_name': '__tuple__', 'items': <tuple>}`` with its
    elements encoded recursively. Lists and dicts are rebuilt as a plain
    ``list``/``dict`` with encoded contents (dict keys are left untouched).
    Any other value is returned unchanged.

    Args:
      x: The value to encode.

    Returns:
      A structure equal to `x` except that every tuple is tagged.
    """
    if isinstance(x, tuple):
        encoded_items = tuple(map(_encode_tuple, x))
        return {'class_name': '__tuple__', 'items': encoded_items}
    if isinstance(x, list):
        return list(map(_encode_tuple, x))
    if isinstance(x, dict):
        return {k: _encode_tuple(v) for k, v in x.items()}
    return x
def decode(json_string):
    """Parses JSON text, restoring the tagged objects written by `Encoder`.

    Args:
      json_string: The JSON document, in any form accepted by `json.loads`
        (str, bytes or bytearray).

    Returns:
      The decoded Python structure. Every JSON object is passed through
      `_decode_helper`, which turns tagged dicts back into tuples,
      `TensorShape`s, registered `TypeSpec`s and `Ellipsis`.
    """
    hook = _decode_helper
    return json.loads(json_string, object_hook=hook)
def _decode_helper(obj):
    """Turns a tagged JSON dict back into the object it stands for.

    Used as the `object_hook` for `json.loads`. Anything that is not a dict,
    or a dict without a ``'class_name'`` key, is returned unchanged, as is a
    dict whose ``'class_name'`` is not one of the tags handled below.

    Args:
      obj: A decoded JSON value (normally a dict).

    Returns:
      A `TensorShape`, a deserialized `TypeSpec`, a tuple, `Ellipsis`, or
      `obj` itself.
    """
    if not isinstance(obj, dict) or 'class_name' not in obj:
        return obj
    tag = obj['class_name']
    if tag == 'TensorShape':
        return tensor_shape.TensorShape(obj['items'])
    if tag == 'TypeSpec':
        spec_cls = type_spec_registry.lookup(obj['type_spec'])
        # The serialized payload may itself be a tagged dict; decode it first.
        return spec_cls._deserialize(_decode_helper(obj['serialized']))  # pylint: disable=protected-access
    if tag == '__tuple__':
        # Each item is decoded one level (dict items only), then tupled.
        return tuple(_decode_helper(item) for item in obj['items'])
    if tag == '__ellipsis__':
        return Ellipsis
    return obj
def get_json_type(obj):
    """Serializes any object to a JSON-serializable structure.

    Checks are applied in order; the first match wins:
      * objects with `get_config` -> {'class_name': ..., 'config': ...}
      * NumPy arrays -> nested lists; NumPy scalars -> Python scalars
      * callables that have a `__name__` (and classes) -> that name
      * `Dimension` -> its value; `TensorShape` -> `as_list()`
      * `DType` -> its name; other mappings -> a plain dict
      * `Ellipsis` -> {'class_name': '__ellipsis__'}
      * `wrapt.ObjectProxy` -> the wrapped object
      * registered `TypeSpec` -> {'class_name': 'TypeSpec', ...}
      * `enum.Enum` members -> their value

    Args:
      obj: the object to serialize

    Returns:
      JSON-serializable structure representing `obj`.

    Raises:
      TypeError: if `obj` cannot be serialized.
      ValueError: if `obj` is a `TypeSpec` whose class has not been registered
        with `type_spec_registry`.
    """
    # if obj is a serializable Keras class instance
    # e.g. optimizer, layer
    if hasattr(obj, 'get_config'):
        return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}

    # NumPy arrays and scalars. Using isinstance (instead of comparing
    # `type(obj).__module__` with 'numpy') accepts ndarray subclasses defined
    # in numpy submodules. It also stops other numpy-module objects, such as
    # ufuncs, from being sent to `.item()`, which they do not have.
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()

    # misc named functions (e.g. loss function). Callables without a
    # `__name__` (e.g. functools.partial) fall through and end in the
    # TypeError below instead of an AttributeError.
    if callable(obj) and hasattr(obj, '__name__'):
        return obj.__name__

    # if obj is a python 'type'. In practice classes are callable and are
    # already handled above; kept for backward compatibility.
    if type(obj).__name__ == type.__name__:
        return obj.__name__

    if isinstance(obj, tensor_shape.Dimension):
        return obj.value

    if isinstance(obj, tensor_shape.TensorShape):
        return obj.as_list()

    if isinstance(obj, dtypes.DType):
        return obj.name

    if isinstance(obj, collections.abc.Mapping):
        return dict(obj)

    if obj is Ellipsis:
        return {'class_name': '__ellipsis__'}

    # Unwrap proxies and hand back the underlying object.
    if isinstance(obj, wrapt.ObjectProxy):
        return obj.__wrapped__

    if isinstance(obj, internal.TypeSpec):
        # Only the registry lookup can fail because the class is unregistered.
        # Errors raised by `_serialize()` propagate unchanged instead of being
        # misreported as a registration problem.
        try:
            type_spec_name = type_spec_registry.get_name(type(obj))
        except ValueError as e:
            raise ValueError('Unable to serialize {} to JSON, because the TypeSpec '
                             'class {} has not been registered.'
                             .format(obj, type(obj))) from e
        return {'class_name': 'TypeSpec', 'type_spec': type_spec_name,
                'serialized': obj._serialize()}  # pylint: disable=protected-access

    if isinstance(obj, enum.Enum):
        return obj.value

    # A single formatted message; passing two positional args made the
    # message render as a tuple.
    raise TypeError('Not JSON Serializable: {!r}'.format(obj))