Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/saved_model/layer_serialization.py: 43%

77 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Classes and functions implementing Layer SavedModel serialization.""" 

16 

17import tensorflow.compat.v2 as tf 

18 

19from keras.src.mixed_precision import policy 

20from keras.src.saving.legacy import serialization 

21from keras.src.saving.legacy.saved_model import base_serialization 

22from keras.src.saving.legacy.saved_model import constants 

23from keras.src.saving.legacy.saved_model import save_impl 

24from keras.src.saving.legacy.saved_model import serialized_attributes 

25 

26 

class LayerSavedModelSaver(base_serialization.SavedModelSaver):
    """SavedModel saver for Keras layers.

    Produces the layer's python-property metadata and the checkpointable
    objects/functions that are written into the SavedModel.
    """

    @property
    def object_identifier(self):
        """Identifier constant recorded for layer objects."""
        return constants.LAYER_IDENTIFIER

    @property
    def python_properties(self):
        """Metadata dictionary built by `_python_properties_internal`."""
        # TODO(kathywu): Add python property validator
        return self._python_properties_internal()

    def _python_properties_internal(self):
        """Returns dictionary of all python properties.

        The dictionary always holds the basic layer attributes plus the
        entries returned by `get_serialized`. `input_spec`,
        `activity_regularizer` and `build_input_shape` are added only when
        the layer defines them.
        """
        layer = self.obj

        # TODO(kathywu): Add support for metrics serialization.
        # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec)
        # once the python config serialization has caught up.
        metadata = {
            "name": layer.name,
            "trainable": layer.trainable,
            "expects_training_arg": layer._expects_training_arg,
            "dtype": policy.serialize(layer._dtype_policy),
            # Not every layer defines `_batch_input_shape`; default to None.
            "batch_input_shape": getattr(layer, "_batch_input_shape", None),
            "stateful": layer.stateful,
            "must_restore_from_config": layer._must_restore_from_config,
            "preserve_input_structure_in_config": (
                layer._preserve_input_structure_in_config
            ),
            "autocast": layer._autocast,
        }
        metadata.update(get_serialized(layer))

        if layer.input_spec is not None:

            def _serialize_spec(spec):
                # Falsy entries of the structure are recorded as None.
                if spec:
                    return serialization.serialize_keras_object(spec)
                return None

            # Layer's input_spec has already been type-checked in the property
            # setter.
            metadata["input_spec"] = tf.nest.map_structure(
                _serialize_spec, layer.input_spec
            )

        if layer.activity_regularizer is not None:
            # Only regularizers exposing `get_config` are serialized.
            if hasattr(layer.activity_regularizer, "get_config"):
                metadata[
                    "activity_regularizer"
                ] = serialization.serialize_keras_object(
                    layer.activity_regularizer
                )

        if layer._build_input_shape is not None:
            metadata["build_input_shape"] = layer._build_input_shape

        return metadata

    def objects_to_serialize(self, serialization_cache):
        """Returns the `objects_to_serialize` of the serialized attributes."""
        attributes = self._get_serialized_attributes(serialization_cache)
        return attributes.objects_to_serialize

    def functions_to_serialize(self, serialization_cache):
        """Returns the `functions_to_serialize` of the serialized attributes."""
        attributes = self._get_serialized_attributes(serialization_cache)
        return attributes.functions_to_serialize

    def _get_serialized_attributes(self, serialization_cache):
        """Generates or retrieves serialized attributes from cache."""
        keras_cache = serialization_cache.setdefault(
            constants.KERAS_CACHE_KEY, {}
        )
        if self.obj in keras_cache:
            return keras_cache[self.obj]

        # The new entry is stored in the cache before it is populated.
        serialized_attr = serialized_attributes.SerializedAttributes.new(
            self.obj
        )
        keras_cache[self.obj] = serialized_attr

        skip_population = (
            save_impl.should_skip_serialization(self.obj)
            or self.obj._must_restore_from_config
        )
        if not skip_population:
            (
                object_dict,
                function_dict,
            ) = self._get_serialized_attributes_internal(serialization_cache)
            serialized_attr.set_and_validate_objects(object_dict)
            serialized_attr.set_and_validate_functions(function_dict)

        return serialized_attr

    def _get_serialized_attributes_internal(self, serialization_cache):
        """Returns the `(objects, functions)` dictionaries for the layer."""
        layer = self.obj
        objects = save_impl.wrap_layer_objects(layer, serialization_cache)
        functions = save_impl.wrap_layer_functions(layer, serialization_cache)
        # Attribute validator requires that the default save signature is added
        # to function dict, even if the value is None.
        functions["_default_save_signature"] = None
        return objects, functions

124 

125 

126# TODO(kathywu): Move serialization utils (and related utils from 

127# generic_utils.py) to a separate file. 

def get_serialized(obj):
    """Serializes `obj` within `serialization.skip_failed_serialization()`.

    The result stores the config dictionary, which may be used when reviving
    the object. When loading, the program will attempt to revive the object
    from config, and if that fails, the object will be revived from the
    SavedModel.

    Args:
        obj: The object passed to `serialization.serialize_keras_object`.

    Returns:
        The value returned by `serialization.serialize_keras_object(obj)`.
    """
    with serialization.skip_failed_serialization():
        serialized = serialization.serialize_keras_object(obj)
        return serialized

135 

136 

class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
    """SavedModel saver for InputLayer objects.

    Only python properties are recorded; no objects or functions are
    serialized.
    """

    @property
    def object_identifier(self):
        """Identifier constant recorded for input layers."""
        return constants.INPUT_LAYER_IDENTIFIER

    @property
    def python_properties(self):
        """Metadata dictionary describing the wrapped input layer."""
        input_layer = self.obj
        return {
            "class_name": type(input_layer).__name__,
            "name": input_layer.name,
            "dtype": input_layer.dtype,
            "sparse": input_layer.sparse,
            "ragged": input_layer.ragged,
            "batch_input_shape": input_layer._batch_input_shape,
            "config": input_layer.get_config(),
        }

    def objects_to_serialize(self, serialization_cache):
        """Returns an empty dict."""
        return {}

    def functions_to_serialize(self, serialization_cache):
        """Returns an empty dict."""
        return {}

162 

163 

class RNNSavedModelSaver(LayerSavedModelSaver):
    """SavedModel saver for RNN layers; also serializes the layer states."""

    @property
    def object_identifier(self):
        """Identifier constant recorded for RNN layers."""
        return constants.RNN_LAYER_IDENTIFIER

    def _get_serialized_attributes_internal(self, serialization_cache):
        """Extends the parent's result with a `"states"` object entry."""
        parent_result = super()._get_serialized_attributes_internal(
            serialization_cache
        )
        objects, functions = parent_result

        wrapped_states = tf.__internal__.tracking.wrap(self.obj.states)
        # SavedModel requires every saved object to be Trackable. States that
        # are still a tuple after wrapping contain no trackable item (e.g. an
        # empty tuple, or (None, None) for a stateless ConvLSTM2D). Converting
        # them to a list lets the wrapper turn them into a Trackable for
        # saving. When loaded, ConvLSTM2D is able to handle the tuple/list
        # conversion.
        if isinstance(wrapped_states, tuple):
            wrapped_states = tf.__internal__.tracking.wrap(
                list(wrapped_states)
            )

        objects["states"] = wrapped_states
        return objects, functions

186 

187 

class VocabularySavedModelSaver(LayerSavedModelSaver):
    """Handles vocabulary layer serialization.

    This class is needed for StringLookup, IntegerLookup, and TextVectorization,
    which all have a vocabulary as part of the config. Currently, we keep this
    vocab as part of the config until saving, when we need to clear it to avoid
    initializing a StaticHashTable twice (once when restoring the config and
    once when restoring module resources). After clearing the vocab, we persist
    a property to the layer indicating it was constructed with a vocab.
    """

    @property
    def python_properties(self):
        """Layer metadata with the vocabulary removed from its config."""
        # TODO(kathywu): Add python property validator
        metadata = self._python_properties_internal()
        config = metadata["config"]
        # Clear the vocabulary from the config during saving.
        config["vocabulary"] = None
        # Persist a property to track that a vocabulary was passed on
        # construction.
        config["has_input_vocabulary"] = self.obj._has_input_vocabulary
        return metadata

212