Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/builder_impl.py: 23%
239 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""SavedModel builder implementation."""
17import functools
18import os
20from google.protobuf.any_pb2 import Any
22from tensorflow.core.framework import types_pb2
23from tensorflow.core.protobuf import meta_graph_pb2
24from tensorflow.core.protobuf import saved_model_pb2
25from tensorflow.core.protobuf import saver_pb2
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.lib.io import file_io
29from tensorflow.python.ops import variables
30from tensorflow.python.platform import tf_logging
31from tensorflow.python.saved_model import fingerprinting_utils
32from tensorflow.python.saved_model import path_helpers
33from tensorflow.python.saved_model import signature_def_utils
34from tensorflow.python.saved_model.pywrap_saved_model import constants
35from tensorflow.python.saved_model.pywrap_saved_model import metrics
36from tensorflow.python.training import saver as tf_saver
37from tensorflow.python.util import compat
38from tensorflow.python.util.deprecation import deprecated_args
39from tensorflow.python.util.tf_export import tf_export
# API label for SavedModel metrics: the label that `save()` passes to
# `metrics.IncrementWriteApi` for this v1 builder write path.
_SAVE_BUILDER_LABEL: str = "save_v1_builder"
45# Base class for the SavedModelBuilder that is only used by Tensorflow
46# internally. Please use tf.compat.v1.saved_model.SavedModelBuilder instead.
47@tf_export("__internal__.saved_model.SavedModelBuilder", v1=[])
class _SavedModelBuilder(object):
  """Builds the `SavedModel` protocol buffer and saves variables and assets.

  The `SavedModelBuilder` class provides the functionality to build a
  `SavedModel` protocol buffer. Specifically, this allows multiple meta
  graphs to be saved as part of a single language-neutral `SavedModel`,
  while sharing variables and assets.

  To build a SavedModel, the first meta graph must be saved with variables.
  Subsequent meta graphs will simply be saved with their graph definitions. If
  assets need to be saved and written or copied to disk, they can be provided
  when the meta graph def is added. If multiple meta graph defs are associated
  with an asset of the same name, only the first version is retained.

  Each meta graph added to the SavedModel must be annotated with tags. The tags
  provide a means to identify the specific meta graph to load and restore, along
  with the shared set of variables and assets.

  Typical usage for the `SavedModelBuilder`:

  ```python
  ...
  builder = tf.compat.v1.saved_model.Builder(export_dir)

  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph_and_variables(sess,
                                    ["foo-tag"],
                                    signature_def_map=foo_signatures,
                                    assets_list=foo_assets)
  ...

  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph(["bar-tag", "baz-tag"])
  ...

  builder.save()
  ```

  Note: This function will only be available through the v1 compatibility
  library as tf.compat.v1.saved_model.builder.SavedModelBuilder or
  tf.compat.v1.saved_model.Builder. Tensorflow 2.0 will introduce a new
  object-based method of creating SavedModels.
  """

  def __init__(self, export_dir):
    self._saved_model = saved_model_pb2.SavedModel()
    self._saved_model.saved_model_schema_version = (
        constants.SAVED_MODEL_SCHEMA_VERSION)

    self._export_dir = export_dir
    if file_io.file_exists(export_dir):
      # An existing directory is only acceptable when it is empty.
      if file_io.list_directory(export_dir):
        raise AssertionError(
            f"Export directory {export_dir} already exists, and isn't empty. "
            "Please choose a different export directory, or delete all the "
            "contents of the specified directory.")
    else:
      file_io.recursive_create_dir(self._export_dir)

    # Boolean to track whether variables and assets corresponding to the
    # SavedModel have been saved. Specifically, the first meta graph to be added
    # MUST use the add_meta_graph_and_variables() API. Subsequent add operations
    # on the SavedModel MUST use the add_meta_graph() API which does not save
    # weights.
    self._has_saved_variables = False
    # Destination paths of asset files already copied into the export
    # directory; shared across meta graphs so each is copied at most once.
    self._saved_asset_files = set()

  def _save_and_write_assets(self, meta_graph_def, assets_list=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      meta_graph_def: The meta graph def to which the assets will be added.
      assets_list: The list where the asset paths are setup.
    """
    # Creates a function that adds assets into the meta graph def.
    write_fn = functools.partial(_add_asset_to_metagraph, meta_graph_def)
    asset_filename_map = _maybe_save_assets(write_fn, assets_list)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir,
                                   self._saved_asset_files)

  def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
    """Tags the meta graph def and adds it to the SavedModel.

    Tags the meta graph def with the supplied tags, adds signature defs to it if
    provided and appends the meta graph def to the SavedModel proto.

    Args:
      meta_graph_def: The meta graph def to add to the SavedModel.
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
        def.
    """
    for tag in tags:
      meta_graph_def.meta_info_def.tags.append(tag)

    if signature_def_map is not None:
      for key, signature_def in signature_def_map.items():
        meta_graph_def.signature_def[key].CopyFrom(signature_def)

    proto_meta_graph_def = self._saved_model.meta_graphs.add()
    proto_meta_graph_def.CopyFrom(meta_graph_def)

  def _validate_tensor_info(self, tensor_info):
    """Validates the `TensorInfo` proto.

    Checks if the `encoding` (`name` or `coo_sparse` or `type_spec`) and
    `dtype` fields exist and are non-empty. Composite tensors are validated
    recursively through their components.

    Args:
      tensor_info: `TensorInfo` protocol buffer to validate.

    Raises:
      AssertionError: If the `encoding` or `dtype` fields of the supplied
          `TensorInfo` proto are not populated.
    """
    if tensor_info is None:
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have the name "
          "and dtype fields set.")
    encoding = tensor_info.WhichOneof("encoding")
    if encoding is None:
      # TODO(soergel) validate each of the fields of coo_sparse
      raise AssertionError(
          f"Invalid `tensor_info`: {tensor_info}. All TensorInfo protos used "
          "in the SignatureDefs must have one of the 'encoding' fields (e.g., "
          "name or coo_sparse) set.")
    if encoding == "composite_tensor":
      for component in tensor_info.composite_tensor.components:
        self._validate_tensor_info(component)
    elif tensor_info.dtype == types_pb2.DT_INVALID:
      raise AssertionError(
          f"Invalid `tensor_info`: {tensor_info}. All TensorInfo protos used in"
          " the SignatureDefs must have the dtype field set.")

  def _validate_signature_def_map(self, signature_def_map):
    """Validates the `SignatureDef` entries in the signature def map.

    Validation of entries in the signature def map includes ensuring that the
    `name` and `dtype` fields of the TensorInfo protos of the `inputs` and
    `outputs` of each `SignatureDef` are populated. Also ensures that reserved
    SignatureDef keys for the initialization and train ops are not used.

    Args:
      signature_def_map: The map of signature defs to be validated.

    Raises:
      AssertionError: If a TensorInfo is not valid.
      KeyError: If a reserved signature key is used in the map.
    """
    for signature_def in signature_def_map.values():
      for tensor_info in signature_def.inputs.values():
        self._validate_tensor_info(tensor_info)
      for tensor_info in signature_def.outputs.values():
        self._validate_tensor_info(tensor_info)
    if constants.INIT_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          f"SignatureDef map key \"{constants.INIT_OP_SIGNATURE_KEY}\" is "
          "reserved for initialization. Please use a different key.")
    if constants.TRAIN_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          f"SignatureDef map key \"{constants.TRAIN_OP_SIGNATURE_KEY}\" is "
          "reserved for the train op. Please use a different key.")

  def _maybe_create_saver(self, saver=None):
    """Creates a sharded saver if one does not already exist."""
    if not saver:
      # Initialize a saver to generate a sharded output for all saveables in the
      # current scope.
      saver = tf_saver.Saver(
          variables._all_saveable_objects(),  # pylint: disable=protected-access
          sharded=True,
          write_version=saver_pb2.SaverDef.V2,
          allow_empty=True)
    return saver

  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_list=None,
                     clear_devices=False,
                     init_op=None,
                     train_op=None,
                     saver=None):
    """Adds the current meta graph to the SavedModel.

    Creates a Saver in the current scope and uses the Saver to export the meta
    graph def. Invoking this API requires the `add_meta_graph_and_variables()`
    API to have been invoked before.

    Args:
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
        def. The map itself is not modified.
      assets_list: Assets to be saved with SavedModel. Note
          that this list should be a subset of the assets saved as part of
          the first meta graph in the SavedModel.
      clear_devices: Set to true if the device info on the default graph should
          be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
          that when the init_op is specified it is run after the restore op at
          load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      saver: An instance of tf.compat.v1.train.Saver that will be used to export
        the metagraph. If None, a sharded Saver that restores all variables will
        be used.

    Raises:
      AssertionError: If the variables for the SavedModel have not been saved
        yet, or if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a reserved signature key.
    """
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated. Work on a copy: the init/train op entries added below
    # must not leak into the caller's dict, or reusing that dict for another
    # meta graph would trip the reserved-key check.
    signature_def_map = dict(signature_def_map or {})
    self._validate_signature_def_map(signature_def_map)

    # Create a SignatureDef pointing to the graph initialization op, which will
    # be added to the MetaGraphDef.
    _add_op_to_signature_def_map(signature_def_map, init_op,
                                 constants.INIT_OP_SIGNATURE_KEY)
    _add_op_to_signature_def_map(signature_def_map, train_op,
                                 constants.TRAIN_OP_SIGNATURE_KEY)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights). Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph. Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=True)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_list=None,
                                   clear_devices=False,
                                   init_op=None,
                                   train_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    # pylint: disable=line-too-long
    """Adds the current meta graph to the SavedModel and saves variables.

    Creates a Saver to save the variables from the provided session. Exports the
    corresponding meta graph def. This function assumes that the variables to be
    saved have been initialized. For a given `SavedModelBuilder`, this API must
    be called exactly once and for the first meta graph to save. For subsequent
    meta graph defs to be added, the `add_meta_graph()` API must be used.

    Args:
      sess: The TensorFlow session from which to save the meta graph and
        variables.
      tags: The set of tags with which to save the meta graph.
      signature_def_map: The map of signature def map to add to the meta graph
        def. The map itself is not modified.
      assets_list: Assets to be saved with SavedModel.
      clear_devices: Set to true if the device info on the default graph should
          be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
          that when the init_op is specified it is run after the restore op at
          load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      saver: An instance of tf.compat.v1.train.Saver that will be used to export the
        metagraph and save variables. If None, a sharded Saver that restores
        all variables will be used.

    Raises:
      AssertionError: If variables have already been saved by this builder, or
        if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a reserved signature key.
    """
    # pylint: enable=line-too-long
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated. Work on a copy so the caller's dict does not receive
    # the reserved init/train op entries added below.
    signature_def_map = dict(signature_def_map or {})
    self._validate_signature_def_map(signature_def_map)

    # Create a SignatureDef pointing to the graph initialization op, which will
    # be added to the MetaGraphDef.
    _add_op_to_signature_def_map(signature_def_map, init_op,
                                 constants.INIT_OP_SIGNATURE_KEY)
    _add_op_to_signature_def_map(signature_def_map, train_op,
                                 constants.TRAIN_OP_SIGNATURE_KEY)

    path_helpers.get_or_create_variables_dir(self._export_dir)
    variables_path = path_helpers.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights). Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph. Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True

  def save(self, as_text=False):
    """Writes a `SavedModel` protocol buffer to disk.

    The function writes the SavedModel protocol buffer to the export directory
    in a serialized format.

    Args:
      as_text: Writes the SavedModel protocol buffer in text format to
        disk. Protocol buffers in text format are useful for debugging, but
        parsing fails when it encounters an unknown field and so is not forward
        compatible. This means changes to TensorFlow may prevent deployment of
        new text format SavedModels to existing serving binaries. Do not deploy
        `as_text` SavedModels to production.

    Returns:
      The path to which the SavedModel protocol buffer was written.
    """
    metrics.IncrementWriteApi(_SAVE_BUILDER_LABEL)
    if not file_io.file_exists(self._export_dir):
      file_io.recursive_create_dir(self._export_dir)

    if as_text:
      path = file_io.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
      file_io.write_string_to_file(path, str(self._saved_model))
    else:
      path = file_io.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
      # Binary serialization is only needed for this branch; deterministic=True
      # requests protobuf's stable (e.g. map-entry) ordering.
      saved_model_serialized = self._saved_model.SerializeToString(
          deterministic=True)
      file_io.write_string_to_file(path, saved_model_serialized)
    tf_logging.info("SavedModel written to: %s", compat.as_text(path))
    metrics.IncrementWrite(write_version="1")

    # Placeholder for internal TF1 model fingerprint write

    return path
