Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/signature_serialization.py: 23%
160 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helpers for working with signatures in tf.saved_model.save."""
17from absl import logging
19from tensorflow.python.eager import def_function
20from tensorflow.python.eager import function as defun
21from tensorflow.python.framework import composite_tensor
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import tensor_spec
24from tensorflow.python.ops import resource_variable_ops
25from tensorflow.python.saved_model import function_serialization
26from tensorflow.python.saved_model import revived_types
27from tensorflow.python.saved_model import signature_constants
28from tensorflow.python.trackable import base
29from tensorflow.python.types import core
30from tensorflow.python.util import compat
31from tensorflow.python.util import nest
32from tensorflow.python.util.compat import collections_abc
# Name of a root-object child that, when it is a function, is exported as the
# default signature (see `find_function_to_export`).
DEFAULT_SIGNATURE_ATTR = "_default_save_signature"
# Reserved attribute name under which SavedModel signatures are stored on
# objects (see `validate_augmented_graph_view`).
SIGNATURE_ATTRIBUTE_NAME = "signatures"
# Maximum number of functions for which renamed (normalized) signature input
# names are logged by `canonicalize_signatures`.
_NUM_DISPLAY_NORMALIZED_SIGNATURES = 5
def _get_signature(function):
  """Resolves `function` to a ConcreteFunction usable as a signature.

  A `tf.function` whose `input_signature` is set is first converted to its
  concrete function. Anything that is not (or does not resolve to) a
  `ConcreteFunction` is rejected.

  Args:
    function: a `tf.function`, a concrete function, or any other object.

  Returns:
    The `ConcreteFunction`, or None if `function` cannot be used.
  """
  polymorphic_types = (defun.Function, def_function.Function)
  has_input_signature = (
      isinstance(function, polymorphic_types) and
      function.input_signature is not None)
  if has_input_signature:
    function = function._get_concrete_function_garbage_collected()  # pylint: disable=protected-access
  return function if isinstance(function, defun.ConcreteFunction) else None
def _valid_signature(concrete_function):
  """Returns whether `concrete_function` can be converted to a signature.

  A function qualifies when it has at least one output, none of its inputs
  are tf.Variables, and its structured outputs can be normalized into a
  string-keyed dictionary of Tensors.
  """
  if concrete_function.outputs:
    try:
      _validate_inputs(concrete_function)
      _normalize_outputs(
          concrete_function.structured_outputs, "unused", "unused")
    except ValueError:
      return False
    return True
  # Functions without outputs don't make sense as signatures: there is no way
  # to run an Operation with no outputs as a 1.x-style SignatureDef.
  return False
def _validate_inputs(concrete_function):
  """Raises an error if any input of `concrete_function` is a tf.Variable.

  Args:
    concrete_function: the concrete function whose structured input
      signature is checked.

  Raises:
    ValueError: if the flattened structured input signature contains a
      `VariableSpec`; such functions cannot be exported as signatures.
  """
  flat_inputs = nest.flatten(concrete_function.structured_input_signature)
  if any(isinstance(inp, resource_variable_ops.VariableSpec)
         for inp in flat_inputs):
    # Note the trailing space after the quoted name: the adjacent literals are
    # concatenated, and previously produced "'name'with tf.Variable input".
    raise ValueError(
        f"Unable to serialize concrete_function '{concrete_function.name}' "
        "with tf.Variable input. Functions that expect tf.Variable "
        "inputs cannot be exported as signatures.")
def _get_signature_name_changes(concrete_function):
  """Finds user-specified signature input names that were normalized.

  Pairs each `input_arg` of the function's FunctionDef signature with the
  corresponding graph input. It compares the arg name to the graph input op's
  `_user_specified_name` attribute, if present.

  Args:
    concrete_function: the concrete function to inspect.

  Returns:
    A dict `{user-specified name: signature input name}` containing only the
    inputs whose names differ.
  """
  renamed = {}
  arg_defs = concrete_function.function_def.signature.input_arg
  for arg_def, graph_input in zip(arg_defs, concrete_function.graph.inputs):
    try:
      given_name = compat.as_str(
          graph_input.op.get_attr("_user_specified_name"))
    except ValueError:
      # Signature input does not have a user-specified name.
      continue
    if arg_def.name != given_name:
      renamed[given_name] = arg_def.name
  return renamed
def find_function_to_export(saveable_view):
  """Finds a function on the root object to export as the default signature.

  Used when the user did not specify signatures. The root object's children
  are scanned in order:
    * A child named `DEFAULT_SIGNATURE_ATTR` that is a `tf.function` or
      concrete function is returned immediately, as-is.
    * Otherwise, every function child that resolves (via `_get_signature`) to
      a concrete function passing `_valid_signature` is a candidate. It is
      returned only when it is the sole candidate.

  Args:
    saveable_view: object exposing a `root` attribute and a
      `list_children(obj)` method yielding `(name, child)` pairs.

  Returns:
    The function to export, or None if no suitable function was found.
  """
  children = saveable_view.list_children(saveable_view.root)

  # TODO(b/205014194): Discuss removing this behaviour. It can lead to WTFs when
  # a user decides to annotate more functions with tf.function and suddenly
  # serving that model way later in the process stops working.
  possible_signatures = []
  for name, child in children:
    if not isinstance(child, (def_function.Function, defun.ConcreteFunction)):
      continue
    if name == DEFAULT_SIGNATURE_ATTR:
      return child
    concrete = _get_signature(child)
    if concrete is not None and _valid_signature(concrete):
      possible_signatures.append(concrete)

  # Candidates were already resolved to ConcreteFunctions and validated above.
  # Re-running `_get_signature` would return the same object, so an unambiguous
  # single candidate is returned directly instead of being validated twice.
  if len(possible_signatures) == 1:
    return possible_signatures[0]
  return None
def _make_signature_wrapper(signature_function, signature_key):
  """Returns a kwargs-only function that calls `signature_function`.

  The wrapper converts the structured outputs into a dictionary of Tensors via
  `_normalize_outputs`, which matches the format of 1.x-style signatures.
  Building the closure in this factory binds `signature_function` and
  `signature_key` for one signature. Otherwise it would capture loop variables
  of `canonicalize_signatures`, which are rebound on every iteration.
  """

  def signature_wrapper(**kwargs):
    structured_outputs = signature_function(**kwargs)
    return _normalize_outputs(
        structured_outputs, signature_function.name, signature_key)

  return signature_wrapper


def _tensor_spec_signature(signature_function):
  """Returns `{arg keyword: TensorSpec}` describing `signature_function` inputs.

  The function's argument keywords are zipped, in order, with its TensorSpec
  inputs. Each resulting spec is named after its keyword.
  """
  if signature_function.structured_input_signature is not None:
    # The structured input signature may contain other non-tensor arguments.
    inputs = [
        inp for inp in nest.flatten(
            signature_function.structured_input_signature,
            expand_composites=True)
        if isinstance(inp, tensor_spec.TensorSpec)
    ]
  else:
    # Structured input signature isn't always defined for some functions.
    inputs = signature_function.inputs

  specs = {}
  for keyword, inp in zip(
      signature_function._arg_keywords,  # pylint: disable=protected-access
      inputs):
    keyword = compat.as_str(keyword)
    if isinstance(inp, tensor_spec.TensorSpec):
      specs[keyword] = tensor_spec.TensorSpec(
          inp.shape, inp.dtype, name=keyword)
    else:
      specs[keyword] = tensor_spec.TensorSpec.from_tensor(inp, name=keyword)
  return specs


def canonicalize_signatures(signatures):
  """Converts `signatures` into a dictionary of concrete functions.

  Args:
    signatures: None, a single function, or a mapping from signature key to
      function. A non-mapping value is used as the default serving signature.
      Each function must be a `tf.function` with an input signature or a
      concrete function.

  Returns:
    A tuple `(concrete_signatures, wrapped_functions, defaults)`:
      concrete_signatures: `{signature key: ConcreteFunction}`. Each function
        returns a dictionary of Tensors.
      wrapped_functions: `{original concrete function: result of
        function_serialization.wrap_cached_variables}`.
      defaults: `{(signature key, argument name): default}`. These are the
        truthy argument defaults of `core.GenericFunction` inputs.

  Raises:
    ValueError: if a value cannot be used as a signature, or if a function
      expects tf.Variable inputs.
  """
  if signatures is None:
    return {}, {}, {}
  if not isinstance(signatures, collections_abc.Mapping):
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
  num_normalized_signatures_counter = 0
  concrete_signatures = {}
  wrapped_functions = {}
  defaults = {}
  for signature_key, function in signatures.items():
    original_function = _get_signature(function)
    if original_function is None:
      raise ValueError(
          "Expected a TensorFlow function for which to generate a signature, "
          f"but got {function}. Only `tf.functions` with an input signature or "
          "concrete functions can be used as a signature.")

    # Reuse the previous wrapping when the same concrete function appears
    # under several signature keys.
    signature_function = (
        wrapped_functions.get(original_function) or
        function_serialization.wrap_cached_variables(original_function))
    wrapped_functions[original_function] = signature_function
    _validate_inputs(signature_function)

    if num_normalized_signatures_counter < _NUM_DISPLAY_NORMALIZED_SIGNATURES:
      signature_name_changes = _get_signature_name_changes(signature_function)
      if signature_name_changes:
        num_normalized_signatures_counter += 1
        logging.info(
            "Function `%s` contains input name(s) %s with unsupported "
            "characters which will be renamed to %s in the SavedModel.",
            compat.as_str(signature_function.graph.name),
            ", ".join(signature_name_changes.keys()),
            ", ".join(signature_name_changes.values()))

    # Re-wrap the function so that it returns a dictionary of Tensors.
    signature_wrapper = _make_signature_wrapper(
        signature_function, signature_key)
    if hasattr(function, "__name__"):
      signature_wrapper.__name__ = "signature_wrapper_" + function.__name__
    wrapped_function = def_function.function(signature_wrapper)
    tensor_spec_signature = _tensor_spec_signature(signature_function)
    final_concrete = wrapped_function._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
        **tensor_spec_signature)
    # pylint: disable=protected-access
    # If there is only one input to the signature, a very common case, then
    # ordering is unambiguous and we can let people pass a positional
    # argument. Since SignatureDefs are unordered (protobuf "map") multiple
    # arguments means we need to be keyword-only.
    final_concrete._num_positional_args = (
        1 if len(final_concrete._arg_keywords) == 1 else 0)
    # pylint: enable=protected-access
    concrete_signatures[signature_key] = final_concrete

    if isinstance(function, core.GenericFunction):
      full_arg_spec = function._function_spec.fullargspec  # pylint: disable=protected-access
      arg_defaults = full_arg_spec.defaults or []
      # Defaults belong to the trailing `len(arg_defaults)` positional args.
      defaulted_args = (
          full_arg_spec.args[-len(arg_defaults):] if arg_defaults else [])
      for arg, default in zip(defaulted_args, arg_defaults):
        # Falsy defaults (e.g. None, 0, False, "") are not recorded.
        # NOTE(review): if a default can be a Tensor, `not default` depends on
        # its truth value; confirm this check is intended for Tensor defaults.
        if not default:
          continue
        defaults[(signature_key, arg)] = default
  return concrete_signatures, wrapped_functions, defaults
def _normalize_outputs(outputs, function_name, signature_key):
  """Converts signature outputs to a dictionary and checks they are tensors.

  Args:
    outputs: structured outputs of a signature function. This may be a
      mapping, a namedtuple (keys taken from `_asdict()`), a sequence (keys
      "output_<index>"), or a single value (key "output_0").
    function_name: name of the function, used only in error messages.
    signature_key: signature key, used only in error messages.

  Returns:
    A mapping from output name to output value.

  Raises:
    ValueError: if a key is not a string, or a value is neither a Tensor nor
      a CompositeTensor.
  """
  if isinstance(outputs, collections_abc.Mapping):
    named_outputs = outputs
  elif hasattr(outputs, "_asdict"):
    # namedtuple: reuse its field names as output keys.
    named_outputs = outputs._asdict()
  else:
    if isinstance(outputs, collections_abc.Sequence):
      sequence = outputs
    else:
      sequence = [outputs]
    named_outputs = {
        "output_{}".format(index): value
        for index, value in enumerate(sequence)
    }

  for key, value in named_outputs.items():
    if not isinstance(key, compat.bytes_or_text_types):
      raise ValueError(
          f"Got a dictionary with a non-string key {key!r} in the output of "
          f"the function {compat.as_str_any(function_name)} used to generate "
          f"the SavedModel signature {signature_key!r}.")
    is_tensor_like = isinstance(
        value, (ops.Tensor, composite_tensor.CompositeTensor))
    if not is_tensor_like:
      raise ValueError(
          f"Got a non-Tensor value {value!r} for key {key!r} in the output of "
          f"the function {compat.as_str_any(function_name)} used to generate "
          f"the SavedModel signature {signature_key!r}. "
          "Outputs for functions used as signatures must be a single Tensor, "
          "a sequence of Tensors, or a dictionary from string to Tensor.")
  return named_outputs
240# _SignatureMap is immutable to ensure that users do not expect changes to be
241# reflected in the SavedModel. Using public APIs, tf.saved_model.load() is the
242# only way to create a _SignatureMap and there is no way to modify it. So we can
243# safely ignore/overwrite ".signatures" attributes attached to objects being
244# saved if they contain a _SignatureMap. A ".signatures" attribute containing
245# any other type (e.g. a regular dict) will raise an exception asking the user
246# to first "del obj.signatures" if they want it overwritten.
class _SignatureMap(collections_abc.Mapping, base.Trackable):
  """A read-only mapping from signature names to signature functions.

  The mapping protocol is provided through `collections_abc.Mapping`, backed
  by a private dict. The only mutator is the private `_add_signature`.
  """

  def __init__(self):
    self._signatures = {}

  def _add_signature(self, name, concrete_function):
    """Stores `concrete_function` under `name`, replacing any prior entry."""
    # Ideally this object would be immutable, but restore is streaming, so a
    # private API is needed to add signatures to an existing object.
    self._signatures[name] = concrete_function

  def __getitem__(self, key):
    return self._signatures[key]

  def __iter__(self):
    return iter(self._signatures)

  def __len__(self):
    return len(self._signatures)

  def __repr__(self):
    return f"_SignatureMap({self._signatures})"

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns function-valued entries as children, for SavedModel saves only."""
    if save_type != base.SaveType.SAVEDMODEL:
      return {}
    function_types = (def_function.Function, defun.ConcreteFunction)
    return {
        name: fn for name, fn in self.items()
        if isinstance(fn, function_types)
    }
