Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py: 19%
461 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Import a trackable object from a SavedModel."""
17import collections
18import functools
19import os
20import sys
22from absl import logging
24from tensorflow.core.framework import graph_debug_info_pb2
25from tensorflow.core.function.capture import restore_captures
26from tensorflow.python.checkpoint import checkpoint
27from tensorflow.python.checkpoint import checkpoint_options
28from tensorflow.python.checkpoint import graph_view
29from tensorflow.python.checkpoint import restore
30from tensorflow.python.distribute import distribute_lib
31from tensorflow.python.distribute import distribute_utils
32from tensorflow.python.distribute import values_util
33from tensorflow.python.eager import context
34from tensorflow.python.eager import function
35from tensorflow.python.eager.polymorphic_function import saved_model_utils as function_saved_model_utils
36from tensorflow.python.framework import config
37from tensorflow.python.framework import constant_op
38from tensorflow.python.framework import dtypes
39from tensorflow.python.framework import errors
40from tensorflow.python.framework import ops
41from tensorflow.python.ops import array_ops
42from tensorflow.python.ops import control_flow_assert
43from tensorflow.python.ops import control_flow_ops
44from tensorflow.python.ops import lookup_ops
45from tensorflow.python.ops import resource_variable_ops
46from tensorflow.python.ops import variables
47from tensorflow.python.saved_model import fingerprinting
48from tensorflow.python.saved_model import fingerprinting_utils
49from tensorflow.python.saved_model import function_deserialization
50from tensorflow.python.saved_model import load_options
51from tensorflow.python.saved_model import load_v1_in_v2
52from tensorflow.python.saved_model import loader_impl
53from tensorflow.python.saved_model import path_helpers
54from tensorflow.python.saved_model import registration
55from tensorflow.python.saved_model import revived_types
56from tensorflow.python.saved_model import utils_impl as saved_model_utils
57from tensorflow.python.saved_model.pywrap_saved_model import metrics
58from tensorflow.python.trackable import asset
59from tensorflow.python.trackable import autotrackable
60from tensorflow.python.trackable import base
61from tensorflow.python.trackable import data_structures
62from tensorflow.python.trackable import resource
63from tensorflow.python.trackable import trackable_utils
64from tensorflow.python.training import py_checkpoint_reader
65from tensorflow.python.training.saving import saveable_object_util
66from tensorflow.python.util import nest
67from tensorflow.python.util.tf_export import tf_export
# API label for SavedModel metrics.
_LOAD_V2_LABEL = "load_v2"
# Built-in registrations use the "oneof kind" field in the SavedObject proto,
# instead of "registered_name" field. The "kind" field has almost the same
# functionality as the registered_name, but only contains built-in TensorFlow
# types (like variable, functions, assets).
# Maps each "kind" oneof name to the Trackable class used to revive it in
# `Loader._recreate`.
_BUILT_IN_REGISTRATIONS = {
    "asset": asset.Asset,
    "resource": resource.RestoredResource,
    "constant": function_saved_model_utils.TrackableConstant}
def _unused_handle():
  """Returns a placeholder as a handle that is not supposed to be accessed.

  The returned resource placeholder carries a control dependency on an
  `Assert` that always fires at execution time, so any attempt to actually run
  it fails with an explanatory error. If called while building a function
  graph, that graph is also marked unsaveable.
  """
  error_message = ("Trying to access a placeholder that is not supposed to be "
                   "executed. This means you are executing a graph generated "
                   "from the cross-replica context in an in-replica context.")
  # Fix: original message read "loaded form a SavedModel".
  save_error_message = (
      "It seems that you are trying to save a "
      "tf.types.experimental.ConcreteFunction that involves a distributed "
      "model, and the model contains parts that are loaded from a SavedModel. "
      "It's not supported to save such tf.types.experimental.ConcreteFunction. "
      "Try saving a tf.function with input_signature instead, and file a bug if"
      " there are still issues.")

  # The predicate is a `placeholder_with_default(False)`, so the assert always
  # triggers if the placeholder is ever executed.
  assert_op = control_flow_assert.Assert(
      array_ops.placeholder_with_default(False, shape=()), [error_message])
  if (not context.executing_eagerly()
     ) and ops.get_default_graph().building_function:
    ops.get_default_graph().mark_as_unsaveable(save_error_message)

  with ops.control_dependencies([assert_op]):
    return array_ops.placeholder(dtype=dtypes.resource)
class _WrapperFunction(function.ConcreteFunction):
  """A class wraps a concrete function to handle different distributed contexts.

  The reason for wrapping a concrete function is because the _captured_inputs
  fields used for in-replica context and cross-replica context are different.
  When `load()` is called from within a tf.distribute.strategy scope, the
  captured inputs are distributed variables. In the in-replica case the call
  must use the matching component of each distributed variable; in the
  cross-replica case the call is only building a graph that will never
  actually execute (e.g. while constructing a functional model), so each
  distributed variable is replaced with a guarded placeholder that errors if
  it is ever accessed.
  """

  def __init__(self, concrete_function):
    # Shallow copy the concrete_function
    self.__dict__.update(vars(concrete_function))

  def _call_flat(self, args, captured_inputs):
    in_replica = (distribute_lib.get_replica_context() is not None or
                  values_util.is_saving_non_distributed())

    def resolve(capture):
      """Maps a capture to the value usable in the current context."""
      if not distribute_utils.is_distributed_variable(capture):
        return capture
      if in_replica:
        # In the replica context, `var.handle` resolves the replica-local
        # handle; when saving a non-distributed version of the model it
        # resolves the primary variable handle, since only one copy of a
        # replicated variable is saved.
        return capture.handle
      # Cross-replica context: the graph being built must never actually run,
      # so substitute a placeholder guarded by an always-failing assert.
      return _unused_handle()

    captured_inputs = [resolve(capture) for capture in captured_inputs]
    return super()._call_flat(args, captured_inputs)
class Loader(object):
  """Helper class to load an object-based SavedModel."""

  def __init__(self, object_graph_proto, saved_model_proto, export_dir,
               ckpt_options, save_options, filters):
    """Builds all objects from the SavedModel and restores checkpoint state.

    Args:
      object_graph_proto: A SavedObjectGraph proto describing the nodes.
      saved_model_proto: The SavedModel proto (only meta_graphs[0] is used).
      export_dir: Directory the SavedModel was loaded from.
      ckpt_options: CheckpointOptions used when restoring variable values.
      save_options: LoadOptions controlling e.g. checkpoint skipping, aliases.
      filters: Optional node filters (list of "root.x" paths, or dict mapping
        such paths to pre-created objects); None loads everything.
    """
    meta_graph = saved_model_proto.meta_graphs[0]
    self._asset_file_def = meta_graph.asset_file_def
    self._operation_attributes = {
        node.name: node.attr for node in meta_graph.graph_def.node}
    self._proto = object_graph_proto
    self._export_dir = export_dir
    self._concrete_functions = (
        function_deserialization.load_function_def_library(
            library=meta_graph.graph_def.library,
            saved_object_graph=self._proto,
            wrapper_function=_WrapperFunction))
    # Store a set of all concrete functions that have been set up with
    # captures.
    self._restored_concrete_functions = set()
    self._checkpoint_options = ckpt_options
    self._save_options = save_options

    # Metagraph has a mapping from FunctionDef name to aliases
    self._concrete_function_aliases = meta_graph.meta_info_def.function_aliases
    # Create a mapping from alias to Function, which can be used with
    # SaveOptions
    self.function_aliases = {}

    self._pretty_printer = checkpoint.ObjectGraphProtoPrettyPrinter(self._proto)

    # Stores user-defined node_filters argument.
    self._node_filters = filters
    # Stores map of string paths to integers.
    self._node_path_to_id = self._convert_node_paths_to_ints()
    self._loaded_nodes = {}
    if isinstance(filters, dict):
      # If node_filters is a dict, then the values may contain already created
      # trackable objects. In this case, create a dictionary mapping node IDs to
      # the already created nodes. This dict will be updated in
      # `_retrieve_all_filtered_nodes` with tracked children.
      for node_path, node in filters.items():
        if isinstance(node, tuple):
          self._loaded_nodes[self._node_path_to_id[node_path]] = node
        else:
          self._loaded_nodes[self._node_path_to_id[node_path]] = (node, setattr)

    # Get a list of all integer node ids to load, or None if all nodes should be
    # loaded. This list includes ids of child nodes.
    self._filtered_nodes = self._retrieve_all_filtered_nodes()

    # Order all nodes or filtered nodes using the dependencies.
    self._ordered_node_ids = self._generate_ordered_node_ids()

    # Recreate objects, connect edges, wire function captures and
    # save/restore functions. This populates `self._nodes`.
    self._load_all()

    if not save_options.experimental_skip_checkpoint:
      self._restore_checkpoint()
    # In graph mode, register resource initializers so they run with the
    # standard table-initializer collection.
    for node in self._nodes:
      if isinstance(node, resource.CapturableResource):
        init_op = node._initialize()  # pylint: disable=protected-access
        if not context.executing_eagerly():
          ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
212 def _convert_node_paths_to_ints(self):
213 """Maps all string node paths in node_filters to the int node ids."""
214 if self._node_filters is None:
215 return None
216 path_to_int = {}
217 for node_id in self._node_filters:
218 int_node_id = None
219 if isinstance(node_id, str):
220 node_path = node_id.split(".")
221 if node_path[0] != "root":
222 raise ValueError(
223 "When passing string identifiers to node_filters, the first name"
224 f" must be root. Received {node_path[0]}.")
225 int_node_id = 0
226 for n, name in enumerate(node_path[1:]):
227 int_node_id = self._find_node_child(
228 int_node_id, name, ".".join(node_path[:n+2]))
229 path_to_int[node_id] = int_node_id
230 else:
231 raise TypeError("Elements in node_filters must be strings.")
232 return path_to_int
  def _retrieve_all_filtered_nodes(self):
    """Traverses through the object graph to get the IDs of all nodes to load.

    As a side-effect, if node_filters is a dictionary that contains already-
    created objects, then the children tracked by those objects will be
    added to node_filters.

    Returns:
      List of all nodes to load, or None if all nodes should be loaded.

    """
    if self._node_filters is None:
      return None  # All nodes should be loaded.

    all_filtered_nodes = set()
    nodes_to_visit = list(self._node_filters)

    # Breadth-first traversal over the children of every filtered node.
    while nodes_to_visit:
      node_path = nodes_to_visit.pop(0)
      node_id = self._node_path_to_id[node_path]
      if node_id in all_filtered_nodes:
        continue
      all_filtered_nodes.add(node_id)

      node, setter = self._loaded_nodes.get(node_id, (None, None))
      if node is not None:
        if not isinstance(node, base.Trackable):
          raise TypeError(
              "Error when processing dictionary values passed to nodes_to_load."
              f"Object at {node_path} is expected to be a checkpointable (i.e. "
              "'trackable') TensorFlow object (e.g. tf.Variable, tf.Module or "
              "Keras layer).")
        node._maybe_initialize_trackable()  # pylint: disable=protected-access

      for reference in self._proto.nodes[node_id].children:
        child_object, _ = self._loaded_nodes.get(
            reference.node_id, (None, None))

        # See if node already tracks the child reference, in which case add the
        # child to the loaded_nodes dict.
        if child_object is None and node is not None:
          child_object = node._lookup_dependency(reference.local_name)  # pylint: disable=protected-access
          if isinstance(child_object, data_structures.TrackableDataStructure):
            # Make setattr a noop to avoid overwriting already existing data
            # structures.
            # NOTE(review): once replaced, `setter` stays a no-op for this
            # node's remaining children in this loop — confirm intended.
            setter = lambda *args: None

        self._loaded_nodes[reference.node_id] = (child_object, setter)

        # Record the child's string path so it can be visited (and looked up
        # by callers of `get`).
        child_path = "{}.{}".format(node_path, reference.local_name)
        self._node_path_to_id[child_path] = reference.node_id
        nodes_to_visit.append(child_path)

    # If the root was reached, all nodes are transitively included; fall back
    # to loading everything.
    if 0 in all_filtered_nodes:
      return None
    return all_filtered_nodes
291 def _find_node_child(self, node_id, child_name, path):
292 for reference in self._proto.nodes[node_id].children:
293 if reference.local_name == child_name:
294 return reference.node_id
295 raise ValueError(f"Unable to find node {path}.")
  def _load_all(self):
    """Loads all nodes and functions from the SavedModel and their edges."""
    # Order matters: edges can only be connected after the nodes exist, and
    # the remaining function captures rely on `self._nodes` being populated.
    self._load_nodes()
    self._load_edges()

    # Set up concrete functions that aren't part of the object graph
    # (e.g. gradient functions)
    self._setup_remaining_functions()
    self._load_checkpoint_save_and_restore_functions()
  def _load_checkpoint_save_and_restore_functions(self):
    """Restores the checkpoint-related save/restore functions to all nodes."""
    # Shared one-element list used as a mutable session holder by the
    # recreated SaveableObject factories.
    # NOTE(review): appears to act as a lazily-filled session cache shared by
    # all factories — confirm against `recreate_saveable_objects`.
    temp_session = [None]
    for node_id, proto in self._iter_all_nodes():
      node = self.get(node_id)
      if proto.saveable_objects.keys() == {
          trackable_utils.SERIALIZE_TO_TENSORS_NAME}:
        # Restore Trackable serialize- and restore-from-tensor functions.
        assert len(proto.saveable_objects) == 1
        saveable_object_proto = next(iter(proto.saveable_objects.values()))
        save_fn_id = saveable_object_proto.save_function
        restore_fn_id = saveable_object_proto.restore_function
        node._serialize_to_tensors = self.get(save_fn_id)  # pylint: disable=protected-access
        node._restore_from_tensors = self.get(restore_fn_id)  # pylint: disable=protected-access
      else:
        # Restore legacy SaveableObject functions.
        saveable_fn_by_name = {}
        for name, saveable_object_proto in proto.saveable_objects.items():
          save_fn_id = saveable_object_proto.save_function
          restore_fn_id = saveable_object_proto.restore_function
          saveable_fn_by_name[name] = (self.get(save_fn_id),
                                       self.get(restore_fn_id))

        node._self_saveable_object_factories = (  # pylint: disable=protected-access
            saveable_object_util.recreate_saveable_objects(saveable_fn_by_name,
                                                           temp_session))
  def _load_edges(self):
    """Adds edges from objects to other objects and functions."""
    for node_id, object_proto in self._iter_all_nodes():
      self._add_object_graph_edges(object_proto, node_id)

    # If root object isn't loaded, then create edges from the root for
    # checkpoint compatibility.
    if self._filtered_nodes is not None and 0 not in self._filtered_nodes:
      root = self.get(0)
      for node_path in self._node_filters:
        loaded_node = self._nodes[self._node_path_to_id[node_path]]
        path = node_path.split(".")
        current_node = root
        # Recreate the attribute chain "root.a.b..." with dummy user objects
        # for any intermediate attribute that does not exist yet.
        for name in path[1:-1]:
          if not hasattr(current_node, name):
            setattr(current_node, name, self._recreate_base_user_object()[0])
          current_node = getattr(current_node, name)
        # Attach the actually loaded node at the end of the chain.
        if not hasattr(current_node, path[-1]):
          setattr(current_node, path[-1], loaded_node)
  def _add_object_graph_edges(self, proto, node_id):
    """Adds edges from an object to its children.

    Args:
      proto: A SavedObject proto whose `children` describe the edges.
      node_id: Index of the parent object in `self._nodes`.
    """
    obj = self._nodes[node_id]
    setter = self._node_setters[node_id]

    for reference in proto.children:
      setter(obj, reference.local_name, self._nodes[reference.node_id])
      # Note: if an object has an attribute `__call__` add a class method
      # that allows `obj()` syntax to work. This is done per-instance to
      # allow `callable` to be used to find out if an object is callable.
      if reference.local_name == "__call__" and not callable(obj):
        setattr(type(obj), "__call__", _call_attribute)
367 def _setup_remaining_functions(self):
368 concrete_function_names = sorted(self._proto.concrete_functions.keys())
369 for name in concrete_function_names:
370 if name in self._restored_concrete_functions:
371 continue
372 self._setup_function_captures(name, self._nodes)
374 def _setup_function_captures(self, concrete_function_name, nodes):
375 """Setup captures and variables in a restored function."""
376 if concrete_function_name in self._restored_concrete_functions:
377 return
378 self._restored_concrete_functions.add(concrete_function_name)
379 concrete_function = self._concrete_functions[concrete_function_name]
380 proto = self._proto.concrete_functions[concrete_function_name]
381 inputs = [nodes[node_id] for node_id in proto.bound_inputs]
382 restore_captures.restore_captures(concrete_function, inputs)
384 def _initialize_loaded_nodes(self):
385 nodes = {}
386 node_setters = {}
387 for node_id, (node, setter) in self._loaded_nodes.items():
388 nodes[node_id] = node
389 node_setters[node_id] = setter
390 return nodes, node_setters
392 def _get_node_dependencies(self, proto):
393 """Returns a dictionary of all dependencies of an object.
395 Args:
396 proto: A SavedObject proto.
398 Returns:
399 Dict mapping string dependency name *or* int node id to the node id.
400 The int node id key is used for mapping function captures.
401 """
402 dependencies = {ref.local_name: ref.node_id for ref in proto.dependencies}
403 kind = proto.WhichOneof("kind")
404 if kind == "function":
405 concrete_functions = proto.function.concrete_functions
406 for fn_name in concrete_functions:
407 for bound_input in self._proto.concrete_functions[fn_name].bound_inputs:
408 dependencies[bound_input] = bound_input
409 elif kind == "bare_concrete_function":
410 fn_name = proto.bare_concrete_function.concrete_function_name
411 for bound_input in self._proto.concrete_functions[fn_name].bound_inputs:
412 dependencies[bound_input] = bound_input
413 elif kind == "resource":
414 # Make sure that the resource creator is listed as a dependency.
415 for child in proto.children:
416 if child.local_name == "_create_resource":
417 dependencies["_create_resource"] = child.node_id
418 return dependencies
420 def _generate_ordered_node_ids(self):
421 """Orders the node ids so that dependencies appear first."""
422 if self._filtered_nodes is None:
423 unordered_ids = range(len(self._proto.nodes))
424 else:
425 unordered_ids = list(self._filtered_nodes)
427 # Maps node ids -> list of dependencies (ids of other nodes that must be
428 # loaded before it).
429 dependency_map = collections.defaultdict(list)
430 for node_id in unordered_ids:
431 deps = dependency_map[node_id]
432 if self._loaded_nodes.get(node_id) is not None:
433 # Deps are only used if the node has not been created.
434 continue
435 proto = self._proto.nodes[node_id]
436 for dep in set(self._get_node_dependencies(proto).values()):
437 deps.append(dep)
438 if self._filtered_nodes is not None and dep not in self._filtered_nodes:
439 raise ValueError(
440 "Unable to partially load SavedModel since the specified filter "
441 "does not include all required objects for loading (e.g. "
442 "variables used in functions or deserialization dependencies). "
443 "Please include this path in the filter: "
444 f"{self._pretty_printer.node_names[dep]}")
446 # Add optimizer slot variable to dependency map.
447 prev_slot = None
448 for slot_variable_proto in proto.slot_variables:
449 slot_variable_node_id = slot_variable_proto.slot_variable_node_id
450 # The optimizer and original variable must be created before the slot
451 # variable, since the slot variable is generated using the Optimizer's
452 # add_slot API.
453 slot_deps = dependency_map[slot_variable_node_id]
454 slot_deps.append(node_id)
455 slot_deps.append(slot_variable_proto.original_variable_node_id)
457 if prev_slot is not None:
458 # Add previous slot to deps so that the optimizer slot variables are
459 # added in order. The ordering is needed because the slot name and
460 # variable are both added to ordered lists, which are exposed to the
461 # user via `Optimizer.get_slot_names()` and `Optimizer.weights`.
462 # TODO(kathywu): Maybe enforce some sort of deterministic ordering in
463 # `order_by_dependency` to avoid doing this?
464 slot_deps.append(prev_slot)
465 prev_slot = slot_variable_node_id
466 try:
467 return list(trackable_utils.order_by_dependency(dependency_map))
468 except trackable_utils.CyclicDependencyError:
469 # This should not happen since there is already a validation for cycles
470 # when saving, but raise an error just in case.
471 raise ValueError("Encountered a cycle in the deserialization dependencies"
472 "in the SavedModel. This is extremely unexpected, please"
473 "file a bug and make sure you are not manually modifying"
474 " the SavedModel.")
476 def _iter_all_nodes(self):
477 for node_id in self._ordered_node_ids:
478 yield node_id, self._proto.nodes[node_id]
  def _load_nodes(self):
    """Load all saved objects."""
    # `nodes` maps from node ids to recreated objects
    # `node_setters` maps from node ids to setter functions
    # (same signature as setattr) for setting children.
    nodes, node_setters = self._initialize_loaded_nodes()

    # Figure out which objects are slot variables. These objects are created
    # with Optimizer.add_slot rather than _recreate_variable.
    # Maps slot node id -> optimizer node id, SlotVariableReference proto
    slot_variable_node_ids = {}

    for node_id, proto in self._iter_all_nodes():
      for slot_variable_proto in proto.slot_variables:
        slot_variable_node_id = slot_variable_proto.slot_variable_node_id
        slot_variable_node_ids[slot_variable_node_id] = (node_id,
                                                         slot_variable_proto)

    # Re-create everything.
    for node_id, proto in self._iter_all_nodes():
      if nodes.get(node_id) is not None:
        # Already supplied via node_filters; don't recreate.
        continue
      elif node_id in slot_variable_node_ids:
        # Use the public Optimizer interface when creating slot variables.
        # The dependency ordering guarantees the optimizer and original
        # variable exist in `nodes` by the time the slot is visited.
        optimizer_node_id, slot_variable_proto = slot_variable_node_ids[node_id]
        optimizer_object = nodes[optimizer_node_id]
        optimized_variable = nodes[
            slot_variable_proto.original_variable_node_id]
        slot_variable = optimizer_object.add_slot(
            var=optimized_variable,
            slot_name=slot_variable_proto.slot_name)
        nodes[slot_variable_proto.slot_variable_node_id] = slot_variable
        node_setters[slot_variable_proto.slot_variable_node_id] = setattr
      else:
        node, setter = self._recreate(proto, node_id, nodes)
        nodes[node_id] = node
        node_setters[node_id] = setter

    # If root object is not loaded, add a dummy root object for checkpoint
    # compatibility.
    if 0 not in nodes:
      nodes[0] = self._recreate_base_user_object()[0]

    # Unfiltered node ids are left as None placeholders in the list.
    self._nodes = [nodes.get(node_id)
                   for node_id in range(len(self._proto.nodes))]
    self._node_setters = node_setters
  def _restore_checkpoint(self):
    """Load state from checkpoint into the deserialized objects."""
    variables_path = path_helpers.get_variables_path(self._export_dir)
    # TODO(b/205010730): Clean use of private methods of TrackableSaver.
    # pylint: disable=protected-access
    saver = checkpoint.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
    with ops.device("CPU"):
      saver._file_prefix_placeholder = constant_op.constant(variables_path)
    if self._save_options.allow_partial_checkpoint:
      # Tolerate values in the checkpoint with no matching object, but still
      # require that at least something was restored.
      load_status = saver.restore(variables_path,
                                  self._checkpoint_options).expect_partial()
      load_status.assert_nontrivial_match()
    else:
      load_status = saver.restore(variables_path, self._checkpoint_options)
      load_status.assert_existing_objects_matched()
    ckpt = load_status._checkpoint

    if not context.executing_eagerly():
      reader = py_checkpoint_reader.NewCheckpointReader(variables_path)

      # When running in eager mode, the `restore` call above has already run and
      # restored the state of trackables, and calling `position.restore_ops()`
      # would re-run the restore. In graph mode, that will return a cached list
      # of ops that must run to restore the object on that position. We have to
      # wire them in the initializers of the objects so that they get
      # initialized properly when using common practices (e.g. the ones used by
      # ManagedSession) without further user action.
      for object_id, obj in dict(ckpt.object_by_proto_id).items():
        position = restore.CheckpointPosition(checkpoint=ckpt,
                                              proto_id=object_id)
        registered_saver = position.get_registered_saver_name()
        if registered_saver:
          raise NotImplementedError(
              "Loading a SavedModel that uses registered checkpoint saver is "
              f"not supported in graph mode. The loaded object {obj} uses the "
              f"saver registered with the name {registered_saver}.")

        restore_ops = position.restore_ops(reader)
        if restore_ops:
          if resource_variable_ops.is_resource_variable(obj):
            if len(restore_ops) == 1:
              obj._initializer_op = restore_ops[0]
            else:
              obj._initializer_op = control_flow_ops.group(*restore_ops)
          elif (isinstance(obj, lookup_ops.LookupInterface) or
                isinstance(obj, resource.CapturableResource)):
            # We don't need to check for eager execution here, since this code
            # path should only be taken if we are restoring in graph mode.
            ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops)
          else:
            raise NotImplementedError(
                f"Unable to restore state of object {obj} from the checkpoint.")
580 def adjust_debug_info_func_names(self, debug_info):
581 """Rewrite func names in the debug info by using the concrete func names."""
582 output_debug_info = graph_debug_info_pb2.GraphDebugInfo()
583 output_debug_info.files[:] = debug_info.files
584 for key in debug_info.traces:
585 node, func = key.split("@")
586 new_func = ""
587 if func in self._concrete_functions:
588 new_func = self._concrete_functions[func].function_def.signature.name
589 output_debug_info.traces[node + "@" + new_func].CopyFrom(
590 debug_info.traces[key])
591 return output_debug_info
593 def get(self, node_id):
594 if isinstance(node_id, str):
595 node_id = self._node_path_to_id[node_id]
596 return self._nodes[node_id]
  def _recreate(self, proto, node_id, nodes):
    """Creates a Python object from a SavedObject protocol buffer.

    Args:
      proto: a SavedObject proto
      node_id: int, the index of this object in the SavedObjectGraph node list.
      nodes: dict mapping int node_ids -> created objects.

    Returns:
      The recreated object, and the set-attribute function for reconnecting
      the trackable children.
    """
    # Prefer a user-registered class; fall back to the built-in kinds
    # (asset, resource, constant).
    registered_class = registration.get_registered_class(proto.registered_name)
    if registered_class is None:
      registered_class = _BUILT_IN_REGISTRATIONS.get(proto.WhichOneof("kind"))

    # Resolve dependency ids to the already-created objects (guaranteed to
    # exist by the dependency ordering of node creation).
    dependencies = {}
    for key, dep_node_id in self._get_node_dependencies(proto).items():
      dependencies[key] = nodes[dep_node_id]

    if registered_class:
      obj = registered_class._deserialize_from_proto(  # pylint: disable=protected-access
          proto=proto.serialized_user_proto,
          object_proto=proto,
          dependencies=dependencies,
          export_dir=self._export_dir,
          asset_file_def=self._asset_file_def,
          operation_attributes=self._operation_attributes)
      if isinstance(obj, base.Trackable):
        setter = type(obj)._add_trackable_child  # pylint: disable=protected-access
      else:
        # Returned object may be non-Trackable (e.g. when restoring captures).
        setter = setattr
      return obj, setter
    else:
      return self._recreate_default(proto, node_id, dependencies)
635 def _recreate_default(self, proto, node_id, deps):
636 """Creates a Python object from a SavedObject protocol buffer."""
637 factory = {
638 "user_object": (
639 lambda: self._recreate_user_object(proto.user_object, node_id)),
640 "function": lambda: self._recreate_function(proto.function, deps),
641 "bare_concrete_function": functools.partial(
642 self._recreate_bare_concrete_function,
643 proto=proto.bare_concrete_function, dependencies=deps),
644 "variable": lambda: self._recreate_variable(proto.variable),
645 "captured_tensor": functools.partial(
646 self._get_tensor_from_fn, proto.captured_tensor),
647 }
648 kind = proto.WhichOneof("kind")
649 if kind not in factory:
650 raise ValueError(f"Unknown SavedObject type: {kind}. Expected one of "
651 f"{list(factory.keys())}.")
652 return factory[kind]()
  def _recreate_user_object(self, proto, node_id):
    """Instantiates a SavedUserObject."""
    if proto.identifier == "optimizer":
      # Make sure that the Keras optimizers module is imported. This is needed
      # to be able to load the "optimizer" object (OptimizerV2), which has
      # special logic around adding slot variables with `add_slot` in this file.
      try:
        # Tries the Keras 2.9+ location first, then the older module path.
        import keras.optimizers.legacy as _  # pylint: disable=g-import-not-at-top
      except ImportError:
        try:
          import keras.optimizers.optimizer_v2 as _  # pylint: disable=g-import-not-at-top
        except ImportError as e:
          raise ImportError(
              "Error when importing Keras. Unable to load SavedModel that "
              "contains an optimizer without the Keras module.") from e
    # Fall back to a plain trackable user object if no revived type matches.
    looked_up = revived_types.deserialize(proto)
    if looked_up is None:
      return self._recreate_base_user_object(proto, node_id)
    return looked_up
674 def _recreate_base_user_object(self, proto=None, node_id=None):
675 del proto, node_id
676 # Note: each user object has its own class. This allows making each one
677 # individually callable by adding a `__call__` method to the classes of
678 # the objects instances that have a `__call__` property.
680 class _UserObject(autotrackable.AutoTrackable):
681 pass
683 return _UserObject(), setattr
  def _recreate_function(self, proto, dependencies):
    """Recreates a restored tf.function from a SavedFunction proto.

    Args:
      proto: A SavedFunction proto.
      dependencies: Dict mapping capture node ids to created objects.

    Returns:
      Tuple of (restored function, setattr).
    """
    fn = function_deserialization.recreate_function(
        proto, self._concrete_functions)
    for name in proto.concrete_functions:
      self._setup_function_captures(name, dependencies)

    if self._save_options.experimental_load_function_aliases:
      for name in proto.concrete_functions:
        if name in self._concrete_function_aliases:
          alias = self._concrete_function_aliases[name]
          self.function_aliases[alias] = fn
          # We only need to save the mapping from alias to a tf.Function
          # once even though it can appear multiple times in
          # self._concrete_function_aliases due to one-to-many mapping from
          # tf.Function to concrete functions.
          break

    return fn, setattr
704 def _recreate_bare_concrete_function(self, proto, dependencies):
705 fn = function_deserialization.setup_bare_concrete_function(
706 proto, self._concrete_functions)
707 self._setup_function_captures(proto.concrete_function_name, dependencies)
708 return fn, setattr
710 def _recreate_variable(self, proto):
711 name = proto.name if proto.name else None
712 if name is not None:
713 dbg_name = name
714 else:
715 dbg_name = "<variable loaded from saved model>"
716 synchronization, aggregation, trainable = (
717 variables.validate_synchronization_aggregation_trainable(
718 proto.synchronization, proto.aggregation, proto.trainable,
719 name=dbg_name))
721 def uninitialized_variable_creator(next_creator, **kwargs):
722 """A variable creator that creates uninitialized variables."""
723 del next_creator
724 return resource_variable_ops.UninitializedVariable(**kwargs)
726 # Create a variable_creator_scope that creates uninitialized variables with
727 # a lower priority such that a potential distributed variable_creator_scope
728 # can take precedence.
729 with ops.get_default_graph()._variable_creator_scope( # pylint: disable=protected-access
730 uninitialized_variable_creator,
731 priority=50):
732 saved_device = proto.device
733 load_with_device = (
734 self._save_options.experimental_variable_policy
735 ._save_variable_devices() and config.get_soft_device_placement() and
736 saved_device)
737 if load_with_device:
738 with ops.device(saved_device):
739 return variables.Variable(
740 shape=proto.shape,
741 dtype=proto.dtype,
742 name=name,
743 trainable=trainable,
744 synchronization=synchronization,
745 aggregation=aggregation), setattr
746 else:
747 return variables.Variable(
748 shape=proto.shape,
749 dtype=proto.dtype,
750 name=name,
751 trainable=trainable,
752 synchronization=synchronization,
753 aggregation=aggregation), setattr
755 def _get_tensor_from_fn(self, proto):
756 outer_graph = self._concrete_functions[proto.concrete_function].graph
757 captured_tensor = outer_graph.get_tensor_by_name(proto.name)
758 return captured_tensor, setattr
761def _call_attribute(instance, *args, **kwargs):
762 return instance.__call__(*args, **kwargs)
@tf_export("saved_model.load", v1=["saved_model.load_v2"])
def load(export_dir, tags=None, options=None):
  """Load a SavedModel from `export_dir`.

  Signatures associated with the SavedModel are available as functions:

  ```python
  imported = tf.saved_model.load(path)
  f = imported.signatures["serving_default"]
  print(f(x=tf.constant([[1.]])))
  ```

  Objects exported with `tf.saved_model.save` additionally have trackable
  objects and functions assigned to attributes:

  ```python
  exported = tf.train.Checkpoint(v=tf.Variable(3.))
  exported.f = tf.function(
      lambda x: exported.v * x,
      input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  tf.saved_model.save(exported, path)
  imported = tf.saved_model.load(path)
  assert 3. == imported.v.numpy()
  assert 6. == imported.f(x=tf.constant(2.)).numpy()
  ```

  _Loading Keras models_

  Keras models are trackable, so they can be saved to SavedModel. The object
  returned by `tf.saved_model.load` is not a Keras object (i.e. doesn't have
  `.fit`, `.predict`, etc. methods). A few attributes and functions are still
  available: `.variables`, `.trainable_variables` and `.__call__`.

  ```python
  model = tf.keras.Model(...)
  tf.saved_model.save(model, path)
  imported = tf.saved_model.load(path)
  outputs = imported(inputs)
  ```

  Use `tf.keras.models.load_model` to restore the Keras model.

  _Importing SavedModels from TensorFlow 1.x_

  SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
  graph instead of `tf.function` objects. These SavedModels will be loaded with
  the following attributes:

  * `.signatures`: A dictionary mapping signature names to functions.
  * `.prune(feeds, fetches) `: A method which allows you to extract
    functions for new subgraphs. This is equivalent to importing the SavedModel
    and naming feeds and fetches in a Session from TensorFlow 1.x.

    ```python
    imported = tf.saved_model.load(path_to_v1_saved_model)
    pruned = imported.prune("x:0", "out:0")
    pruned(tf.ones([]))
    ```

    See `tf.compat.v1.wrap_function` for details.
  * `.variables`: A list of imported variables.
  * `.graph`: The whole imported graph.
  * `.restore(save_path)`: A function that restores variables from a checkpoint
    saved from `tf.compat.v1.Saver`.

  _Consuming SavedModels asynchronously_

  When consuming SavedModels asynchronously (the producer is a separate
  process), the SavedModel directory will appear before all files have been
  written, and `tf.saved_model.load` will fail if pointed at an incomplete
  SavedModel. Rather than checking for the directory, check for
  "saved_model_dir/saved_model.pb". This file is written atomically as the last
  `tf.saved_model.save` file operation.

  Args:
    export_dir: The SavedModel directory to load from.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A trackable object with a `signatures` attribute mapping from signature
    keys to functions. If the SavedModel was exported by `tf.saved_model.save`,
    it also points to trackable objects, functions, debug info which it has been
    saved.

  Raises:
    ValueError: If `tags` don't match a MetaGraph in the SavedModel.
  """
  # Accept pathlib.Path-like inputs by normalizing to a plain string path.
  if isinstance(export_dir, os.PathLike):
    export_dir = os.fspath(export_dir)
  # `load` is a thin wrapper: load everything (no node filters) and return
  # the root object.
  return load_partial(export_dir, None, tags, options)["root"]
@tf_export("__internal__.saved_model.load_partial", v1=[])
def load_partial(export_dir, filters, tags=None, options=None):
  """Partially load a SavedModel (saved from V2).

  Similar to `tf.saved_model.load`, but with an additional argument that
  lets you specify which nodes to load.
  `tf.saved_model.load_partial(export_dir, ["root"])` and
  `tf.saved_model.load(export_dir)` are equivalent.

  Note: This only works for SavedModels saved with TensorFlow V2 from
  `tf.saved_model.save` or Keras. This will not load SavedModels save from
  the Estimator API.

  In Tensorflow V2, SavedModel stores the **object graph** of the saved object.
  The graph contains nodes (`tf.Module`, `tf.Variable`, `tf.function`, Keras
  layers, etc.) and edges that are the name of the attributes connecting the
  objects.

  *Example 1*

  ```
  model = tf.Module()
  model.child_layer = tf.Module()
  model.child_layer.v = tf.Variable(5.)
  tf.saved_model.save(model, '/tmp/model')
  loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.child_layer', 'root.child_layer.v'])
  loaded['root.child_layer'].v.numpy()
  5.
  loaded['root.child_layer'].v is loaded['root.child_layer.v']
  True

  *Example 2*
  model = tf.Module()
  model.child_layer = tf.Module()
  model.child_layer.v = tf.Variable(5.)
  >>>
  tf.saved_model.save(model, '/tmp/model')
  # Create a variable
  new_variable = tf.Variable(0.)
  loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   {'root.child_layer': None, 'root.child_layer.v': new_variable})
  loaded['root.child_layer'].v.numpy()
  5.
  new_variable.numpy()
  5.
  ```

  **Loading under different distribution strategies**
  You can load different parts of the model under different distribution
  strategies. Note that this is very experimental so use with care.

  ```
  model = tf.Module()
  model.layer_1 = tf.Module()
  model.layer_1.v = tf.Variable(5.)
  model.layer_2 = tf.Module()
  model.layer_2.v = tf.Variable(7.)
  tf.saved_model.save(model, '/tmp/model')
  # Load with no strategy
  loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.layer_1'])
  loaded['root.layer_1'].v
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
  ...   loaded2 = tf.__internal__.saved_model.load_partial(
  ...     '/tmp/model',
  ...     ['root.layer_2'])
  loaded2['root.layer_2'].v
  MirroredVariable:{
    0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>
  }
  ```

  Args:
    export_dir: The SavedModel directory to load from.
    filters: A list or dictionary where each element or key is a string
      path to nodes that should be loaded. Node paths consist of all the child
      attribute names to reach that node in the form: `root.{attribute_name}`.
      The loader will load all of the specified nodes and their recursive
      descendants. When this option is defined, the loader will return a
      dictionary mapping the node paths to the loaded objects.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A dictionary mapping node paths from the filter to loaded objects.
  """
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  # A single MetaGraph carrying an object_graph_def is the TF2 format; anything
  # else falls back to the TF1/Estimator loading path below.
  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    metrics.IncrementReadApi(_LOAD_V2_LABEL)
    meta_graph_def = saved_model_proto.meta_graphs[0]
    # tensor_content field contains raw bytes in little endian format
    # which causes problems when loaded on big-endian systems
    # requiring byteswap
    if sys.byteorder == "big":
      saved_model_utils.swap_function_tensor_content(meta_graph_def, "little",
                                                     "big")
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          f"Got an incompatible argument to `tags`: {tags}. The SavedModel at "
          f"{export_dir} has one MetaGraph with tags "
          f"{meta_graph_def.meta_info_def.tags}. You may omit the argument, "
          "pass 'None', or pass matching tags.")
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    # init_scope lifts loading out of any surrounding tf.function trace so
    # restored objects live in the eager/outer context.
    with ops.init_scope():
      try:
        loader = Loader(object_graph_proto, saved_model_proto, export_dir,
                        ckpt_options, options, filters)
      except errors.NotFoundError as err:
        # NOTE(review): chaining with `raise ... from err` would preserve the
        # original traceback context here — confirm before changing.
        raise FileNotFoundError(
            str(err) + "\n You may be trying to load on a different device "
            "from the computational device. Consider setting the "
            "`experimental_io_device` option in `tf.saved_model.LoadOptions` "
            "to the io_device such as '/job:localhost'.")
      # Node 0 is the root of the saved object graph.
      root = loader.get(0)
      root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
    metrics.IncrementRead(write_version="2")
  else:
    if filters:
      raise ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any"
                       " version) cannot be loaded with node filters.")
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info
  # For privacy concerns, please see the note in
  # tensorflow/cc/saved_model/metrics.h
  metrics.SetReadPath(saved_model_path=str(export_dir))

  # Read and log SavedModel checksum, if it is nonzero.
  try:
    fingerprint = fingerprinting.read_fingerprint(export_dir)
  except FileNotFoundError:
    logging.info(
        "Fingerprint not found. Saved model loading will continue.")
    singleprint = ""
  except RuntimeError:
    logging.exception(
        "Fingerprint was found, but there was an error when reading the proto.")
    singleprint = ""
  else:
    metrics.SetReadFingerprint(
        fingerprint=fingerprinting_utils.to_proto(
            fingerprint).SerializeToString())
    singleprint = fingerprint.singleprint()
  metrics.SetReadPathAndSingleprint(path=export_dir, singleprint=singleprint)

  if options.experimental_load_function_aliases:
    if hasattr(root, "function_aliases"):
      raise ValueError(
          "Could not load with experimental_load_function_aliases option"
          " because the top-level object already has an attributed with name"
          " 'function_aliases'"
      )
    # NOTE(review): `loader` is only bound on the TF2 branch above; reaching
    # this line (or the `filters` lookup below) via the TF1 fallback path with
    # this option set would raise NameError — confirm whether that path is
    # reachable with these options.
    root.function_aliases = loader.function_aliases

  if filters:
    return {node_id: loader.get(node_id) for node_id in filters}
  else:
    return {"root": root}