@tf_export(v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"])  # pylint: disable=missing-docstring
class SavedModelBuilder(_SavedModelBuilder):
  __doc__ = _SavedModelBuilder.__doc__.replace("assets_list",
                                               "assets_collection")

  def __init__(self, export_dir):
    super(SavedModelBuilder, self).__init__(export_dir=export_dir)

  def _add_collections(self, assets_collection, main_op, train_op):
    """Add asset and op collections to be saved.

    Args:
      assets_collection: Asset path tensors to record in the assets collection
        and copy into the export directory, or None.
      main_op: Op registered under `constants.MAIN_OP_KEY`, or None.
      train_op: Op or Tensor registered under `constants.TRAIN_OP_KEY`, or None.
    """
    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(assets_collection)

    self._maybe_add_main_op(main_op)

    self._add_train_op(train_op)

  def _save_and_write_assets(self, assets_collection_to_add=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
    # Add assets to the collection with key `saved_model.ASSETS_KEY`, in the
    # graph.
    asset_filename_map = _maybe_save_assets(_add_asset_to_collection,
                                            assets_collection_to_add)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir,
                                   self._saved_asset_files)

  def _maybe_add_main_op(self, main_op):
    """Adds main op to the SavedModel.

    Args:
      main_op: Main op to run as part of graph initialization. If None, no main
        op will be added to the graph.

    Raises:
      TypeError: If the main op is provided but is not of type `Operation`.
      ValueError: if the Graph already contains an init op.
    """
    if main_op is None:
      return

    if not isinstance(main_op, ops.Operation):
      raise TypeError(f"Expected {main_op} to be an Operation but got type "
                      f"{type(main_op)} instead.")

    # Validate that no other init ops have been added to this graph already.
    # We check main_op and legacy_init_op for thoroughness and explicitness.
    for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
      if ops.get_collection(init_op_key):
        raise ValueError(
            "Graph already contains one or more main ops under the "
            f"collection {init_op_key}.")

    ops.add_to_collection(constants.MAIN_OP_KEY, main_op)

  def _add_train_op(self, train_op):
    """Add train op to the SavedModel.

    Note that this functionality is in development, and liable to be
    moved elsewhere.

    Args:
      train_op: Op or group of ops that are used for training. These are stored
        as a collection with key TRAIN_OP_KEY, but not executed.

    Raises:
      TypeError: If `train_op` is neither a `Tensor` nor an `Operation`.
    """
    if train_op is not None:
      if not isinstance(train_op, (ops.Tensor, ops.Operation)):
        raise TypeError(f"`train_op` {train_op} needs to be a Tensor or Op.")
      ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)

  # The docstrings of the two public methods below are copied from the base
  # class after their definitions (with "assets_list" renamed).
  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_collection=None,
                     legacy_init_op=None,
                     clear_devices=False,
                     main_op=None,
                     strip_default_attrs=False,
                     saver=None):
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless.
    main_op = main_op if main_op is not None else legacy_init_op

    # Add assets and ops
    self._add_collections(assets_collection, main_op, None)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights). Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph. Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_collection=None,
                                   legacy_init_op=None,
                                   clear_devices=False,
                                   main_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless. Use an
    # explicit None check (as add_meta_graph does) rather than truthiness.
    main_op = main_op if main_op is not None else legacy_init_op

    # Add assets and ops
    self._add_collections(assets_collection, main_op, None)

    path_helpers.get_or_create_variables_dir(self._export_dir)
    variables_path = path_helpers.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights). Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph. Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True

  add_meta_graph.__doc__ = _SavedModelBuilder.add_meta_graph.__doc__.replace(
      "assets_list", "assets_collection")
  add_meta_graph_and_variables.__doc__ = (
      _SavedModelBuilder.add_meta_graph_and_variables.__doc__.replace(
          "assets_list", "assets_collection"))
