# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras SavedModel deserialization."""

import re
import types
import warnings

import tensorflow.compat.v1.logging as logging
import tensorflow.compat.v2 as tf
from google.protobuf import message

from keras.src import backend
from keras.src import regularizers
from keras.src.engine import input_spec
from keras.src.optimizers.legacy import optimizer_v2
from keras.protobuf import saved_metadata_pb2
from keras.protobuf import versions_pb2
from keras.src.saving import object_registration
from keras.src.saving.legacy import model_config
from keras.src.saving.legacy import saving_utils
from keras.src.saving.legacy import serialization
from keras.src.saving.legacy.saved_model import constants
from keras.src.saving.legacy.saved_model import json_utils
from keras.src.saving.legacy.saved_model import utils
from keras.src.saving.legacy.saved_model.serialized_attributes import (
    CommonEndpoints,
)
from keras.src.utils import layer_utils
from keras.src.utils import metrics_utils
from keras.src.utils import tf_inspect
from keras.src.utils.generic_utils import LazyLoader

# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.

models_lib = LazyLoader("models_lib", globals(), "keras.src.models")
base_layer = LazyLoader("base_layer", globals(), "keras.src.engine.base_layer")
layers_module = LazyLoader("layers_module", globals(), "keras.src.layers")
input_layer = LazyLoader("input_layer", globals(), "keras.src.engine.input_layer")
functional_lib = LazyLoader(
    "functional_lib", globals(), "keras.src.engine.functional"
)
training_lib = LazyLoader("training_lib", globals(), "keras.src.engine.training")
training_lib_v1 = LazyLoader(
    "training_lib_v1", globals(), "keras.src.engine.training_v1"
)
metrics = LazyLoader("metrics", globals(), "keras.src.metrics")
base_rnn = LazyLoader("base_rnn", globals(), "keras.src.layers.rnn.base_rnn")


PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
    CommonEndpoints.all_checkpointable_objects
)
PUBLIC_ATTRIBUTES.add(constants.KERAS_ATTR)


def load(path, compile=True, options=None):
    """Loads Keras objects from a SavedModel.

    Any Keras layer or model saved to the SavedModel will be loaded back
    as Keras objects. Other objects are loaded as regular trackable objects
    (same as `tf.saved_model.load`).

    Currently, Keras saving/loading only retains the Keras object's weights,
    losses, and call function.

    The loaded model can be re-compiled, but the original optimizer, compiled
    loss functions, and metrics are not retained. This is temporary, and
    `model.save` will soon be able to serialize compiled models.

    Args:
        path: Path to SavedModel.
        compile: If true, compile the model after loading it.
        options: Optional `tf.saved_model.LoadOptions` object that specifies
            options for loading from SavedModel.

    Returns:
        Object loaded from SavedModel.
    """
    # TODO(kathywu): Add saving/loading of optimizer, compiled losses and
    # metrics.
    # TODO(kathywu): Add code to load from objects that contain all endpoints

    # Look for metadata file or parse the SavedModel
    metadata = saved_metadata_pb2.SavedMetadata()
    meta_graph_def = tf.__internal__.saved_model.parse_saved_model(
        path
    ).meta_graphs[0]
    object_graph_def = meta_graph_def.object_graph_def
    path_to_metadata_pb = tf.io.gfile.join(path, constants.SAVED_METADATA_PATH)
    if tf.compat.v1.gfile.Exists(path_to_metadata_pb):
        try:
            with tf.io.gfile.GFile(path_to_metadata_pb, "rb") as f:
                file_content = f.read()
            metadata.ParseFromString(file_content)
        except message.DecodeError as e:
            raise IOError(
                f"Cannot parse keras metadata at path {path_to_metadata_pb}: "
                f"Received error: {e}"
            )
    else:
        logging.warning(
            "SavedModel saved prior to TF 2.5 detected when loading "
            "Keras model. Please ensure that you are saving the model "
            "with model.save() or tf.keras.models.save_model(), *NOT* "
            "tf.saved_model.save(). To confirm, there should be a file "
            'named "keras_metadata.pb" in the SavedModel directory.'
        )
        _read_legacy_metadata(object_graph_def, metadata, path)

    if not metadata.nodes:
        # When there are no Keras objects, return the results from the core
        # loader
        return tf.saved_model.load(path, options=options)

    metadata = _update_to_current_version(metadata)
    # Recreate layers and metrics using the info stored in the metadata.
    keras_loader = KerasObjectLoader(metadata, object_graph_def)
    keras_loader.load_layers(compile=compile)

    # Generate a dictionary of all loaded nodes.
    nodes_to_load = {"root": None}
    for node_id, loaded_node in keras_loader.loaded_nodes.items():
        nodes_to_load[keras_loader.get_path(node_id)] = loaded_node
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message="Trying to load ShardedVariables"
        )
        loaded = tf.__internal__.saved_model.load_partial(
            path, nodes_to_load, options=options
        )

    # Finalize the loaded layers and remove the extra tracked dependencies.
    keras_loader.finalize_objects()
    keras_loader.del_tracking()

    model = loaded["root"]

    if isinstance(model, training_lib.Model) and compile:
        # TODO(kathywu): Use compiled objects from SavedModel, instead of
        # creating new objects from the training config.
        training_config = model._serialized_attributes["metadata"].get(
            "training_config", None
        )
        if training_config is not None:
            model.compile(
                **saving_utils.compile_args_from_training_config(
                    training_config
                ),
                from_serialized=True,
            )
            saving_utils.try_build_compiled_arguments(model)
            if isinstance(model.optimizer, optimizer_v2.OptimizerV2):
                if model.optimizer.get_slot_names():
                    logging.warning(
                        "Your optimizer uses slots. "
                        "Slots cannot be restored from saved_model, "
                        "as a result, your model is starting with "
                        "a new initialized optimizer."
                    )
        else:
            logging.warning(
                "No training configuration found in save file, so the "
                "model was *not* compiled. Compile it manually."
            )

    # Force variables and resources to initialize.
    if not tf.executing_eagerly():
        sess = backend.get_session()  # Variables are initialized by this call.
        sess.run(
            tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.TABLE_INITIALIZERS
            )
        )

    return model
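
# Illustrative usage (a sketch, not part of this module's API): user code
# normally reaches this loader through `tf.keras.models.load_model` on a legacy
# SavedModel directory, but `load` can also be called directly. The import
# alias and the path below are hypothetical.
#
#     from keras.src.saving.legacy.saved_model import load as keras_sm_load
#
#     model = keras_sm_load.load("/tmp/my_legacy_saved_model", compile=True)
#     model.summary()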


def _update_to_current_version(metadata):
    """Applies version updates to the metadata proto for backwards compat."""
    for node in metadata.nodes:
        if node.version.producer == 1 and node.identifier in [
            constants.MODEL_IDENTIFIER,
            constants.SEQUENTIAL_IDENTIFIER,
            constants.NETWORK_IDENTIFIER,
        ]:
            node_metadata = json_utils.decode(node.metadata)
            save_spec = node_metadata.get("save_spec")

            if save_spec is not None:
                node_metadata["full_save_spec"] = ([save_spec], {})
                node.metadata = json_utils.Encoder().encode(node_metadata)
    return metadata
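
# For intuition, the producer-1 upgrade above is equivalent to the following
# rewrite (hypothetical value names): a node whose decoded metadata only
# carries a single "save_spec" entry gains the (args, kwargs) style
# "full_save_spec" that newer loading code expects.
#
#     node_metadata = {"save_spec": input_spec_of_model}
#     node_metadata["full_save_spec"] = ([node_metadata["save_spec"]], {})
#     # -> ([input_spec_of_model], {})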


def _read_legacy_metadata(object_graph_def, metadata, path):
    """Builds a KerasMetadata proto from the SavedModel ObjectGraphDef."""
    # Older SavedModels store the metadata directly in the proto instead of the
    # separate pb file.
    node_paths = _generate_object_paths(object_graph_def)
    for node_id, proto in enumerate(object_graph_def.nodes):
        if (
            proto.WhichOneof("kind") == "user_object"
            and proto.user_object.identifier
            in constants.KERAS_OBJECT_IDENTIFIERS
        ):
            if not proto.user_object.metadata:
                raise ValueError(
                    "Unable to create a Keras model from SavedModel at "
                    f"{path}. This SavedModel was exported with "
                    "`tf.saved_model.save`, and lacks the Keras metadata file. "
                    "Please save your Keras model by calling `model.save` "
                    "or `tf.keras.models.save_model`. Note that "
                    "you can still load this SavedModel with "
                    "`tf.saved_model.load`."
                )
            metadata.nodes.add(
                node_id=node_id,
                node_path=node_paths[node_id],
                version=versions_pb2.VersionDef(
                    producer=1, min_consumer=1, bad_consumers=[]
                ),
                identifier=proto.user_object.identifier,
                metadata=proto.user_object.metadata,
            )


def _generate_object_paths(object_graph_def):
    """Traverses through an ObjectGraphDef and builds a map of all node
    paths."""
    paths = {0: "root"}
    nodes_to_visit = [0]

    while nodes_to_visit:
        current_node = nodes_to_visit.pop()
        current_path = paths[current_node]
        for reference in object_graph_def.nodes[current_node].children:
            if reference.node_id in paths:
                continue
            paths[reference.node_id] = f"{current_path}.{reference.local_name}"
            nodes_to_visit.append(reference.node_id)

    return paths
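
# Worked example (hypothetical object graph): if node 0 ("root") has a child
# named "layer-0" pointing at node 3, and node 3 has a child named "kernel"
# pointing at node 7, the mapping returned above would be:
#
#     {0: "root", 3: "root.layer-0", 7: "root.layer-0.kernel"}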


def _is_graph_network(layer):
    """Determines whether the layer is a graph network."""

    if isinstance(layer, RevivedNetwork):
        return False
    elif isinstance(layer, functional_lib.Functional):
        return layer._is_graph_network or isinstance(
            layer, models_lib.Sequential
        )
    return False


class KerasObjectLoader:
    """Loader that recreates Keras objects (e.g. layers, models).

    Layers and models are revived from either the config or SavedModel following
    these rules:
    1. If object is a graph network (i.e. Sequential or Functional) then it will
       be initialized using the structure from the config only after the
       children layers have been created. Graph networks must be initialized
       with inputs and outputs, so all child layers must be created beforehand.
    2. If object's config exists and the class can be found, then revive from
       config.
    3. Object may have already been created if its parent was revived from
       config. In this case, do nothing.
    4. If nothing of the above applies, compose the various artifacts from the
       SavedModel to create a subclassed layer or model. At this time, custom
       metrics are not supported.
    """

    def __init__(self, metadata, object_graph_def):
        self._metadata = {x.node_id: x for x in metadata.nodes}
        self._proto = object_graph_def

        self._node_paths = {
            node_data.node_id: node_data.node_path
            for node_data in metadata.nodes
        }
        self.loaded_nodes = {}  # Maps node path -> loaded node

        # Store all node ids that have already been traversed when tracking
        # nodes that were recreated from the config.
        self._traversed_nodes_from_config = set()

        # Maps model id -> (blank model obj, list of child layer or their node
        # ids) This tracks all layers in functional and sequential models. These
        # models are only reconstructed after all of their child layers have
        # been created.
        self.model_layer_dependencies = {}
        self._models_to_reconstruct = []

    def del_tracking(self):
        """Removes tracked references that are only used when loading the
        model."""
        # Now that the node object has been fully loaded, and the checkpoint has
        # been restored, the object no longer needs to track objects added from
        # SerializedAttributes. (Note that saving a training checkpoint still
        # functions correctly, because layers and variables are tracked
        # separately by the Layer object.)
        # TODO(kathywu): Instead of outright deleting these nodes (which would
        # make restoring from a different checkpoint tricky), mark them as extra
        # dependencies that are OK to overwrite.
        for node in self.loaded_nodes.values():
            node = node[0]
            if not isinstance(node, base_layer.Layer):
                # Loaded nodes can contain other trackable objects created when
                # loading layers from the config, such as variables.
                continue
            for name in PUBLIC_ATTRIBUTES:
                node._delete_tracking(name)

            if isinstance(node, functional_lib.Functional):
                # Delete the temporary layer dependencies, which were used to
                # restore the checkpointed values. When the model is live, the
                # user can delete or add layers to the model at any time, so
                # these layer dependencies may be obsolete.
                dependencies = list(node._self_unconditional_dependency_names)
                for name in dependencies:
                    if (
                        re.match(r"^layer(_with_weights)?-[\d+]", name)
                        is not None
                    ):
                        node._delete_tracking(name)

    def _add_children_recreated_from_config(self, obj, proto, node_id):
        """Recursively records objects recreated from config."""

        if node_id in self._traversed_nodes_from_config:
            return

        parent_path = self._node_paths[node_id]
        self._traversed_nodes_from_config.add(node_id)
        obj._maybe_initialize_trackable()
        if isinstance(obj, base_layer.Layer) and not obj.built:
            metadata = json_utils.decode(self._metadata[node_id].metadata)
            self._try_build_layer(
                obj, node_id, metadata.get("build_input_shape")
            )

        # Create list of all possible children
        children = []
        # Look for direct children
        for reference in proto.children:
            obj_child = obj._lookup_dependency(reference.local_name)
            children.append(
                (obj_child, reference.node_id, reference.local_name)
            )

        # Add metrics that may have been added to the layer._metrics list.
        # This is stored in the SavedModel as layer.keras_api.layer_metrics in
        # SavedModels created after Tf 2.2.
        metric_list_node_id = self._search_for_child_node(
            node_id, [constants.KERAS_ATTR, "layer_metrics"]
        )
        if metric_list_node_id is not None and hasattr(obj, "_metrics"):
            obj_metrics = {m.name: m for m in obj._metrics}
            for reference in self._proto.nodes[metric_list_node_id].children:
                metric = obj_metrics.get(reference.local_name)
                if metric is not None:
                    metric_path = "{}.layer_metrics.{}".format(
                        constants.KERAS_ATTR, reference.local_name
                    )
                    children.append((metric, reference.node_id, metric_path))

        for obj_child, child_id, child_name in children:
            child_proto = self._proto.nodes[child_id]

            if not isinstance(obj_child, tf.__internal__.tracking.Trackable):
                continue
            if (
                child_proto.user_object.identifier
                in tf.__internal__.saved_model.load.registered_identifiers()
            ):
                setter = tf.__internal__.saved_model.load.get_setter(
                    child_proto.user_object
                )
            elif (
                obj_child._object_identifier
                in constants.KERAS_OBJECT_IDENTIFIERS
            ):
                setter = _revive_setter
            else:
                setter = setattr

            if child_id in self.loaded_nodes:
                if self.loaded_nodes[child_id][0] is not obj_child:
                    # This means that the same trackable object is referenced by
                    # two different objects that were recreated from the config.
                    logging.warning(
                        "Looks like there is an object (perhaps variable or "
                        "layer) that is shared between different "
                        "layers/models. This may cause issues when restoring "
                        "the variable values. Object: {}".format(obj_child)
                    )
                continue

            # Overwrite variable names with the ones saved in the SavedModel.
            if (
                child_proto.WhichOneof("kind") == "variable"
                and child_proto.variable.name
            ):
                obj_child._handle_name = child_proto.variable.name + ":0"

            if isinstance(
                obj_child, tf.__internal__.tracking.TrackableDataStructure
            ):
                setter = lambda *args: None

            child_path = f"{parent_path}.{child_name}"
            self._node_paths[child_id] = child_path
            self._add_children_recreated_from_config(
                obj_child, child_proto, child_id
            )
            self.loaded_nodes[child_id] = obj_child, setter

    def load_layers(self, compile=True):
        """Load all layer nodes from the metadata."""
        # Load metrics after models and layers, since it's likely that models
        # and layers will create the metric when initialized (this avoids
        # wasting time by creating objects multiple times).
        metric_list = []
        for node_metadata in self._metadata.values():
            if node_metadata.identifier == constants.METRIC_IDENTIFIER:
                metric_list.append(node_metadata)
                continue

            self.loaded_nodes[node_metadata.node_id] = self._load_layer(
                node_metadata.node_id,
                node_metadata.identifier,
                node_metadata.metadata,
            )

        for node_metadata in metric_list:
            try:
                self.loaded_nodes[node_metadata.node_id] = self._load_layer(
                    node_metadata.node_id,
                    node_metadata.identifier,
                    node_metadata.metadata,
                )
            except ValueError as e:
                # Metrics are only needed when the model is compiled later. We
                # ignore errors when trying to load custom metrics when
                # `compile=False` until custom metrics are serialized properly
                # (b/135550038).
                if compile:
                    raise e
                logging.warning(
                    "Unable to restore custom metric. Please ensure that "
                    "the layer implements `get_config` and `from_config` "
                    "when saving. In addition, please use the "
                    "`custom_objects` arg when calling `load_model()`."
                )

    def _load_layer(self, node_id, identifier, metadata):
        """Load a single layer from a SavedUserObject proto."""
        metadata = json_utils.decode(metadata)

        # If node was already created
        if node_id in self.loaded_nodes:
            node, setter = self.loaded_nodes[node_id]

            # Revive setter requires the object to have a
            # `_serialized_attributes` property. Add it here.
            _maybe_add_serialized_attributes(node, metadata)

            config = metadata.get("config")
            if _is_graph_network(node) and serialization.validate_config(
                config
            ):
                child_nodes = self._get_child_layer_node_ids(node_id)
                self.model_layer_dependencies[node_id] = (node, child_nodes)
                if not child_nodes:
                    self._models_to_reconstruct.append(node_id)
            return node, setter

        # Detect whether this object can be revived from the config. If not,
        # then revive from the SavedModel instead.
        obj, setter = self._revive_from_config(identifier, metadata, node_id)
        if obj is None:
            obj, setter = revive_custom_object(identifier, metadata)

        # Add an attribute that stores the extra functions/objects saved in the
        # SavedModel. Most of these functions/objects are ignored, but some are
        # used later in the loading process (e.g. the list of regularization
        # losses, or the training config of compiled models).
        _maybe_add_serialized_attributes(obj, metadata)
        return obj, setter

    def _revive_from_config(self, identifier, metadata, node_id):
        """Revives a layer/model from config, or returns None."""
        if identifier == constants.METRIC_IDENTIFIER:
            obj = self._revive_metric_from_config(metadata)
        else:
            obj = self._revive_graph_network(
                identifier, metadata, node_id
            ) or self._revive_layer_or_model_from_config(metadata, node_id)

        if obj is None:
            return None, None

        setter = self._config_node_setter(_revive_setter)
        self._add_children_recreated_from_config(
            obj, self._proto.nodes[node_id], node_id
        )
        return obj, setter

    def _revive_graph_network(self, identifier, metadata, node_id):
        """Revives a graph network from config."""
        # Determine whether the metadata contains information for reviving a
        # functional or Sequential model.
        config = metadata.get("config")
        if not serialization.validate_config(config):
            return None

        class_name = tf.compat.as_str(metadata["class_name"])
        if object_registration.get_registered_object(class_name) is not None:
            return None
        model_is_functional_or_sequential = (
            metadata.get("is_graph_network", False)
            or class_name == "Sequential"
            or class_name == "Functional"
        )
        if not model_is_functional_or_sequential:
            return None

        # Revive functional and sequential models as blank model objects for now
        # (must be initialized to enable setattr tracking and attribute
        # caching). Reconstruction of the network is deferred until all of the
        # model's layers have been revived.
        if class_name == "Sequential":
            model = models_lib.Sequential(name=config["name"])
        # The model is a custom Sequential model.
        elif identifier == constants.SEQUENTIAL_IDENTIFIER:
            # Uses the custom class name, since the config does not have one.
            model = models_lib.Sequential(name=class_name)
        else:
            model = models_lib.Functional(
                inputs=[], outputs=[], name=config["name"]
            )

        # Record this model and its layers. This will later be used to
        # reconstruct the model.
        layers = self._get_child_layer_node_ids(node_id)
        self.model_layer_dependencies[node_id] = (model, layers)
        if not layers:
            self._models_to_reconstruct.append(node_id)
        return model

    def _revive_layer_or_model_from_config(self, metadata, node_id):
        """Revives a layer/custom model from config; returns None if
        infeasible."""
        # Check that the following requirements are met for reviving from
        # config:
        # 1. Object can be deserialized from config.
        # 2. If the object needs to be built, then the build input shape can
        #    be found.
        class_name = metadata.get("class_name")
        config = metadata.get("config")
        shared_object_id = metadata.get("shared_object_id")
        must_restore_from_config = metadata.get("must_restore_from_config")
        if not serialization.validate_config(config):
            return None

        try:
            try:
                obj = model_config.model_from_config(
                    serialization.serialize_keras_class_and_config(
                        class_name, config, shared_object_id=shared_object_id
                    )
                )
            except (TypeError, KeyError) as e:
                # A name conflict has occurred. The `class_name` is in the Keras
                # native framework; however, the value in the framework is
                # different from the user's class definition which confuses the
                # KerasObjectLoader.
                builtin_layer = layers_module.get_builtin_layer(class_name)
                if builtin_layer:
                    raise RuntimeError(
                        f"Unable to restore object of class '{class_name}'. "
                        "One of several possible causes could be "
                        "a missing custom object. "
                        "Decorate your custom object with "
                        "`@keras.utils.register_keras_serializable()` and "
                        "include that file in your program, "
                        "or pass your class in a "
                        "`keras.utils.CustomObjectScope` "
                        "that wraps this load call. "
                        f"\n\nException: {e}"
                    ) from e
                else:
                    raise
        except Exception as e:
            if must_restore_from_config:
                raise e
            else:
                return None

        # Use the dtype, name, and trainable status. Often times these are not
        # specified in custom configs, so retrieve their values from the
        # metadata.

        obj._name = metadata["name"]
        if metadata.get("trainable") is not None:
            obj.trainable = metadata["trainable"]
        if metadata.get("dtype") is not None:
            obj._set_dtype_policy(metadata["dtype"])
        if metadata.get("stateful") is not None:
            obj.stateful = metadata["stateful"]
        if metadata.get("autocast") is not None:
            obj._autocast = metadata["autocast"]
        # Restore model save spec for subclassed models. (layers do not store a
        # SaveSpec)
        if isinstance(obj, training_lib.Model):
            full_save_spec = metadata.get("full_save_spec")
            if full_save_spec is not None:
                args_spec, kwargs_spec = full_save_spec
                inputs_spec = args_spec.pop(0)
                obj._set_save_spec(inputs_spec, args_spec, kwargs_spec)

        build_input_shape = metadata.get("build_input_shape")
        built = self._try_build_layer(obj, node_id, build_input_shape)

        if not built:
            # If the layer cannot be built, revive a custom layer instead.
            return None
        return obj

    def _revive_metric_from_config(self, metadata):
        """Revives a metric object using the config saved in the metadata."""
        class_name = tf.compat.as_str(metadata["class_name"])
        config = metadata.get("config")

        if not serialization.validate_config(config):
            return None

        try:
            obj = metrics.deserialize(
                serialization.serialize_keras_class_and_config(
                    class_name, config
                )
            )
        except ValueError:
            return None

        build_input_shape = metadata.get("build_input_shape")
        if build_input_shape is not None and hasattr(obj, "_build"):
            obj._build(build_input_shape)

        return obj

    def _try_build_layer(self, obj, node_id, build_input_shape):
        """Attempts to build the layer."""
        if obj.built or hasattr(obj.build, "_is_default"):
            obj.built = True
            return True

        if build_input_shape is None:
            build_input_shape = self._infer_inputs(
                node_id, convert_to_shapes=True
            )

        if build_input_shape is not None:
            obj.build(build_input_shape)
            base_layer.Layer.build(obj, build_input_shape)
            return True

        return False

    def get_path(self, node_id):
        return self._node_paths[node_id]

    def finalize_objects(self):
        """Finish setting up Keras objects.

        This function is executed after all objects and functions have been
        created. Call functions and losses are attached to each layer, and once
        all layers have been fully set up, graph networks are initialized.

        Subclassed models that are revived from the SavedModel are treated like
        layers, and have their call/loss functions attached here.
        """
        # Finish setting up layers and subclassed models. This step attaches
        # call functions and losses to each object, and sets model
        # inputs/outputs.
        layers_revived_from_config = []
        layers_revived_from_saved_model = []
        for node_id, (node, _) in self.loaded_nodes.items():
            if (
                not isinstance(node, base_layer.Layer)
                # Don't finalize models until all layers have finished loading.
                or node_id in self.model_layer_dependencies
            ):
                continue

            self._unblock_model_reconstruction(node_id, node)

            if isinstance(node, input_layer.InputLayer):
                continue
            elif isinstance(node, metrics.Metric):
                continue

            if isinstance(node, (RevivedLayer, RevivedInputLayer)):
                layers_revived_from_saved_model.append(node)
            else:
                layers_revived_from_config.append(node)

        _finalize_saved_model_layers(layers_revived_from_saved_model)
        _finalize_config_layers(layers_revived_from_config)

        # Initialize graph networks, now that layer dependencies have been
        # resolved.
        self._reconstruct_all_models()

    def _unblock_model_reconstruction(self, layer_id, layer):
        """Removes layer from blocking model reconstruction."""
        for model_id, v in self.model_layer_dependencies.items():
            _, layers = v
            if layer_id not in layers:
                continue
            layers[layers.index(layer_id)] = layer
            if all(isinstance(x, base_layer.Layer) for x in layers):
                self._models_to_reconstruct.append(model_id)

    def _reconstruct_all_models(self):
        """Reconstructs the network structure of all models."""
        all_initialized_models = set()
        while self._models_to_reconstruct:
            model_id = self._models_to_reconstruct.pop(0)
            all_initialized_models.add(model_id)
            model, layers = self.model_layer_dependencies[model_id]
            self._reconstruct_model(model_id, model, layers)
            _finalize_config_layers([model])

        if all_initialized_models != set(self.model_layer_dependencies.keys()):
            # This should not happen.
            uninitialized_model_ids = (
                set(self.model_layer_dependencies.keys())
                - all_initialized_models
            )
            uninitialized_model_names = [
                self.model_layer_dependencies[model_id][0].name
                for model_id in uninitialized_model_ids
            ]
            raise ValueError(
                "Error loading model(s) in the SavedModel format. "
                "The following model(s) could not be initialized: "
                f"{uninitialized_model_names}"
            )

    def _reconstruct_model(self, model_id, model, layers):
        """Reconstructs the network structure."""
        config = json_utils.decode(self._metadata[model_id].metadata)["config"]

        # Set up model inputs
        if model.inputs:
            # Inputs may already be created if the model is instantiated in
            # another object's __init__.
            pass
        elif isinstance(model, models_lib.Sequential):
            if not layers or not isinstance(layers[0], input_layer.InputLayer):
                if config["layers"][0]["class_name"] == "InputLayer":
                    layers.insert(
                        0,
                        input_layer.InputLayer.from_config(
                            config["layers"][0]["config"]
                        ),
                    )
                elif "batch_input_shape" in config["layers"][0]["config"]:
                    batch_input_shape = config["layers"][0]["config"][
                        "batch_input_shape"
                    ]
                    layers.insert(
                        0,
                        input_layer.InputLayer(
                            input_shape=batch_input_shape[1:],
                            batch_size=batch_input_shape[0],
                            dtype=layers[0].dtype,
                            name=layers[0].name + "_input",
                        ),
                    )
            model.__init__(layers, name=config["name"])
            if not model.inputs:
                first_layer = self._get_child_layer_node_ids(model_id)[0]
                input_specs = self._infer_inputs(first_layer)
                input_shapes = self._infer_inputs(
                    first_layer, convert_to_shapes=True
                )
                model._set_inputs(input_specs)
                if not model.built and not isinstance(input_specs, dict):
                    model.build(input_shapes)
        else:  # Reconstruct functional model
            (
                inputs,
                outputs,
                created_layers,
            ) = functional_lib.reconstruct_from_config(
                config, created_layers={layer.name: layer for layer in layers}
            )
            model.__init__(inputs, outputs, name=config["name"])
            functional_lib.connect_ancillary_layers(model, created_layers)

        # Set model dtype.
        _set_network_attributes_from_metadata(model)

        # Unblock models that are dependent on this model.
        self._unblock_model_reconstruction(model_id, model)

    def _get_child_layer_node_ids(self, node_id):
        """Returns the node ids of each layer in a Sequential/Functional
        model."""
        # Sequential and Functional track layers with names following the format
        # "layer-N". Use this to generate the list of layers.
        num_layers = 0
        child_layers = {}
        pattern = re.compile("layer-(\\d+)")

        for child in self._proto.nodes[node_id].children:
            m = pattern.match(child.local_name)
            if m is None:
                continue
            layer_n = int(m.group(1))
            num_layers = max(layer_n + 1, num_layers)
            child_layers[layer_n] = child.node_id

        ordered = []
        for n in range(num_layers):
            child = child_layers.get(n)
            if child is None:
                break
            ordered.append(child)
        return ordered

    def _search_for_child_node(self, parent_id, path_to_child):
        """Returns node id of child node.

        A helper method for traversing the object graph proto.

        As an example, say that the object graph proto in the SavedModel
        contains an object with the following child and grandchild attributes:

        `parent.child_a.child_b`

        This method can be used to retrieve the node id of `child_b` using the
        parent's node id by calling:

        `_search_for_child_node(parent_id, ['child_a', 'child_b'])`.

        Args:
            parent_id: node id of parent node
            path_to_child: list of children names.

        Returns:
            node_id of child, or None if child isn't found.
        """
        if not path_to_child:
            return parent_id

        for child in self._proto.nodes[parent_id].children:
            if child.local_name == path_to_child[0]:
                return self._search_for_child_node(
                    child.node_id, path_to_child[1:]
                )
        return None

    def _infer_inputs(self, layer_node_id, convert_to_shapes=False):
        """Infers input shape of layer from SavedModel functions."""
        call_fn_id = self._search_for_child_node(
            layer_node_id, ["call_and_return_all_conditional_losses"]
        )
        if call_fn_id is None:
            return None

        concrete_functions = self._proto.nodes[
            call_fn_id
        ].function.concrete_functions
        if not concrete_functions:
            return None
        call_fn_name = concrete_functions[0]
        call_fn_proto = self._proto.concrete_functions[call_fn_name]
        structured_input_signature = tf.__internal__.saved_model.decode_proto(
            call_fn_proto.canonicalized_input_signature
        )
        inputs = structured_input_signature[0][0]
        if convert_to_shapes:
            return tf.nest.map_structure(lambda spec: spec.shape, inputs)
        else:
            return inputs

    def _config_node_setter(self, setter):
        """Creates edges for nodes that are recreated from config."""

        def setattr_wrapper(obj, name, value):
            # Avoid overwriting attributes of objects recreated from the config.
            if obj._lookup_dependency(name) is None:
                setter(obj, name, value)

        return setattr_wrapper


def _finalize_saved_model_layers(layers):
    """Runs the final steps of loading Keras Layers from SavedModel."""

    # 1. Set up call functions for all layers initialized from the SavedModel (
    # and not the config)
    for layer in layers:
        layer.built = True
        layer_call = getattr(
            _get_keras_attr(layer), "call_and_return_conditional_losses", None
        )
        if layer_call and layer_call.concrete_functions:
            call_spec = layer_utils.CallFunctionSpec(
                tf_inspect.getfullargspec(layer_call)
            )
            layer.call = utils.use_wrapped_call(
                layer, layer_call, call_spec, return_method=True
            )
            expects_training_arg = layer._serialized_attributes["metadata"][
                "expects_training_arg"
            ]
            if "training" in layer_call.function_spec.arg_names:
                # This could change the value of `expects_training_arg` if this
                # layer doesn't expect a training arg, but has a child layer
                # that does.
                expects_training_arg = True
            layer._init_call_fn_args(expects_training_arg)
        else:
            layer.call = types.MethodType(
                _unable_to_call_layer_due_to_serialization_issue, layer
            )

    for layer in layers:
        # 2. Set model inputs and outputs.
        if isinstance(layer, RevivedNetwork):
            _set_network_attributes_from_metadata(layer)

            if hasattr(
                _get_keras_attr(layer), "call_and_return_conditional_losses"
            ):
                call_fn = _get_keras_attr(
                    layer
                ).call_and_return_conditional_losses
                if not call_fn.concrete_functions:
                    continue
                if call_fn.input_signature is None:
                    args, kwargs = infer_inputs_from_restored_call_function(
                        call_fn
                    )
                    args = list(args)
                    inputs = args.pop(0)
                else:
                    args = call_fn.input_signature
                    args = list(args)
                    inputs = args.pop(0)
                    kwargs = None
                layer._set_save_spec(inputs, args, kwargs)

                # V1 models require calling _set_inputs to set the `.inputs`
                # attr. Skip this step when there are multiple tensor inputs
                # (this behavior is not well supported in V1 models).
                if not any(
                    isinstance(x, tf.TensorSpec)
                    for x in tf.nest.flatten([args, kwargs])
                ):
                    layer._set_inputs(inputs)

        # 3. Add losses that aren't generated by the layer.call function.
        _restore_layer_unconditional_losses(layer)
        _restore_layer_activation_loss(layer)

        # 4. Restore metrics list
        _restore_layer_metrics(layer)


def _unable_to_call_layer_due_to_serialization_issue(
    layer, *unused_args, **unused_kwargs
):
    """Replaces the `layer.call` if the layer was not fully serialized.

    Keras Model/Layer serialization is relatively relaxed because SavedModels
    are not always loaded back as keras models. Thus, when there is an issue
    tracing a non-signature function, a warning is logged instead of raising an
    error. This results in a SavedModel where the model's call function is
    saved, but the internal layer call functions are not.

    When deserialized with `tf.keras.models.load_model`, the internal layers
    which do not have serialized call functions should raise an error when
    called.

    Args:
        layer: Layer without the serialized call function.

    Raises:
        ValueError
    """

    raise ValueError(
        f"Cannot call custom layer {layer.name} of type {type(layer)}, because "
        "the call function was not serialized to the SavedModel. "
        "Please try one of the following methods to fix this issue:"
        "\n\n(1) Implement `get_config` and `from_config` in the layer/model "
        "class, and pass the object to the `custom_objects` argument when "
        "loading the model. For more details, see: "
        "https://www.tensorflow.org/guide/keras/save_and_serialize"
        "\n\n(2) Ensure that the subclassed model or layer overwrites `call` "
        "and not `__call__`. The input shape and dtype will be automatically "
        "recorded when the object is called, and used when saving. To manually "
        "specify the input shape/dtype, decorate the call function with "
        "`@tf.function(input_signature=...)`."
    )


def _finalize_config_layers(layers):
    """Runs the final steps of loading Keras Layers from config."""
    for layer in layers:
        # It is assumed that layers define their unconditional losses after
        # being recreated from the config and built. The exceptions to this are
        # Functional and Sequential models, which only store conditional losses
        # (losses dependent on the inputs) in the config. Unconditional losses
        # like weight regularization must be revived from the SavedModel.
        if _is_graph_network(layer):
            _restore_layer_unconditional_losses(layer)

        # Some layers, like Dense, record their activation loss function in the
        # config. However, not all layers do this, so the activation loss may be
        # missing when restored from the config/hdf5.
        # TODO(kathywu): Investigate ways to improve the config to ensure
        # consistent loading behavior between HDF5 and SavedModel.
        _restore_layer_activation_loss(layer)

        # Restore metrics list.
        _restore_layer_metrics(layer)

        # Restore RNN layer states.
        if (
            isinstance(layer, base_rnn.RNN)
            and layer.stateful
            and hasattr(_get_keras_attr(layer), "states")
        ):
            layer.states = getattr(_get_keras_attr(layer), "states", None)
            for variable in tf.nest.flatten(layer.states):
                backend.track_variable(variable)

        # Perform any layer defined finalization of the layer state.
        layer.finalize_state()


def _finalize_metric(metric):
    metric.update_state = types.MethodType(
        metrics_utils.update_state_wrapper(metric.keras_api.update_state),
        metric,
    )
    metric.result = metric.keras_api.result


def _restore_layer_unconditional_losses(layer):
    """Restore unconditional losses from SavedModel."""
    if hasattr(_get_keras_attr(layer), "layer_regularization_losses"):
        losses = getattr(
            _get_keras_attr(layer), "layer_regularization_losses", []
        )
    else:
        # Some earlier SavedModels may not have layer_regularization_losses
        # serialized separately. Fall back to using the regularization_losses
        # list if it does not exist.
        losses = layer._serialized_attributes.get("regularization_losses", [])
    for loss in losses:
        layer.add_loss(loss)


def _restore_layer_activation_loss(layer):
    """Restore activation loss from SavedModel."""
    # Use wrapped activity regularizer function if the layer's activity
    # regularizer wasn't created during initialization.
    activity_regularizer = getattr(
        _get_keras_attr(layer), "activity_regularizer_fn", None
    )
    if activity_regularizer and not layer.activity_regularizer:
        try:
            layer.activity_regularizer = activity_regularizer
        except AttributeError:
            # This may happen if a layer wrapper is saved with an activity
            # regularizer. The wrapper object's activity regularizer is
            # unsettable.
            pass


def revive_custom_object(identifier, metadata):
    """Revives object from SavedModel."""
    if tf.compat.v1.executing_eagerly_outside_functions():
        model_class = training_lib.Model
    else:
        model_class = training_lib_v1.Model

    revived_classes = {
        constants.INPUT_LAYER_IDENTIFIER: (
            RevivedInputLayer,
            input_layer.InputLayer,
        ),
        constants.LAYER_IDENTIFIER: (RevivedLayer, base_layer.Layer),
        constants.MODEL_IDENTIFIER: (RevivedNetwork, model_class),
        constants.NETWORK_IDENTIFIER: (
            RevivedNetwork,
            functional_lib.Functional,
        ),
        constants.SEQUENTIAL_IDENTIFIER: (
            RevivedNetwork,
            models_lib.Sequential,
        ),
    }
    parent_classes = revived_classes.get(identifier, None)

    class_name = tf.compat.as_str(metadata["class_name"])
    if parent_classes is not None:
        parent_classes = revived_classes[identifier]
        revived_cls = type(class_name, parent_classes, {})
        return revived_cls._init_from_metadata(metadata)
    else:
        raise ValueError(
            f'Unable to restore custom object of class "{class_name}" '
            f"(type {identifier}). Please make sure that this class is "
            "included in the `custom_objects` arg when calling `load_model()`. "
            "Also, check that the class implements `get_config` and "
            f"`from_config`.\n\nComplete metadata: {metadata}"
        )


def _restore_layer_metrics(layer):
    metrics_list = getattr(_get_keras_attr(layer), "layer_metrics", {})
    layer_metrics = {m.name: m for m in layer._metrics}
    for name, metric in metrics_list.items():
        if name not in layer_metrics:
            # Metrics may be added during initialization/building of custom
            # layers.
            layer._metrics.append(metric)


# TODO(kathywu): Centrally define keys and functions for both serialization and
# deserialization.
class RevivedLayer:
    """Keras layer loaded from a SavedModel."""

    @classmethod
    def _init_from_metadata(cls, metadata):
        """Create revived layer from metadata stored in the SavedModel proto."""
        init_args = dict(name=metadata["name"], trainable=metadata["trainable"])
        if metadata.get("dtype") is not None:
            init_args["dtype"] = metadata["dtype"]
        if metadata.get("batch_input_shape") is not None:
            init_args["batch_input_shape"] = metadata["batch_input_shape"]

        revived_obj = cls(**init_args)

        with utils.no_automatic_dependency_tracking_scope(revived_obj):

            revived_obj._call_spec.expects_training_arg = metadata[
                "expects_training_arg"
            ]
            config = metadata.get("config")
            if serialization.validate_config(config):
                revived_obj._config = config
            if metadata.get("input_spec") is not None:
                revived_obj.input_spec = recursively_deserialize_keras_object(
                    metadata["input_spec"],
                    module_objects={"InputSpec": input_spec.InputSpec},
                )
            if metadata.get("activity_regularizer") is not None:
                revived_obj.activity_regularizer = regularizers.deserialize(
                    metadata["activity_regularizer"]
                )
            if metadata.get("_is_feature_layer") is not None:
                revived_obj._is_feature_layer = metadata["_is_feature_layer"]
            if metadata.get("stateful") is not None:
                revived_obj.stateful = metadata["stateful"]
            if metadata.get("autocast") is not None:
                revived_obj._autocast = metadata["autocast"]
            if metadata.get("preserve_input_structure_in_config") is not None:
                revived_obj._preserve_input_structure_in_config = metadata[
                    "preserve_input_structure_in_config"
                ]

        return revived_obj, _revive_setter

    @property
    def keras_api(self):
        return self._serialized_attributes.get(constants.KERAS_ATTR, None)

    def get_config(self):
        if hasattr(self, "_config"):
            return self._config
        else:
            raise NotImplementedError


def _revive_setter(layer, name, value):
    """Setter function that saves some attributes to separate dictionary."""
    # Many attributes in the SavedModel conflict with properties defined in
    # Layer and Model. Save these attributes to a separate dictionary.
    if name in PUBLIC_ATTRIBUTES:

        if isinstance(value, tf.__internal__.tracking.Trackable):
            layer._track_trackable(value, name=name)
        layer._serialized_attributes[name] = value

    elif (
        isinstance(layer, functional_lib.Functional)
        and re.match(r"^layer(_with_weights)?-[\d+]", name) is not None
    ):
        # Edges named "layer-n" or "layer_with_weights-n", which are tracked in
        # network._track_layers, should not be added as an attribute. They
        # should be temporarily added as a dependency so that checkpointed
        # values can be restored. These dependencies are manually deleted in
        # KerasObjectLoader.del_tracking.

        # Set `overwrite=True` in the case that `layer` already tracks a
        # different layer-n. This may cause variable values to not be loaded
        # properly in the original layer-n, but we already warn the users about
        # this (ctrl-f "shared between different layers/models").
        layer._track_trackable(value, name, overwrite=True)
    elif getattr(layer, name, None) is not None:
        # Don't overwrite already defined attributes.
        pass
    else:
        setattr(layer, name, value)


class RevivedInputLayer:
    """InputLayer loaded from a SavedModel."""

    @classmethod
    def _init_from_metadata(cls, metadata):
        """Revives the saved InputLayer from the Metadata."""
        init_args = dict(
            name=metadata["name"],
            dtype=metadata["dtype"],
            sparse=metadata["sparse"],
            ragged=metadata["ragged"],
            batch_input_shape=metadata["batch_input_shape"],
        )
        revived_obj = cls(**init_args)
        with utils.no_automatic_dependency_tracking_scope(revived_obj):
            revived_obj._config = metadata["config"]

        return revived_obj, setattr

    def get_config(self):
        return self._config


def recursively_deserialize_keras_object(config, module_objects=None):
    """Deserialize Keras object from a nested structure."""
    if isinstance(config, dict):
        if "class_name" in config:
            return serialization.deserialize_keras_object(
                config, module_objects=module_objects
            )
        else:
            return {
                key: recursively_deserialize_keras_object(
                    config[key], module_objects
                )
                for key in config
            }
    elif isinstance(config, (tuple, list)):
        return [
            recursively_deserialize_keras_object(x, module_objects)
            for x in config
        ]
    else:
        raise ValueError(
            "Unable to decode Keras layer config. Config should be a "
            f"dictionary, tuple or list. Received: config={config}"
        )
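
# A minimal sketch of how this helper is used elsewhere in this file (see
# RevivedLayer._init_from_metadata); the config literal here is hypothetical.
#
#     specs = recursively_deserialize_keras_object(
#         [{"class_name": "InputSpec", "config": {"ndim": 4}}],
#         module_objects={"InputSpec": input_spec.InputSpec},
#     )
#     # -> [InputSpec(ndim=4)]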


def infer_inputs_from_restored_call_function(fn):
    """Returns TypeSpec of inputs from a restored call function.

    Args:
        fn: Restored layer call function. It is assumed that `fn` has at least
            one concrete function and that the inputs are in the first argument.

    Returns:
        TypeSpec of call function inputs in the form of (args, kwargs)
    """

    def common_spec(x, y):
        if not isinstance(x, tf.TypeSpec):
            # Doesn't particularly matter what is returned in this case because
            # the result will be filtered out in _set_input_shape.
            return x

        result = x._without_tensor_names().most_specific_common_supertype(
            [y._without_tensor_names()]
        )
        if result is None:
            # Please file a bug if you are being hindered by this error.
            raise TypeError(f"No common supertype of {x} and {y}.")
        return result

    spec = fn.concrete_functions[0].structured_input_signature
    for concrete in fn.concrete_functions[1:]:
        spec2 = concrete.structured_input_signature
        spec = tf.nest.map_structure(common_spec, spec, spec2)
    return spec
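
# For illustration (assumed shapes, not taken from a real SavedModel): merging
# the input specs of two concrete functions whose first arguments were traced
# with shapes [1, 28, 28] and [None, 28, 28] keeps the most specific common
# supertype, with tensor names dropped.
#
#     common_spec(tf.TensorSpec([1, 28, 28]), tf.TensorSpec([None, 28, 28]))
#     # -> tf.TensorSpec([None, 28, 28])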


class RevivedNetwork(RevivedLayer):
    """Keras network of layers loaded from a SavedModel."""

    @classmethod
    def _init_from_metadata(cls, metadata):
        """Create revived network from metadata stored in the SavedModel
        proto."""
        revived_obj = cls(name=metadata["name"])

        # Store attributes revived from SerializedAttributes in an un-tracked
        # dictionary. The attributes are the ones listed in CommonEndpoints or
        # "keras_api" for keras-specific attributes.
        with utils.no_automatic_dependency_tracking_scope(revived_obj):

            revived_obj._call_spec.expects_training_arg = metadata[
                "expects_training_arg"
            ]
            config = metadata.get("config")
            if serialization.validate_config(config):
                revived_obj._config = config

            if metadata.get("activity_regularizer") is not None:
                revived_obj.activity_regularizer = regularizers.deserialize(
                    metadata["activity_regularizer"]
                )
            if metadata.get("autocast") is not None:
                revived_obj._autocast = metadata["autocast"]

        return revived_obj, _revive_setter


def _set_network_attributes_from_metadata(revived_obj):
    """Sets attributes recorded in the metadata."""
    with utils.no_automatic_dependency_tracking_scope(revived_obj):

        metadata = revived_obj._serialized_attributes["metadata"]
        if metadata.get("dtype") is not None:
            revived_obj._set_dtype_policy(metadata["dtype"])
        revived_obj._trainable = metadata["trainable"]


def _maybe_add_serialized_attributes(layer, metadata):
    # Store attributes revived from SerializedAttributes in an un-tracked
    # dictionary. The attributes are the ones listed in CommonEndpoints or
    # "keras_api" for keras-specific attributes.
    if not hasattr(layer, "_serialized_attributes"):
        with utils.no_automatic_dependency_tracking_scope(layer):
            layer._serialized_attributes = {"metadata": metadata}


def _get_keras_attr(layer):
    return getattr(layer, "_serialized_attributes", {}).get(
        constants.KERAS_ATTR, None
    )