Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py: 44%

78 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Classes and functions implementing Layer SavedModel serialization.""" 

16 

17from tensorflow.python.keras.mixed_precision import policy 

18from tensorflow.python.keras.saving.saved_model import base_serialization 

19from tensorflow.python.keras.saving.saved_model import constants 

20from tensorflow.python.keras.saving.saved_model import save_impl 

21from tensorflow.python.keras.saving.saved_model import serialized_attributes 

22from tensorflow.python.keras.utils import generic_utils 

23from tensorflow.python.trackable import data_structures 

24from tensorflow.python.util import nest 

25 

26 

class LayerSavedModelSaver(base_serialization.SavedModelSaver):
  """SavedModel serialization logic for generic Keras `Layer` objects."""

  @property
  def object_identifier(self):
    """Identifier recorded for plain layers in the SavedModel metadata."""
    return constants.LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Python-side metadata describing the wrapped layer."""
    # TODO(kathywu): Add python property validator
    return self._python_properties_internal()

  def _python_properties_internal(self):
    """Builds the dictionary of Python properties for `self.obj`.

    Returns:
      A dict containing the layer's basic attributes, merged with the output
      of `get_serialized(self.obj)`, plus `input_spec`,
      `activity_regularizer` and `build_input_shape` entries when those are
      available on the layer.
    """
    # TODO(kathywu): Add support for metrics serialization.
    # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
    # the python config serialization has caught up.
    layer = self.obj

    # pylint: disable=protected-access
    props = {
        'name': layer.name,
        'trainable': layer.trainable,
        'expects_training_arg': layer._expects_training_arg,
        'dtype': policy.serialize(layer._dtype_policy),
        'batch_input_shape': getattr(layer, '_batch_input_shape', None),
        'stateful': layer.stateful,
        'must_restore_from_config': layer._must_restore_from_config,
    }
    # pylint: enable=protected-access
    props.update(get_serialized(layer))

    input_spec = layer.input_spec
    if input_spec is not None:
      # Layer's input_spec has already been type-checked in the property
      # setter. Falsy entries in the structure are recorded as None.
      def _serialize_spec(spec):
        if not spec:
          return None
        return generic_utils.serialize_keras_object(spec)

      props['input_spec'] = nest.map_structure(_serialize_spec, input_spec)

    regularizer = layer.activity_regularizer
    # Only regularizers that expose `get_config` are serialized.
    if regularizer is not None and hasattr(regularizer, 'get_config'):
      props['activity_regularizer'] = generic_utils.serialize_keras_object(
          regularizer)

    build_input_shape = layer._build_input_shape  # pylint: disable=protected-access
    if build_input_shape is not None:
      props['build_input_shape'] = build_input_shape

    return props

  def objects_to_serialize(self, serialization_cache):
    """Returns the trackable objects to save for this layer."""
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.objects_to_serialize

  def functions_to_serialize(self, serialization_cache):
    """Returns the functions to save for this layer."""
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.functions_to_serialize

  def _get_serialized_attributes(self, serialization_cache):
    """Returns this layer's `SerializedAttributes`, creating and caching them.

    Args:
      serialization_cache: Dict shared across the save. Keras entries live
        under `constants.KERAS_CACHE_KEY`, keyed by layer object.

    Returns:
      The cached `SerializedAttributes` for `self.obj`. If the layer is
      skipped for serialization or must be restored from config, the
      attributes are returned without objects or functions being set.
    """
    layer = self.obj
    keras_cache = serialization_cache.setdefault(constants.KERAS_CACHE_KEY, {})
    if layer in keras_cache:
      return keras_cache[layer]

    # The fresh attributes are cached before the wrapping below runs.
    # NOTE(review): presumably this lets re-entrant lookups for the same layer
    # (during wrap_layer_objects/functions) hit the cache; confirm.
    attributes = serialized_attributes.SerializedAttributes.new(layer)
    keras_cache[layer] = attributes

    skip = (save_impl.should_skip_serialization(layer) or
            layer._must_restore_from_config)  # pylint: disable=protected-access
    if skip:
      return attributes

    object_dict, function_dict = self._get_serialized_attributes_internal(
        serialization_cache)
    attributes.set_and_validate_objects(object_dict)
    attributes.set_and_validate_functions(function_dict)
    return attributes

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Wraps the layer's objects and functions for saving.

    Returns:
      An `(objects, functions)` pair produced by `save_impl`, with a
      `_default_save_signature` entry (set to None) added to `functions`.
    """
    layer = self.obj
    objects = save_impl.wrap_layer_objects(layer, serialization_cache)
    functions = save_impl.wrap_layer_functions(layer, serialization_cache)
    # The attribute validator requires the default save signature key to be
    # present in the function dict, even when its value is None.
    functions['_default_save_signature'] = None
    return objects, functions

104 

105 

# TODO(kathywu): Move serialization utils (and related utils from
# generic_utils.py) to a separate file.
def get_serialized(obj):
  """Serializes `obj` under `generic_utils.skip_failed_serialization()`.

  The resulting config dictionary is stored so that, when loading, the
  program can first attempt to revive the object from config, falling back
  to reviving it from the SavedModel if that fails.

  Args:
    obj: The Keras object to serialize.

  Returns:
    Whatever `generic_utils.serialize_keras_object(obj)` returns.
  """
  with generic_utils.skip_failed_serialization():
    serialized = generic_utils.serialize_keras_object(obj)
    return serialized

114 

115 

class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
  """SavedModel serialization for `InputLayer` objects."""

  @property
  def object_identifier(self):
    """Identifier recorded for input layers in the SavedModel metadata."""
    return constants.INPUT_LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Metadata describing the input layer, including its full config."""
    layer = self.obj
    return {
        'class_name': type(layer).__name__,
        'name': layer.name,
        'dtype': layer.dtype,
        'sparse': layer.sparse,
        'ragged': layer.ragged,
        'batch_input_shape': layer._batch_input_shape,  # pylint: disable=protected-access
        'config': layer.get_config(),
    }

  def objects_to_serialize(self, serialization_cache):
    """Input layers contribute no trackable objects."""
    del serialization_cache  # Unused.
    return {}

  def functions_to_serialize(self, serialization_cache):
    """Input layers contribute no functions."""
    del serialization_cache  # Unused.
    return {}

