Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/saved_model/layer_serialization.py: 43%
77 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes and functions implementing Layer SavedModel serialization."""
17import tensorflow.compat.v2 as tf
19from keras.src.mixed_precision import policy
20from keras.src.saving.legacy import serialization
21from keras.src.saving.legacy.saved_model import base_serialization
22from keras.src.saving.legacy.saved_model import constants
23from keras.src.saving.legacy.saved_model import save_impl
24from keras.src.saving.legacy.saved_model import serialized_attributes
class LayerSavedModelSaver(base_serialization.SavedModelSaver):
    """SavedModel saver for generic Keras `Layer` objects.

    Produces the Python-side metadata for a layer and, through a per-save
    cache stored in `serialization_cache`, the objects and functions that are
    written into the SavedModel alongside it.
    """

    @property
    def object_identifier(self):
        """Identifier recorded in the SavedModel for plain layers."""
        return constants.LAYER_IDENTIFIER

    @property
    def python_properties(self):
        """Metadata dict describing the wrapped layer (`self.obj`)."""
        # TODO(kathywu): Add python property validator
        return self._python_properties_internal()

    def _python_properties_internal(self):
        """Builds the metadata dict for the wrapped layer.

        Returns:
          A dict containing core layer attributes (name, trainable flag,
          serialized dtype policy, etc.), merged with the output of
          `get_serialized(self.obj)`. It also gets an `input_spec`,
          `activity_regularizer` and/or `build_input_shape` entry when the
          layer has a value for them.
        """
        # TODO(kathywu): Add support for metrics serialization.
        # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec)
        # once the python config serialization has caught up.
        layer = self.obj
        metadata = {
            "name": layer.name,
            "trainable": layer.trainable,
            "expects_training_arg": layer._expects_training_arg,
            "dtype": policy.serialize(layer._dtype_policy),
            "batch_input_shape": getattr(layer, "_batch_input_shape", None),
            "stateful": layer.stateful,
            "must_restore_from_config": layer._must_restore_from_config,
            "preserve_input_structure_in_config": (
                layer._preserve_input_structure_in_config
            ),
            "autocast": layer._autocast,
        }

        # Merge in the serialized config of the layer itself.
        metadata.update(get_serialized(layer))

        input_spec = layer.input_spec
        if input_spec is not None:
            # Layer's input_spec has already been type-checked in the
            # property setter, so each leaf is either a spec or falsy.
            def _serialize_spec(spec):
                if not spec:
                    return None
                return serialization.serialize_keras_object(spec)

            metadata["input_spec"] = tf.nest.map_structure(
                _serialize_spec, input_spec
            )

        regularizer = layer.activity_regularizer
        # Only regularizers exposing `get_config` can be serialized.
        if regularizer is not None and hasattr(regularizer, "get_config"):
            metadata["activity_regularizer"] = (
                serialization.serialize_keras_object(regularizer)
            )

        build_shape = layer._build_input_shape
        if build_shape is not None:
            metadata["build_input_shape"] = build_shape
        return metadata

    def objects_to_serialize(self, serialization_cache):
        """Returns the objects to save for this layer (cached per save)."""
        attributes = self._get_serialized_attributes(serialization_cache)
        return attributes.objects_to_serialize

    def functions_to_serialize(self, serialization_cache):
        """Returns the functions to save for this layer (cached per save)."""
        attributes = self._get_serialized_attributes(serialization_cache)
        return attributes.functions_to_serialize

    def _get_serialized_attributes(self, serialization_cache):
        """Returns the layer's `SerializedAttributes`, creating them once.

        Results live in a sub-dict of `serialization_cache` keyed by
        `constants.KERAS_CACHE_KEY`, indexed by the layer object. A newly
        created entry is stored before it is populated, and it is returned
        unpopulated when `save_impl.should_skip_serialization` is truthy or
        the layer has `_must_restore_from_config` set.
        """
        keras_cache = serialization_cache.setdefault(
            constants.KERAS_CACHE_KEY, {}
        )
        layer = self.obj
        if layer in keras_cache:
            return keras_cache[layer]

        serialized_attr = serialized_attributes.SerializedAttributes.new(layer)
        keras_cache[layer] = serialized_attr

        # Guard clauses: leave the attributes empty for skipped layers.
        if save_impl.should_skip_serialization(layer):
            return serialized_attr
        if layer._must_restore_from_config:
            return serialized_attr

        objects, functions = self._get_serialized_attributes_internal(
            serialization_cache
        )
        serialized_attr.set_and_validate_objects(objects)
        serialized_attr.set_and_validate_functions(functions)
        return serialized_attr

    def _get_serialized_attributes_internal(self, serialization_cache):
        """Returns `(objects, functions)` dicts wrapped by `save_impl`."""
        layer = self.obj
        objects = save_impl.wrap_layer_objects(layer, serialization_cache)
        functions = save_impl.wrap_layer_functions(layer, serialization_cache)
        # The attribute validator requires the default save signature key to
        # be present in the function dict, even when its value is None.
        functions["_default_save_signature"] = None
        return objects, functions
# TODO(kathywu): Move serialization utils (and related utils from
# generic_utils.py) to a separate file.
def get_serialized(obj):
    """Serializes `obj` with `serialization.serialize_keras_object`.

    The call runs inside `serialization.skip_failed_serialization()`. This
    presumably lets objects whose config cannot be produced still yield a
    result; confirm in the `serialization` module.

    The resulting config is stored so that loading can first try to revive
    the object from it, and fall back to reviving from the SavedModel if
    that fails.

    Args:
      obj: The object (typically a layer) to serialize.

    Returns:
      Whatever `serialization.serialize_keras_object(obj)` returns.
    """
    with serialization.skip_failed_serialization():
        return serialization.serialize_keras_object(obj)
class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
    """SavedModel saver for `InputLayer` objects.

    An input layer is described entirely by `python_properties`. It
    contributes no extra objects or functions to the SavedModel.
    """

    @property
    def object_identifier(self):
        """Identifier recorded in the SavedModel for input layers."""
        return constants.INPUT_LAYER_IDENTIFIER

    @property
    def python_properties(self):
        """Metadata dict: class name, name, dtype, sparsity, shape, config."""
        layer = self.obj
        return {
            "class_name": type(layer).__name__,
            "name": layer.name,
            "dtype": layer.dtype,
            "sparse": layer.sparse,
            "ragged": layer.ragged,
            "batch_input_shape": layer._batch_input_shape,
            "config": layer.get_config(),
        }

    def objects_to_serialize(self, serialization_cache):
        """Input layers have no objects to serialize."""
        return {}

    def functions_to_serialize(self, serialization_cache):
        """Input layers have no functions to serialize."""
        return {}
class RNNSavedModelSaver(LayerSavedModelSaver):
    """SavedModel saver for RNN layers.

    Identical to the generic layer saver, except that the layer's `states`
    are also saved as a tracked object.
    """

    @property
    def object_identifier(self):
        """Identifier recorded in the SavedModel for RNN layers."""
        return constants.RNN_LAYER_IDENTIFIER

    def _get_serialized_attributes_internal(self, serialization_cache):
        """Adds a trackable `states` entry to the parent's object dict."""
        objects, functions = super()._get_serialized_attributes_internal(
            serialization_cache
        )
        wrap = tf.__internal__.tracking.wrap
        states = wrap(self.obj.states)
        # SavedModel requires every saved object to be Trackable. If wrapping
        # still yields a plain tuple, it held no trackable items (e.g. an
        # empty tuple, or (None, None) for a stateless ConvLSTM2D). Wrapping
        # it as a list makes it Trackable. ConvLSTM2D handles the resulting
        # tuple/list difference when loading.
        if isinstance(states, tuple):
            states = wrap(list(states))
        objects["states"] = states
        return objects, functions
class VocabularySavedModelSaver(LayerSavedModelSaver):
    """Handles vocabulary layer serialization.

    Used for StringLookup, IntegerLookup and TextVectorization, whose
    configs include a vocabulary. The vocabulary stays in the config until
    save time, when it is cleared. This prevents a StaticHashTable from being
    initialized twice on load: once when restoring the config and once when
    restoring module resources. In place of the vocabulary, a flag is saved
    recording whether the layer was constructed with one.
    """

    @property
    def python_properties(self):
        """Layer metadata with the vocabulary stripped from the config."""
        # TODO(kathywu): Add python property validator
        metadata = self._python_properties_internal()
        config = metadata["config"]
        # Drop the vocabulary itself from the saved config...
        config["vocabulary"] = None
        # ...but remember whether one was passed at construction time.
        config["has_input_vocabulary"] = self.obj._has_input_vocabulary
        return metadata