1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Helper classes to list and validate all attributes serialized to SavedModel.
"""
17
18from tensorflow.python.eager import def_function
19from tensorflow.python.keras.saving.saved_model import constants
20from tensorflow.python.keras.saving.saved_model import save_impl
21from tensorflow.python.keras.utils.generic_utils import LazyLoader
22from tensorflow.python.trackable import base as trackable
23from tensorflow.python.trackable.autotrackable import AutoTrackable
24
# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
# The Keras engine/metrics/recurrent modules are bound through `LazyLoader`
# rather than imported directly. NOTE(review): presumably this defers each
# import until first use to break an import cycle with the saving code —
# confirm before replacing these with plain imports.
base_layer = LazyLoader("base_layer", globals(),
                        "tensorflow.python.keras.engine.base_layer")
training_lib = LazyLoader("training_lib", globals(),
                          "tensorflow.python.keras.engine.training")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
recurrent = LazyLoader("recurrent", globals(),
                       "tensorflow.python.keras.layers.recurrent")
# pylint:enable=g-inconsistent-quotes
40
41
class SerializedAttributes(object):
  """Class that tracks and validates all serialization attributes.

  Keras models contain many Python-defined components. For example, the
  trainable_variables property lists the model's trainable variables by
  recursively retrieving the trainable variables from each of the child layers.
  Another example is model.call, a python function that calls child layers and
  adds ops to the backend graph.

  Only TensorFlow checkpointable objects and functions can be serialized to
  SavedModel. Serializing a Keras model as-is results in a checkpointable object
  that does not resemble a Keras model at all. Thus, extra checkpointable
  objects and functions must be created during serialization.

  **Defining new serialized attributes**
  Child classes should be defined using:
    SerializedAttributes.with_attributes(
        'name', checkpointable_objects=[...], functions=[...], copy_from=[...])
  This class is used to cache generated checkpointable objects and functions,
  ensuring that new objects and functions are generated a single time.

  **Usage during serialization**
  Each Layer/Model object should have a corresponding instance of
  SerializedAttributes. Create a new instance by calling
  `SerializedAttributes.new(obj)`. Objects and functions may be saved using
  `.set_and_validate_objects`/`.set_and_validate_functions`.
  The properties `.checkpointable_objects` and `.functions` return the cached
  values.

  **Adding/changing attributes to save to SavedModel**
  1. Change the call to `SerializedAttributes.with_attributes` in the correct
     class:
     - CommonEndpoints: Base attributes to be added during serialization. If
       these attributes are present in a Trackable object, it can be
       deserialized to a Keras Model.
     - LayerAttributes: Attributes to serialize for Layer objects.
     - ModelAttributes: Attributes to serialize for Model objects.
     - MetricAttributes: Attributes to serialize for Metric objects.
     - RNNAttributes: Attributes to serialize for RNN layers.
  2. Update class docstring
  3. Update arguments to any calls to `set_and_validate_*`. For example, if
     `call_raw_tensors` is added to the ModelAttributes function list, then
     a `call_raw_tensors` function should be passed to
     `set_and_validate_functions`.

  **Common endpoints vs other attributes**
  Only common endpoints are attached directly to the root object. Keras-specific
  attributes are saved to a separate trackable object with the name "keras_api".
  The number of objects attached to the root is limited because any naming
  conflicts will cause user code to break.

  Another reason is that this will only affect users who call
  `tf.saved_model.load` instead of `tf.keras.models.load_model`. These are
  advanced users who are likely to have defined their own tf.functions and
  trackable objects. The added Keras-specific attributes are kept out of the way
  in the "keras_api" namespace.

  Properties defined in this class may be used to filter out keras-specific
  attributes:
  - `functions_to_serialize`: Returns dict of functions to attach to the root
    object.
  - `objects_to_serialize`: Returns dict of objects to attach to the root
    object (including the separate trackable object containing keras-specific
    attributes).

  All changes to the serialized attributes must be backwards-compatible, so
  attributes should not be removed or modified without sufficient justification.
  """

  @staticmethod
  def with_attributes(
      name, checkpointable_objects=None, functions=None, copy_from=None):
    """Creates a subclass with all attributes as specified in the arguments.

    Args:
      name: Name of subclass
      checkpointable_objects: List of checkpointable objects to be serialized
        in the SavedModel. The argument is copied and never modified.
      functions: List of functions to be serialized in the SavedModel. The
        argument is copied and never modified.
      copy_from: List of other SerializedAttributes subclasses. The returned
        class will copy checkpoint objects/functions from each subclass.

    Returns:
      Child class with attributes as defined in the `checkpointable_objects`
      and `functions` lists.
    """
    # Copy the inputs: the inherited attributes are appended below, and that
    # must not leak back into a list owned by the caller.
    checkpointable_objects = list(checkpointable_objects or [])
    functions = list(functions or [])

    if copy_from is not None:
      for cls in copy_from:
        checkpointable_objects.extend(cls.all_checkpointable_objects)
        functions.extend(cls.all_functions)

    classdict = {
        'all_checkpointable_objects': set(checkpointable_objects),
        'all_functions': set(functions)}
    return type(name, (SerializedAttributes,), classdict)

  @staticmethod
  def new(obj):
    """Returns a new SerializedAttribute object.

    Args:
      obj: The Keras object being serialized.

    Returns:
      An instance of the SerializedAttributes subclass matching `obj`:
      ModelAttributes, MetricAttributes, RNNAttributes or LayerAttributes.

    Raises:
      TypeError: If `obj` is none of the Keras types checked below.
    """
    # The specific types are tested before the generic `base_layer.Layer`.
    # NOTE(review): presumably Model/Metric/RNN all subclass Layer, so this
    # order is what selects the most specific attribute set — confirm.
    if isinstance(obj, training_lib.Model):
      return ModelAttributes()
    elif isinstance(obj, metrics.Metric):
      return MetricAttributes()
    elif isinstance(obj, recurrent.RNN):
      return RNNAttributes()
    elif isinstance(obj, base_layer.Layer):
      return LayerAttributes()
    else:
      raise TypeError('Internal error during serialization: Expected Keras '
                      'Layer object, got {} of type {}'.format(obj, type(obj)))

  def __init__(self):
    # Values registered through set_and_validate_objects/_functions.
    self._object_dict = {}
    self._function_dict = {}
    # Every registered value is also set as an attribute on this object; it is
    # attached to the root under `constants.KERAS_ATTR` by objects_to_serialize.
    self._keras_trackable = AutoTrackable()

  @property
  def functions(self):
    """Returns dictionary of all functions, omitting entries set to None."""
    return {key: value for key, value in self._function_dict.items()
            if value is not None}

  @property
  def checkpointable_objects(self):
    """Returns dictionary of all checkpointable objects, omitting None."""
    return {key: value for key, value in self._object_dict.items()
            if value is not None}

  @property
  def functions_to_serialize(self):
    """Returns functions to attach to the root object during serialization.

    Only functions listed in `CommonEndpoints.all_functions` are included;
    `save_impl.LayerCall` wrappers are replaced by their `wrapped_call`.
    """
    functions = {}
    for key, v in self.functions.items():
      if key in CommonEndpoints.all_functions:
        functions[key] = (v.wrapped_call if isinstance(v, save_impl.LayerCall)
                          else v)
    return functions

  @property
  def objects_to_serialize(self):
    """Returns objects to attach to the root object during serialization.

    Contains the objects listed in `CommonEndpoints.all_checkpointable_objects`
    plus the Keras-specific trackable under the key `constants.KERAS_ATTR`.
    """
    objects = {key: value for key, value in self.checkpointable_objects.items()
               if key in CommonEndpoints.all_checkpointable_objects}
    objects[constants.KERAS_ATTR] = self._keras_trackable
    return objects

  def set_and_validate_functions(self, function_dict):
    """Saves function dictionary, and validates dictionary values.

    Args:
      function_dict: Dict mapping every name in `all_functions` to a
        `def_function.Function`, a `save_impl.LayerCall`, or None.

    Returns:
      The `functions` property after the update.

    Raises:
      ValueError: If a required key is missing or a value is not a function.
    """
    for key in self.all_functions:
      if key not in function_dict:
        raise ValueError('Function {} missing from serialized function dict.'
                         .format(key))
      fn = function_dict[key]
      # Not all functions are required, so None is accepted as a placeholder.
      if fn is not None and not isinstance(
          fn, (def_function.Function, save_impl.LayerCall)):
        raise ValueError(
            'Function dictionary contained a non-function object: {} (for key'
            ' {})'.format(fn, key))
      self._function_dict[key] = fn

      # Extract TensorFlow `Function` from LayerCall.
      tf_fn = fn.wrapped_call if isinstance(fn, save_impl.LayerCall) else fn
      setattr(self._keras_trackable, key, tf_fn)
    return self.functions

  def set_and_validate_objects(self, object_dict):
    """Saves objects to a dictionary, and validates the values.

    Args:
      object_dict: Dict mapping every name in `all_checkpointable_objects` to
        a `trackable.Trackable`.

    Returns:
      The `checkpointable_objects` property after the update.

    Raises:
      ValueError: If a required key is missing or a value is not trackable.
    """
    for key in self.all_checkpointable_objects:
      if key not in object_dict:
        raise ValueError(
            'Object {} missing from serialized object dict.'.format(key))
      obj = object_dict[key]
      if not isinstance(obj, trackable.Trackable):
        raise ValueError(
            'Object dictionary contained a non-trackable object: {} (for key'
            ' {})'.format(obj, key))
      self._object_dict[key] = obj
      setattr(self._keras_trackable, key, obj)
    return self.checkpointable_objects
224
225
class CommonEndpoints(
    SerializedAttributes.with_attributes(
        'CommonEndpoints',
        checkpointable_objects=[
            'variables',
            'trainable_variables',
            'regularization_losses',
        ],
        functions=[
            '__call__',
            'call_and_return_all_conditional_losses',
            '_default_save_signature',
        ],
    )):
  """Endpoints shared by every object that Keras can load back as a model.

  Checkpointable objects:
    variables: Every variable in the model and its sublayers.
    trainable_variables: Every trainable variable in the model and its
      sublayers.
    regularization_losses: Every unconditional loss (a loss that does not
      depend on the inputs) in the model and its sublayers.

  Functions:
    __call__: Takes inputs and returns the outputs of the model's call
      function.
    call_and_return_all_conditional_losses: Returns the tuple
      (call function outputs, list of all losses that depend on the inputs).
    _default_save_signature: The traced model call function; only included
      when the top-level exported object is a Keras model.
  """
247
248
class LayerAttributes(
    SerializedAttributes.with_attributes(
        'LayerAttributes',
        checkpointable_objects=[
            'non_trainable_variables',
            'layers',
            'metrics',
            'layer_regularization_losses',
            'layer_metrics',
        ],
        functions=[
            'call_and_return_conditional_losses',
            'activity_regularizer_fn',
        ],
        copy_from=[CommonEndpoints],
    )):
  """Checkpointable objects and functions saved to the SavedModel for a Layer.

  Includes every attribute of CommonEndpoints, plus:

  Checkpointable objects:
    non_trainable_variables: Non-trainable variables in the layer and its
      sublayers.
    layers: Every sublayer.
    metrics: Every metric in the layer and its sublayers.
    layer_regularization_losses: Losses owned only by this layer.
    layer_metrics: Metrics owned by this layer.

  Functions:
    call_and_return_conditional_losses: Takes inputs and returns the tuple
      (outputs of the call function, list of input-dependent losses). The
      activity regularizer loss is left out of that list so that the
      deserialized Layer can define a different activity regularizer.
    activity_regularizer_fn: Callable returning the activity regularizer loss.
  """
273
274
class ModelAttributes(
    SerializedAttributes.with_attributes(
        'ModelAttributes',
        copy_from=[LayerAttributes],
    )):
  """Checkpointable objects and functions saved to the SavedModel for a Model.

  Models currently serialize exactly the LayerAttributes set (which already
  contains the CommonEndpoints); no Model-only attributes are added.
  """
  # TODO(kathywu): Add attributes `compile_losses` and `compile_metrics`, which
  # list all losses and metrics defined by `model.compile`.
285
286
class MetricAttributes(SerializedAttributes.with_attributes(
    'MetricAttributes', checkpointable_objects=['variables'], functions=[])):
  """Attributes added to a Metric object when it is saved to SavedModel.

  Checkpointable objects:
    variables: Every variable of the metric.

  No functions are serialized for metrics.
  """
299
300
class RNNAttributes(
    SerializedAttributes.with_attributes(
        'RNNAttributes',
        checkpointable_objects=['states'],
        copy_from=[LayerAttributes],
    )):
  """Checkpointable objects and functions saved to the SavedModel for an RNN.

  Includes every attribute of LayerAttributes (and thus CommonEndpoints),
  plus:

  Checkpointable objects:
    states: The RNN's state variables.
  """
311