Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/load_v1_in_v2.py: 23%
151 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import a TF v1-style SavedModel when executing eagerly."""
17import functools
19from tensorflow.python.eager import context
20from tensorflow.python.eager import lift_to_graph
21from tensorflow.python.eager import wrap_function
22from tensorflow.python.framework import composite_tensor
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import func_graph
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import sparse_tensor
27from tensorflow.python.platform import tf_logging as logging
28from tensorflow.python.saved_model import function_deserialization
29from tensorflow.python.saved_model import loader_impl
30from tensorflow.python.saved_model import signature_serialization
31from tensorflow.python.saved_model.pywrap_saved_model import metrics
32from tensorflow.python.trackable import asset
33from tensorflow.python.trackable import autotrackable
34from tensorflow.python.trackable import resource
35from tensorflow.python.training import monitored_session
36from tensorflow.python.training import saver as tf_saver
37from tensorflow.python.util import nest
39# API label for SavedModel metrics.
40_LOAD_V1_V2_LABEL = "load_v1_in_v2"
class _Initializer(resource.CapturableResource):
  """Carries the initialization function of a loaded TF 1.x SavedModel.

  `tf.saved_model.load` creates one of these when it imports a TF 1.x-style
  SavedModel that has an initialization op; the object holds a function that
  runs that initialization. No manual user intervention is required:
  `tf.saved_model.save` finds this object and adds it to the exported
  SavedModel automatically, and `tf.saved_model.load` runs the initialization
  function automatically. Without it, re-exporting an imported 1.x
  SavedModel would lose the original SavedModel's initialization procedure.
  """

  def __init__(self, init_fn, asset_paths):
    """Stores the initialization function and the assets it is fed.

    Args:
      init_fn: Callable invoked with one positional argument per entry of
        `asset_paths` (that entry's `asset_path`).
      asset_paths: Sequence of objects exposing an `asset_path` attribute
        (presumably `asset.Asset` instances -- confirm against callers).
    """
    super().__init__()
    self._asset_paths = asset_paths
    self._init_fn = init_fn

  def _create_resource(self):
    # A constant stands in for the resource so that, when re-saved, the traced
    # `create_resource` still has a valid return value.
    return constant_op.constant(1.0)

  def _initialize(self):
    # Run the stored initializer, feeding it the current path of each asset.
    current_paths = [a.asset_path for a in self._asset_paths]
    return self._init_fn(*current_paths)
class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
  """Loads a SavedModel without using Sessions.

  The selected MetaGraph is imported into a graph built by
  `wrap_function.wrap_function`, and `load` returns an `AutoTrackable` root
  exposing that graph's variables, initializer, assets and signatures.
  """

  def get_meta_graph_def_from_tags(self, tags):
    """Override to support implicit one-MetaGraph loading with tags=None.

    Args:
      tags: Tags identifying the MetaGraph to load, or `None` to select the
        SavedModel's only MetaGraph.

    Returns:
      The selected `MetaGraphDef`. Explicit `tags` are delegated to the
      parent implementation.

    Raises:
      ValueError: If `tags` is `None` and the SavedModel does not contain
        exactly one MetaGraph.
    """
    if tags is None:
      meta_graphs = self._saved_model.meta_graphs
      if len(meta_graphs) != 1:
        tag_sets = [mg.meta_info_def.tags for mg in meta_graphs]
        raise ValueError(
            "Importing a SavedModel with `tf.saved_model.load` requires a "
            "`tags=` argument if there is more than one MetaGraph. Got "
            f"`tags=None`, but there are {len(meta_graphs)} "
            f"MetaGraphs in the SavedModel with tag sets: {tag_sets}. Pass a "
            "`tags=` argument to load this SavedModel."
        )
      return meta_graphs[0]
    return super().get_meta_graph_def_from_tags(tags)

  def load_graph(self, returns, meta_graph_def):
    """Called from wrap_function to import `meta_graph_def`.

    Args:
      returns: One-element list used as an out-parameter; `returns[0]` is set
        to the saver produced by the import (it may be `None`, which
        `_extract_saver_restore` handles).
      meta_graph_def: The `MetaGraphDef` to import into the current graph.
    """
    # pylint: disable=protected-access
    saver, _ = tf_saver._import_meta_graph_with_return_elements(meta_graph_def)
    # pylint: enable=protected-access
    returns[0] = saver

  def _extract_saver_restore(self, wrapped, saver):
    """Builds a function that runs the saver's restore op.

    Args:
      wrapped: The wrapped function whose graph holds the imported MetaGraph.
      saver: The saver returned through `load_graph`, or `None`.

    Returns:
      `None` when `saver` is `None`. Otherwise, a pruned function that takes
      the checkpoint filename and fetches the filename tensor together with
      the restore op.
    """
    if saver is None:
      return None
    saver_def = saver.saver_def
    filename_tensor = wrapped.graph.as_graph_element(
        saver_def.filename_tensor_name
    )
    restore_op = wrapped.graph.as_graph_element(saver_def.restore_op_name)
    # We both feed and fetch filename_tensor so we have an operation to use to
    # feed into variable initializers (only relevant for v1 graph building).
    return wrapped.prune(
        feeds=[filename_tensor],
        fetches=[filename_tensor, restore_op],
    )

  def restore_variables(self, wrapped, restore_from_saver):
    """Restores variables from the checkpoint.

    Args:
      wrapped: The wrapped function whose graph holds the imported MetaGraph.
      restore_from_saver: Function built by `_extract_saver_restore`, or
        `None`, in which case nothing is restored.
    """
    if restore_from_saver is None:
      return
    initializer, _ = restore_from_saver(
        constant_op.constant(self._variables_path)
    )
    if ops.executing_eagerly_outside_functions():
      return
    # Add the initialization operation to the "saved_model_initializers"
    # collection in case we don't have any lifted variables to attach it to.
    ops.add_to_collection("saved_model_initializers", initializer)
    one_unlifted = False
    for variable in wrapped.graph.get_collection_ref(
        ops.GraphKeys.GLOBAL_VARIABLES
    ):
      if variable.graph is wrapped.graph:
        # The variable still lives in the wrapped function's own graph, i.e.
        # it was not lifted out.
        one_unlifted = True
      # pylint: disable=protected-access
      variable._initializer_op = initializer
      # pylint: enable=protected-access
    if one_unlifted:
      logging.warning(
          "Some variables could not be lifted out of a loaded function. "
          "Please run "
          '`sess.run(tf.get_collection("saved_model_initializers"))` to '
          "restore these variables."
      )

  def _extract_signatures(self, wrapped, meta_graph_def):
    """Creates ConcreteFunctions for signatures in `meta_graph_def`.

    Args:
      wrapped: The wrapped function whose graph holds the imported MetaGraph.
      meta_graph_def: The `MetaGraphDef` whose `signature_def` map is used.

    Returns:
      A dict mapping each signature key to a pruned function. Its keyword
      arguments are the signature's input names; SparseTensor and other
      composite inputs are split into one argument per component tensor.

    Raises:
      lift_to_graph.UnliftableError: If a signature output depends on a
        placeholder that is not one of the signature's inputs. An explanatory
        prefix is prepended to the error message before re-raising.
    """
    signature_functions = {}
    for signature_key, signature_def in meta_graph_def.signature_def.items():
      if signature_def.inputs:
        # Order inputs by name; only the (unique) names are compared.
        input_items = sorted(
            signature_def.inputs.items(), key=lambda item: item[0]
        )
        original_input_names, input_specs = zip(*input_items)
      else:
        original_input_names = []
        input_specs = []
      # TODO(b/205015292): Support optional arguments
      feeds = [
          wrap_function._get_element_from_tensor_info(input_spec, wrapped.graph)  # pylint: disable=protected-access
          for input_spec in input_specs
      ]
      input_names = []
      input_tensors = []
      for original_input_name, feed in zip(original_input_names, feeds):
        if isinstance(feed, sparse_tensor.SparseTensor):
          # We have to give explicit name for SparseTensor arguments, because
          # these are not present in the TensorInfo.
          indices_name = "%s_indices" % original_input_name
          values_name = "%s_values" % original_input_name
          dense_shape_name = "%s_dense_shape" % original_input_name
          input_names.extend([indices_name, values_name, dense_shape_name])
          input_tensors.extend([feed.indices, feed.values, feed.dense_shape])
        elif isinstance(feed, composite_tensor.CompositeTensor):
          component_tensors = nest.flatten(feed, expand_composites=True)
          input_names.extend(
              "%s_component_%d" % (original_input_name, n)
              for n in range(len(component_tensors))
          )
          input_tensors.extend(component_tensors)
        else:
          input_names.append(original_input_name)
          input_tensors.append(feed)
      fetches = dict(signature_def.outputs.items())
      try:
        signature_fn = wrapped.prune(feeds=feeds, fetches=fetches)
      except lift_to_graph.UnliftableError as ex:
        # Mutate the exception to add a bit more detail. `str()` guards against
        # a non-string first argument, which would otherwise raise a TypeError
        # here and hide the original error.
        args = ex.args
        message = str(args[0]) if args else ""
        message = (
            "A SavedModel signature needs an input for each placeholder the "
            "signature's outputs use. An output for signature '{}' depends on "
            "a placeholder which is not an input (i.e. the placeholder is not "
            "fed a value).\n\n"
        ).format(signature_key) + message
        ex.args = (message,) + args[1:]
        raise
      # pylint: disable=protected-access
      signature_fn._arg_keywords = input_names
      signature_fn._func_graph.structured_input_signature = (
          (),
          func_graph.convert_structure_to_signature(
              dict(zip(input_names, input_tensors))
          ),
      )
      # Allowing positional arguments does not create any ambiguity if there's
      # only one.
      signature_fn._num_positional_args = 1 if len(input_names) == 1 else 0
      # pylint: enable=protected-access
      signature_functions[signature_key] = signature_fn
    return signature_functions

  def load(self, tags):
    """Creates an object from the MetaGraph identified by `tags`.

    Args:
      tags: Tags selecting the MetaGraph; `None` is accepted when the
        SavedModel contains exactly one MetaGraph.

    Returns:
      An `AutoTrackable` carrying these attributes: `initializer`,
      `asset_paths`, `signatures`, `variables`, `tensorflow_version`,
      `tensorflow_git_version`, `graph` and `prune`. It also has `restore`
      when the MetaGraph has a saver.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    load_shared_name_suffix = "_load_{}".format(ops.uid())
    functions = function_deserialization.load_function_def_library(
        meta_graph_def.graph_def.library,
        load_shared_name_suffix=load_shared_name_suffix,
    )
    # Replace existing functions in the MetaGraphDef with renamed functions so
    # we don't have duplicates or name collisions.
    meta_graph_def.graph_def.library.Clear()
    for function in functions.values():
      meta_graph_def.graph_def.library.function.add().CopyFrom(
          function.function_def
      )
    # We've renamed functions and shared names. We need the same operation on
    # the GraphDef itself for consistency.
    for node_def in meta_graph_def.graph_def.node:
      function_deserialization.fix_node_def(
          node_def, functions, load_shared_name_suffix
      )

    # `load_graph` reports the imported saver through this one-element list.
    load_graph_returns = [None]
    wrapped = wrap_function.wrap_function(
        functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
        signature=[],
    )
    (saver,) = load_graph_returns
    restore_from_saver = self._extract_saver_restore(wrapped, saver)
    self.restore_variables(wrapped, restore_from_saver)
    with wrapped.graph.as_default():
      init_op = (
          loader_impl.get_init_op(meta_graph_def)
          or monitored_session.Scaffold.default_local_init_op()
      )
      # Add a dummy Tensor we know we can fetch to add control dependencies to.
      init_anchor = constant_op.constant(0.0, name="dummy_fetch")

    root = autotrackable.AutoTrackable()
    if restore_from_saver is not None:
      root.restore = lambda path: restore_from_saver(constant_op.constant(path))
    asset_feed_tensors = []
    asset_paths = []
    for tensor_name, value in loader_impl.get_asset_tensors(
        self._export_dir, meta_graph_def
    ).items():
      asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
      asset_paths.append(asset.Asset(value))
    init_fn = wrapped.prune(
        feeds=asset_feed_tensors,
        fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)],
    )
    initializer = _Initializer(init_fn, asset_paths)
    # pylint: disable=protected-access
    local_init_op, _ = initializer._initialize()
    # pylint: enable=protected-access
    with ops.init_scope():
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, local_init_op)
        for variable in wrapped.graph.get_collection_ref(
            ops.GraphKeys.LOCAL_VARIABLES
        ):
          # pylint: disable=protected-access
          variable._initializer_op = local_init_op
          # pylint: enable=protected-access
    root.initializer = initializer
    root.asset_paths = asset_paths
    signature_functions = self._extract_signatures(wrapped, meta_graph_def)

    root.signatures = signature_serialization.create_signature_map(
        signature_functions
    )
    root.variables = list(wrapped.graph.variables)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version
    )
    root.graph = wrapped.graph
    root.prune = wrapped.prune
    return root
def load(export_dir, tags):
  """Load a v1-style SavedModel as an object.

  Args:
    export_dir: Location of the SavedModel, passed to the loader.
    tags: Tags of the MetaGraph to load, or `None` when the SavedModel holds a
      single MetaGraph.

  Returns:
    The root object produced by `_EagerSavedModelLoader.load`.
  """
  metrics.IncrementReadApi(_LOAD_V1_V2_LABEL)
  loaded_root = _EagerSavedModelLoader(export_dir).load(tags=tags)
  # Only reached when loading succeeded, so failed loads are not counted here.
  metrics.IncrementRead(write_version="1")
  return loaded_root