def _maybe_save_assets(write_fn, assets_to_add=None):
  """Saves assets to the meta graph.

  Args:
    write_fn: A function callback that writes assets into meta graph. It is
      called as `write_fn(asset_filename, asset_tensor)` once per asset tensor.
    assets_to_add: The list where the asset paths are setup, or None.

  Returns:
    A dict of asset basenames for saving to the original full path to the asset.

  Raises:
    ValueError: Indicating an invalid filepath tensor.
    TypeError: Propagated from `_asset_path_from_tensor` when an element is not
      a scalar string constant tensor.
  """
  # Map of target file names to original filenames
  asset_filename_map = {}

  if assets_to_add is None:
    tf_logging.info("No assets to save.")
    return asset_filename_map

  # Iterate over the supplied assets, build the `AssetFile` proto and add them
  # to the meta graph.
  for asset_tensor in assets_to_add:
    asset_source_filepath = _asset_path_from_tensor(asset_tensor)
    if not asset_source_filepath:
      raise ValueError(f"Asset filepath tensor {asset_tensor} is invalid.")

    asset_filename = get_asset_filename_to_add(
        asset_source_filepath, asset_filename_map)

    # Call the passed-in function that builds AssetFileDef proto and adds it
    # to either the collection or asset_file_def field of the meta graph.
    # Note that this should be done even when the file is a duplicate of an
    # already-added file, as the tensor reference should still exist.
    write_fn(asset_filename, asset_tensor)

    # In the cases where we are adding a duplicate, this will result in the
    # last of the filepaths being the one used for copying the file to the
    # SavedModel. Since the files in question are the same, it doesn't matter
    # either way.
    asset_filename_map[asset_filename] = asset_source_filepath

  tf_logging.info("Assets added to graph.")
  return asset_filename_map
