Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/json_utils.py: 24%

71 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utils for creating and loading the Layer metadata for SavedModel. 

16 

17These are required to retain the original format of the build input shape, since 

18layers and models may have different build behaviors depending on whether the shape 

19is a list, tuple, or TensorShape. For example, Network.build() will create 

20separate inputs if the given input_shape is a list, and will create a single 

21input if the given shape is a tuple. 

22""" 

23 

24import collections 

25import enum 

26import json 

27import numpy as np 

28import wrapt 

29 

30from tensorflow.python.framework import dtypes 

31from tensorflow.python.framework import tensor_shape 

32from tensorflow.python.framework import type_spec_registry 

33from tensorflow.python.types import internal 

34 

35 

class Encoder(json.JSONEncoder):
  """JSON encoder that handles TensorShapes and tuples.

  Tuples anywhere in the input are rewritten into tagged dicts by
  `_encode_tuple` before encoding, so that `decode` can turn them back into
  tuples. Objects the stock encoder cannot handle are converted in `default`.
  """

  def default(self, obj):  # pylint: disable=method-hidden
    """Encodes objects for types that aren't handled by the default encoder."""
    if isinstance(obj, tensor_shape.TensorShape):
      # A shape of unknown rank has no list form; store None so that decoding
      # rebuilds it as `TensorShape(None)`.
      items = obj.as_list() if obj.rank is not None else None
      return {'class_name': 'TensorShape', 'items': items}
    return get_json_type(obj)

  def iterencode(self, o, _one_shot=False):
    """Encodes `o` chunk by chunk after tagging its tuples.

    `JSONEncoder.encode` delegates to this method and `json.dump` calls it
    directly. Overriding it here, rather than `encode`, therefore makes tuple
    tagging apply to both `json.dumps(..., cls=Encoder)` and
    `json.dump(..., cls=Encoder)`. Tuples are tagged exactly once per call.
    """
    return super().iterencode(_encode_tuple(o), _one_shot)

48 

49 

50def _encode_tuple(x): 

51 if isinstance(x, tuple): 

52 return {'class_name': '__tuple__', 

53 'items': tuple(_encode_tuple(i) for i in x)} 

54 elif isinstance(x, list): 

55 return [_encode_tuple(i) for i in x] 

56 elif isinstance(x, dict): 

57 return {key: _encode_tuple(value) for key, value in x.items()} 

58 else: 

59 return x 

60 

61 

def decode(json_string):
  """Parses `json_string`, reviving objects stored in tagged-dict form.

  Every decoded JSON object is passed to `_decode_helper`. That helper restores
  tuples, `TensorShape`s, `TypeSpec`s and `Ellipsis` from the dicts written
  during encoding.

  Args:
    json_string: the JSON document to parse (anything `json.loads` accepts).

  Returns:
    The decoded Python structure.
  """
  hook = _decode_helper
  result = json.loads(json_string, object_hook=hook)
  return result

64 

65 

66def _decode_helper(obj): 

67 """A decoding helper that is TF-object aware.""" 

68 if isinstance(obj, dict) and 'class_name' in obj: 

69 if obj['class_name'] == 'TensorShape': 

70 return tensor_shape.TensorShape(obj['items']) 

71 elif obj['class_name'] == 'TypeSpec': 

72 return type_spec_registry.lookup(obj['type_spec'])._deserialize( # pylint: disable=protected-access 

73 _decode_helper(obj['serialized'])) 

74 elif obj['class_name'] == '__tuple__': 

75 return tuple(_decode_helper(i) for i in obj['items']) 

76 elif obj['class_name'] == '__ellipsis__': 

77 return Ellipsis 

78 return obj 

79 

80 

def get_json_type(obj):
  """Serializes any object to a JSON-serializable structure.

  `Encoder.default` falls back to this function for objects it does not handle
  itself. The returned value need not be fully JSON-native: a `Mapping` is only
  shallow-copied into a `dict`, and a `wrapt.ObjectProxy` is unwrapped. The
  JSON encoder calls its `default` hook again for any value it still cannot
  encode.

  Args:
    obj: the object to serialize

  Returns:
    JSON-serializable structure representing `obj`.

  Raises:
    TypeError: if `obj` cannot be serialized.
    ValueError: if `obj` is a `TypeSpec` whose class has not been registered
      with `type_spec_registry`.
  """
  # If obj is a serializable Keras class instance, e.g. optimizer, layer.
  if hasattr(obj, 'get_config'):
    return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}

  # If obj is any numpy type: arrays become nested lists, scalars become the
  # equivalent Python scalar.
  if type(obj).__module__ == np.__name__:
    if isinstance(obj, np.ndarray):
      return obj.tolist()
    return obj.item()

  # Misc functions (e.g. loss function) are stored by name. Python classes are
  # callable too, so they are also handled here by name. Callables without a
  # `__name__` (e.g. `functools.partial` objects or callable instances) fall
  # through to the checks below instead of raising AttributeError.
  if callable(obj):
    name = getattr(obj, '__name__', None)
    if name is not None:
      return name

  if isinstance(obj, tensor_shape.Dimension):
    return obj.value

  if isinstance(obj, tensor_shape.TensorShape):
    return obj.as_list()

  if isinstance(obj, dtypes.DType):
    return obj.name

  if isinstance(obj, collections.abc.Mapping):
    return dict(obj)

  if obj is Ellipsis:
    return {'class_name': '__ellipsis__'}

  if isinstance(obj, wrapt.ObjectProxy):
    return obj.__wrapped__

  if isinstance(obj, internal.TypeSpec):
    # Only the registry lookup signals "not registered" via ValueError. Keep
    # `_serialize()` outside the try so that its own errors are not
    # misreported.
    try:
      type_spec_name = type_spec_registry.get_name(type(obj))
    except ValueError as e:
      raise ValueError('Unable to serialize {} to JSON, because the TypeSpec '
                       'class {} has not been registered.'
                       .format(obj, type(obj))) from e
    return {'class_name': 'TypeSpec', 'type_spec': type_spec_name,
            'serialized': obj._serialize()}  # pylint: disable=protected-access

  if isinstance(obj, enum.Enum):
    return obj.value

  raise TypeError('Not JSON Serializable: {!r} (type: {})'.format(
      obj, type(obj).__name__))