# Register `_SignatureMap` as a revived type so loaded SavedModels rebuild it.
# Standard dependencies are enough to reconstruct the trackable items in
# dictionaries, so the factory needs no extra saved information; entries are
# added back one at a time through `_add_signature`.
revived_types.register_revived_type(
    "signature_map",
    lambda obj: isinstance(obj, _SignatureMap),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=lambda proto: _SignatureMap(),
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            setter=_SignatureMap._add_signature,  # pylint: disable=protected-access
        ),
    ],
)
def create_signature_map(signatures):
  """Creates a `_SignatureMap` containing `signatures`.

  Args:
    signatures: mapping from signature name to concrete function, as produced
      by `canonicalize_signatures`.

  Returns:
    A `_SignatureMap` holding every entry of `signatures`.

  Raises:
    AssertionError: if an entry is not a ConcreteFunction, does not return a
      mapping, or has an unexpected number of positional arguments.
  """
  signature_map = _SignatureMap()
  for name, func in signatures.items():
    # These invariants hold for any signature produced by
    # canonicalize_signatures. They are a sanity check on saving, because
    # crashing on load (e.g. in _add_signature) would be more problematic if
    # future export changes violated them. The checks are explicit rather than
    # `assert` statements so they still run under `python -O`. They keep
    # raising AssertionError for compatibility.
    if not isinstance(func, defun.ConcreteFunction):
      raise AssertionError(
          f"Signature {name!r} must be a ConcreteFunction, got "
          f"{type(func).__name__}.")
    if not isinstance(func.structured_outputs, collections_abc.Mapping):
      raise AssertionError(
          f"Signature {name!r} must return a mapping of outputs, got "
          f"{type(func.structured_outputs).__name__}.")
    # pylint: disable=protected-access
    # A single argument may be passed positionally; otherwise keyword-only.
    expected_positional = 1 if len(func._arg_keywords) == 1 else 0
    if func._num_positional_args != expected_positional:
      raise AssertionError(
          f"Signature {name!r} has _num_positional_args="
          f"{func._num_positional_args}, expected {expected_positional}.")
    signature_map._add_signature(name, func)
    # pylint: enable=protected-access
  return signature_map
def validate_augmented_graph_view(augmented_graph_view):
  """Checks the root's reserved `signatures` attribute before saving.

  Only the first child of the root named `SIGNATURE_ATTRIBUTE_NAME` is
  inspected. It must be a `_SignatureMap`, since only those are overwritten
  safely on save.

  Args:
    augmented_graph_view: object exposing a `root` attribute and a
      `list_children(obj)` method yielding `(name, child)` pairs.

  Raises:
    ValueError: if that child exists and is not a `_SignatureMap`.
  """
  children = augmented_graph_view.list_children(augmented_graph_view.root)
  for child_name, child in children:
    if child_name != SIGNATURE_ATTRIBUTE_NAME:
      continue
    if isinstance(child, _SignatureMap):
      break
    raise ValueError(
        f"Exporting an object {augmented_graph_view.root} which has an attribute "
        f"named '{SIGNATURE_ATTRIBUTE_NAME}'. This is a reserved attribute "
        "used to store SavedModel signatures in objects which come from "
        "`tf.saved_model.load`. Delete this attribute "
        f"(e.g. `del obj.{SIGNATURE_ATTRIBUTE_NAME}`) before saving if "
        "this shadowing is acceptable.")