140 

141 

class RNNSavedModelSaver(LayerSavedModelSaver):
  """SavedModel serialization for RNN layers."""

  @property
  def object_identifier(self):
    """Identifier recorded for RNN layers in the SavedModel metadata."""
    return constants.RNN_LAYER_IDENTIFIER

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Adds the layer's `states` to the base layer's serialized objects."""
    objects, functions = super()._get_serialized_attributes_internal(
        serialization_cache)

    wrapped_states = data_structures.wrap_or_unwrap(self.obj.states)
    # SavedModel requires every saved object to be Trackable. A tuple that
    # survives wrap_or_unwrap contains no trackable items (e.g. an empty
    # tuple, or (None, None) for a stateless ConvLSTM2D). Converting it to a
    # list and wrapping again yields a Trackable container for saving.
    # ConvLSTM2D handles the tuple/list difference when loaded.
    if isinstance(wrapped_states, tuple):
      wrapped_states = data_structures.wrap_or_unwrap(list(wrapped_states))

    objects['states'] = wrapped_states
    return objects, functions

164 

165 

class IndexLookupLayerSavedModelSaver(LayerSavedModelSaver):
  """SavedModel serialization for index lookup layers."""

  @property
  def python_properties(self):
    """Layer metadata, with the vocabulary cleared when the table is static.

    When the serialized config has a truthy `has_static_table`, its
    `vocabulary` entry is set to None.
    NOTE(review): presumably the static table already carries the vocabulary,
    so storing it in the config would be redundant; confirm.
    """
    # TODO(kathywu): Add python property validator
    metadata = self._python_properties_internal()
    layer_config = metadata['config']
    if layer_config.get('has_static_table', False):
      layer_config['vocabulary'] = None
    return metadata