Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/loader_impl.py: 31%
170 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Loader implementation for SavedModel with hermetic, language-neutral exports.
16"""
18import os
19import sys
21from google.protobuf import message
22from google.protobuf import text_format
24from tensorflow.core.framework import graph_debug_info_pb2
25from tensorflow.core.protobuf import meta_graph_pb2
26from tensorflow.core.protobuf import saved_model_pb2
27from tensorflow.python.framework import ops
28from tensorflow.python.lib.io import file_io
29from tensorflow.python.ops import variables
30from tensorflow.python.platform import tf_logging
31from tensorflow.python.saved_model import constants
32from tensorflow.python.saved_model import path_helpers
33from tensorflow.python.saved_model import signature_def_utils
34from tensorflow.python.saved_model import utils_impl as saved_model_utils
35from tensorflow.python.saved_model.pywrap_saved_model import metrics
36from tensorflow.python.training import saver as tf_saver
37from tensorflow.python.util import compat
38from tensorflow.python.util import deprecation
39from tensorflow.python.util.tf_export import tf_export
# API label for SavedModel metrics. It is passed to
# `metrics.IncrementReadApi` each time `SavedModelLoader.load` reads a model.
_LOADER_LABEL = "loader"
def parse_saved_model_with_debug_info(export_dir):
  """Reads the savedmodel as well as the graph debug info.

  Args:
    export_dir: Directory containing the SavedModel and GraphDebugInfo files.

  Returns:
    A `(SavedModel, GraphDebugInfo)` tuple of protocol buffers. When no debug
    info file is present, the second element is an empty `GraphDebugInfo`.

  Raises:
    IOError: If the saved model file does not exist, or cannot be successfully
      parsed, or if a debug info file exists but cannot be parsed. A missing
      graph debug info file is fine.
  """
  saved_model = parse_saved_model(export_dir)

  debug_info_path = file_io.join(
      path_helpers.get_debug_dir(export_dir),
      constants.DEBUG_INFO_FILENAME_PB)
  debug_info = graph_debug_info_pb2.GraphDebugInfo()
  if file_io.file_exists(debug_info_path):
    with file_io.FileIO(debug_info_path, "rb") as debug_file:
      try:
        debug_info.ParseFromString(debug_file.read())
      except message.DecodeError as e:
        # Chain the decode error so the underlying protobuf failure stays
        # visible as the direct cause in tracebacks.
        raise IOError(f"Cannot parse file {debug_info_path}: {e}.") from e

  return (saved_model, debug_info)
@tf_export("__internal__.saved_model.parse_saved_model", v1=[])
def parse_saved_model(export_dir):
  """Reads the savedmodel.pb or savedmodel.pbtxt file containing `SavedModel`.

  The binary file (`constants.SAVED_MODEL_FILENAME_PB`) is preferred; the text
  file (`constants.SAVED_MODEL_FILENAME_PBTXT`) is read only when the binary
  one does not exist.

  Args:
    export_dir: String or Pathlike, path to the directory containing the
      SavedModel file.

  Returns:
    A `SavedModel` protocol buffer.

  Raises:
    IOError: If neither file exists, or the file that is read cannot be
      successfully parsed.
    UnicodeDecodeError: If the text-format file is not valid UTF-8 (the
      decode happens before text parsing and is not converted to IOError).
  """
  # Normalize str / bytes / os.PathLike to bytes once for both candidate paths.
  export_dir_bytes = compat.as_bytes(compat.path_to_str(export_dir))
  # Build the path to the SavedModel in pbtxt format.
  path_to_pbtxt = file_io.join(
      export_dir_bytes,
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
  # Build the path to the SavedModel in pb format.
  path_to_pb = file_io.join(
      export_dir_bytes,
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))

  # Parse the SavedModel protocol buffer.
  saved_model = saved_model_pb2.SavedModel()
  if file_io.file_exists(path_to_pb):
    with file_io.FileIO(path_to_pb, "rb") as f:
      file_content = f.read()
    try:
      saved_model.ParseFromString(file_content)
    except message.DecodeError as e:
      raise IOError(f"Cannot parse file {path_to_pb}: {str(e)}.") from e
    return saved_model
  elif file_io.file_exists(path_to_pbtxt):
    with file_io.FileIO(path_to_pbtxt, "rb") as f:
      file_content = f.read()
    try:
      text_format.Merge(file_content.decode("utf-8"), saved_model)
    except text_format.ParseError as e:
      raise IOError(f"Cannot parse file {path_to_pbtxt}: {str(e)}.") from e
    return saved_model
  else:
    raise IOError(
        f"SavedModel file does not exist at: {export_dir}{os.path.sep}"
        f"{{{constants.SAVED_MODEL_FILENAME_PBTXT}|"
        f"{constants.SAVED_MODEL_FILENAME_PB}}}")
def get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
  """Gets the asset tensors, if defined in the meta graph def to load.

  Asset definitions are taken from `meta_graph_def_to_load.asset_file_def`
  when it is non-empty; otherwise they are unpacked from the
  `constants.ASSETS_KEY` collection (as `Any`-wrapped `AssetFileDef` protos).

  Args:
    export_dir: Directory where the SavedModel is located. May be `str`,
      `bytes`, or an `os.PathLike` object.
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    import_scope: Optional `string` -- if specified, prepend this followed by
      '/' to all returned asset tensor names.

  Returns:
    A dictionary of asset tensors, keyed by the name of the asset tensor. The
    value in the map is the path (as bytes) of the asset file inside the
    SavedModel's assets directory.
  """
  # Collection-def that may contain the assets key.
  collection_def = meta_graph_def_to_load.collection_def

  asset_protos = []
  if meta_graph_def_to_load.asset_file_def:
    asset_protos = meta_graph_def_to_load.asset_file_def
  elif constants.ASSETS_KEY in collection_def:
    for asset_any_proto in collection_def[constants.ASSETS_KEY].any_list.value:
      asset_proto = meta_graph_pb2.AssetFileDef()
      asset_any_proto.Unpack(asset_proto)
      asset_protos.append(asset_proto)

  # Location of the assets for SavedModel. Normalize through path_to_str, as
  # parse_saved_model does, so os.PathLike export dirs are accepted;
  # compat.as_bytes on its own rejects non-string objects.
  assets_directory = file_io.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.ASSETS_DIRECTORY))

  # Process each asset and add it to the asset tensor dictionary.
  asset_tensor_dict = {}
  for asset_proto in asset_protos:
    tensor_name = asset_proto.tensor_info.name
    if import_scope:
      tensor_name = "%s/%s" % (import_scope, tensor_name)
    asset_tensor_dict[tensor_name] = file_io.join(
        compat.as_bytes(assets_directory),
        compat.as_bytes(asset_proto.filename))

  return asset_tensor_dict
def _get_main_op_tensor(
    meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY):
  """Returns the op registered under `init_op_key`, or `None` if absent.

  The meta graph def's collection is only used to validate that exactly one
  node is registered; the returned object is looked up in the collection of
  the current default graph (`ops.get_collection`), so the meta graph is
  expected to have been imported into that graph already.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    init_op_key: Name of the collection to check; should be one of MAIN_OP_KEY
      or the deprecated LEGACY_INIT_OP_KEY.

  Returns:
    The first element of the default graph's `init_op_key` collection, or
    `None` when the meta graph def has no such collection.

  Raises:
    RuntimeError: If the meta graph def's collection for `init_op_key` does
      not list exactly one node.
  """
  # TODO(kathywu): Rename this method to _get_op_from_collection when
  # dependency from SavedModelEstimator is removed.
  collections = meta_graph_def_to_load.collection_def
  if init_op_key not in collections:
    return None

  node_names = collections[init_op_key].node_list.value
  num_nodes = len(node_names)
  if num_nodes != 1:
    raise RuntimeError("Expected exactly one SavedModel init op. "
                       f"Found {num_nodes}: {node_names}.")
  return ops.get_collection(init_op_key)[0]
def _get_op_from_collection(meta_graph_def, op_key):
  """Returns the single op registered in collection `op_key`, or `None`.

  Thin alias for `_get_main_op_tensor`, which raises RuntimeError when the
  collection in `meta_graph_def` lists other than exactly one node.
  """
  return _get_main_op_tensor(
      meta_graph_def_to_load=meta_graph_def, init_op_key=op_key)
198def _get_op_from_signature_def(meta_graph_def, op_signature_key, import_scope):
199 """Retrieve op stored in the imported meta graph's signature def."""
200 if op_signature_key in meta_graph_def.signature_def:
201 return signature_def_utils.load_op_from_signature_def(
202 meta_graph_def.signature_def[op_signature_key], op_signature_key,
203 import_scope)
204 else:
205 return None
def get_init_op(meta_graph_def, import_scope=None):
  """Returns the init op of `meta_graph_def`, or `None` if none is found.

  Sources are tried in order: the signature def under
  `constants.INIT_OP_SIGNATURE_KEY`, then the `constants.MAIN_OP_KEY`
  collection, then the deprecated `constants.LEGACY_INIT_OP_KEY` collection.

  Args:
    meta_graph_def: The `MetaGraphDef` that has been imported.
    import_scope: Optional scope prefix used when the meta graph was imported;
      only applied to the signature-def lookup.

  Returns:
    The first init op found, or `None`.
  """
  # Compare with None explicitly (as get_train_op does) instead of chaining
  # with `or`: `or` evaluates the truth value of the returned graph element,
  # and a graph-mode tf.Tensor refuses conversion to bool.
  init_op = _get_op_from_signature_def(
      meta_graph_def, constants.INIT_OP_SIGNATURE_KEY, import_scope)
  if init_op is None:
    init_op = _get_op_from_collection(meta_graph_def, constants.MAIN_OP_KEY)
  if init_op is None:
    init_op = _get_op_from_collection(
        meta_graph_def, constants.LEGACY_INIT_OP_KEY)
  return init_op