def get_asset_filename_to_add(asset_filepath, asset_filename_map):
  """Get a unique basename to add to the SavedModel if this file is unseen.

  Assets come from users as full paths, and we save them out to the
  SavedModel as basenames. In some cases, the basenames collide. Here,
  we dedupe asset basenames by first checking if the file is the same,
  and, if different, generate and return an index-suffixed basename
  that can be used to add the asset to the SavedModel.

  Args:
    asset_filepath: the full path to the asset that is being saved
    asset_filename_map: a dict of filenames used for saving the asset in
      the SavedModel to full paths from which the filenames were derived.

  Returns:
    Uniquified filename string if the file is not a duplicate, or the filename
    the file was previously stored under if it has already been seen and saved.
  """
  asset_filename = os.path.basename(asset_filepath)

  if asset_filename not in asset_filename_map:
    # This is an unseen asset. Safe to add.
    return asset_filename

  other_asset_filepath = asset_filename_map[asset_filename]
  if other_asset_filepath == asset_filepath:
    # This is the same file, stored twice in the list. No need
    # to make unique.
    return asset_filename

  # The same path may already be stored under an index-suffixed name because
  # its basename previously collided with a different file. Reuse that name
  # rather than minting a new suffix (which would copy the same file twice).
  for saved_filename, saved_filepath in asset_filename_map.items():
    if saved_filepath == asset_filepath:
      return saved_filename

  # Else, asset_filename is in the map, and the filepath is different. Dedupe.
  if not file_io.filecmp(asset_filepath, other_asset_filepath):
    # Files are different; dedupe filenames.
    return _get_unique_asset_filename(asset_filename, asset_filename_map)

  # Files are the same; don't make unique.
  return asset_filename
