Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py: 17%
579 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras SavedModel deserialization."""
17import os
18import re
19import types
21from google.protobuf import message
23from tensorflow.python.eager import context
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import sparse_tensor
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import tensor_spec
28from tensorflow.python.keras import backend
29from tensorflow.python.keras import regularizers
30from tensorflow.python.keras.engine import input_spec
31from tensorflow.python.keras.optimizer_v2 import optimizer_v2
32from tensorflow.python.keras.protobuf import saved_metadata_pb2
33from tensorflow.python.keras.protobuf import versions_pb2
34from tensorflow.python.keras.saving import saving_utils
35from tensorflow.python.keras.saving.saved_model import constants
36from tensorflow.python.keras.saving.saved_model import json_utils
37from tensorflow.python.keras.saving.saved_model import utils
38from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints
39from tensorflow.python.keras.utils import generic_utils
40from tensorflow.python.keras.utils import metrics_utils
41from tensorflow.python.keras.utils.generic_utils import LazyLoader
42from tensorflow.python.ops.ragged import ragged_tensor
43from tensorflow.python.platform import gfile
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.saved_model import load as tf_load
46from tensorflow.python.saved_model import loader_impl
47from tensorflow.python.saved_model import nested_structure_coder
48from tensorflow.python.saved_model import revived_types
49from tensorflow.python.trackable import base as trackable
50from tensorflow.python.trackable import data_structures
51from tensorflow.python.util import compat
52from tensorflow.python.util import nest
54# To avoid circular dependencies between keras/engine and keras/saving,
55# code in keras/saving must delay imports.
57# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
58# once the issue with copybara is fixed.
59# pylint:disable=g-inconsistent-quotes
60models_lib = LazyLoader("models_lib", globals(),
61 "tensorflow.python.keras.models")
62base_layer = LazyLoader(
63 "base_layer", globals(),
64 "tensorflow.python.keras.engine.base_layer")
65layers_module = LazyLoader(
66 "layers_module", globals(),
67 "tensorflow.python.keras.layers")
68input_layer = LazyLoader(
69 "input_layer", globals(),
70 "tensorflow.python.keras.engine.input_layer")
71functional_lib = LazyLoader(
72 "functional_lib", globals(),
73 "tensorflow.python.keras.engine.functional")
74training_lib = LazyLoader(
75 "training_lib", globals(),
76 "tensorflow.python.keras.engine.training")
77training_lib_v1 = LazyLoader(
78 "training_lib_v1", globals(),
79 "tensorflow.python.keras.engine.training_v1")
80metrics = LazyLoader("metrics", globals(),
81 "tensorflow.python.keras.metrics")
82recurrent = LazyLoader(
83 "recurrent", globals(),
84 "tensorflow.python.keras.layers.recurrent")
85# pylint:enable=g-inconsistent-quotes
# Names of every serialized endpoint (functions and checkpointable objects)
# plus the Keras attribute container. `KerasObjectLoader.del_tracking` deletes
# the tracked dependency registered under each of these names once loading is
# complete.
PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
    CommonEndpoints.all_checkpointable_objects)
PUBLIC_ATTRIBUTES.add(constants.KERAS_ATTR)
def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin
  """Loads Keras objects from a SavedModel.

  Any Keras layer or model saved to the SavedModel will be loaded back
  as Keras objects. Other objects are loaded as regular trackable objects (same
  as `tf.saved_model.load`).

  When `compile` is true and the loaded root object is a Keras `Model` whose
  saved metadata contains a `training_config`, the model is re-compiled from
  that config. Optimizer slot values are not restored from the SavedModel; a
  warning is logged when the re-created optimizer uses slots. If no training
  config is found, a warning is logged and the model is returned uncompiled.

  Args:
    path: Path to SavedModel.
    compile: If true, compile the model after loading it.
    options: Optional `tf.saved_model.LoadOptions` object that specifies
      options for loading from SavedModel.

  Returns:
    Object loaded from SavedModel.

  Raises:
    IOError: If the Keras metadata file exists but cannot be parsed. The
      original `DecodeError` is attached as the cause.
  """
  # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
  # TODO(kathywu): Add code to load from objects that contain all endpoints

  # Look for metadata file or parse the SavedModel
  metadata = saved_metadata_pb2.SavedMetadata()
  meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0]
  object_graph_def = meta_graph_def.object_graph_def
  path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH)
  if gfile.Exists(path_to_metadata_pb):
    try:
      with gfile.GFile(path_to_metadata_pb, 'rb') as f:
        file_content = f.read()
      metadata.ParseFromString(file_content)
    except message.DecodeError as e:
      # Chain explicitly so the protobuf error is reported as the direct cause.
      raise IOError('Cannot parse keras metadata {}: {}.'
                    .format(path_to_metadata_pb, str(e))) from e
  else:
    logging.warning('SavedModel saved prior to TF 2.5 detected when loading '
                    'Keras model. Please ensure that you are saving the model '
                    'with model.save() or tf.keras.models.save_model(), *NOT* '
                    'tf.saved_model.save(). To confirm, there should be a file '
                    'named "keras_metadata.pb" in the SavedModel directory.')
    _read_legacy_metadata(object_graph_def, metadata)

  if not metadata.nodes:
    # When there are no Keras objects, return the results from the core loader
    return tf_load.load(path, options=options)

  # Recreate layers and metrics using the info stored in the metadata.
  keras_loader = KerasObjectLoader(metadata, object_graph_def)
  keras_loader.load_layers(compile=compile)

  # Generate a dictionary of all loaded nodes, keyed by object-graph path.
  # The 'root' entry is a placeholder that `load_partial` fills in.
  nodes_to_load = {'root': None}
  for node_id, loaded_node in keras_loader.loaded_nodes.items():
    nodes_to_load[keras_loader.get_path(node_id)] = loaded_node
  loaded = tf_load.load_partial(path, nodes_to_load, options=options)

  # Finalize the loaded layers and remove the extra tracked dependencies.
  keras_loader.finalize_objects()
  keras_loader.del_tracking()

  model = loaded['root']

  # pylint: disable=protected-access
  if isinstance(model, training_lib.Model) and compile:
    # TODO(kathywu): Use compiled objects from SavedModel, instead of
    # creating new objects from the training config.
    training_config = model._serialized_attributes['metadata'].get(
        'training_config')
    if training_config is not None:
      model.compile(**saving_utils.compile_args_from_training_config(
          training_config), from_serialized=True)
      saving_utils.try_build_compiled_arguments(model)
      if (isinstance(model.optimizer, optimizer_v2.OptimizerV2) and
          model.optimizer.get_slot_names()):
        logging.warning('Your optimizer uses slots. '
                        'Slots cannot be restored from saved_model, '
                        'as a result, your model is starting with '
                        'a new initialized optimizer.')
    else:
      logging.warning('No training configuration found in save file, so the '
                      'model was *not* compiled. Compile it manually.')
  # pylint: enable=protected-access

  # Force variables and resources to initialize.
  if not context.executing_eagerly():
    sess = backend.get_session()  # Variables are initialized by this call.
    sess.run(ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS))

  return model
def _read_legacy_metadata(object_graph_def, metadata):
  """Builds a KerasMetadata proto from the SavedModel ObjectGraphDef.

  Older SavedModels store the Keras metadata directly in the object graph
  proto instead of the separate pb file. Every `user_object` node whose
  identifier is a Keras object identifier is copied into `metadata.nodes`.

  Args:
    object_graph_def: ObjectGraphDef proto read from the SavedModel.
    metadata: SavedMetadata proto that is populated in place.

  Raises:
    ValueError: If a Keras object node carries no metadata.
  """
  node_paths = _generate_object_paths(object_graph_def)
  for node_id, proto in enumerate(object_graph_def.nodes):
    if (proto.WhichOneof('kind') == 'user_object' and
        proto.user_object.identifier in constants.KERAS_OBJECT_IDENTIFIERS):
      if not proto.user_object.metadata:
        # Adjacent literals are concatenated verbatim, so each sentence needs
        # its own trailing space.
        raise ValueError('Unable to create a Keras model from this SavedModel. '
                         'This SavedModel was created with '
                         '`tf.saved_model.save`, and lacks the Keras metadata. '
                         'Please save your Keras model by calling `model.save` '
                         'or `tf.keras.models.save_model`.')
      metadata.nodes.add(
          node_id=node_id,
          node_path=node_paths[node_id],
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]),
          identifier=proto.user_object.identifier,
          metadata=proto.user_object.metadata)
213def _generate_object_paths(object_graph_def):
214 """Traverses through an ObjectGraphDef and builds a map of all node paths."""
215 paths = {0: 'root'}
216 nodes_to_visit = [0]
218 while nodes_to_visit:
219 current_node = nodes_to_visit.pop()
220 current_path = paths[current_node]
221 for reference in object_graph_def.nodes[current_node].children:
222 if reference.node_id in paths:
223 continue
224 paths[reference.node_id] = '{}.{}'.format(current_path,
225 reference.local_name)
226 nodes_to_visit.append(reference.node_id)
228 return paths
def _is_graph_network(layer):
  """Returns a truthy value if `layer` is a graph network.

  A `RevivedNetwork` never counts as a graph network. A `Functional` instance
  counts when its `_is_graph_network` flag is set or when it is a
  `Sequential` model. Anything else does not count.
  """
  # pylint: disable=protected-access
  if isinstance(layer, RevivedNetwork):
    return False
  if not isinstance(layer, functional_lib.Functional):
    return False
  return (layer._is_graph_network or
          isinstance(layer, models_lib.Sequential))
242class KerasObjectLoader(object):
243 """Loader that recreates Keras objects (e.g. layers, models).
245 Layers and models are revived from either the config or SavedModel following
246 these rules:
247 1. If object is a graph network (i.e. Sequential or Functional) then it will
248 be initialized using the structure from the config only after the children
249 layers have been created. Graph networks must be initialized with inputs
250 and outputs, so all child layers must be created beforehand.
251 2. If object's config exists and the class can be found, then revive from
252 config.
253 3. Object may have already been created if its parent was revived from config.
254 In this case, do nothing.
255 4. If nothing of the above applies, compose the various artifacts from the
256 SavedModel to create a subclassed layer or model. At this time, custom
257 metrics are not supported.
259 """
261 def __init__(self, metadata, object_graph_def):
262 self._metadata = {x.node_id: x for x in metadata.nodes}
263 self._proto = object_graph_def
265 self._node_paths = {node_data.node_id: node_data.node_path
266 for node_data in metadata.nodes}
267 self.loaded_nodes = {} # Maps node path -> loaded node
269 # Store all node ids that have already been traversed when tracking nodes
270 # that were recreated from the config.
271 self._traversed_nodes_from_config = set()
273 # Maps model id -> (blank model obj, list of child layer or their node ids)
274 # This tracks all layers in functional and sequential models. These models
275 # are only reconstructed after all of their child layers have been created.
276 self.model_layer_dependencies = {}
277 self._models_to_reconstruct = []
279 def del_tracking(self):
280 """Removes tracked references that are only used when loading the model."""
281 # Now that the node object has been fully loaded, and the checkpoint has
282 # been restored, the object no longer needs to track objects added from
283 # SerializedAttributes. (Note that saving a training checkpoint still
284 # functions correctly, because layers and variables are tracked separately
285 # by the Layer object.)
286 # TODO(kathywu): Instead of outright deleting these nodes (which would
287 # make restoring from a different checkpoint tricky), mark them as extra
288 # dependencies that are OK to overwrite.
289 for node in self.loaded_nodes.values():
290 node = node[0]
291 if not isinstance(node, base_layer.Layer):
292 # Loaded nodes can contain other trackable objects created when
293 # loading layers from the config, such as variables.
294 continue
295 for name in PUBLIC_ATTRIBUTES:
296 node._delete_tracking(name) # pylint: disable=protected-access
298 if isinstance(node, functional_lib.Functional):
299 # Delete the temporary layer dependencies, which were used to restore
300 # the checkpointed values. When the model is live, the user can delete
301 # or add layers to the model at any time, so these layer dependencies
302 # may be obsolete.
303 dependencies = list(node._self_unconditional_dependency_names) # pylint: disable=protected-access
304 for name in dependencies:
305 if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None:
306 node._delete_tracking(name) # pylint: disable=protected-access
308 def _add_children_recreated_from_config(self, obj, proto, node_id):
309 """Recursively records objects recreated from config."""
310 # pylint: disable=protected-access
311 if node_id in self._traversed_nodes_from_config:
312 return
314 parent_path = self._node_paths[node_id]
315 self._traversed_nodes_from_config.add(node_id)
316 obj._maybe_initialize_trackable()
317 if isinstance(obj, base_layer.Layer) and not obj.built:
318 metadata = json_utils.decode(self._metadata[node_id].metadata)
319 self._try_build_layer(obj, node_id, metadata.get('build_input_shape'))
321 # Create list of all possible children
322 children = []
323 # Look for direct children
324 for reference in proto.children:
325 obj_child = obj._lookup_dependency(reference.local_name)
326 children.append((obj_child, reference.node_id, reference.local_name))
328 # Add metrics that may have been added to the layer._metrics list.
329 # This is stored in the SavedModel as layer.keras_api.layer_metrics in
330 # SavedModels created after Tf 2.2.
331 metric_list_node_id = self._search_for_child_node(
332 node_id, [constants.KERAS_ATTR, 'layer_metrics'])
333 if metric_list_node_id is not None and hasattr(obj, '_metrics'):
334 obj_metrics = {m.name: m for m in obj._metrics}
335 for reference in self._proto.nodes[metric_list_node_id].children:
336 metric = obj_metrics.get(reference.local_name)
337 if metric is not None:
338 metric_path = '{}.layer_metrics.{}'.format(constants.KERAS_ATTR,
339 reference.local_name)
340 children.append((metric, reference.node_id, metric_path))
342 for (obj_child, child_id, child_name) in children:
343 child_proto = self._proto.nodes[child_id]
345 if not isinstance(obj_child, trackable.Trackable):
346 continue
347 if (child_proto.user_object.identifier in
348 revived_types.registered_identifiers()):
349 setter = revived_types.get_setter(child_proto.user_object)
350 elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
351 setter = _revive_setter
352 else:
353 setter = setattr
354 # pylint: enable=protected-access
356 if child_id in self.loaded_nodes:
357 if self.loaded_nodes[child_id][0] is not obj_child:
358 # This means that the same trackable object is referenced by two
359 # different objects that were recreated from the config.
360 logging.warning(
361 'Looks like there is an object (perhaps variable or '
362 'layer) that is shared between different layers/models. '
363 'This may cause issues when restoring the variable '
364 'values. Object: {}'.format(obj_child))
365 continue
367 # Overwrite variable names with the ones saved in the SavedModel.
368 if (child_proto.WhichOneof('kind') == 'variable' and
369 child_proto.variable.name):
370 obj_child._handle_name = child_proto.variable.name + ':0' # pylint: disable=protected-access
372 if isinstance(obj_child, data_structures.TrackableDataStructure):
373 setter = lambda *args: None
375 child_path = '{}.{}'.format(parent_path, child_name)
376 self._node_paths[child_id] = child_path
377 self._add_children_recreated_from_config(
378 obj_child, child_proto, child_id)
379 self.loaded_nodes[child_id] = obj_child, setter
381 def load_layers(self, compile=True): # pylint: disable=redefined-builtin
382 """Load all layer nodes from the metadata."""
383 # Load metrics after models and layers, since it's likely that models
384 # and layers will create the metric when initialized (this avoids wasting
385 # time by creating objects multiple times).
386 metric_list = []
387 for node_metadata in self._metadata.values():
388 if node_metadata.identifier == constants.METRIC_IDENTIFIER:
389 metric_list.append(node_metadata)
390 continue
392 self.loaded_nodes[node_metadata.node_id] = self._load_layer(
393 node_metadata.node_id, node_metadata.identifier,
394 node_metadata.metadata)
396 for node_metadata in metric_list:
397 try:
398 self.loaded_nodes[node_metadata.node_id] = self._load_layer(
399 node_metadata.node_id, node_metadata.identifier,
400 node_metadata.metadata)
401 except ValueError:
402 # Metrics are only needed when the model is compiled later. We ignore
403 # errors when trying to load custom metrics when `compile=False` until
404 # custom metrics are serialized properly (b/135550038).
405 if compile:
406 raise
407 logging.warning('Unable to restore custom metric. Please ensure that '
408 'the layer implements `get_config` and `from_config` '
409 'when saving. In addition, please use the '
410 '`custom_objects` arg when calling `load_model()`.')
412 def _load_layer(self, node_id, identifier, metadata):
413 """Load a single layer from a SavedUserObject proto."""
414 metadata = json_utils.decode(metadata)
416 # If node was already created
417 if node_id in self.loaded_nodes:
418 node, setter = self.loaded_nodes[node_id]
420 # Revive setter requires the object to have a `_serialized_attributes`
421 # property. Add it here.
422 _maybe_add_serialized_attributes(node, metadata)
424 config = metadata.get('config')
425 if _is_graph_network(node) and generic_utils.validate_config(config):
426 child_nodes = self._get_child_layer_node_ids(node_id)
427 self.model_layer_dependencies[node_id] = (node, child_nodes)
428 if not child_nodes:
429 self._models_to_reconstruct.append(node_id)
430 return node, setter
432 # Detect whether this object can be revived from the config. If not, then
433 # revive from the SavedModel instead.
434 obj, setter = self._revive_from_config(identifier, metadata, node_id)
435 if obj is None:
436 obj, setter = revive_custom_object(identifier, metadata)
438 # Add an attribute that stores the extra functions/objects saved in the
439 # SavedModel. Most of these functions/objects are ignored, but some are
440 # used later in the loading process (e.g. the list of regularization
441 # losses, or the training config of compiled models).
442 _maybe_add_serialized_attributes(obj, metadata)
443 return obj, setter
445 def _revive_from_config(self, identifier, metadata, node_id):
446 """Revives a layer/model from config, or returns None."""
447 if identifier == constants.METRIC_IDENTIFIER:
448 obj = self._revive_metric_from_config(metadata)
449 else:
450 obj = (
451 self._revive_graph_network(identifier, metadata, node_id) or
452 self._revive_layer_or_model_from_config(metadata, node_id))
454 if obj is None:
455 return None, None
457 setter = self._config_node_setter(_revive_setter)
458 self._add_children_recreated_from_config(
459 obj, self._proto.nodes[node_id], node_id)
460 return obj, setter
462 def _revive_graph_network(self, identifier, metadata, node_id):
463 """Revives a graph network from config."""
464 # Determine whether the metadata contains information for reviving a
465 # functional or Sequential model.
466 config = metadata.get('config')
467 if not generic_utils.validate_config(config):
468 return None
470 class_name = compat.as_str(metadata['class_name'])
471 if generic_utils.get_registered_object(class_name) is not None:
472 return None
473 model_is_functional_or_sequential = (
474 metadata.get('is_graph_network', False) or
475 class_name == 'Sequential' or
476 class_name == 'Functional')
477 if not model_is_functional_or_sequential:
478 return None
480 # Revive functional and sequential models as blank model objects for now (
481 # must be initialized to enable setattr tracking and attribute caching).
482 # Reconstruction of the network is deferred until all of the model's layers
483 # have been revived.
484 if class_name == 'Sequential':
485 model = models_lib.Sequential(name=config['name'])
486 # The model is a custom Sequential model.
487 elif identifier == constants.SEQUENTIAL_IDENTIFIER:
488 # Uses the custom class name, since the config does not have one.
489 model = models_lib.Sequential(name=class_name)
490 else:
491 model = models_lib.Functional(
492 inputs=[], outputs=[], name=config['name'])
494 # Record this model and its layers. This will later be used to reconstruct
495 # the model.
496 layers = self._get_child_layer_node_ids(node_id)
497 self.model_layer_dependencies[node_id] = (model, layers)
498 if not layers:
499 self._models_to_reconstruct.append(node_id)
500 return model
502 def _revive_layer_or_model_from_config(self, metadata, node_id):
503 """Revives a layer/custom model from config; returns None if infeasible."""
504 # Check that the following requirements are met for reviving from config:
505 # 1. Object can be deserialized from config.
506 # 2. If the object needs to be built, then the build input shape can be
507 # found.
508 class_name = metadata.get('class_name')
509 config = metadata.get('config')
510 shared_object_id = metadata.get('shared_object_id')
511 must_restore_from_config = metadata.get('must_restore_from_config')
512 if not generic_utils.validate_config(config):
513 return None
515 try:
516 obj = layers_module.deserialize(
517 generic_utils.serialize_keras_class_and_config(
518 class_name, config, shared_object_id=shared_object_id))
519 except ValueError:
520 if must_restore_from_config:
521 raise RuntimeError(
522 'Unable to restore a layer of class {cls}. Layers of '
523 'class {cls} require that the class be provided to '
524 'the model loading code, either by registering the '
525 'class using @keras.utils.register_keras_serializable '
526 'on the class def and including that file in your '
527 'program, or by passing the class in a '
528 'keras.utils.CustomObjectScope that wraps this load '
529 'call.'.format(cls=class_name))
530 else:
531 return None
533 # Use the dtype, name, and trainable status. Often times these are not
534 # specified in custom configs, so retrieve their values from the metadata.
535 # pylint: disable=protected-access
536 obj._name = metadata['name']
537 if metadata.get('trainable') is not None:
538 obj.trainable = metadata['trainable']
539 if metadata.get('dtype') is not None:
540 obj._set_dtype_policy(metadata['dtype'])
541 if metadata.get('stateful') is not None:
542 obj.stateful = metadata['stateful']
543 # Restore model save spec for subclassed models. (layers do not store a
544 # SaveSpec)
545 if isinstance(obj, training_lib.Model):
546 save_spec = metadata.get('save_spec')
547 if save_spec is not None:
548 obj._set_save_spec(save_spec)
549 # pylint: enable=protected-access
551 build_input_shape = metadata.get('build_input_shape')
552 built = self._try_build_layer(obj, node_id, build_input_shape)
554 if not built:
555 # If the layer cannot be built, revive a custom layer instead.
556 return None
557 return obj
559 def _revive_metric_from_config(self, metadata):
560 """Revives a metric object using the config saved in the metadata."""
561 class_name = compat.as_str(metadata['class_name'])
562 config = metadata.get('config')
564 if not generic_utils.validate_config(config):
565 return None
567 try:
568 obj = metrics.deserialize(
569 generic_utils.serialize_keras_class_and_config(class_name, config))
570 except ValueError:
571 return None
573 build_input_shape = metadata.get('build_input_shape')
574 if build_input_shape is not None and hasattr(obj, '_build'):
575 obj._build(build_input_shape) # pylint: disable=protected-access
577 return obj
579 def _try_build_layer(self, obj, node_id, build_input_shape):
580 """Attempts to build the layer."""
581 if obj.built or hasattr(obj.build, '_is_default'):
582 obj.built = True
583 return True
585 if build_input_shape is None:
586 build_input_shape = self._infer_inputs(node_id, convert_to_shapes=True)
588 if build_input_shape is not None:
589 obj.build(build_input_shape)
590 base_layer.Layer.build(obj, build_input_shape)
591 return True
593 return False
595 def _load_edges(self):
596 """Add edges for all nodes that are not waiting on initialization."""
597 for node_id, proto in enumerate(self._proto.nodes):
598 if node_id not in self.model_layer_dependencies:
599 self._add_object_graph_edges(proto, node_id)
601 def get_path(self, node_id):
602 return self._node_paths[node_id]
604 def finalize_objects(self):
605 """Finish setting up Keras objects.
607 This function is executed after all objects and functions have been created.
608 Call functions and losses are attached to each layer, and once all layers
609 have been fully set up, graph networks are initialized.
611 Subclassed models that are revived from the SavedModel are treated like
612 layers, and have their call/loss functions attached here.
613 """
614 # Finish setting up layers and subclassed models. This step attaches call
615 # functions and losses to each object, and sets model inputs/outputs.
616 layers_revived_from_config = []
617 layers_revived_from_saved_model = []
618 for node_id, (node, _) in self.loaded_nodes.items():
619 if (not isinstance(node, base_layer.Layer) or
620 # Don't finalize models until all layers have finished loading.
621 node_id in self.model_layer_dependencies):
622 continue
624 self._unblock_model_reconstruction(node_id, node)
626 if isinstance(node, input_layer.InputLayer):
627 continue
628 elif isinstance(node, metrics.Metric):
629 continue
631 if isinstance(node, (RevivedLayer, RevivedInputLayer)):
632 layers_revived_from_saved_model.append(node)
633 else:
634 layers_revived_from_config.append(node)
636 _finalize_saved_model_layers(layers_revived_from_saved_model)
637 _finalize_config_layers(layers_revived_from_config)
639 # Initialize graph networks, now that layer dependencies have been resolved.
640 self._reconstruct_all_models()
642 def _unblock_model_reconstruction(self, layer_id, layer):
643 """Removes layer from blocking model reconstruction."""
644 for model_id, v in self.model_layer_dependencies.items():
645 _, layers = v
646 if layer_id not in layers:
647 continue
648 layers[layers.index(layer_id)] = layer
649 if all(isinstance(x, base_layer.Layer) for x in layers):
650 self._models_to_reconstruct.append(model_id)
652 def _reconstruct_all_models(self):
653 """Reconstructs the network structure of all models."""
654 all_initialized_models = set()
655 while self._models_to_reconstruct:
656 model_id = self._models_to_reconstruct.pop(0)
657 all_initialized_models.add(model_id)
658 model, layers = self.model_layer_dependencies[model_id]
659 self._reconstruct_model(model_id, model, layers)
660 _finalize_config_layers([model])
662 if all_initialized_models != set(self.model_layer_dependencies.keys()):
663 # This should not happen.
664 uninitialized_model_ids = (
665 set(self.model_layer_dependencies.keys()) - all_initialized_models)
666 uninitialized_model_names = [
667 self.model_layer_dependencies[model_id][0].name
668 for model_id in uninitialized_model_ids]
669 raise ValueError('Error when loading from SavedModel -- the following '
670 'models could not be initialized: {}'
671 .format(uninitialized_model_names))
673 def _reconstruct_model(self, model_id, model, layers):
674 """Reconstructs the network structure."""
675 config = json_utils.decode(self._metadata[model_id].metadata)['config']
677 # Set up model inputs
678 if model.inputs:
679 # Inputs may already be created if the model is instantiated in another
680 # object's __init__.
681 pass
682 elif isinstance(model, models_lib.Sequential):
683 if not layers or not isinstance(layers[0], input_layer.InputLayer):
684 if config['layers'][0]['class_name'] == 'InputLayer':
685 layers.insert(0, input_layer.InputLayer.from_config(
686 config['layers'][0]['config']))
687 elif 'batch_input_shape' in config['layers'][0]['config']:
688 batch_input_shape = config['layers'][0]['config']['batch_input_shape']
689 layers.insert(0, input_layer.InputLayer(
690 input_shape=batch_input_shape[1:],
691 batch_size=batch_input_shape[0],
692 dtype=layers[0].dtype,
693 name=layers[0].name + '_input'))
694 model.__init__(layers, name=config['name'])
695 if not model.inputs:
696 first_layer = self._get_child_layer_node_ids(model_id)[0]
697 input_specs = self._infer_inputs(first_layer)
698 input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True)
699 model._set_inputs(input_specs) # pylint: disable=protected-access
700 if not model.built and not isinstance(input_specs, dict):
701 model.build(input_shapes)
702 else: # Reconstruct functional model
703 (inputs, outputs,
704 created_layers) = functional_lib.reconstruct_from_config(
705 config, created_layers={layer.name: layer for layer in layers})
706 model.__init__(inputs, outputs, name=config['name'])
707 functional_lib.connect_ancillary_layers(model, created_layers)
709 # Set model dtype.
710 _set_network_attributes_from_metadata(model)
712 # Unblock models that are dependent on this model.
713 self._unblock_model_reconstruction(model_id, model)
715 def _get_child_layer_node_ids(self, node_id):
716 """Returns the node ids of each layer in a Sequential/Functional model."""
717 # Sequential and Functional track layers with names following the format
718 # "layer-N". Use this to generate the list of layers.
719 num_layers = 0
720 child_layers = {}
721 pattern = re.compile('layer-(\\d+)')
723 for child in self._proto.nodes[node_id].children:
724 m = pattern.match(child.local_name)
725 if m is None:
726 continue
727 layer_n = int(m.group(1))
728 num_layers = max(layer_n + 1, num_layers)
729 child_layers[layer_n] = child.node_id
731 ordered = []
732 for n in range(num_layers):
733 child = child_layers.get(n)
734 if child is None:
735 break
736 ordered.append(child)
737 return ordered
739 def _search_for_child_node(self, parent_id, path_to_child):
740 """Returns node id of child node.
742 A helper method for traversing the object graph proto.
744 As an example, say that the object graph proto in the SavedModel contains an
745 object with the following child and grandchild attributes:
747 `parent.child_a.child_b`
749 This method can be used to retrieve the node id of `child_b` using the
750 parent's node id by calling:
752 `_search_for_child_node(parent_id, ['child_a', 'child_b'])`.
754 Args:
755 parent_id: node id of parent node
756 path_to_child: list of children names.
758 Returns:
759 node_id of child, or None if child isn't found.
760 """
761 if not path_to_child:
762 return parent_id
764 for child in self._proto.nodes[parent_id].children:
765 if child.local_name == path_to_child[0]:
766 return self._search_for_child_node(child.node_id, path_to_child[1:])
767 return None
def _infer_inputs(self, layer_node_id, convert_to_shapes=False):
  """Infers the inputs of a layer from its saved call function.

  Looks up the layer's `call_and_return_all_conditional_losses` child node
  and decodes the canonicalized input signature of its first concrete
  function.

  Args:
    layer_node_id: node id of the layer in the object graph proto.
    convert_to_shapes: if True, return the `shape` of every spec instead of
      the specs themselves.

  Returns:
    The first positional argument of the decoded input signature (or its
    shapes), or None when the call function node or its concrete functions
    are missing.
  """
  fn_node_id = self._search_for_child_node(
      layer_node_id, ['call_and_return_all_conditional_losses'])
  if fn_node_id is None:
    return None

  fn_names = self._proto.nodes[fn_node_id].function.concrete_functions
  if not fn_names:
    return None

  first_fn_proto = self._proto.concrete_functions[fn_names[0]]
  decoded_signature = nested_structure_coder.decode_proto(
      first_fn_proto.canonicalized_input_signature)
  # Element [0] is indexed again with [0] to pick out the inputs.
  inputs = decoded_signature[0][0]
  if not convert_to_shapes:
    return inputs
  return nest.map_structure(lambda spec: spec.shape, inputs)
def _config_node_setter(self, setter):
  """Wraps `setter` so it only creates edges that do not exist yet.

  Args:
    setter: callable taking `(obj, name, value)`.

  Returns:
    A callable with the same signature that forwards to `setter` only when
    `obj._lookup_dependency(name)` returns None.
  """
  def setattr_wrapper(obj, name, value):
    # Objects recreated from the config already own some attributes; leave
    # those untouched.
    already_tracked = obj._lookup_dependency(name) is not None  # pylint: disable=protected-access
    if already_tracked:
      return
    setter(obj, name, value)
  return setattr_wrapper
def _finalize_saved_model_layers(layers):
  """Runs the final steps of loading Keras Layers from SavedModel.

  Args:
    layers: iterable of layers that were revived from the SavedModel (and not
      from their config).
  """
  # pylint: disable=protected-access
  # Step 1: hook up the call function of every layer.
  for layer in layers:
    layer.built = True
    restored_call = getattr(_get_keras_attr(layer),
                            'call_and_return_conditional_losses', None)
    if not (restored_call and restored_call.concrete_functions):
      # No usable call function was serialized: install a `call` that raises.
      layer.call = types.MethodType(
          _unable_to_call_layer_due_to_serialization_issue, layer)
      continue

    layer.call = utils.use_wrapped_call(
        layer, restored_call, return_method=True)
    metadata = layer._serialized_attributes['metadata']
    expects_training_arg = metadata['expects_training_arg']
    if 'training' in restored_call.function_spec.arg_names:
      # The saved function takes `training` even if the layer itself did not
      # declare it (e.g. a child layer expects it), so override the metadata.
      expects_training_arg = True
    layer._init_call_fn_args(expects_training_arg)

  for layer in layers:
    # Step 2: set model inputs and outputs.
    if isinstance(layer, RevivedNetwork):
      _set_network_attributes_from_metadata(layer)

      keras_attr = _get_keras_attr(layer)
      if hasattr(keras_attr, 'call_and_return_conditional_losses'):
        call_fn = keras_attr.call_and_return_conditional_losses
        if not call_fn.concrete_functions:
          # Skips steps 3 and 4 for this layer as well.
          continue
        signature = call_fn.input_signature
        if signature is None:
          inputs = infer_inputs_from_restored_call_function(call_fn)
        else:
          inputs = signature[0]
        layer._set_inputs(inputs)

    # Step 3: add losses that aren't generated by the layer.call function.
    _restore_layer_unconditional_losses(layer)
    _restore_layer_activation_loss(layer)

    # Step 4: restore the metrics list.
    _restore_layer_metrics(layer)

  # pylint: enable=protected-access
def _unable_to_call_layer_due_to_serialization_issue(
    layer, *unused_args, **unused_kwargs):
  """Replaces the `layer.call` if the layer was not fully serialized.

  Keras Model/Layer serialization is relatively relaxed because SavedModels
  are not always loaded back as keras models. Thus, when there is an issue
  tracing a non-signature function, a warning is logged instead of raising an
  error. This results in a SavedModel where the model's call function is saved,
  but the internal layer call functions are not.

  When deserialized with `tf.keras.models.load_model`, the internal layers
  which do not have serialized call functions should raise an error when called.

  Args:
    layer: Layer without the serialized call function.
    *unused_args: Positional call arguments; ignored.
    **unused_kwargs: Keyword call arguments; ignored.

  Raises:
    ValueError: always.
  """
  # The literals below are joined by implicit concatenation, so each sentence
  # boundary needs its own separator.
  raise ValueError(
      'Cannot call custom layer {} of type {}, because the call function was '
      'not serialized to the SavedModel. '
      'Please try one of the following methods to fix this issue:'
      '\n\n(1) Implement `get_config` and `from_config` in the layer/model '
      'class, and pass the object to the `custom_objects` argument when '
      'loading the model. For more details, see: '
      'https://www.tensorflow.org/guide/keras/save_and_serialize'
      '\n\n(2) Ensure that the subclassed model or layer overwrites `call` '
      'and not `__call__`. The input shape and dtype will be automatically '
      'recorded when the object is called, and used when saving. To manually '
      'specify the input shape/dtype, decorate the call function with '
      '`@tf.function(input_signature=...)`.'.format(layer.name, type(layer)))
def _finalize_config_layers(layers):
  """Runs the final steps of loading Keras Layers from config.

  Args:
    layers: iterable of layers that were recreated from their config.
  """
  for layer in layers:
    # Layers rebuilt from config are assumed to recreate their own
    # unconditional losses. Functional and Sequential models are the
    # exception: their config only stores conditional (input-dependent)
    # losses, so unconditional ones such as weight regularization have to be
    # revived from the SavedModel.
    if _is_graph_network(layer):
      _restore_layer_unconditional_losses(layer)

    # Some layers, like Dense, record their activation loss function in the
    # config. However, not all layers do this, so the activation loss may be
    # missing when restored from the config/hdf5.
    # TODO(kathywu): Investigate ways to improve the config to ensure consistent
    # loading behavior between HDF5 and SavedModel.
    _restore_layer_activation_loss(layer)

    # Restore metrics list.
    _restore_layer_metrics(layer)

    # Restore RNN layer states.
    if isinstance(layer, recurrent.RNN) and layer.stateful:
      keras_attr = _get_keras_attr(layer)
      if hasattr(keras_attr, 'states'):
        layer.states = getattr(keras_attr, 'states', None)
        for state_variable in nest.flatten(layer.states):
          backend.track_variable(state_variable)

    # Perform any layer defined finalization of the layer state.
    layer.finalize_state()
def _finalize_metric(metric):
  """Binds the saved `update_state` and `result` functions onto `metric`.

  Args:
    metric: revived metric exposing the saved functions via `keras_api`.
  """
  saved_api = metric.keras_api
  wrapped_update = metrics_utils.update_state_wrapper(saved_api.update_state)
  metric.update_state = types.MethodType(wrapped_update, metric)
  metric.result = metric.keras_api.result
def _restore_layer_unconditional_losses(layer):
  """Adds the unconditional losses stored in the SavedModel to `layer`.

  Args:
    layer: layer whose saved regularization losses are passed to `add_loss`.
  """
  not_saved = object()
  losses = getattr(
      _get_keras_attr(layer), 'layer_regularization_losses', not_saved)
  if losses is not_saved:
    # Earlier SavedModels may not serialize layer_regularization_losses
    # separately; use the regularization_losses list in that case.
    losses = layer._serialized_attributes.get('regularization_losses', [])  # pylint: disable=protected-access
  for loss in losses:
    layer.add_loss(loss)
def _restore_layer_activation_loss(layer):
  """Restores the activation loss from the SavedModel.

  The saved `activity_regularizer_fn` is installed only when the layer does
  not already have an activity regularizer.

  Args:
    layer: layer to receive the saved activity regularizer function.
  """
  saved_regularizer = getattr(_get_keras_attr(layer),
                              'activity_regularizer_fn', None)
  if not saved_regularizer or layer.activity_regularizer:
    return
  try:
    layer.activity_regularizer = saved_regularizer
  except AttributeError:
    # This may happen if a layer wrapper is saved with an activity
    # regularizer. The wrapper object's activity regularizer is unsettable.
    pass
def revive_custom_object(identifier, metadata):
  """Revives object from SavedModel.

  A new class named `metadata['class_name']` is created from the pair of base
  classes registered for `identifier`, and instantiated through its
  `_init_from_metadata` classmethod.

  Args:
    identifier: object identifier; must be one of the input layer, layer,
      model, network or sequential identifiers in `constants`.
    metadata: metadata dict of the saved object; must contain 'class_name'.

  Returns:
    Whatever `_init_from_metadata` of the revived class returns.

  Raises:
    ValueError: if `identifier` is not one of the supported identifiers.
  """
  if ops.executing_eagerly_outside_functions():
    model_class = training_lib.Model
  else:
    model_class = training_lib_v1.Model

  revived_classes = {
      constants.INPUT_LAYER_IDENTIFIER: (
          RevivedInputLayer, input_layer.InputLayer),
      constants.LAYER_IDENTIFIER: (RevivedLayer, base_layer.Layer),
      constants.MODEL_IDENTIFIER: (RevivedNetwork, model_class),
      constants.NETWORK_IDENTIFIER: (RevivedNetwork, functional_lib.Functional),
      constants.SEQUENTIAL_IDENTIFIER: (RevivedNetwork, models_lib.Sequential),
  }
  parent_classes = revived_classes.get(identifier)

  if parent_classes is None:
    raise ValueError('Unable to restore custom object of type {} currently. '
                     'Please make sure that the layer implements `get_config` '
                     'and `from_config` when saving. In addition, please use '
                     'the `custom_objects` arg when calling `load_model()`.'
                     .format(identifier))

  revived_cls = type(
      compat.as_str(metadata['class_name']), parent_classes, {})
  return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access
def _restore_layer_metrics(layer):
  """Appends saved metrics whose names `layer` does not already have.

  Args:
    layer: layer whose `_metrics` list is extended with saved metrics.
  """
  saved_metrics = getattr(_get_keras_attr(layer), 'layer_metrics', {})
  # Names are snapshotted up front. Metrics may be added during
  # initialization/building of custom layers, and those are kept.
  existing_names = {m.name for m in layer._metrics}  # pylint: disable=protected-access
  for metric_name, saved_metric in saved_metrics.items():
    if metric_name in existing_names:
      continue
    layer._metrics.append(saved_metric)  # pylint: disable=protected-access
988# TODO(kathywu): Centrally define keys and functions for both serialization and
989# deserialization.
class RevivedLayer(object):
  """Keras layer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Builds a revived layer from the metadata stored in the SavedModel.

    Args:
      metadata: dict with at least 'name', 'trainable' and
        'expects_training_arg'; the other keys read below are optional.

    Returns:
      Tuple of (revived layer, setter function used to attach attributes).
    """
    ctor_kwargs = {
        'name': metadata['name'],
        'trainable': metadata['trainable'],
    }
    # Optional constructor arguments are forwarded only when present.
    for optional_key in ('dtype', 'batch_input_shape'):
      if metadata.get(optional_key) is not None:
        ctor_kwargs[optional_key] = metadata[optional_key]

    revived_obj = cls(**ctor_kwargs)

    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']

      config = metadata.get('config')
      if generic_utils.validate_config(config):
        revived_obj._config = config

      saved_input_spec = metadata.get('input_spec')
      if saved_input_spec is not None:
        revived_obj.input_spec = recursively_deserialize_keras_object(
            saved_input_spec,
            module_objects={'InputSpec': input_spec.InputSpec})

      saved_regularizer = metadata.get('activity_regularizer')
      if saved_regularizer is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            saved_regularizer)

      is_feature_layer = metadata.get('_is_feature_layer')
      if is_feature_layer is not None:
        revived_obj._is_feature_layer = is_feature_layer

      stateful = metadata.get('stateful')
      if stateful is not None:
        revived_obj.stateful = stateful
      # pylint:enable=protected-access

    return revived_obj, _revive_setter

  @property
  def keras_api(self):
    """Keras-specific serialized attributes, or None when absent."""
    return self._serialized_attributes.get(constants.KERAS_ATTR)

  def get_config(self):
    """Returns the stored config; raises NotImplementedError if none was set."""
    if not hasattr(self, '_config'):
      raise NotImplementedError
    return self._config
def _revive_setter(layer, name, value):
  """Setter function that saves some attributes to separate dictionary.

  Args:
    layer: revived layer receiving the attribute.
    name: attribute (edge) name from the SavedModel.
    value: revived object to attach.
  """
  # Many attributes in the SavedModel conflict with properties defined in
  # Layer and Model. Save these attributes to a separate dictionary.
  if name in PUBLIC_ATTRIBUTES:
    # pylint: disable=protected-access
    if isinstance(value, trackable.Trackable):
      layer._track_trackable(value, name=name)
    layer._serialized_attributes[name] = value
    # pylint: enable=protected-access
  elif (isinstance(layer, functional_lib.Functional) and
        # `\d+` must sit outside a character class: `[\d+]` would match a
        # single digit or a literal '+'.
        re.match(r'^layer(_with_weights)?-\d+', name) is not None):
    # Edges named "layer-n" or "layer_with_weights-n", which are tracked in
    # network._track_layers, should not be added as an attribute. They should
    # be temporarily added as a dependency so that checkpointed values can be
    # restored. These dependencies are manually deleted in
    # KerasObjectLoader.del_tracking.

    # Set `overwrite=True` in the case that `layer` already tracks a different
    # layer-n. This may cause variable values to not be loaded properly in the
    # original layer-n, but we already warn the users about this
    # (ctrl-f "shared between different layers/models").
    layer._track_trackable(value, name, overwrite=True)  # pylint: disable=protected-access
  elif getattr(layer, name, None) is not None:
    # Don't overwrite already defined attributes.
    pass
  else:
    setattr(layer, name, value)
class RevivedInputLayer(object):
  """InputLayer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Revives the saved InputLayer from the Metadata.

    Args:
      metadata: dict containing 'name', 'dtype', 'sparse', 'ragged',
        'batch_input_shape' and 'config'.

    Returns:
      Tuple of (revived input layer, `setattr`).
    """
    ctor_arg_names = ('name', 'dtype', 'sparse', 'ragged', 'batch_input_shape')
    ctor_kwargs = {arg: metadata[arg] for arg in ctor_arg_names}
    revived_obj = cls(**ctor_kwargs)
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      revived_obj._config = metadata['config']  # pylint:disable=protected-access

    return revived_obj, setattr

  def get_config(self):
    """Returns the config recorded in the metadata."""
    return self._config
def recursively_deserialize_keras_object(config, module_objects=None):
  """Deserializes Keras objects found anywhere in a nested structure.

  Args:
    config: a dict, list or tuple. Dicts with a 'class_name' key are handed to
      `generic_utils.deserialize_keras_object`; other dicts and sequences are
      processed element by element.
    module_objects: passed through to `deserialize_keras_object`.

  Returns:
    The deserialized structure. Tuples and lists both come back as lists.

  Raises:
    ValueError: if `config` (or a nested element) is not a dict, list or
      tuple.
  """
  if isinstance(config, dict):
    if 'class_name' not in config:
      return {
          key: recursively_deserialize_keras_object(item, module_objects)
          for key, item in config.items()
      }
    return generic_utils.deserialize_keras_object(
        config, module_objects=module_objects)

  if not isinstance(config, (tuple, list)):
    raise ValueError('Unable to decode config: {}'.format(config))

  return [recursively_deserialize_keras_object(item, module_objects)
          for item in config]
def get_common_shape(x, y):
  """Find a `TensorShape` that is compatible with both `x` and `y`.

  Args:
    x: a `TensorShape`, or None.
    y: a `TensorShape`, or None.

  Returns:
    None if both shapes are None. Otherwise a `TensorShape` that is fully
    unknown when the ranks differ or are unknown, and that otherwise keeps
    each dimension on which `x` and `y` agree and sets the rest to None.

  Raises:
    RuntimeError: if exactly one of `x` and `y` is None.
    TypeError: if a non-None argument is not a `TensorShape`.
  """
  # Parenthesized on purpose: `x is None != y is None` is a chained
  # comparison (`x is None and None != y and y is None`) and is never true.
  if (x is None) != (y is None):
    raise RuntimeError(
        'Cannot find a common shape when LHS shape is None but RHS shape '
        'is not (or vice versa): %s vs. %s' % (x, y))
  if x is None:
    return None  # The associated input was not a Tensor, no shape generated.
  if not isinstance(x, tensor_shape.TensorShape):
    raise TypeError('Expected x to be a TensorShape but saw %s' % (x,))
  if not isinstance(y, tensor_shape.TensorShape):
    raise TypeError('Expected y to be a TensorShape but saw %s' % (y,))
  if x.rank != y.rank or x.rank is None:
    return tensor_shape.TensorShape(None)
  dims = []
  for dim_x, dim_y in zip(x.dims, y.dims):
    if (dim_x != dim_y
        or tensor_shape.dimension_value(dim_x) is None
        or tensor_shape.dimension_value(dim_y) is None):
      # Mismatched or unknown dimensions become unknown in the result.
      dims.append(None)
    else:
      dims.append(tensor_shape.dimension_value(dim_x))
  return tensor_shape.TensorShape(dims)
def infer_inputs_from_restored_call_function(fn):
  """Returns TensorSpec of inputs from a restored call function.

  The input spec of the first concrete function is merged, one at a time,
  with the input spec of every further concrete function.

  Args:
    fn: Restored layer call function. It is assumed that `fn` has at least
      one concrete function and that the inputs are in the first argument.

  Returns:
    TensorSpec of call function inputs.
  """
  def first_arg_spec(concrete):
    return concrete.structured_input_signature[0][0]

  def common_spec(x, y):
    merged_shape = get_common_shape(x.shape, y.shape)
    # The result keeps the spec kind of `x`; sparse and ragged specs are
    # rebuilt from shape and dtype only.
    for spec_cls in (sparse_tensor.SparseTensorSpec,
                     ragged_tensor.RaggedTensorSpec):
      if isinstance(x, spec_cls):
        return spec_cls(merged_shape, x.dtype)
    return tensor_spec.TensorSpec(merged_shape, x.dtype, x.name)

  concrete_fns = fn.concrete_functions
  merged = first_arg_spec(concrete_fns[0])
  for concrete in concrete_fns[1:]:
    merged = nest.map_structure(common_spec, merged, first_arg_spec(concrete))
  return merged
class RevivedNetwork(RevivedLayer):
  """Keras network of layers loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Builds a revived network from the metadata stored in the SavedModel.

    Args:
      metadata: dict with at least 'name' and 'expects_training_arg';
        'config' and 'activity_regularizer' are optional.

    Returns:
      Tuple of (revived network, setter function used to attach attributes).
    """
    revived_obj = cls(name=metadata['name'])

    # Attributes revived from SerializedAttributes (those listed in
    # CommonEndpoints, or "keras_api" for keras-specific ones) live in an
    # un-tracked dictionary, so dependency tracking is switched off here.
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']

      config = metadata.get('config')
      if generic_utils.validate_config(config):
        revived_obj._config = config

      saved_regularizer = metadata.get('activity_regularizer')
      if saved_regularizer is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            saved_regularizer)
      # pylint:enable=protected-access

    return revived_obj, _revive_setter  # pylint:disable=protected-access
def _set_network_attributes_from_metadata(revived_obj):
  """Applies the dtype policy and trainable flag recorded in the metadata.

  Args:
    revived_obj: revived network whose `_serialized_attributes['metadata']`
      holds 'trainable' and, optionally, 'dtype'.
  """
  with utils.no_automatic_dependency_tracking_scope(revived_obj):
    # pylint:disable=protected-access
    metadata = revived_obj._serialized_attributes['metadata']
    saved_dtype = metadata.get('dtype')
    if saved_dtype is not None:
      revived_obj._set_dtype_policy(saved_dtype)
    revived_obj._trainable = metadata['trainable']
    # pylint:enable=protected-access
def _maybe_add_serialized_attributes(layer, metadata):
  """Creates `layer._serialized_attributes` unless it already exists.

  Args:
    layer: object to receive the `_serialized_attributes` dictionary.
    metadata: stored under the 'metadata' key of the new dictionary.
  """
  if hasattr(layer, '_serialized_attributes'):
    return
  # Attributes revived from SerializedAttributes (those listed in
  # CommonEndpoints, or "keras_api" for keras-specific ones) are kept in an
  # un-tracked dictionary.
  with utils.no_automatic_dependency_tracking_scope(layer):
    layer._serialized_attributes = {'metadata': metadata}  # pylint: disable=protected-access
def _get_keras_attr(layer):
  """Returns the keras-specific serialized attributes of `layer`, or None.

  None is returned both when `layer` has no `_serialized_attributes` and when
  that dictionary has no `constants.KERAS_ATTR` entry.
  """
  serialized_attributes = getattr(layer, '_serialized_attributes', {})
  return serialized_attributes.get(constants.KERAS_ATTR, None)