1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes and functions implementing Layer SavedModel serialization."""
16
17from tensorflow.python.keras.mixed_precision import policy
18from tensorflow.python.keras.saving.saved_model import base_serialization
19from tensorflow.python.keras.saving.saved_model import constants
20from tensorflow.python.keras.saving.saved_model import save_impl
21from tensorflow.python.keras.saving.saved_model import serialized_attributes
22from tensorflow.python.keras.utils import generic_utils
23from tensorflow.python.trackable import data_structures
24from tensorflow.python.util import nest
25
26
class LayerSavedModelSaver(base_serialization.SavedModelSaver):
  """SavedModel saver for Keras layers.

  Exposes the object identifier, the python-side metadata dictionary and the
  objects/functions that are serialized for the wrapped layer (`self.obj`).
  """

  @property
  def object_identifier(self):
    """Identifier string written for generic layers."""
    return constants.LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Metadata dictionary describing the wrapped layer."""
    # TODO(kathywu): Add python property validator
    return self._python_properties_internal()

  def _python_properties_internal(self):
    """Builds the dictionary of python properties for the wrapped layer.

    Returns:
      A dict holding the basic layer attributes, merged with the result of
      `get_serialized(self.obj)`, plus the optional `input_spec`,
      `activity_regularizer` and `build_input_shape` entries when the layer
      defines them.
    """
    # TODO(kathywu): Add support for metrics serialization.
    # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
    # the python config serialization has caught up.
    layer = self.obj
    # pylint: disable=protected-access
    metadata = {
        'name': layer.name,
        'trainable': layer.trainable,
        'expects_training_arg': layer._expects_training_arg,
        'dtype': policy.serialize(layer._dtype_policy),
        'batch_input_shape': getattr(layer, '_batch_input_shape', None),
        'stateful': layer.stateful,
        'must_restore_from_config': layer._must_restore_from_config,
    }
    # pylint: enable=protected-access

    metadata.update(get_serialized(layer))

    input_spec = layer.input_spec
    if input_spec is not None:

      def _serialize_spec(spec):
        # Falsy leaves of the structure are stored as None.
        if spec:
          return generic_utils.serialize_keras_object(spec)
        return None

      # Layer's input_spec has already been type-checked in the property
      # setter.
      metadata['input_spec'] = nest.map_structure(_serialize_spec, input_spec)

    regularizer = layer.activity_regularizer
    # Only regularizers that expose `get_config` are serialized.
    if regularizer is not None and hasattr(regularizer, 'get_config'):
      metadata['activity_regularizer'] = (
          generic_utils.serialize_keras_object(regularizer))

    build_input_shape = layer._build_input_shape  # pylint: disable=protected-access
    if build_input_shape is not None:
      metadata['build_input_shape'] = build_input_shape

    return metadata

  def objects_to_serialize(self, serialization_cache):
    """Returns the objects to serialize, taken from the cached attributes."""
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.objects_to_serialize

  def functions_to_serialize(self, serialization_cache):
    """Returns the functions to serialize, taken from the cached attributes."""
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.functions_to_serialize

  def _get_serialized_attributes(self, serialization_cache):
    """Returns the layer's serialized attributes, creating them if needed.

    Args:
      serialization_cache: Mapping shared across the save; the Keras entries
        live under `constants.KERAS_CACHE_KEY`.

    Returns:
      The `SerializedAttributes` instance associated with the wrapped layer.
    """
    layer = self.obj
    keras_cache = serialization_cache.setdefault(constants.KERAS_CACHE_KEY, {})
    if layer in keras_cache:
      return keras_cache[layer]

    # The fresh entry is registered in the cache before it is populated.
    serialized_attr = serialized_attributes.SerializedAttributes.new(layer)
    keras_cache[layer] = serialized_attr

    needs_attributes = not (
        save_impl.should_skip_serialization(layer) or
        layer._must_restore_from_config)  # pylint: disable=protected-access
    if needs_attributes:
      object_dict, function_dict = self._get_serialized_attributes_internal(
          serialization_cache)
      serialized_attr.set_and_validate_objects(object_dict)
      serialized_attr.set_and_validate_functions(function_dict)

    return serialized_attr

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Returns the (objects, functions) dictionaries for the wrapped layer."""
    wrapped_objects = save_impl.wrap_layer_objects(
        self.obj, serialization_cache)
    wrapped_functions = save_impl.wrap_layer_functions(
        self.obj, serialization_cache)
    # Attribute validator requires that the default save signature is added to
    # function dict, even if the value is None.
    wrapped_functions['_default_save_signature'] = None
    return wrapped_objects, wrapped_functions
104
105
106# TODO(kathywu): Move serialization utils (and related utils from
107# generic_utils.py) to a separate file.
def get_serialized(obj):
  """Serializes `obj` inside the `skip_failed_serialization` scope.

  Args:
    obj: The Keras object to serialize.

  Returns:
    The result of `generic_utils.serialize_keras_object(obj)`.
  """
  with generic_utils.skip_failed_serialization():
    # Store the config dictionary, which may be used when reviving the object.
    # When loading, the program will attempt to revive the object from config,
    # and if that fails, the object will be revived from the SavedModel.
    serialized = generic_utils.serialize_keras_object(obj)
    return serialized
114
115
class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
  """SavedModel saver for InputLayer objects.

  Only python metadata is emitted; no child objects or functions are
  serialized for these layers.
  """

  @property
  def object_identifier(self):
    """Identifier string written for input layers."""
    return constants.INPUT_LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Metadata dictionary describing the wrapped input layer."""
    input_layer = self.obj
    return {
        'class_name': type(input_layer).__name__,
        'name': input_layer.name,
        'dtype': input_layer.dtype,
        'sparse': input_layer.sparse,
        'ragged': input_layer.ragged,
        'batch_input_shape': input_layer._batch_input_shape,  # pylint: disable=protected-access
        'config': input_layer.get_config(),
    }

  def objects_to_serialize(self, serialization_cache):
    """Returns an empty dict: no objects are serialized."""
    return {}

  def functions_to_serialize(self, serialization_cache):
    """Returns an empty dict: no functions are serialized."""
    return {}