def _get_unique_asset_filename(asset_filename, asset_filename_map):
  """Returns a filename not yet used as a key of `asset_filename_map`.

  If `asset_filename` itself is free it is returned unchanged. Otherwise the
  bytes string `<asset_filename>_<n>` with the smallest free n >= 1 is returned.

  Args:
    asset_filename: The preferred (colliding) asset basename.
    asset_filename_map: Dict whose keys are the filenames already taken.

  Returns:
    `asset_filename`, or an index-suffixed bytes variant of it.
  """
  candidate = asset_filename
  suffix = 0
  while candidate in asset_filename_map:
    suffix += 1
    candidate = compat.as_bytes("_").join(
        [compat.as_bytes(asset_filename), compat.as_bytes(str(suffix))])
  return candidate
def _asset_path_from_tensor(path_tensor):
  """Returns the filepath value stored in constant `path_tensor`.

  Args:
    path_tensor: Tensor of a file-path.

  Returns:
    The string value i.e. path of the tensor, if valid (the single entry of the
    `Const` op's `string_val`).

  Raises:
    TypeError: If the tensor is not a `Tensor`, is not produced by a `Const`
      op, is not of dtype string, or does not hold exactly one string value.
  """
  if not isinstance(path_tensor, ops.Tensor):
    raise TypeError(f"Asset path tensor {path_tensor} must be a Tensor.")
  if path_tensor.op.type != "Const":
    raise TypeError(f"Asset path tensor {path_tensor} must be of type constant. "
                    f"Has type {path_tensor.op.type} instead.")
  if path_tensor.dtype != dtypes.string:
    raise TypeError(f"Asset path tensor {path_tensor} must be of dtype string. "
                    f"Has type {path_tensor.dtype} instead.")
  str_values = path_tensor.op.get_attr("value").string_val
  if len(str_values) != 1:
    raise TypeError(f"Asset path tensor {path_tensor} must be a scalar.")
  return str_values[0]