def get_train_op(meta_graph_def, import_scope=None):
  """Returns the train op of `meta_graph_def`, or `None` if none is found.

  The signature def under `constants.TRAIN_OP_SIGNATURE_KEY` takes precedence
  over the `constants.TRAIN_OP_KEY` collection.
  """
  signature_op = _get_op_from_signature_def(
      meta_graph_def, constants.TRAIN_OP_SIGNATURE_KEY, import_scope)
  if signature_op is not None:
    return signature_op
  return _get_op_from_collection(meta_graph_def, constants.TRAIN_OP_KEY)
@tf_export(v1=[
    "saved_model.contains_saved_model",
    "saved_model.maybe_saved_model_directory",
    "saved_model.loader.maybe_saved_model_directory"
])
@deprecation.deprecated_endpoints(
    "saved_model.loader.maybe_saved_model_directory")
def maybe_saved_model_directory(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Only file existence is checked; nothing is read or parsed. A `false` result
  means the directory definitely does not hold a SavedModel. A `true` result
  means one of the SavedModel files is present, with no guarantee that it
  can actually be loaded.

  Args:
    export_dir: Absolute string path to possible export location. For example,
      '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  pbtxt_file = file_io.join(export_dir, constants.SAVED_MODEL_FILENAME_PBTXT)
  pb_file = file_io.join(export_dir, constants.SAVED_MODEL_FILENAME_PB)
  if file_io.file_exists(pbtxt_file):
    return True
  return file_io.file_exists(pb_file)
@tf_export("saved_model.contains_saved_model", v1=[])
def contains_saved_model(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Only file existence is checked; nothing is read or parsed. A `false` result
  means the directory definitely does not hold a SavedModel. A `true` result
  means one of the SavedModel files is present, with no guarantee that it
  can actually be loaded.

  Args:
    export_dir: Absolute path to possible export location. For example,
      '/my/foo/model'. `os.PathLike` objects are converted with `os.fspath`.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  normalized_dir = (
      os.fspath(export_dir) if isinstance(export_dir, os.PathLike)
      else export_dir)
  return maybe_saved_model_directory(normalized_dir)
@tf_export(v1=["saved_model.load", "saved_model.loader.load"])
@deprecation.deprecated(None, "Use `tf.saved_model.load` instead.")
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
  """Loads the model from a SavedModel as specified by tags.

  Convenience wrapper that builds a `SavedModelLoader` for `export_dir` and
  calls its `load` method.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
      correspond to the tags used when saving the variables using the
      SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
      to be loaded are located.
    import_scope: Optional `string` -- if specified, prepend this string
      followed by '/' to all loaded tensor names. This scope is applied to
      tensor instances loaded into the passed session, but it is *not* written
      through to the static `MetaGraphDef` protocol buffer that is returned.
    **saver_kwargs: Optional keyword arguments passed through to Saver.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    IOError: If the SavedModel file is missing or cannot be parsed.
    RuntimeError: MetaGraphDef associated with the tags cannot be found.

  @compatibility(TF2)

  `tf.compat.v1.saved_model.load` or `tf.compat.v1.saved_model.loader.load` is
  not compatible with eager execution. Please use `tf.saved_model.load` instead
  to load your model. You can refer to the [SavedModel guide]
  (https://www.tensorflow.org/guide/saved_model) for more information as well as
  "Importing SavedModels from TensorFlow 1.x" in the [`tf.saved_model.load`]
  (https://www.tensorflow.org/api_docs/python/tf/saved_model/load) docstring.

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `sess`                | Not supported   | -                          |
  | `tags`                | `tags`          | -                          |
  | `export_dir`          | `export_dir`    | -                          |
  | `import_scope`        | Not supported   | Name scopes are not needed.
  :                       :                 : By default, variables are  :
  :                       :                 : associated with the loaded :
  :                       :                 : object and function names  :
  :                       :                 : are deduped.               :
  | `saver_kwargs`        | Not supported   | -                          |

  #### Before & After Usage Example

  Before:

  ```
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    tf.compat.v1.saved_model.loader.load(sess, ["foo-tag"], export_dir)
  ```

  After:

  ```
  model = tf.saved_model.load(export_dir, tags=["foo-tag"])
  ```
  @end_compatibility
  """
  return SavedModelLoader(export_dir).load(
      sess, tags, import_scope, **saver_kwargs)
class SavedModelLoader(object):
  """Load graphs and restore variable values from a `SavedModel`.

  The SavedModel proto is parsed once, in the constructor, and reused by all
  methods.
  """

  def __init__(self, export_dir):
    """Creates a `SavedModelLoader`.

    Args:
      export_dir: Directory in which the SavedModel protocol buffer and
        variables to be loaded are located.

    Raises:
      IOError: If the SavedModel file is missing or cannot be parsed.
    """
    self._export_dir = export_dir
    self._variables_path = path_helpers.get_variables_path(export_dir)
    self._saved_model = parse_saved_model(export_dir)

  @property
  def export_dir(self):
    """Directory containing the SavedModel."""
    return self._export_dir

  @property
  def variables_path(self):
    """Path to variable checkpoint files."""
    return self._variables_path

  @property
  def saved_model(self):
    """SavedModel object parsed from the export directory."""
    return self._saved_model

  def get_meta_graph_def_from_tags(self, tags):
    """Return MetaGraphDef with the exact specified tags.

    Tags are compared as sets: order and duplicates are ignored, but the match
    must be exact (neither a subset nor a superset).

    Args:
      tags: A list or set (or other iterable) of string tags that identify the
        MetaGraphDef.

    Returns:
      The first MetaGraphDef whose tag set equals `tags`. This is the proto
      held inside `self.saved_model`, not a copy.

    Raises:
      RuntimeError: if no metagraphs were found with the associated tags.
    """
    # Build the requested tag set once. Rebuilding it per iteration is
    # wasteful and, for a one-shot iterable such as a generator, compares
    # against an empty set after the first meta graph.
    requested_tags = set(tags)
    available_tags = []
    for meta_graph_def in self._saved_model.meta_graphs:
      graph_tags = set(meta_graph_def.meta_info_def.tags)
      if graph_tags == requested_tags:
        return meta_graph_def
      available_tags.append(graph_tags)

    raise RuntimeError(
        f"MetaGraphDef associated with tags {str(tags).strip('[]')} "
        "could not be found in SavedModel, with available tags "
        f"'{available_tags}'. To inspect available tag-sets in"
        " the SavedModel, please use the SavedModel CLI: `saved_model_cli`.")

  def load_graph(self, graph, tags, import_scope=None, **saver_kwargs):
    """Load ops and nodes from SavedModel MetaGraph into graph.

    Args:
      graph: tf.Graph object.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      A tuple of
        * Saver defined by the MetaGraph, which can be used to restore the
          variable values.
        * List of `Operation`/`Tensor` objects returned from
          `tf.import_graph_def` (may be `None`).
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    if sys.byteorder == "big":
      # Stored function tensor content is converted from little- to
      # big-endian for big-endian hosts.
      # NOTE(review): meta_graph_def is the proto owned by self.saved_model;
      # if the swap happens in place, calling load_graph twice on the same
      # loader would swap it back. Confirm before reusing loaders there.
      saved_model_utils.swap_function_tensor_content(meta_graph_def, "little",
                                                     "big")
    with graph.as_default():
      return tf_saver._import_meta_graph_with_return_elements(  # pylint: disable=protected-access
          meta_graph_def, import_scope=import_scope, **saver_kwargs)

  def restore_variables(self, sess, saver, import_scope=None):
    """Restore SavedModel variable values into the session.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      saver: a tf.compat.v1.train.Saver object. Can be None if there are no
        variables in graph. This may be the saver returned by the load_graph()
        function, or a default `tf.compat.v1.train.Saver()`.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.

    Raises:
      ValueError: if no saver was passed to the saver argument, and there are
        variables in the graph.
    """
    with sess.graph.as_default():
      if (saver is None and
          not variables._all_saveable_objects(scope=import_scope)):  # pylint: disable=protected-access
        tf_logging.info("The specified SavedModel has no variables; no "
                        "checkpoints were restored.")
      elif isinstance(saver, tf_saver.Saver):
        saver.restore(sess, self._variables_path)
      else:
        raise ValueError(
            "No tf.train.Saver object was passed to the function "
            "`SavedModelLoader.restore_variables`. Since there are variables in"
            " the graph, a saver is required.")

  def run_init_ops(self, sess, tags, import_scope=None):
    """Run initialization ops defined in the `MetaGraphDef`.

    Asset file paths are fed into their asset tensors while running the init
    op. Nothing is run when the meta graph defines no init op.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    with sess.graph.as_default():
      # Get asset tensors, if any.
      asset_tensors_dictionary = get_asset_tensors(
          self._export_dir, meta_graph_def, import_scope=import_scope)

      init_op = get_init_op(meta_graph_def, import_scope)
      if init_op is not None:
        sess.run(fetches=[init_op], feed_dict=asset_tensors_dictionary)

  def load(self, sess, tags, import_scope=None, **saver_kwargs):
    """Load the MetaGraphDef graph and restore variable values into the session.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      `MetagraphDef` proto of the graph that was loaded.
    """
    metrics.IncrementReadApi(_LOADER_LABEL)

    with sess.graph.as_default():
      saver, _ = self.load_graph(sess.graph, tags, import_scope,
                                 **saver_kwargs)
      self.restore_variables(sess, saver, import_scope)
      self.run_init_ops(sess, tags, import_scope)
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)

    # Reuse the proto parsed in __init__ rather than reading and parsing the
    # SavedModel file from disk a second time just for metrics.
    meta_graphs = self._saved_model.meta_graphs
    if len(meta_graphs) == 1 and meta_graphs[0].HasField("object_graph_def"):
      metrics.IncrementRead(write_version="2")
    else:
      metrics.IncrementRead(write_version="1")

    return meta_graph_def