Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/saved_model/serialized_attributes.py: 50%
80 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helper classes that list&validate all attributes to serialize to SavedModel.
16"""
18import tensorflow.compat.v2 as tf
20from keras.src.saving.legacy.saved_model import constants
21from keras.src.saving.legacy.saved_model import order_preserving_set as ops
22from keras.src.saving.legacy.saved_model import save_impl
23from keras.src.utils.generic_utils import LazyLoader
25# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
26# once the issue with copybara is fixed.
# Engine modules are bound through LazyLoader rather than imported directly.
# NOTE(review): presumably this defers the import (e.g. to break an import
# cycle between the saving code and the engine) — confirm against
# `keras.src.utils.generic_utils.LazyLoader`.
base_layer = LazyLoader(
    "base_layer",
    globals(),
    "keras.src.engine.base_layer",
)
training_lib = LazyLoader(
    "training_lib",
    globals(),
    "keras.src.engine.training",
)
metrics = LazyLoader(
    "metrics",
    globals(),
    "keras.src.metrics",
)
base_rnn = LazyLoader(
    "base_rnn",
    globals(),
    "keras.src.layers.rnn.base_rnn",
)
class SerializedAttributes:
    """Class that tracks and validates all serialization attributes.

    Keras models contain many Python-defined components. For example, the
    trainable_variable property lists the model's trainable variables by
    recursively retrieving the trainable variables from each of the child
    layers. Another example is model.call, a python function that calls child
    layers and adds ops to the backend graph.

    Only Tensorflow checkpointable objects and functions can be serialized to
    SavedModel. Serializing a Keras model as-is results in a checkpointable
    object that does not resemble a Keras model at all. Thus, extra
    checkpointable objects and functions must be created during serialization.

    **Defining new serialized attributes**
    Child classes should be defined using:
      SerializedAttributes.with_attributes(
          'name', checkpointable_objects=[...],
          functions=[...], copy_from=[...])
    This class is used to cache generated checkpointable objects and functions,
    ensuring that new objects and functions are generated a single time.

    **Usage during serialization**
    Each Layer/Model object should have a corresponding instance of
    SerializedAttributes. Create a new instance by calling
    `SerializedAttributes.new(obj)`. Objects and functions may be saved using
    `.set_and_validate_objects`/`.set_and_validate_functions`.
    The properties `.checkpointable_objects` and `.functions` return the
    cached values.

    **Adding/changing attributes to save to SavedModel**
    1. Change the call to `SerializedAttributes.with_attributes` in the correct
       class:
       - CommonEndpoints: Base attributes to be added during serialization. If
         these attributes are present in a Trackable object, it can be
         deserialized to a Keras Model.
       - LayerAttributes: Attributes to serialize for Layer objects.
       - ModelAttributes: Attributes to serialize for Model objects.
    2. Update class docstring
    3. Update arguments to any calls to `set_and_validate_*`. For example, if
       `call_raw_tensors` is added to the ModelAttributes function list, then
       a `call_raw_tensors` function should be passed to
       `set_and_validate_functions`.

    **Common endpoints vs other attributes**
    Only common endpoints are attached directly to the root object.
    Keras-specific attributes are saved to a separate trackable object with the
    name "keras_api". The number of objects attached to the root is limited
    because any naming conflicts will cause user code to break.

    Another reason is that this will only affect users who call
    `tf.saved_model.load` instead of `tf.keras.models.load_model`. These are
    advanced users who are likely to have defined their own tf.functions and
    trackable objects. The added Keras-specific attributes are kept out of the
    way in the "keras_api" namespace.

    Properties defined in this class may be used to filter out keras-specific
    attributes:
    - `functions_to_serialize`: Returns dict of functions to attach to the root
      object.
    - `objects_to_serialize`: Returns dict of objects to attach to the root
      object (including separate trackable object containing keras-specific
      attributes)

    All changes to the serialized attributes must be backwards-compatible, so
    attributes should not be removed or modified without sufficient
    justification.
    """

    @staticmethod
    def with_attributes(
        name, checkpointable_objects=None, functions=None, copy_from=None
    ):
        """Creates a subclass with all attributes as specified in the arguments.

        The lists passed in are copied and never modified. In the resulting
        class, the subclass's own entries come first, followed by the entries
        of each class in `copy_from`, in order.

        Args:
          name: Name of subclass
          checkpointable_objects: List of checkpointable objects to be
            serialized in the SavedModel.
          functions: List of functions to be serialized in the SavedModel.
          copy_from: List of other SerializedAttributes subclasses. The returned
            class will copy checkpoint objects/functions from each subclass.

        Returns:
          Child class with attributes as defined in the `checkpointable_objects`
          and `functions` lists.
        """
        # Copy the inputs: the `copy_from` loop extends these lists, which
        # would otherwise silently mutate the caller's arguments.
        checkpointable_objects = list(checkpointable_objects or [])
        functions = list(functions or [])

        if copy_from is not None:
            for cls in copy_from:
                checkpointable_objects.extend(cls.all_checkpointable_objects)
                functions.extend(cls.all_functions)

        # OrderPreservingSets are used here to guarantee serialization
        # determinism of Keras objects.
        classdict = {
            "all_checkpointable_objects": ops.OrderPreservingSet(
                checkpointable_objects
            ),
            "all_functions": ops.OrderPreservingSet(functions),
        }
        return type(name, (SerializedAttributes,), classdict)

    @staticmethod
    def new(obj):
        """Returns a new SerializedAttributes instance matching `obj`'s type.

        The types are checked in the order Model, Metric, RNN, then Layer. The
        first match wins.

        Raises:
          TypeError: If `obj` is not an instance of any of those types.
        """
        if isinstance(obj, training_lib.Model):
            return ModelAttributes()
        elif isinstance(obj, metrics.Metric):
            return MetricAttributes()
        elif isinstance(obj, base_rnn.RNN):
            return RNNAttributes()
        elif isinstance(obj, base_layer.Layer):
            return LayerAttributes()
        else:
            raise TypeError(
                "Internal error during serialization. Expected Keras "
                f"Layer object. Received: {obj} "
                f"(of type {type(obj)})"
            )

    def __init__(self):
        # Validated checkpointable objects, keyed by attribute name.
        self._object_dict = {}
        # Validated functions (may include None placeholders), keyed by name.
        self._function_dict = {}
        # Container that receives every validated attribute via setattr; it is
        # attached to the root under `constants.KERAS_ATTR`
        # (see `objects_to_serialize`).
        self._keras_trackable = tf.__internal__.tracking.AutoTrackable()

    @staticmethod
    def _unwrap_layer_call(fn):
        """Returns the wrapped call of a `save_impl.LayerCall`, else `fn`."""
        if isinstance(fn, save_impl.LayerCall):
            return fn.wrapped_call
        return fn

    @property
    def functions(self):
        """Returns dictionary of all functions whose value is not None."""
        return {
            key: value
            for key, value in self._function_dict.items()
            if value is not None
        }

    @property
    def checkpointable_objects(self):
        """Returns dictionary of all checkpointable objects that are not None."""
        return {
            key: value
            for key, value in self._object_dict.items()
            if value is not None
        }

    @property
    def functions_to_serialize(self):
        """Returns functions to attach to the root object during serialization.

        Only keys that belong to `CommonEndpoints.all_functions` are included.
        `save_impl.LayerCall` values are replaced by their `wrapped_call`.
        """
        return {
            key: self._unwrap_layer_call(fn)
            for key, fn in self.functions.items()
            if key in CommonEndpoints.all_functions
        }

    @property
    def objects_to_serialize(self):
        """Returns objects to attach to the root object during serialization.

        This covers the `CommonEndpoints` checkpointable objects plus the
        Keras-specific trackable under the `constants.KERAS_ATTR` key.
        """
        objects = {
            key: value
            for key, value in self.checkpointable_objects.items()
            if key in CommonEndpoints.all_checkpointable_objects
        }
        objects[constants.KERAS_ATTR] = self._keras_trackable
        return objects

    def set_and_validate_functions(self, function_dict):
        """Saves function dictionary, and validates dictionary values.

        Every name in `all_functions` must be a key of `function_dict`. Its
        value may be None, a `tf.function`, a ConcreteFunction or a
        `save_impl.LayerCall`.

        Returns:
          The `functions` property (all non-None saved functions).

        Raises:
          ValueError: If a required key is missing or a value has the wrong
            type. Keys processed before the failing one remain saved.
        """
        for key in self.all_functions:
            if key not in function_dict:
                raise ValueError(
                    f"Function {key} missing from serialized "
                    "tf.function dictionary."
                )
            fn = function_dict[key]
            # Not all functions are required, so None is accepted as-is.
            if fn is not None and not isinstance(
                fn,
                (
                    tf.__internal__.function.Function,
                    tf.types.experimental.ConcreteFunction,
                    save_impl.LayerCall,
                ),
            ):
                raise ValueError(
                    "The tf.function dictionary contained a non-function "
                    f"object: {fn} (for key {key}). Only "
                    "tf.function instances or ConcreteFunction instances "
                    "should be passed."
                )
            self._function_dict[key] = fn
            # The trackable receives the TensorFlow `Function`, not the
            # LayerCall wrapper around it.
            setattr(self._keras_trackable, key, self._unwrap_layer_call(fn))
        return self.functions

    def set_and_validate_objects(self, object_dict):
        """Saves objects to a dictionary, and validates the values.

        Every name in `all_checkpointable_objects` must be a key of
        `object_dict`, and its value must be a `Trackable`.

        Returns:
          The `checkpointable_objects` property.

        Raises:
          ValueError: If a required key is missing or a value is not
            trackable. Keys processed before the failing one remain saved.
        """
        for key in self.all_checkpointable_objects:
            if key not in object_dict:
                raise ValueError(
                    f"Object {key} missing from serialized object dictionary."
                )
            obj = object_dict[key]
            if not isinstance(obj, tf.__internal__.tracking.Trackable):
                raise ValueError(
                    "The object dictionary contained a non-trackable "
                    f"object: {obj} (for key {key}). "
                    "Only trackable objects are "
                    "allowed, such as Keras layers/models or "
                    "tf.Module instances."
                )
            self._object_dict[key] = obj
            setattr(self._keras_trackable, key, obj)
        return self.checkpointable_objects
class CommonEndpoints(
    SerializedAttributes.with_attributes(
        name="CommonEndpoints",
        checkpointable_objects=[
            "variables",
            "trainable_variables",
            "regularization_losses",
        ],
        functions=[
            "__call__",
            "call_and_return_all_conditional_losses",
            "_default_save_signature",
        ],
    )
):
    """Endpoints shared by every model that Keras is able to load.

    These attributes are the ones attached directly to the root object during
    serialization. Everything else goes under the Keras-specific trackable.

    Checkpointable objects:
      variables: Every variable of the model and of its sublayers.
      trainable_variables: Every trainable variable of the model and of its
        sublayers.
      regularization_losses: Every unconditional loss (a loss that does not
        depend on the inputs) of the model and of its sublayers.

    Functions:
      __call__: Takes inputs and returns what the model's call function
        returns.
      call_and_return_all_conditional_losses: Returns a tuple of
        (call function outputs, list of every input-dependent loss).
      _default_save_signature: The traced model call function. It is only
        present when the top-level exported object is a Keras model.
    """
class LayerAttributes(
    SerializedAttributes.with_attributes(
        name="LayerAttributes",
        checkpointable_objects=[
            "non_trainable_variables",
            "layers",
            "metrics",
            "layer_regularization_losses",
            "layer_metrics",
        ],
        functions=[
            "call_and_return_conditional_losses",
            "activity_regularizer_fn",
        ],
        copy_from=[CommonEndpoints],
    )
):
    """Checkpointable objects and functions written to SavedModel for a Layer.

    Includes every CommonEndpoints attribute, plus:

    Checkpointable objects:
      non_trainable_variables: Non-trainable variables of the layer and of its
        sublayers.
      layers: Every sublayer.
      metrics: Every metric of the layer and of its sublayers.
      layer_regularization_losses: Losses owned by this layer alone.
      layer_metrics: Metrics owned by this layer.

    Functions:
      call_and_return_conditional_losses: Takes inputs and returns a tuple of
        (call function outputs, list of input-dependent losses). The activity
        regularizer loss is left out of that list so that a deserialized Layer
        can define a different activity regularizer.
      activity_regularizer_fn: Callable producing the activity regularizer
        loss.
    """
class ModelAttributes(
    SerializedAttributes.with_attributes(
        name="ModelAttributes",
        copy_from=[LayerAttributes],
    )
):
    """Checkpointable objects and functions written to SavedModel for a Model.

    Currently identical to LayerAttributes, which also brings in the
    CommonEndpoints attributes.
    """

    # TODO(kathywu): Add attributes `compile_losses` and `compile_metrics`,
    # which list all losses and metrics defined by `model.compile`.
class MetricAttributes(
    SerializedAttributes.with_attributes(
        name="MetricAttributes",
        checkpointable_objects=["variables"],
        functions=[],
    )
):
    """Attributes added to a Metric object when it is saved to SavedModel.

    Checkpointable objects:
      variables: Every variable of the metric.

    No functions are serialized for metrics.
    """
class RNNAttributes(
    SerializedAttributes.with_attributes(
        name="RNNAttributes",
        checkpointable_objects=["states"],
        copy_from=[LayerAttributes],
    )
):
    """Checkpointable objects and functions written to SavedModel for an RNN.

    Includes every LayerAttributes attribute (and therefore every
    CommonEndpoints attribute). The RNN's own entry comes first, followed by
    the LayerAttributes entries:

    Checkpointable objects:
      states: The RNN's state variables.
    """