Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py: 19%
562 statements
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Exports a SavedModel from a Trackable Python object."""
17import collections
18import os
19import re
20import sys
21import traceback
23from absl import logging
25from tensorflow.core.framework import function_pb2
26from tensorflow.core.framework import versions_pb2
27from tensorflow.core.protobuf import meta_graph_pb2
28from tensorflow.core.protobuf import saved_model_pb2
29from tensorflow.core.protobuf import saved_object_graph_pb2
30from tensorflow.python.checkpoint import checkpoint
31from tensorflow.python.checkpoint import checkpoint_options
32from tensorflow.python.checkpoint import functional_saver
33from tensorflow.python.checkpoint import graph_view
34from tensorflow.python.checkpoint import save_util_v1
35from tensorflow.python.checkpoint import util as checkpoint_util
36from tensorflow.python.eager import context
37from tensorflow.python.eager import def_function
38from tensorflow.python.eager import function as defun
39from tensorflow.python.eager.polymorphic_function import polymorphic_function
40from tensorflow.python.eager.polymorphic_function import saved_model_exported_concrete
41from tensorflow.python.eager.polymorphic_function import saved_model_utils
42from tensorflow.python.framework import dtypes
43from tensorflow.python.framework import error_interpolation
44from tensorflow.python.framework import errors
45from tensorflow.python.framework import function as framework_fn
46from tensorflow.python.framework import meta_graph
47from tensorflow.python.framework import ops
48from tensorflow.python.framework import tensor_util
49from tensorflow.python.framework import versions
50from tensorflow.python.lib.io import file_io
51from tensorflow.python.ops import array_ops
52from tensorflow.python.ops import control_flow_ops
53from tensorflow.python.ops import resource_variable_ops
54from tensorflow.python.saved_model import builder_impl
55from tensorflow.python.saved_model import fingerprinting_utils
56from tensorflow.python.saved_model import function_serialization
57from tensorflow.python.saved_model import path_helpers
58from tensorflow.python.saved_model import pywrap_saved_model
59from tensorflow.python.saved_model import registration
60from tensorflow.python.saved_model import revived_types
61from tensorflow.python.saved_model import save_context
62from tensorflow.python.saved_model import save_options
63from tensorflow.python.saved_model import signature_constants
64from tensorflow.python.saved_model import signature_def_utils
65from tensorflow.python.saved_model import signature_serialization
66from tensorflow.python.saved_model import tag_constants
67from tensorflow.python.saved_model import tracing_utils
68from tensorflow.python.saved_model import utils_impl
69from tensorflow.python.saved_model.pywrap_saved_model import constants
70from tensorflow.python.saved_model.pywrap_saved_model import metrics
71from tensorflow.python.trackable import asset
72from tensorflow.python.trackable import base
73from tensorflow.python.trackable import resource
74from tensorflow.python.trackable import trackable_utils
75from tensorflow.python.training.saving import trace_saveable_util
76from tensorflow.python.types import core as types_core
77from tensorflow.python.util import compat
78from tensorflow.python.util import object_identity
79from tensorflow.python.util.tf_export import tf_export
81_UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant))
83# Container for tensors captured from external functions.
84_CapturedTensor = collections.namedtuple("_CapturedTensor",
85 ["name", "concrete_function"])
87# Number of untraced functions to display to user in warning message.
88_NUM_DISPLAY_UNTRACED_FUNCTIONS = 5
90# API label for SavedModel metrics.
91_SAVE_V2_LABEL = "save_v2"
94class _AugmentedGraphView(graph_view.ObjectGraphView):
95 """An extendable graph which also tracks functions attached to objects.
97 Extensions through `add_object` appear in the object graph and any checkpoints
98 generated from it, even if they are not dependencies of the node they were
99 attached to in the saving program. For example a `.signatures` attribute is
100 added to exported SavedModel root objects without modifying the root object
101 itself.
103 Also tracks functions attached to objects in the graph, through the caching
104 `list_children` method. Enumerating functions only through this method
105 ensures that we get a consistent view of functions, even if object attributes
106 create new functions every time they are accessed.
107 """
109 def __init__(self, root):
110 super(_AugmentedGraphView, self).__init__(root)
112 # Cache the results of `GraphView.list_children()` to ensure that the
113 # `Trackable` children are gathered exactly once.
114 self._children_cache = object_identity.ObjectIdentityDictionary()
116 # Cache shared between objects in the same object graph. This is passed to
117 # `Trackable._trackable_children()`.
118 self._serialization_cache = object_identity.ObjectIdentityDictionary()
120 # Maps functions -> wrapped functions that capture non-cached variables.
121 self._wrapped_functions = {}
123 self.untraced_functions = []
125 def set_signature(self, signature_map, wrapped_functions):
126 """Attach signature to the root object.
128 Args:
129 signature_map: An object that contains signature functions.
130 wrapped_functions: A dictionary mapping functions to functions that are
131 guaranteed to not capture cached variables (functions that capture
132 cached variables can't be saved).
133 """
134 self.list_children(self.root)
135 # Overrides existing dependency.
136 name = signature_serialization.SIGNATURE_ATTRIBUTE_NAME
137 self._children_cache[self.root][name] = signature_map
138 self._wrapped_functions.update(wrapped_functions)
140 def _breadth_first_traversal(self):
141 """Returns all trackable objects in the SavedObjectGraph."""
142# This method is overridden to merge all equivalent constant tensors and
143 # Assets in the object graph.
145 trackable_objects, _ = (
146 super(_AugmentedGraphView, self)._breadth_first_traversal())
148 asset_paths = object_identity.ObjectIdentityDictionary()
149 constant_captures = object_identity.ObjectIdentityDictionary()
150 for obj in trackable_objects:
151 if isinstance(obj, asset.Asset):
152 asset_paths[obj.asset_path] = obj
153 if isinstance(obj, saved_model_utils.TrackableConstant):
154 constant_captures[obj.capture] = obj
156 def _get_merged_trackable(x):
157 if isinstance(x, asset.Asset):
158 return asset_paths[x.asset_path]
159 if isinstance(x, saved_model_utils.TrackableConstant):
160 if x.capture in asset_paths:
161 return asset_paths[x.capture]
162 else:
163 return constant_captures[x.capture]
164 return x
166 for obj in list(self._children_cache.keys()):
167 if _get_merged_trackable(obj) is not obj:
168 del self._children_cache[obj]
169 continue
170 for name, child in self._children_cache[obj].items():
171 self._children_cache[obj][name] = _get_merged_trackable(child)
173 return super(_AugmentedGraphView, self)._breadth_first_traversal()
175 def list_children(self, obj):
176 """Lists children of `obj` for SavedModel."""
177 if obj not in self._children_cache:
178 children = self._children_cache[obj] = {}
180 for name, child in super(_AugmentedGraphView, self).list_children(
181 obj,
182 save_type=base.SaveType.SAVEDMODEL,
183 cache=self._serialization_cache):
184 if isinstance(child, defun.ConcreteFunction):
185 child = self._maybe_uncache_variable_captures(child)
186 children[name] = child
188 # Keep track of untraced functions for later reporting to the user.
189 if isinstance(obj, def_function.Function) and not children:
190 self.untraced_functions.append(obj.name)
192 for name, child in self._children_cache[obj].items():
193 yield base.TrackableReference(name, child)
195 def get_child(self, obj, name):
196 return self._children_cache[obj][name]
198 def _maybe_uncache_variable_captures(self, concrete_function):
199 if concrete_function in self._wrapped_functions:
200 return self._wrapped_functions[concrete_function]
201 for capture in concrete_function.captured_inputs:
202 if hasattr(capture, "_cached_variable"):
203 if concrete_function not in self._wrapped_functions:
204 wrapped = self._wrapped_functions[concrete_function] = (
205 function_serialization.wrap_cached_variables(concrete_function))
206 return wrapped
207 return concrete_function
209 def list_dependencies(self, obj):
210 """Yields `Trackables` that must be loaded before `obj`.
212 Dependencies and children are both dictionaries of `Trackables`. Children
213 define the object graph structure (used in both checkpoints and SavedModel),
214 while dependencies define the order used to load the SavedModel.
216 Args:
217 obj: A `Trackable` object
219 Yields:
220 Tuple of dependency names and trackable objects.
222 Raises:
223 TypeError: if any of the returned dependencies are not instances of
224 `Trackable`.
225 """
226 if obj not in self._children_cache:
227 # Slot variables do not appear in the children_cache.
228 children = {}
229 else:
230 children = self._children_cache[obj]
231 for name, dep in obj._deserialization_dependencies(children).items(): # pylint: disable=protected-access
232 if not isinstance(dep, base.Trackable):
233 raise TypeError(
234 f"The dependency of type {type(dep)} is not an instance `Trackable`"
235 ", and can't be saved to SavedModel. Please check the "
236 "implementation of `_deserialization_dependencies` in the parent "
237 f"object {obj}.")
238 yield name, dep
241class _SaveableView(object):
242 """Provides a frozen view over a trackable root.
244 This class helps to create a single stable view over an object to save. The
245 saving code should access properties and functions via this class and not via
246 the original object, as there are cases where an object constructs its
247 trackable attributes and functions dynamically per call and will yield
248 different objects if invoked more than once.
250 Changes to the graph, for example adding objects, must happen in
251 `augmented_graph_view` (an `_AugmentedGraphView`) before the `_SaveableView`
252 is constructed. Changes after the `_SaveableView` has been constructed will be
253 ignored.
254 """
256 def __init__(self, augmented_graph_view, options):
257 """Initializes a SaveableView.
259 Args:
260 augmented_graph_view: A GraphView object.
261 options: A SaveOptions instance.
262 """
264 self.augmented_graph_view = augmented_graph_view
265 self._options = options
267 (self._trackable_objects, self.node_paths, self.node_ids,
268 self._slot_variables, self.object_names) = (
269 checkpoint_util.objects_ids_and_slot_variables_and_paths(
270 self.augmented_graph_view))
272 untraced_functions = self.augmented_graph_view.untraced_functions
273 if untraced_functions:
274 logging.info(
275 "Found untraced functions such as %s while saving (showing %d of %d)."
276 " These functions will not be directly callable after loading.",
277 ", ".join(untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]),
278 min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(untraced_functions)),
279 len(untraced_functions))
281 self._initialize_save_and_restore_functions()
282 self._initialize_nodes_and_concrete_functions()
284 self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
286 def _initialize_save_and_restore_functions(self):
287 """Generates all checkpoint save/restore functions.
289 The save and restore functions are generated in the eager context (or in the
290 user's Graph/Session) before being copied to the exported GraphDef. These
291 functions record the ops for saving/restoring the entire object or
292 individual objects (e.g. variables and hash tables).
294 The global save and restore functions are generated for compatibility with
295 TF1 and loading from C++, and are saved in the `MetaGraphDef.saver_def`.
297 The individual functions are generated for the Python TF2 use case, where
298 users use the loaded SavedModel as-is, or compose new models using parts
299 of the object loaded from the SavedModel. These functions are recorded in
300 the `saveable_objects` map in the `SavedObject` proto.
301 """
302 checkpoint_factory_map, registered_savers = (
303 save_util_v1.get_checkpoint_factories_and_keys(self.object_names))
304 self._obj_to_registered_saver = object_identity.ObjectIdentityDictionary()
305 for saver_name, trackables in registered_savers.items():
306 for trackable in trackables.values():
307 self._obj_to_registered_saver[trackable] = saver_name
308 self._saveable_objects_map = (
309 _gen_save_and_restore_functions(checkpoint_factory_map))
311 def _initialize_nodes_and_concrete_functions(self):
312 """Creates graph with nodes for trackable objects and functions.
314 Adds functions for each trackable object to `self.nodes` and associated
315 concrete functions to `self.concrete_functions` for serialization.
316 """
317 self.nodes = list(self._trackable_objects)
318 self.gradient_functions = []
319 self.gradient_defs = []
321 for obj in self.nodes:
322 if obj in self._saveable_objects_map:
323 for save_fn, restore_fn in self._saveable_objects_map[obj].values():
324 self.node_ids[save_fn] = len(self.nodes)
325 self.nodes.append(save_fn)
327 self.node_ids[restore_fn] = len(self.nodes)
328 self.nodes.append(restore_fn)
330 self.concrete_functions = [
331 obj for obj in self.nodes if isinstance(obj, defun.ConcreteFunction)
332 ]
334 @property
335 def concrete_and_gradient_functions(self):
336 return self.concrete_functions + self.gradient_functions
338 @property
339 def root(self):
340 return self.nodes[0]
342 def fill_object_graph_proto(self, proto):
343 """Populate the nodes, children and slot_variables of a SavedObjectGraph."""
344 for node_id, node in enumerate(self.nodes):
345 assert self.node_ids[node] == node_id
346 object_proto = proto.nodes.add()
347 object_proto.slot_variables.extend(self._slot_variables.get(node, ()))
348 if isinstance(node, _CapturedTensor):
349 continue
350 for child in self.augmented_graph_view.list_children(node):
351 child_proto = object_proto.children.add()
352 child_proto.node_id = self.node_ids[child.ref]
353 child_proto.local_name = child.name
354 for name, ref in self.augmented_graph_view.list_dependencies(node):
355 child_proto = object_proto.dependencies.add()
356 child_proto.node_id = self.node_ids[ref]
357 child_proto.local_name = name
359 if node in self._saveable_objects_map:
360 assert node not in self._obj_to_registered_saver, (
361 "Objects can't have both SaveableObjects and a registered saver")
363 for local_name, (save_fn, restore_fn) in (
364 self._saveable_objects_map[node].items()):
365 saveable_object_proto = object_proto.saveable_objects[local_name]
366 saveable_object_proto.save_function = self.node_ids[save_fn]
367 saveable_object_proto.restore_function = self.node_ids[restore_fn]
369 elif node in self._obj_to_registered_saver:
370 object_proto.registered_saver = self._obj_to_registered_saver[node]
372 def map_resources(self):
373 """Makes new resource handle ops corresponding to existing resource tensors.
375 Creates resource handle ops in the current default graph, whereas
376 `accessible_objects` will be from an eager context. Resource mapping adds
377 resource handle ops to the main GraphDef of a SavedModel, which allows the
378 C++ loader API to interact with resources.
380 Returns:
381 A tuple of (object_map, tensor_map, asset_info):
382 object_map: A dictionary mapping from object in `accessible_objects` to
383 replacement objects created to hold the new resource tensors.
384 tensor_map: A dictionary mapping from resource tensors extracted from
385 `accessible_objects` to newly created resource tensors.
386 asset_info: An _AssetInfo tuple describing external assets referenced
387 from accessible_objects.
388 """
389 # Only makes sense when adding to the export Graph
390 assert not context.executing_eagerly()
391 # TODO(b/205007558): Handle MirroredVariables and other types of variables
392 # which may need special casing.
393 object_map = object_identity.ObjectIdentityDictionary()
394 tensor_map = object_identity.ObjectIdentityDictionary()
395 asset_info = _AssetInfo(
396 asset_defs=[],
397 asset_initializers_by_resource=object_identity.ObjectIdentityDictionary(),
398 asset_filename_map={},
399 asset_index={})
401 for node_id in _dependency_sorted_node_ids(self):
402 obj = self.nodes[node_id]
403 tensors = obj._export_to_saved_model_graph( # pylint: disable=protected-access
404 object_map=object_map,
405 tensor_map=tensor_map,
406 options=self._options)
407 if isinstance(obj, asset.Asset):
408 _add_asset_info(obj, asset_info, tensor_map[obj.asset_path])
409 if tensors:
410 for tensor in tensors:
411 self.captured_tensor_node_ids[tensor] = node_id
413 return object_map, tensor_map, asset_info
415 def add_capture_and_node(self, capture, node):
416 node_id = len(self.nodes)
417 self.nodes.append(node)
418 self.node_ids[capture] = node_id
419 self.node_ids[node] = node_id
420 self.captured_tensor_node_ids[capture] = node_id
421 return node_id
423 def get_concrete_resource_initializers(self):
424 concrete_initializers = []
425 for obj in self.nodes:
426 if isinstance(obj, resource.CapturableResource):
427 concrete_initializers.append(
428 self.augmented_graph_view.get_child(
429 obj, "_initialize").get_concrete_function())
430 return concrete_initializers
433def _gen_save_and_restore_functions(checkpoint_factory_map):
434 """Generates global and individual save/restore concrete functions.
436 The global functions record the ops to save and restore the entire object to
437 a file prefix, while the individual functions save and restore value tensors
438 for resources.
440 This function is intended to run on the output of
441 `save_util_v1.get_checkpoint_factories_and_keys(object_names)`,
442 which returns the generated map of `_CheckpointFactoryData`.
444 Args:
445 checkpoint_factory_map: A dictionary mapping trackable objects to
446 a list of `_CheckpointFactoryData`.
448 Returns:
449 saveable_fn_map: Maps obj -> factory name -> (concrete save, restore).
452 """
453 # Maps obj -> factory attribute_name -> (concrete save, concrete restore)
454 # This
455 saveable_fn_map = object_identity.ObjectIdentityDictionary()
457 for obj, factory_data_list in checkpoint_factory_map.items():
458 if resource_variable_ops.is_resource_variable(obj) or not factory_data_list:
459 # There is no need to trace the save and restore functions for variables.
460 continue
462 if factory_data_list[0].name == trackable_utils.SERIALIZE_TO_TENSORS_NAME:
463 # Trace Trackable save and restore functions.
464 assert len(factory_data_list) == 1
465 saveable_fn_map[obj] = {trackable_utils.SERIALIZE_TO_TENSORS_NAME: (
466 tracing_utils.trace_save_and_restore(obj))}
467 else:
468 # Trace deprecated SaveableObject save and restore functions.
469 saveable_fn_map[obj] = (
470 trace_saveable_util.trace_save_restore_function_map(
471 obj, factory_data_list))
472 return saveable_fn_map
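# A minimal sketch of the structure returned above, assuming a hypothetical
# `table` trackable that serializes itself to tensors (the names here are
# illustrative only):
#
#   {table: {trackable_utils.SERIALIZE_TO_TENSORS_NAME:
#                (concrete_save_fn, concrete_restore_fn)}}
#
# Resource variables and objects without checkpoint factories are skipped, so
# they never appear as keys in this map.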
475def _tensor_dict_to_tensorinfo(tensor_dict):
476 return {
477 key: utils_impl.build_tensor_info_internal(value)
478 for key, value in tensor_dict.items()
479 }
482def _to_safe_name_scope(signature_key, user_input_name):
483 """Creates a sanitized name scope from user signature and input names.
485 Concatenates signature and input names, sanitizing as needed to be a valid
486 scope name.
488 Args:
489 signature_key: The user-provided key for the signature.
490 user_input_name: The user-provided name for the input placeholder.
492 Returns:
493 A name scope that is safe to be used in tf.name_scope().
494 """
495 name_scope = "{}_{}".format(signature_key, user_input_name)
496 if re.match(r"^[A-Za-z0-9.][A-Za-z0-9_.\\-]*$", name_scope):
497 return name_scope
498 invalid_prefix_stripped = re.sub(r"^[^A-Za-z0-9.]*", "", name_scope)
499 return re.sub(r"[^A-Za-z0-9_.\\-]", "_", invalid_prefix_stripped)
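# A small worked example of the sanitizer above (hypothetical inputs; the helper
# is defined for illustration only and never called by this module):
def _example_safe_name_scopes():
  # Valid names pass through unchanged.
  assert _to_safe_name_scope("serving_default", "x") == "serving_default_x"
  # Invalid characters (here, a space) are replaced with underscores.
  assert _to_safe_name_scope("serving_default", "my input") == (
      "serving_default_my_input")
  # Invalid leading characters are stripped rather than replaced.
  assert _to_safe_name_scope("_key", "x") == "key_x"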
502def _map_function_arguments_to_created_inputs(
503 function_arguments, signature_key, function_name, defaults=None
504):
505 """Creates exterior placeholders in the exported graph for function arguments.
507 Functions have two types of inputs: tensors captured from the outside (eager)
508 context, and arguments to the function which we expect to receive from the
509 user at each call. `_map_captures_to_created_tensors` replaces
510 captured tensors with stand-ins (typically these are resource dtype tensors
511 associated with variables). `_map_function_arguments_to_created_inputs` runs over
512 every argument, creating a new placeholder for each which will belong to the
513 exported graph rather than the function body.
515 Args:
516 function_arguments: A list of argument placeholders in the function body.
517 signature_key: The name of the signature being exported, for error messages.
518 function_name: The name of the function, for error messages.
519 defaults: A dictionary mapping (signature_key, user_specified_name) to
520 Tensor representing default values.
522 Returns:
523 A tuple of (mapped_inputs, exterior_placeholders)
524 mapped_inputs: A list with entries corresponding to `function_arguments`
525 containing all of the inputs of the function gathered from the exported
526 graph (both captured resources and arguments).
527 exterior_argument_placeholders: A dictionary mapping from argument names
528 to placeholders in the exported graph, containing the explicit arguments
529 to the function which a user is expected to provide.
531 Raises:
532 ValueError: If argument names are not unique.
533 """
534 # `exterior_argument_placeholders` holds placeholders which are outside the
535 # function body, directly contained in a MetaGraph of the SavedModel. The
536 # function body itself contains nearly identical placeholders used when
537 # running the function, but these exterior placeholders allow Session-based
538 # APIs to call the function using feeds and fetches which name Tensors in the
539 # MetaGraph.
540 exterior_argument_placeholders = {}
541 mapped_inputs = []
542 for placeholder in function_arguments:
543 # `export_captures` contains an exhaustive set of captures, so if we don't
544 # find the input there then we now know we have an argument.
545 user_input_name = compat.as_str_any(
546 placeholder.op.get_attr("_user_specified_name"))
547 # If the internal placeholders for a function have names which were
548 # uniquified by TensorFlow, then a single user-specified argument name
549 # must refer to multiple Tensors. The resulting signatures would be
550 # confusing to call. Instead, we throw an exception telling the user to
551 # specify explicit names.
552 if user_input_name != placeholder.op.name:
553 # This should be unreachable, since concrete functions may not be
554 # generated with non-unique argument names.
555 raise ValueError(
556 "Got non-flat/non-unique argument names for SavedModel signature "
557 f"'{signature_key}': more than one argument to "
558 f"'{compat.as_str_any(function_name)}' was named "
559 f"'{user_input_name}'. "
560 "Signatures have one Tensor per named input, so to have "
561 "predictable names Python functions used to generate these "
562 "signatures should avoid *args and Tensors in nested "
563 "structures unless unique names are specified for each. Use "
564 "tf.TensorSpec(..., name=...) to provide a name for a Tensor "
565 "input.")
566 default_value = defaults.get((signature_key, user_input_name))
567 if default_value is not None:
568 placeholder_with_default = array_ops.placeholder_with_default(
569 input=default_value.numpy(),
570 shape=placeholder.shape,
571 name=_to_safe_name_scope(signature_key, user_input_name),
572 )
573 exterior_argument_placeholders[user_input_name] = placeholder_with_default
574 mapped_inputs.append(placeholder_with_default)
575 else:
576 arg_placeholder = array_ops.placeholder(
577 shape=placeholder.shape,
578 dtype=placeholder.dtype,
579 name=_to_safe_name_scope(signature_key, user_input_name),
580 )
581 exterior_argument_placeholders[user_input_name] = arg_placeholder
582 mapped_inputs.append(arg_placeholder)
583 return mapped_inputs, exterior_argument_placeholders
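# A minimal sketch of the mapping above, assuming a hypothetical signature
# "serving_default" whose function takes a single argument named "x" with no
# registered default: `exterior_argument_placeholders` would be
# {"x": <Placeholder named "serving_default_x">} and `mapped_inputs` would hold
# that same placeholder, ready to be fed into the function call made by
# `_generate_signatures` below.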
586def _generate_signatures(signature_functions, object_map, defaults=None):
587 """Validates and calls `signature_functions` in the exported graph.
589 Args:
590 signature_functions: A dictionary mapping string keys to concrete TensorFlow
591 functions (e.g. from `signature_serialization.canonicalize_signatures`)
592 which will be used to generate SignatureDefs.
593 object_map: A dictionary that contains mappings from signature functions to
594 concrete functions in the exported graph.
595 defaults: A dictionary mapping (signature_key, user_specified_name) to
596 Tensor representing default values.
598 Returns:
599 Each function in the `signature_functions` dictionary is called with
600 placeholder Tensors, generating a function call operation and output
601 Tensors. The placeholder Tensors, the function call operation, and the
602 output Tensors from the function call are part of the default Graph.
604 This function then returns a dictionary with the same structure as
605 `signature_functions`, with the concrete functions replaced by SignatureDefs
606 implicitly containing information about how to call each function from a
607 TensorFlow 1.x Session / the C++ Loader API. These SignatureDefs reference
608 the generated placeholders and Tensor outputs by name.
610 The caller is expected to include the default Graph (the graph set while
611 calling this function) as a MetaGraph in a SavedModel, including the returned
612 SignatureDefs as part of that MetaGraph.
613 """
614 signatures = {}
615 for signature_key, function in sorted(signature_functions.items()):
616 if function.graph.captures:
617 argument_inputs = function.graph.inputs[:-len(function.graph.captures)]
618 else:
619 argument_inputs = function.graph.inputs
620 mapped_inputs, exterior_argument_placeholders = (
621 _map_function_arguments_to_created_inputs(
622 argument_inputs, signature_key, function.name, defaults
623 )
624 )
625 kwarg_names = list(
626 sorted(
627 object_map[function].function.structured_input_signature[1].keys()))
628 outputs = object_map[function](**{
629 kwarg_name: mapped_input
630 for kwarg_name, mapped_input in zip(kwarg_names, mapped_inputs)
631 })
632 signatures[signature_key] = signature_def_utils.build_signature_def(
633 _tensor_dict_to_tensorinfo(exterior_argument_placeholders),
634 _tensor_dict_to_tensorinfo(outputs),
635 method_name=signature_constants.PREDICT_METHOD_NAME)
636 return signatures
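# A minimal sketch of the result, for a hypothetical model exported with a
# single serving signature: the returned dictionary would resemble
#   {"serving_default": <SignatureDef referencing the placeholders and outputs
#                        created above, with method_name
#                        "tensorflow/serving/predict">}
# and is later copied into `meta_graph_def.signature_def` by
# `_fill_meta_graph_def`.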
639_AssetInfo = collections.namedtuple(
640 "_AssetInfo",
641 [
642 # List of AssetFileDef protocol buffers
643 "asset_defs",
644 # Map from asset variable resource Tensors to their init ops
645 "asset_initializers_by_resource",
646 # Map from base asset filenames to full paths
647 "asset_filename_map",
648 # Map from Asset to index of corresponding AssetFileDef
649 "asset_index"
650 ])
653def _add_asset_info(trackable_asset, asset_info, mapped_path_variable):
654 """Add `trackable_asset` to `asset_info`."""
655 original_path_tensor = trackable_asset.asset_path
656 original_path = tensor_util.constant_value(original_path_tensor)
657 try:
658 original_path = str(original_path.astype(str))
659 except AttributeError:
660 # Already a string rather than a numpy array
661 pass
663 path = builder_impl.get_asset_filename_to_add(
664 asset_filepath=original_path,
665 asset_filename_map=asset_info.asset_filename_map)
666 asset_info.asset_filename_map[path] = original_path
667 asset_def = meta_graph_pb2.AssetFileDef()
668 asset_def.filename = path
669 asset_def.tensor_info.name = mapped_path_variable.initial_value.name
670 asset_info.asset_defs.append(asset_def)
671 asset_info.asset_initializers_by_resource[original_path_tensor] = (
672 mapped_path_variable.initializer)
673 asset_info.asset_index[trackable_asset] = len(asset_info.asset_defs) - 1
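# A minimal sketch, using a hypothetical vocabulary asset at "/data/vocab.txt":
# after `_add_asset_info` runs, `asset_info.asset_filename_map` gains an entry
# like {"vocab.txt": "/data/vocab.txt"} (base filename -> original path), an
# AssetFileDef with filename "vocab.txt" is appended to `asset_defs`, and the
# asset's initializer op is recorded so the copied file can be re-bound to its
# resource at load time.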
676def _iterate_op_types(fn):
677 """Iterates through each op in the function and returns the op type and op."""
678 if isinstance(fn, framework_fn._DefinedFunction): # pylint: disable=protected-access
679 for node in fn.definition.node_def:
680 op_type = node.attr["_gradient_op_type"].s
681 if op_type:
682 raise ValueError(
683 "Unable to save gradient functions when exporting a "
684 "_DefinedFunction (generally created through graph freezing utils "
685 "or through V1 graph importers). Please save with "
686 "`options=tf.SaveOptions(experimental_custom_gradients=False)`")
687 else:
688 for op in fn.graph.get_operations():
689 try:
690 op_type = op.get_attr("_gradient_op_type")
691 except ValueError:
692 continue
693 yield op_type, op
696def _get_outer_most_capture(fn, capture, func_graph_map):
697 """Tries to find the original captured tensor if capture more than once."""
698 outer_fn = fn
699 while outer_fn is not None and not isinstance(capture, ops.EagerTensor):
700 if capture.graph is not outer_fn.graph:
701 outer_fn = func_graph_map.get(outer_fn.graph.outer_graph)
702 else:
703 try:
704 capture_index = outer_fn.graph.internal_captures.index(capture)
705 except ValueError:
706 break # Capture is a tensor inside function, and not captured from
707 # another external function
708 capture = outer_fn.graph.external_captures[capture_index]
709 outer_fn = func_graph_map.get(outer_fn.graph.outer_graph)
710 return outer_fn, capture
713def _trace_gradient_functions(graph, saveable_view):
714 """Traces gradient functions and records them in the SaveableView."""
715 functions = list(graph._functions.values()) # pylint: disable=protected-access
716 func_graph_map = {f.graph: f for f in functions if hasattr(f, "graph")}
717 seen_op_types = set()
719 for fn in functions:
720 for op_type, op in _iterate_op_types(fn):
721 if op_type in seen_op_types:
722 continue
723 seen_op_types.add(op_type)
725 try:
726 custom_gradient = ops.gradient_registry.lookup(op_type)
727 except LookupError:
728 continue
730 try:
731 grad_fn = (
732 def_function.function(custom_gradient).get_concrete_function(
733 None, *op.inputs))
734 except Exception as exc:
735 traceback.print_exc()
736 raise ValueError(
737 "Error when tracing gradients for SavedModel.\n\n"
738 "Check the error log to see the error that was raised when "
739 "converting a gradient function to a concrete function. You may "
740 "need to update the custom gradient, or disable saving gradients "
741 "with the option "
742 "tf.saved_model.SaveOptions(experimental_custom_gradients=False)"
743 f".\n\tProblematic op name: {op.name}\n\tGradient inputs: "
744 f"{op.inputs}") from exc
746 with graph.as_default():
747 # The gradient function will capture all intermediate values. These
748 # captures must be serialized so that they can be re-bound to the function
749 # when loading.
750 bad_captures = []
751 for capture in grad_fn.captured_inputs:
752 if capture.dtype in _UNCOPIABLE_DTYPES:
753 continue
754 # Tries to find the outermost capture in case the tensor is a constant
755 # or not actually captured in the current function (this could happen
756 # if the function is a while loop body, in which case the captured
757 # input is not the internal captured tensor).
758 outer_fn, outer_capture = _get_outer_most_capture(
759 fn, capture, func_graph_map
760 )
761 if outer_fn is None or isinstance(outer_capture, ops.EagerTensor):
762 if outer_capture not in saveable_view.captured_tensor_node_ids:
763 raise ValueError(
764 f"Found invalid capture {outer_capture} when "
765 "saving custom gradients."
766 )
767 saveable_view.captured_tensor_node_ids[capture] = (
768 saveable_view.captured_tensor_node_ids[outer_capture]
769 )
770 elif outer_capture.graph is outer_fn.graph:
771 capture_name = outer_capture.name
772 # It's possible for AtomicFunctions to save different names
773 # for input tensors when serialized to FunctionDef (all
774 # non-alphanumeric characters are converted to '_').
775 if isinstance(outer_fn, defun.AtomicFunction): # pylint:disable=protected-access
776 try:
777 arg_index = outer_fn.graph.inputs.index(outer_capture)
778 capture_name = (
779 outer_fn.cached_definition.signature.input_arg[
780 arg_index
781 ].name
782 + ":0"
783 )
784 except ValueError:
785 pass
787 node = _CapturedTensor(capture_name, outer_fn.name)
788 saveable_view.add_capture_and_node(capture, node)
789 else:
790 bad_captures.append(capture.name)
791 if not bad_captures:
792 grad_fn.add_to_graph(graph)
793 else:
794 raise ValueError(
795 f"Cannot save custom gradient {op_type} called in function {fn} "
796 "because SavedModel is unable to serialize the captured "
797 f"inputs: {bad_captures}"
798 )
800 saveable_view.gradient_functions.append(grad_fn)
801 func_graph_map[grad_fn.graph] = grad_fn
803 grad_def = function_pb2.RegisteredGradient()
804 grad_def.gradient_func = grad_fn.name
805 grad_def.registered_op_type = op_type
806 saveable_view.gradient_defs.append(grad_def)
809def _fill_meta_graph_def(
810 meta_graph_def,
811 saveable_view,
812 signature_functions,
813 namespace_whitelist,
814 save_custom_gradients,
815 defaults=None,
816):
817 """Generates a MetaGraph which calls `signature_functions`.
819 Args:
820 meta_graph_def: The MetaGraphDef proto to fill.
821 saveable_view: The _SaveableView being exported.
822 signature_functions: A dictionary mapping signature keys to concrete
823 functions containing signatures to add to the MetaGraph.
824 namespace_whitelist: List of strings containing whitelisted op namespaces.
825 save_custom_gradients: Whether to save custom gradients.
826 defaults: A dictionary mapping (signature_key, user_specified_name) to
827 Tensor representing default values.
829 Returns:
830 A tuple of (_AssetInfo, Graph) containing the captured assets and
831 exported Graph generated from tracing the saveable_view.
832 """
833 # List objects from the eager context to make sure Optimizers give us the
834 # right Graph-dependent variables.
835 resource_initializers = saveable_view.get_concrete_resource_initializers()
836 exported_graph = ops.Graph()
837 resource_initializer_ops = []
838 with exported_graph.as_default():
839 object_map, tensor_map, asset_info = saveable_view.map_resources()
840 signatures = _generate_signatures(signature_functions, object_map, defaults)
841 if save_custom_gradients:
842 # Custom gradient functions must be traced in the same context as the one
843 # in which they are registered.
844 _trace_gradient_functions(exported_graph, saveable_view)
845 with exported_graph.as_default():
846 # Create initializers for assets and resources.
847 for resource_initializer_function in resource_initializers:
848 asset_dependencies = []
849 for capture in resource_initializer_function.graph.external_captures:
850 asset_initializer = asset_info.asset_initializers_by_resource.get(
851 capture, None)
852 if asset_initializer is not None:
853 asset_dependencies.append(asset_initializer)
854 with ops.control_dependencies(asset_dependencies):
855 mapped_initializer = object_map[resource_initializer_function]
856 resource_initializer_ops.append(mapped_initializer())
857 resource_initializer_ops.extend(
858 asset_info.asset_initializers_by_resource.values())
859 with ops.control_dependencies(resource_initializer_ops):
860 init_op = control_flow_ops.no_op()
861 # Add the same op to the main_op collection and to the init_op
862 # signature. The collection is for compatibility with older loader APIs;
863 # only one will be executed.
864 meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value.append(
865 init_op.name)
866 meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
867 signature_def_utils.op_signature_def(init_op,
868 constants.INIT_OP_SIGNATURE_KEY))
870 # Saving an object-based checkpoint again gathers variables. We need to do the
871 # gathering from the eager context so Optimizers save the right set of
872 # variables, but want any operations associated with the save/restore to be in
873 # the exported graph (thus the `to_graph` argument).
874 def call_with_mapped_captures(function, args):
875 if function in object_map:
876 return object_map[function](*args)
877 # Registered saver/restore functions do not appear in `object_map`, because
878 # they are not in the object graph.
879 return saved_model_exported_concrete.ExportedConcreteFunction(
880 function, tensor_map)(*args)
882 for obj in object_map.values():
883 obj._maybe_initialize_trackable() # pylint: disable=protected-access
884 named_saveable_objects, registered_savers = (
885 save_util_v1.frozen_saveables_and_savers(
886 graph_view=saveable_view.augmented_graph_view,
887 object_map=object_map,
888 to_graph=exported_graph,
889 call_with_mapped_captures=call_with_mapped_captures))
890 saver = functional_saver.MultiDeviceSaver.from_saveables(
891 named_saveable_objects, registered_savers, call_with_mapped_captures)
893 with exported_graph.as_default():
894 saver_def = saver.to_proto()
895 meta_graph_def.saver_def.CopyFrom(saver_def)
897 # At this point all nodes that can be added to the SavedObjectGraph have been
898 # added, so run the following to validate deserialization dependencies.
899 _dependency_sorted_node_ids(saveable_view)
901 graph_def = exported_graph.as_graph_def(add_shapes=True)
902 graph_def.library.registered_gradients.extend(saveable_view.gradient_defs)
903 _verify_ops(graph_def, namespace_whitelist)
905 meta_graph_def.graph_def.CopyFrom(graph_def)
906 meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
907 meta_graph_def.meta_info_def.tensorflow_version = versions.__version__
908 meta_graph_def.meta_info_def.tensorflow_git_version = (
909 versions.__git_version__)
910 # We currently always strip default attributes.
911 meta_graph_def.meta_info_def.stripped_default_attrs = True
912 meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
913 meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def))
914 meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
915 for signature_key, signature in signatures.items():
916 meta_graph_def.signature_def[signature_key].CopyFrom(signature)
917 meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
918 # Store tensor_content in little-endian format.
919 if sys.byteorder == "big":
920 utils_impl.swap_function_tensor_content(meta_graph_def, "big", "little")
921 return asset_info, exported_graph
924def _verify_ops(graph_def, namespace_whitelist):
925 """Verifies that all namespaced ops in the graph are whitelisted.
927 Args:
928 graph_def: the GraphDef to validate.
929 namespace_whitelist: a list of namespaces to allow. If `None`, all will be
930 allowed. If an op does not have a namespace, it will be allowed.
932 Raises:
933 ValueError: If the graph contains ops that violate the whitelist.
934 """
935 # By default, if the user has not specified a whitelist, we want to allow
936 # everything. We check for None directly rather than falseness, since the
937 # user may instead want to pass an empty list to disallow all custom
938 # namespaced ops.
939 if namespace_whitelist is None:
940 return
942 invalid_ops = []
943 invalid_namespaces = set()
945 all_operations = []
946 all_operations.extend(meta_graph.ops_used_by_graph_def(graph_def))
948 for op in all_operations:
949 if ">" in op:
950 namespace = op.split(">")[0]
951 if namespace not in namespace_whitelist:
952 invalid_ops.append(op)
953 invalid_namespaces.add(namespace)
954 if invalid_ops:
955 raise ValueError(
956 "Attempted to save ops from non-whitelisted namespaces to SavedModel: "
957 f"{invalid_ops}.\nPlease verify that these ops should be saved, since "
958 "they must be available when loading the SavedModel. If loading from "
959 "Python, you must import the library defining these ops. From C++, "
960 "link the custom ops to the serving binary. Once you've confirmed this,"
961 " add the following namespaces to the `namespace_whitelist` "
962 f"argument in tf.saved_model.SaveOptions: {invalid_namespaces}.")
965def _dependency_sorted_node_ids(saveable_view):
966 """Returns topologically sorted nodes, sorted by dependencies."""
967 dependency_map = {}
968 for node in saveable_view.nodes:
969 node_id = saveable_view.node_ids[node]
970 deps = dependency_map[node_id] = []
971 # TODO(kathywu): Remove once all of these have been converted to trackable.
972 if isinstance(node, _CapturedTensor):
973 continue # These are not `Trackable` and therefore have no dependencies.
974 for _, dep in saveable_view.augmented_graph_view.list_dependencies(node):
975 if dep not in saveable_view.node_ids:
976 node_path = trackable_utils.pretty_print_node_path(
977 saveable_view.node_paths[node])
978 raise ValueError(
979 f"Found an untracked dependency. Object {node_path} depends "
980 f"on {dep}, but this dependency isn't listed as a child. "
981 "Please track this child by overriding `_trackable_children` "
982 "or use `._track_trackable`.")
983 deps.append(saveable_view.node_ids[dep])
984 try:
985 return trackable_utils.order_by_dependency(dependency_map)
986 except trackable_utils.CyclicDependencyError as err:
987 pretty_printed_nodes = []
988 pretty_printed_dependencies = []
990 for x, deps in err.leftover_dependency_map.items():
991 node_path = trackable_utils.pretty_print_node_path(
992 saveable_view.node_paths[saveable_view.nodes[x]])
993 pretty_printed_nodes.append(
994 f"\tNode {x} = {node_path} (type {type(saveable_view.nodes[x])})")
995 pretty_printed_dependencies.append(f"\tNode {x} depends on nodes {deps}")
996 pretty_printed_nodes = "\n".join(pretty_printed_nodes)
997 pretty_printed_dependencies = "\n".join(pretty_printed_dependencies)
998 raise ValueError(
999 "There is one or more dependency cycle in the saved Trackable object. "
1000 "Saving cannot continue until this cycle is resolved."
1001 f"\n>> Unresolved nodes:\n{pretty_printed_nodes}"
1002 f"\n>> Unresolved cyclic dependencies:\n{pretty_printed_dependencies}")
1005def _serialize_object_graph(saveable_view, asset_file_def_index):
1006 """Save a SavedObjectGraph proto for `root`."""
1007 # SavedObjectGraph is similar to the TrackableObjectGraph proto in the
1008 # checkpoint. It will eventually go into the SavedModel.
1009 proto = saved_object_graph_pb2.SavedObjectGraph()
1010 saveable_view.fill_object_graph_proto(proto)
1012 for concrete_function in saveable_view.concrete_and_gradient_functions:
1013 name = compat.as_text(concrete_function.name)
1014 serialized = function_serialization.serialize_concrete_function(
1015 concrete_function, saveable_view.captured_tensor_node_ids)
1016 if serialized is not None:
1017 proto.concrete_functions[name].CopyFrom(serialized)
1019 for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
1020 _write_object_proto(obj, obj_proto, asset_file_def_index,
1021 saveable_view.augmented_graph_view.list_children)
1022 return proto
1025def _write_object_proto(obj, proto, asset_file_def_index, list_children_fn):
1026 """Saves an object into SavedObject proto."""
1027 if isinstance(obj, asset.Asset):
1028 proto.asset.SetInParent()
1029 proto.asset.asset_file_def_index = asset_file_def_index[obj]
1030 elif resource_variable_ops.is_resource_variable(obj):
1031 options = save_context.get_save_options()
1032 obj._write_object_proto(proto, options) # pylint: disable=protected-access
1033 elif isinstance(obj, def_function.Function):
1034 proto.function.CopyFrom(
1035 function_serialization.serialize_function(
1036 obj, [x.ref for x in list_children_fn(obj)]))
1037 elif isinstance(obj, defun.ConcreteFunction):
1038 proto.bare_concrete_function.CopyFrom(
1039 function_serialization.serialize_bare_concrete_function(obj))
1040 elif isinstance(obj, _CapturedTensor):
1041 proto.captured_tensor.name = obj.name
1042 proto.captured_tensor.concrete_function = obj.concrete_function
1043 elif isinstance(obj, resource.CapturableResource):
1044 proto.resource.device = obj._resource_device # pylint: disable=protected-access
1045 else:
1046 registered_type_proto = revived_types.serialize(obj)
1047 if registered_type_proto is None:
1048 # Fallback for types with no matching registration
1049 # pylint:disable=protected-access
1050 registered_type_proto = saved_object_graph_pb2.SavedUserObject(
1051 identifier=obj._object_identifier,
1052 version=versions_pb2.VersionDef(
1053 producer=1, min_consumer=1, bad_consumers=[]))
1054 # pylint:enable=protected-access
1055 proto.user_object.CopyFrom(registered_type_proto)
1057 registered_name = registration.get_registered_class_name(obj)
1058 if registered_name:
1059 proto.registered_name = registered_name
1060 serialized_user_proto = obj._serialize_to_proto(object_proto=proto) # pylint: disable=protected-access
1061 if serialized_user_proto is not None:
1062 proto.serialized_user_proto.Pack(serialized_user_proto)
1065def _export_debug_info(exported_graph, export_dir):
1066 """Exports debug information from graph to file.
1068 Creates and writes GraphDebugInfo with traces for ops in all functions of the
1069 exported_graph.
1071 Args:
1072 exported_graph: A Graph that has been created by tracing a saveable view.
1073 export_dir: SavedModel directory in which to write the debug info.
1074 """
1075 exported_operations = []
1076 for fn_name in exported_graph._functions: # pylint: disable=protected-access
1077 fn = exported_graph._get_function(fn_name) # pylint: disable=protected-access
1078 if not isinstance(fn, defun.AtomicFunction): # pylint: disable=protected-access
1079 continue
1081 fn_graph = fn.graph
1082 for fn_op in fn_graph.get_operations():
1083 exported_operations.append((fn_name, fn_op))
1085 graph_debug_info = error_interpolation.create_graph_debug_info_def(
1086 exported_operations)
1087 file_io.atomic_write_string_to_file(
1088 file_io.join(
1089 path_helpers.get_or_create_debug_dir(export_dir),
1090 constants.DEBUG_INFO_FILENAME_PB),
1091 graph_debug_info.SerializeToString(deterministic=True))
1094@tf_export(
1095 "saved_model.save",
1096 v1=["saved_model.save", "saved_model.experimental.save"])
1097def save(obj, export_dir, signatures=None, options=None):
1098 # pylint: disable=line-too-long
1099 """Exports a [tf.Module](https://www.tensorflow.org/api_docs/python/tf/Module) (and subclasses) `obj` to [SavedModel format](https://www.tensorflow.org/guide/saved_model#the_savedmodel_format_on_disk).
1101 The `obj` must inherit from the [`Trackable`
1102 class](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/tracking/base.py#L591).
1104 Example usage:
1106 >>> class Adder(tf.Module):
1107 ... @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
1108 ... def add(self, x):
1109 ... return x + x
1111 >>> model = Adder()
1112 >>> tf.saved_model.save(model, '/tmp/adder')
1114 The resulting SavedModel is then servable with an input named "x", a scalar
1115 with dtype float32.
1117 _Signatures_
1119 Signatures define the input and output types for a computation. The optional
1120 save `signatures` argument controls which methods in `obj` will be
1121 available to programs which consume `SavedModel`s, for example, serving
1122 APIs. Python functions may be decorated with
1123 `@tf.function(input_signature=...)` and passed as signatures directly, or
1124 lazily with a call to `get_concrete_function` on the method decorated with
1125 `@tf.function`.
1127 Example:
1129 >>> class Adder(tf.Module):
1130 ... @tf.function
1131 ... def add(self, x):
1132 ... return x + x
1134 >>> model = Adder()
1135 >>> tf.saved_model.save(
1136 ... model, '/tmp/adder',signatures=model.add.get_concrete_function(
1137 ... tf.TensorSpec([], tf.float32)))
1139 If a `@tf.function` does not have an input signature and
1140 `get_concrete_function` is not called on that method, the function will not
1141 be directly callable in the restored SavedModel.
1143 Example:
1145 >>> class Adder(tf.Module):
1146 ... @tf.function
1147 ... def add(self, x):
1148 ... return x + x
1150 >>> model = Adder()
1151 >>> tf.saved_model.save(model, '/tmp/adder')
1152 >>> restored = tf.saved_model.load('/tmp/adder')
1153 >>> restored.add(1.)
1154 Traceback (most recent call last):
1155 ...
1156 ValueError: Found zero restored functions for caller function.
1158 If the `signatures` argument is omitted, `obj` will be searched for
1159 `@tf.function`-decorated methods. If exactly one traced `@tf.function` is
1160 found, that method will be used as the default signature for the SavedModel.
1161 Else, any `@tf.function` attached to `obj` or its dependencies will be
1162 exported for use with `tf.saved_model.load`.
1164 When invoking a signature in an exported SavedModel, `Tensor` arguments are
1165 identified by name. These names will come from the Python function's argument
1166 names by default. They may be overridden by specifying a `name=...` argument
1167 in the corresponding `tf.TensorSpec` object. Explicit naming is required if
1168 multiple `Tensor`s are passed through a single argument to the Python
1169 function.
1171 The outputs of functions used as `signatures` must either be flat lists, in
1172 which case outputs will be numbered, or a dictionary mapping string keys to
1173 `Tensor`, in which case the keys will be used to name outputs.
1175 Signatures are available in objects returned by `tf.saved_model.load` as a
1176 `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
1177 on an object with a custom `.signatures` attribute will raise an exception.
1179 _Using `tf.saved_model.save` with Keras models_
1181 While Keras has its own [saving and loading
1182 API](https://www.tensorflow.org/guide/keras/save_and_serialize),
1183 this function can be used to export Keras models. For example, exporting with
1184 a signature specified:
1186 >>> class Adder(tf.keras.Model):
1187 ... @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
1188 ... def concat(self, x):
1189 ... return x + x
1191 >>> model = Adder()
1192 >>> tf.saved_model.save(model, '/tmp/adder')
1194 Exporting from a function without a fixed signature:
1196 >>> class Adder(tf.keras.Model):
1197 ... @tf.function
1198 ... def concat(self, x):
1199 ... return x + x
1201 >>> model = Adder()
1202 >>> tf.saved_model.save(
1203 ... model, '/tmp/adder',
1204 ... signatures=model.concat.get_concrete_function(
1205 ... tf.TensorSpec(shape=[], dtype=tf.string, name="string_input")))
1207 `tf.keras.Model` instances constructed from inputs and outputs already have a
1208 signature and so do not require a `@tf.function` decorator or a `signatures`
1209 argument. If neither is specified, the model's forward pass is exported.
1211 >>> x = tf.keras.layers.Input((4,), name="x")
1212 >>> y = tf.keras.layers.Dense(5, name="out")(x)
1213 >>> model = tf.keras.Model(x, y)
1214 >>> tf.saved_model.save(model, '/tmp/saved_model/')
1216 The exported SavedModel takes "x" with shape [None, 4] and returns "out"
1217 with shape [None, 5].
1219 _Variables and Checkpoints_
1221 Variables must be tracked by assigning them to an attribute of a tracked
1222 object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
1223 from `tf.keras.layers`, optimizers from `tf.train`) track their variables
1224 automatically. This is the same tracking scheme that `tf.train.Checkpoint`
1225 uses, and an exported `Checkpoint` object may be restored as a training
1226 checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
1227 "variables/" subdirectory.
1229 `tf.function` does not hard-code device annotations from outside the function
1230 body, instead using the calling context's device. This means, for example,
1231 that exporting a model that runs on a GPU and serving it on a CPU will
1232 generally work, with some exceptions:
1234 * `tf.device` annotations inside the body of the function will be hard-coded
1235 in the exported model; this type of annotation is discouraged.
1236 * Device-specific operations, e.g. with "cuDNN" in the name or with
1237 device-specific layouts, may cause issues.
1238 * For `ConcreteFunctions`, active distribution strategies will cause device
1239 placements to be hard-coded in the function.
1241 SavedModels exported with `tf.saved_model.save` [strip default-valued
1242 attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
1243 automatically, which removes one source of incompatibilities when the consumer
1244 of a SavedModel is running an older TensorFlow version than the
1245 producer. There are however other sources of incompatibilities which are not
1246 handled automatically, such as when the exported model contains operations
1247 which the consumer does not have definitions for.
1249 Args:
1250 obj: A trackable object (e.g. tf.Module or tf.train.Checkpoint) to export.
1251 export_dir: A directory in which to write the SavedModel.
1252 signatures: Optional, one of three types:
1253 * A `tf.function` with an input signature specified, which will use the
1254 default serving signature key.
1255 * The result of `f.get_concrete_function` on a `@tf.function`-decorated
1256 function `f`, in which case `f` will be used to generate a signature for
1257 the SavedModel under the default serving signature key.
1258 * A dictionary, which maps signature keys to either `tf.function`
1259 instances with input signatures or concrete functions. Keys of such a
1260 dictionary may be arbitrary strings, but will typically be from the
1261 `tf.saved_model.signature_constants` module.
1262 options: `tf.saved_model.SaveOptions` object for configuring save options.
1264 Raises:
1265 ValueError: If `obj` is not trackable.
1267 @compatibility(eager)
1268 Not well supported when graph building. From TensorFlow 1.x,
1269 `tf.compat.v1.enable_eager_execution()` should run first. Calling
1270 tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
1271 add new save operations to the default graph each iteration.
1273 May not be called from within a function body.
1274 @end_compatibility
1275 """
1276 if isinstance(export_dir, os.PathLike):
1277 export_dir = os.fspath(export_dir)
1278 # pylint: enable=line-too-long
1279 metrics.IncrementWriteApi(_SAVE_V2_LABEL)
1280 save_and_return_nodes(obj, export_dir, signatures, options)
1282 metrics.IncrementWrite(write_version="2")
1285def save_and_return_nodes(obj,
1286 export_dir,
1287 signatures=None,
1288 options=None,
1289 experimental_skip_checkpoint=False):
1290 """Saves a SavedModel while returning all saved nodes and their paths.
1292 Please see `tf.saved_model.save` for details.
1294 Args:
1295 obj: A trackable object to export.
1296 export_dir: A directory in which to write the SavedModel.
1297 signatures: A function or dictionary of functions to save in the SavedModel
1298 as signatures.
1299 options: `tf.saved_model.SaveOptions` object for configuring save options.
1300 experimental_skip_checkpoint: If set to `True`, the checkpoint will not be
1301 written.
1303 Returns:
1304 A tuple of (a list of saved nodes in the order they are serialized to the
1305 `SavedObjectGraph`, dictionary mapping nodes to one possible path from
1306 the root node to the key node)
1307 """
1308 options = options or save_options.SaveOptions()
1309 saved_model = saved_model_pb2.SavedModel()
1310 meta_graph_def = saved_model.meta_graphs.add()
1312 _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
1313 _build_meta_graph(obj, signatures, options, meta_graph_def))
1314 saved_model.saved_model_schema_version = (
1315 constants.SAVED_MODEL_SCHEMA_VERSION)
1317 # Write the checkpoint, copy assets into the assets directory, and write out
1318 # the SavedModel proto itself.
1319 if not experimental_skip_checkpoint:
1320 path_helpers.get_or_create_variables_dir(export_dir)
1321 ckpt_options = checkpoint_options.CheckpointOptions(
1322 experimental_io_device=options.experimental_io_device)
1323 object_saver.save(
1324 path_helpers.get_variables_path(export_dir), options=ckpt_options)
1325 builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
1326 export_dir)
1327 # Note that this needs to be the last file operation when saving the
1328 # SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
1329 # indication that the SavedModel is completely written.
1330 if context.executing_eagerly():
1331 try:
1332 context.async_wait() # Ensure save operations have completed.
1333 except errors.NotFoundError as err:
1334 raise FileNotFoundError(
1335 f"{err}\n You may be trying to save on a different device from the "
1336 "computational device. Consider setting the "
1337 "`experimental_io_device` option in `tf.saved_model.SaveOptions` "
1338 "to the io_device such as '/job:localhost'.") from err
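# Illustrative (not part of this function): a caller hitting the error above
# can route checkpoint I/O to the host device, e.g.
#
#   tf.saved_model.save(
#       obj, export_dir,
#       options=tf.saved_model.SaveOptions(
#           experimental_io_device="/job:localhost"))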
1340 # We will slowly migrate code in this function to pywrap_saved_model.Save
1341 # as we build up the C++ API.
1342 pywrap_saved_model.Save(export_dir)
1344 saved_model_serialized = saved_model.SerializeToString(deterministic=True)
1346 fingerprinting_utils.write_fingerprint(export_dir, saved_model_serialized)
1348 path = file_io.join(
1349 compat.as_str(export_dir),
1350 compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
1351 file_io.atomic_write_string_to_file(path, saved_model_serialized)
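# Illustrative (not part of this function): because saved_model.pb is the last
# file written, consumers can treat its presence as a completion signal, e.g.
#
#   os.path.exists(os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PB))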
1353 # Save debug info, if requested.
1354 if options.save_debug_info:
1355 _export_debug_info(exported_graph, export_dir)
1356 # For privacy concerns, please see the note in
1357 # tensorflow/cc/saved_model/metrics.h
1358 metrics.SetWritePath(saved_model_path=str(export_dir))
1359 # Clean reference cycles so repeated export()s don't make work for the garbage
1360 # collector. Before this point, we need to keep references to captured
1361 # constants in the saved graph.
1362 ops.dismantle_graph(exported_graph)
1364 return saved_nodes, node_paths
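# Sketch of how the return values might be used by a caller (the `root`
# object here is hypothetical):
#
#   root = tf.train.Checkpoint(v=tf.Variable(1.0))
#   nodes, node_paths = save_and_return_nodes(root, "/tmp/export")
#   # `nodes` lists trackables in SavedObjectGraph serialization order;
#   # `node_paths` maps each node to one path from `root` to that node.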
1367def export_meta_graph(obj, filename, signatures=None, options=None):
1368 """Exports the MetaGraph proto of the `obj` to a file.
1370 This function goes through the same procedure as `saved_model.save` to
1371 produce the given object's MetaGraph, then saves it to the given file. It
1372 skips saving checkpoint information, and is useful when all one wants is the
1373 graph defining the model.
1375 Args:
1376 obj: A trackable object to build the MetaGraph from.
1377 filename: The file into which to write the MetaGraph.
1378 signatures: Optional, either a `tf.function` with an input signature
1379 specified or the result of `f.get_concrete_function` on a
1380 `@tf.function`-decorated function `f`, in which case `f` will be used to
1381 generate a signature for the SavedModel under the default serving
1382 signature key. `signatures` may also be a dictionary, in which case it
1383 maps from signature keys to either `tf.function` instances with input
1384 signatures or concrete functions. The keys of such a dictionary may be
1385 arbitrary strings, but will typically be from the
1386 `tf.saved_model.signature_constants` module.
1387 options: Optional, `tf.saved_model.SaveOptions` object that specifies
1388 options for saving.
1389 """
1390 options = options or save_options.SaveOptions()
1391 export_dir = os.path.dirname(filename)
1392 meta_graph_def, exported_graph, _, _, _, _ = _build_meta_graph(
1393 obj, signatures, options)
1395 file_io.atomic_write_string_to_file(
1396 filename, meta_graph_def.SerializeToString(deterministic=True))
1398 # Save debug info, if requested.
1399 if options.save_debug_info:
1400 _export_debug_info(exported_graph, export_dir)
1402 # Clean reference cycles so repeated export()s don't make work for the garbage
1403 # collector. Before this point, we need to keep references to captured
1404 # constants in the saved graph.
1405 ops.dismantle_graph(exported_graph)
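# Sketch of exporting only the MetaGraph (illustrative; `module` is the
# hypothetical trackable from the example above):
#
#   export_meta_graph(module, "/tmp/my_module/meta.pb",
#                     signatures=module.double.get_concrete_function())
#   # The written file holds a serialized MetaGraphDef; no variables
#   # checkpoint is produced.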
1408def _build_meta_graph_impl(obj, signatures, options, meta_graph_def=None):
1409 """Creates a MetaGraph containing the resources and functions of an object."""
1410 if ops.inside_function():
1411 raise AssertionError(
1412 "`tf.saved_model.save` is not supported inside a traced @tf.function. "
1413 "Move the call to the outer eagerly-executed context.")
1414 # pylint: enable=line-too-long
1415 if not isinstance(obj, base.Trackable):
1416 raise ValueError(
1417 "Expected an object of type `Trackable`, such as `tf.Module` or a "
1418 f"subclass of the `Trackable` class, for export. Got {obj} "
1419 f"with type {type(obj)}.")
1420 meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef()
1422 augmented_graph_view = _AugmentedGraphView(obj)
1423 if signatures is None:
1424 signatures = signature_serialization.find_function_to_export(
1425 augmented_graph_view)
1427 signatures, wrapped_functions, defaults = (
1428 signature_serialization.canonicalize_signatures(signatures)
1429 )
1430 signature_serialization.validate_augmented_graph_view(augmented_graph_view)
1431 signature_map = signature_serialization.create_signature_map(signatures)
1432 augmented_graph_view.set_signature(signature_map, wrapped_functions)
1434 # Use _SaveableView to provide a frozen listing of properties and functions.
1435 saveable_view = _SaveableView(augmented_graph_view, options)
1436 object_saver = checkpoint.TrackableSaver(augmented_graph_view)
1437 asset_info, exported_graph = _fill_meta_graph_def(
1438 meta_graph_def,
1439 saveable_view,
1440 signatures,
1441 options.namespace_whitelist,
1442 options.experimental_custom_gradients,
1443 defaults,
1444 )
1445 if options.function_aliases:
1446 function_aliases = meta_graph_def.meta_info_def.function_aliases
1447 for alias, func in options.function_aliases.items():
1448 if isinstance(func, types_core.ConcreteFunction):
1449 function_aliases[func.name] = alias
1450 elif isinstance(func, polymorphic_function.Function):
1451 for fdef in func._list_all_concrete_functions(): # pylint: disable=protected-access
1452 function_aliases[fdef.name] = alias
1453 else:
1454 raise TypeError(
1455 f"Unsupported type {type(func)}. Functions in `function_aliases`"
1456 " should be created by `tf.function`, or be concrete functions."
1457 )
1458 object_graph_proto = _serialize_object_graph(saveable_view,
1459 asset_info.asset_index)
1460 meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
1461 return (meta_graph_def, exported_graph, object_saver, asset_info,
1462 saveable_view.nodes, saveable_view.node_paths)
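# Illustrative: the alias handling above is driven by SaveOptions, e.g.
#
#   options = tf.saved_model.SaveOptions(
#       function_aliases={"my_alias": module.double})
#   tf.saved_model.save(module, "/tmp/my_module", options=options)
#
# which records each concrete function traced from `module.double` in
# meta_info_def.function_aliases, keyed by the concrete function's name with
# the alias string as the value.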
1465def _build_meta_graph(obj, signatures, options, meta_graph_def=None):
1466 """Creates a MetaGraph under a save context.
1468 Args:
1469 obj: A trackable object to build the MetaGraph from.
1470 signatures: Can be a `tf.function` with an input signature specified or the
1471 result of `f.get_concrete_function` on a `@tf.function`-decorated function
1472 `f`. `signatures` may also be a dictionary, in which case it maps from
1473 signature keys to `tf.function` instances. If None, the signature to
1474 export is found among the `@tf.function`-decorated methods of `obj`.
1475 options: `tf.saved_model.SaveOptions` object that specifies options for
1476 saving.
1477 meta_graph_def: Optional, the MetaGraphDef proto to fill.
1479 Raises:
1480 AssertionError: If this function is called from within a `tf.function`.
1481 ValueError: If `obj` is not trackable.
1483 Returns:
1484 meta_graph_def: Filled MetaGraphDef proto
1485 exported_graph: `tf.Graph` object generated from `obj`.
1486 object_saver: `checkpoint.TrackableSaver` of the `obj` and its dependencies.
1487 asset_info: `_AssetInfo` tuple containing external assets in the `obj`.
1488 saveable_view.nodes: _SaveableView nodes.
1489 saveable_view.node_paths: _SaveableView paths.
1490 """
1492 with save_context.save_context(options):
1493 return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
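# Illustrative: while the save context above is active, code being traced can
# detect that a save is in progress (a sketch, assuming the helpers exposed by
# the save_context module imported above):
#
#   from tensorflow.python.saved_model import save_context
#   if save_context.in_save_context():
#     opts = save_context.get_save_options()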