Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/method_name_updater.py: 32%
38 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""SignatureDef method name utility functions.
17Utility functions for manipulating signature_def.method_names.
18"""
20from tensorflow.python.lib.io import file_io
21from tensorflow.python.platform import tf_logging
22from tensorflow.python.saved_model import constants
23from tensorflow.python.saved_model import loader_impl as loader
24from tensorflow.python.util import compat
25from tensorflow.python.util.tf_export import tf_export
# TODO(jdchung): Consider integrating this into the saved_model_cli so that
# users could do this from the command line directly.
@tf_export(v1=["saved_model.signature_def_utils.MethodNameUpdater"])
class MethodNameUpdater(object):
  """Updates the method name(s) of the SavedModel stored in the given path.

  The `MethodNameUpdater` class provides the functionality to update the method
  name field in the signature_defs of the given SavedModel. For example, it
  can be used to replace the `predict` `method_name` with `regress`.

  Typical usages of the `MethodNameUpdater`
  ```python
  ...
  updater = tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater(
      export_dir)
  # Update the signature_def with key "foo" in all meta graph defs.
  updater.replace_method_name(signature_key="foo", method_name="regress")
  # Update a single signature_def with key "bar" in the meta graph def with
  # tags ["serve"]
  updater.replace_method_name(signature_key="bar", method_name="classify",
                              tags="serve")
  updater.save(new_export_dir)
  ```

  Note: This class is only available through the v1 compatibility
  library as tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater.
  """

  def __init__(self, export_dir):
    """Creates a MethodNameUpdater object.

    Args:
      export_dir: Directory containing the SavedModel files.

    Raises:
      IOError: If the saved model file does not exist, or cannot be successfully
        parsed.
    """
    self._export_dir = export_dir
    self._saved_model = loader.parse_saved_model(export_dir)

  def replace_method_name(self, signature_key, method_name, tags=None):
    """Replaces the method_name in the specified signature_def.

    If `tags` is None, every `MetaGraph` in the SavedModel is selected, and
    each of them must contain a signature_def with `signature_key`. If `tags`
    is not None, only the `MetaGraph`(s) whose tag set equals `tags` are
    selected.

    The update is all-or-nothing: every selected `MetaGraph` is validated
    before any of them is modified. A `ValueError` therefore leaves the
    in-memory SavedModel unchanged.

    Args:
      signature_key: Key of the signature_def to be updated.
      method_name: new method_name to replace the existing one.
      tags: A single tag, or a list/tuple/set of tags, identifying the
        `MetaGraph` to update. Any other value (e.g. a string) is treated as
        one tag. If None, all meta graphs will be updated.

    Raises:
      ValueError: if signature_key or method_name are not defined or
        if no metagraphs were found with the associated tags or
        if a selected meta graph has no signature_def that matches
        signature_key.
    """
    if not signature_key:
      raise ValueError("`signature_key` must be defined.")
    if not method_name:
      raise ValueError("`method_name` must be defined.")

    if tags is not None:
      if isinstance(tags, (list, tuple, set, frozenset)):
        # Previously only `list` was recognized. A tuple/set was wrapped as a
        # single element and could never match any meta graph.
        tags = list(tags)
      else:
        tags = [tags]
    tag_set = None if tags is None else set(tags)

    # Collect the target meta graphs up front so that all validation happens
    # before any signature_def is mutated.
    selected = [
        meta_graph_def for meta_graph_def in self._saved_model.meta_graphs
        if tag_set is None or tag_set == set(meta_graph_def.meta_info_def.tags)
    ]

    if not selected:
      raise ValueError(
          f"MetaGraphDef associated with tags {tags} could not be found in "
          "SavedModel. This means either you specified invalid tags or your "
          "SavedModel does not have a MetaGraphDef with the specified tags.")

    for meta_graph_def in selected:
      if signature_key not in meta_graph_def.signature_def:
        raise ValueError(
            f"MetaGraphDef associated with tags {tags} "
            f"does not have a signature_def with key: '{signature_key}'. "
            "This means either you specified the wrong signature key or "
            "forgot to put the signature_def with the corresponding key in "
            "your SavedModel.")

    for meta_graph_def in selected:
      meta_graph_def.signature_def[signature_key].method_name = method_name

  def save(self, new_export_dir=None):
    """Saves the updated `SavedModel`.

    Only the SavedModel proto file is written. It uses the same format as the
    input: text if the input directory contains the text-format SavedModel
    file, binary otherwise. No other files from the input directory are copied
    to `new_export_dir`.

    Args:
      new_export_dir: Path where the updated `SavedModel` will be saved. If
        None (or empty), the input `SavedModel` will be overridden with the
        updates.

    Raises:
      errors.OpError: If there are errors during the file save operation.
    """
    # The output format mirrors the input: look for the text-format file in
    # the *original* export directory.
    is_input_text_proto = file_io.file_exists(
        file_io.join(
            compat.as_bytes(self._export_dir),
            compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT)))
    if not new_export_dir:
      new_export_dir = self._export_dir

    if is_input_text_proto:
      filename = constants.SAVED_MODEL_FILENAME_PBTXT
      contents = str(self._saved_model)
    else:
      filename = constants.SAVED_MODEL_FILENAME_PB
      # Deterministic serialization keeps repeated saves byte-stable.
      contents = self._saved_model.SerializeToString(deterministic=True)

    # TODO(jdchung): Add a util for the path creation below.
    path = file_io.join(
        compat.as_bytes(new_export_dir), compat.as_bytes(filename))
    file_io.write_string_to_file(path, contents)
    tf_logging.info("SavedModel written to: %s", compat.as_text(path))