140
141
class RNNSavedModelSaver(LayerSavedModelSaver):
  """SavedModel saver for RNN layers; additionally serializes `states`."""

  @property
  def object_identifier(self):
    """Identifier string written for RNN layers."""
    return constants.RNN_LAYER_IDENTIFIER

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Extends the base (objects, functions) pair with the layer's states."""
    base_attributes = super()._get_serialized_attributes_internal(
        serialization_cache)
    objects, functions = base_attributes

    wrapped_states = data_structures.wrap_or_unwrap(self.obj.states)
    # SavedModel requires every saved object to be Trackable. When the states
    # are still a tuple after wrap_or_unwrap, they hold no trackable item
    # (e.g. an empty tuple, or (None, None) for a stateless ConvLSTM2D).
    # Converting to a list lets wrap_or_unwrap turn them into a Trackable for
    # saving. When loaded, ConvLSTM2D is able to handle the tuple/list
    # conversion.
    if isinstance(wrapped_states, tuple):
      wrapped_states = data_structures.wrap_or_unwrap(list(wrapped_states))

    objects['states'] = wrapped_states
    return objects, functions
164
165
class IndexLookupLayerSavedModelSaver(LayerSavedModelSaver):
  """SavedModel saver for index lookup layers."""

  @property
  def python_properties(self):
    """Layer metadata, with `vocabulary` cleared for static-table configs.

    Returns:
      The dictionary from `_python_properties_internal()`. When its `config`
      entry has a truthy `has_static_table`, the config's `vocabulary` value
      is set to None.
    """
    # TODO(kathywu): Add python property validator
    metadata = self._python_properties_internal()
    config = metadata['config']
    if config.get('has_static_table', False):
      config['vocabulary'] = None
    return metadata