def _add_asset_to_metagraph(meta_graph_def, asset_filename, asset_tensor):
  """Appends one `asset_file_def` entry to `meta_graph_def`.

  The new entry records the asset's filename and the name of the tensor that
  holds its path.

  Args:
    meta_graph_def: The meta graph def receiving the new asset entry.
    asset_filename: The filename of the asset to be added.
    asset_tensor: The asset tensor; its `name` fills the entry's tensor info.
  """
  entry = meta_graph_def.asset_file_def.add()
  entry.filename = asset_filename
  info = entry.tensor_info
  info.name = asset_tensor.name
def copy_assets_to_destination_dir(asset_filename_map, destination_dir,
                                   saved_files=None):
  """Copy all assets from source path to destination path.

  Args:
    asset_filename_map: a dict of filenames used for saving the asset in
      the SavedModel to full paths from which the filenames were derived.
    destination_dir: the destination directory that assets are stored in.
    saved_files: a set of destination filepaths that have already been copied
      and will be skipped. Newly copied destinations are added to it in place.
  """
  if saved_files is None:
    saved_files = set()

  assets_destination_dir = path_helpers.get_or_create_assets_dir(
      destination_dir)

  def _as_path_bytes(path):
    # Sources may be str (or PathLike) while destinations are built from bytes;
    # normalize both so equality really means "same path".
    return compat.as_bytes(compat.path_to_str(path))

  # Copy each asset from source path to destination path.
  for asset_basename, asset_source_filepath in asset_filename_map.items():
    asset_destination_filepath = file_io.join(
        compat.as_bytes(assets_destination_dir),
        compat.as_bytes(asset_basename))

    # Copy if source file exists, src & dst are not the same, and dst is not in
    # saved_files
    if (file_io.file_exists(asset_source_filepath) and
        _as_path_bytes(asset_source_filepath) !=
        _as_path_bytes(asset_destination_filepath) and
        asset_destination_filepath not in saved_files):
      file_io.copy(
          asset_source_filepath, asset_destination_filepath, overwrite=True)
      saved_files.add(asset_destination_filepath)

  tf_logging.info("Assets written to: %s",
                  compat.as_text(assets_destination_dir))
def _add_asset_to_collection(asset_filename, asset_tensor):
  """Records an asset in the graph collection keyed by `constants.ASSETS_KEY`.

  An `AssetFileDef` holding the filename and the asset tensor's name is packed
  into a `google.protobuf.Any` and passed to `ops.add_to_collection`.

  Args:
    asset_filename: The filename of the asset to be added.
    asset_tensor: The asset tensor whose name populates the proto's
      `tensor_info`.
  """
  file_def = meta_graph_pb2.AssetFileDef()
  file_def.filename = asset_filename
  file_def.tensor_info.name = asset_tensor.name

  packed = Any()
  packed.Pack(file_def)
  ops.add_to_collection(constants.ASSETS_KEY, packed)
def _add_op_to_signature_def_map(signature_def_map, op, key):
  """Stores an op-wrapping SignatureDef under `key`; no-op when `op` is None.

  Args:
    signature_def_map: Dict of SignatureDefs, modified in place.
    op: The op to wrap via `signature_def_utils.op_signature_def`, or None.
    key: The signature key to store the SignatureDef under.
  """
  if op is None:
    return
  signature_def_map[key] = signature_def_utils.op_signature_def(op, key)