Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/model_utils/export_utils.py: 24%
139 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# LINT.IfChange
16"""Utilities for creating SavedModels."""
18import collections
19import os
20import time
22from tensorflow.python.lib.io import file_io
23from tensorflow.python.ops import op_selector
24from tensorflow.python.platform import gfile
25from tensorflow.python.platform import tf_logging as logging
26from tensorflow.python.saved_model import signature_constants
27from tensorflow.python.saved_model import signature_def_utils
28from tensorflow.python.saved_model import tag_constants
29from tensorflow.python.saved_model import utils
30from tensorflow.python.saved_model.model_utils import export_output as export_output_lib
31from tensorflow.python.saved_model.model_utils import mode_keys
32from tensorflow.python.saved_model.model_utils.mode_keys import KerasModeKeys as ModeKeys
33from tensorflow.python.util import compat
34from tensorflow.python.util import nest
35from tensorflow.python.util import object_identity
# MetaGraph tags recorded in the SavedModel for each export mode.
EXPORT_TAG_MAP = mode_keys.ModeKeyMap(
    **{
        ModeKeys.PREDICT: [tag_constants.SERVING],
        ModeKeys.TRAIN: [tag_constants.TRAINING],
        ModeKeys.TEST: [tag_constants.EVAL],
    }
)

# Every exported mode gets a SignatureDef map, built with the functions
# `export_outputs_for_mode` and `build_all_signature_defs`. By default that map
# holds a single signature describing the input tensors and the output
# predictions, losses, and/or metrics (depending on the mode). The default key
# used in the SignatureDef map for each mode is listed below.
SIGNATURE_KEY_MAP = mode_keys.ModeKeyMap(
    **{
        ModeKeys.PREDICT: signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
        ModeKeys.TRAIN: signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY,
        ModeKeys.TEST: signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY,
    }
)

# Default names used in the SignatureDef input map, which maps strings to
# TensorInfo protos.
SINGLE_FEATURE_DEFAULT_NAME = 'feature'
SINGLE_RECEIVER_DEFAULT_NAME = 'input'
SINGLE_LABEL_DEFAULT_NAME = 'label'
60### Below utilities are specific to SavedModel exports.
63def _must_be_fed(op):
64 return op.type == 'Placeholder'
def _ensure_servable(input_tensors, names_to_output_tensor_infos):
  """Verifies that the signature outputs need no placeholder beyond the inputs.

  Walks backwards from the output tensors, stopping at the input tensors, and
  rejects the signature if the walk reaches an op that must be fed (see
  `_must_be_fed`) and whose outputs are not all among the signature inputs.

  Args:
    input_tensors: An iterable of `Tensor`s specified as the signature's inputs.
    names_to_output_tensor_infos: A mapping from output names to the respective
      `TensorInfo`s corresponding to the signature's output tensors.

  Raises:
    ValueError: If any of the signature's outputs depend on placeholders not
      provided as signature's inputs.
  """
  flat_inputs = nest.flatten(input_tensors, expand_composites=True)
  graph = op_selector.get_unique_graph(flat_inputs)

  # Resolve each TensorInfo to the tensor it denotes in the inputs' graph.
  resolved_outputs = []
  for tensor_info in names_to_output_tensor_infos.values():
    resolved_outputs.append(
        utils.get_tensor_from_tensor_info(tensor_info, graph=graph))
  flat_outputs = nest.flatten(resolved_outputs, expand_composites=True)

  upstream_ops = op_selector.get_backward_walk_ops(
      flat_outputs, stop_at_ts=flat_inputs)

  # Identity-based set: tensors are matched by object, not by equality.
  provided = object_identity.ObjectIdentitySet(flat_inputs)
  for upstream_op in upstream_ops:
    if not _must_be_fed(upstream_op):
      continue
    if all(produced in provided for produced in upstream_op.outputs):
      continue

    # Found an op that needs feeding but is not covered by the inputs.
    input_names = [tensor.name for tensor in flat_inputs]
    output_keys = list(names_to_output_tensor_infos.keys())
    output_names = [tensor.name for tensor in flat_outputs]
    path = op_selector.show_path(upstream_op, flat_outputs, flat_inputs)
    raise ValueError(
        f'The signature\'s input tensors {input_names} are '
        f'insufficient to compute its output keys {output_keys} '
        f'(respectively, tensors {output_names}) because of the '
        f'dependency on `{upstream_op.name}` which is not given as '
        'a signature input, as illustrated by the following dependency path: '
        f'{path}')
def _collect_signature_defs(receiver_tensors, export_outputs, name_prefix,
                            signature_def_map, excluded_signatures):
  """Builds one signature per export output for a single receiver group.

  Args:
    receiver_tensors: a `Tensor` or a dict of string to `Tensor`. A non-dict
      value is wrapped in a dict keyed by `SINGLE_RECEIVER_DEFAULT_NAME`.
    export_outputs: a dict whose values provide `as_signature_def`.
    name_prefix: string prepended to every signature name produced here.
    signature_def_map: dict updated in place with the accepted signatures.
    excluded_signatures: dict updated in place with the error message of every
      signature rejected with a `ValueError`.
  """
  if not isinstance(receiver_tensors, dict):
    receiver_tensors = {SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
  input_tensors = receiver_tensors.values()
  for output_key, export_output in export_outputs.items():
    signature_name = '{}{}'.format(name_prefix, output_key or 'None')
    try:
      signature = export_output.as_signature_def(receiver_tensors)
      _ensure_servable(input_tensors, signature.outputs)
    except ValueError as e:
      # A ValueError marks the signature as unusable; record why and move on.
      excluded_signatures[signature_name] = str(e)
    else:
      signature_def_map[signature_name] = signature


def build_all_signature_defs(receiver_tensors,
                             export_outputs,
                             receiver_tensors_alternatives=None,
                             serving_only=True):
  """Build `SignatureDef`s for all export outputs.

  Args:
    receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying
      input nodes where this receiver expects to be fed by default. Typically,
      this is a single placeholder expecting serialized `tf.Example` protos.
    export_outputs: a dict of ExportOutput instances, each of which has
      an as_signature_def instance method that will be called to retrieve
      the signature_def for all export output tensors.
    receiver_tensors_alternatives: a dict of string to additional
      groups of receiver tensors, each of which may be a `Tensor` or a dict of
      string to `Tensor`. These named receiver tensor alternatives generate
      additional serving signatures, which may be used to feed inputs at
      different points within the input receiver subgraph. A typical usage is
      to allow feeding raw feature `Tensor`s *downstream* of the
      tf.io.parse_example() op. Defaults to None.
    serving_only: boolean; if true, resulting signature defs will only include
      valid serving signatures. If false, all requested signatures will be
      returned.

  Returns:
    A dict mapping signature names to signatures. Signatures built from the
    default receivers are keyed by the export output key; those built from an
    alternative are keyed by '<receiver name>:<export output key>'. A missing
    or empty name component is rendered as 'None'.

  Raises:
    ValueError: if export_outputs is not a dict
  """
  # `None` is not a dict either, so a single isinstance check covers it.
  if not isinstance(export_outputs, dict):
    raise ValueError('`export_outputs` must be a dict. Received '
                     f'{export_outputs} with type '
                     f'{type(export_outputs).__name__}.')

  signature_def_map = {}
  excluded_signatures = {}

  # Signatures fed through the default receivers carry no name prefix.
  _collect_signature_defs(receiver_tensors, export_outputs, '',
                          signature_def_map, excluded_signatures)

  for receiver_name, receiver_tensors_alt in (
      receiver_tensors_alternatives or {}).items():
    _collect_signature_defs(receiver_tensors_alt, export_outputs,
                            '{}:'.format(receiver_name or 'None'),
                            signature_def_map, excluded_signatures)

  _log_signature_report(signature_def_map, excluded_signatures)

  # The above calls to export_output_lib.as_signature_def should return only
  # valid signatures; if there is a validity problem, they raise a ValueError,
  # in which case we exclude that signature from signature_def_map above.
  # The is_valid_signature check ensures that the signatures produced are
  # valid for serving, and acts as an additional sanity check for export
  # signatures produced for serving. We skip this check for training and eval
  # signatures, which are not intended for serving.
  if serving_only:
    signature_def_map = {
        name: signature
        for name, signature in signature_def_map.items()
        if signature_def_utils.is_valid_signature(signature)
    }
  return signature_def_map
# Human-readable labels for signature method names, used when logging.
_FRIENDLY_METHOD_NAMES = {
    signature_constants.CLASSIFY_METHOD_NAME: 'Classify',
    signature_constants.REGRESS_METHOD_NAME: 'Regress',
    signature_constants.PREDICT_METHOD_NAME: 'Predict',
    signature_constants.SUPERVISED_TRAIN_METHOD_NAME: 'Train',
    signature_constants.SUPERVISED_EVAL_METHOD_NAME: 'Eval',
}


def _log_signature_report(signature_def_map, excluded_signatures):
  """Log a report of which signatures were produced.

  Args:
    signature_def_map: dict mapping signature names to the signatures included
      in the export; the `method_name` of each value is read.
    excluded_signatures: dict mapping the names of rejected signatures to the
      message explaining the rejection.
  """
  sig_names_by_method_name = collections.defaultdict(list)

  # We'll collect whatever method_names are present, but also we want to make
  # sure to output a line for each method listed in _FRIENDLY_METHOD_NAMES even
  # if it has no signatures.
  for method_name in _FRIENDLY_METHOD_NAMES:
    sig_names_by_method_name[method_name] = []

  for signature_name, sig in signature_def_map.items():
    sig_names_by_method_name[sig.method_name].append(signature_name)

  # TODO(b/67733540): consider printing the full signatures, not just names
  for method_name, sig_names in sig_names_by_method_name.items():
    # Unknown method names are reported verbatim.
    friendly_name = _FRIENDLY_METHOD_NAMES.get(method_name, method_name)
    logging.info('Signatures INCLUDED in export for {}: {}'.format(
        friendly_name, sig_names if sig_names else 'None'))

  if excluded_signatures:
    logging.info('Signatures EXCLUDED from export because they cannot be '
                 'served via TensorFlow Serving APIs:')
    for signature_name, message in excluded_signatures.items():
      logging.info('\'{}\' : {}'.format(signature_name, message))

  if not signature_def_map:
    logging.warn('Export includes no signatures!')
  elif (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in
        signature_def_map):
    logging.warn('Export includes no default signature!')
# Another process may be creating timestamped directories at the same moment,
# so there is a small chance that the directory we pick already exists. When
# that happens we wait one second to get a new timestamp and try again. Failing
# this many times in a row means something is seriously wrong.
MAX_DIRECTORY_CREATION_ATTEMPTS = 10


def get_timestamped_export_dir(export_dir_base):
  """Returns a path for a new timestamp-named subdirectory of the base.

  Each export is written into a new subdirectory named using the current time,
  expressed as the number of seconds since epoch UTC. This guarantees
  monotonically increasing version numbers even across multiple runs of the
  pipeline.

  Args:
    export_dir_base: A string containing a directory to write the exported
      graph and checkpoints.

  Returns:
    The full path of the new subdirectory (which is not actually created yet).

  Raises:
    RuntimeError: if repeated attempts fail to obtain a unique timestamped
      directory name.
  """
  for attempt in range(1, MAX_DIRECTORY_CREATION_ATTEMPTS + 1):
    seconds_since_epoch = int(time.time())
    candidate = file_io.join(
        compat.as_bytes(export_dir_base),
        compat.as_bytes(str(seconds_since_epoch)))
    if not gfile.Exists(candidate):
      # Collisions are still possible (though extremely unlikely): this
      # directory is not actually created yet, but it will be almost
      # instantly on return from this function.
      return candidate

    # Name already taken: wait for the clock to tick over before retrying.
    time.sleep(1)
    logging.warn('Directory {} already exists; retrying (attempt {}/{})'.format(
        compat.as_str(candidate), attempt, MAX_DIRECTORY_CREATION_ATTEMPTS))

  raise RuntimeError('Failed to obtain a unique export directory name after '
                     f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')
def get_temp_export_dir(timestamped_export_dir):
  """Returns the sibling of the given directory whose name starts with 'temp-'.

  This relies on the fact that TensorFlow Serving ignores subdirectories of
  the base directory that can't be parsed as integers.

  Args:
    timestamped_export_dir: the name of the eventual export directory, e.g.
      /foo/bar/<timestamp>

  Returns:
    A sister directory prefixed with 'temp-', e.g. /foo/bar/temp-<timestamp>.
  """
  parent_dir, leaf = os.path.split(timestamped_export_dir)
  # The leaf may arrive as bytes or text; normalise to text before formatting.
  leaf_text = leaf.decode('utf-8') if isinstance(leaf, bytes) else str(leaf)
  return file_io.join(
      compat.as_bytes(parent_dir),
      compat.as_bytes('temp-{}'.format(leaf_text)))
def export_outputs_for_mode(
    mode, serving_export_outputs=None, predictions=None, loss=None,
    metrics=None):
  """Constructs the `ExportOutput` dict that corresponds to a mode.

  The returned dict can be directly passed to `build_all_signature_defs` helper
  function as the `export_outputs` argument, used for generating a SignatureDef
  map.

  Args:
    mode: A `ModeKeys` specifying the mode.
    serving_export_outputs: Describes the output signatures to be exported to
      `SavedModel` and used during serving. Should be a dict or None.
    predictions: A dict of Tensors or single Tensor representing model
      predictions. This argument is only used if serving_export_outputs is not
      set.
    loss: A dict of Tensors or single Tensor representing calculated loss.
    metrics: A dict of (metric_value, update_op) tuples, or a single tuple.
      metric_value must be a Tensor, and update_op must be a Tensor or Op.

  Returns:
    Dictionary mapping a key to a `tf.estimator.export.ExportOutput` object.
    The key is the expected SignatureDef key for the mode.

  Raises:
    ValueError: if an appropriate ExportOutput cannot be found for the mode.
  """
  if mode not in SIGNATURE_KEY_MAP:
    raise ValueError(
        f'Export output type not found for `mode`: {mode}. Expected one of: '
        f'{list(SIGNATURE_KEY_MAP.keys())}.\n'
        'One likely error is that V1 Estimator Modekeys were somehow passed to '
        'this function. Please ensure that you are using the new ModeKeys.')

  signature_key = SIGNATURE_KEY_MAP[mode]

  # Prediction mode delegates to the serving-output validation helper.
  if mode_keys.is_predict(mode):
    return get_export_outputs(serving_export_outputs, predictions)

  # Train and eval outputs take the same arguments; only the class differs.
  if mode_keys.is_train(mode):
    output_cls = export_output_lib.TrainOutput
  else:
    output_cls = export_output_lib.EvalOutput
  supervised_output = output_cls(
      loss=loss, predictions=predictions, metrics=metrics)
  return {signature_key: supervised_output}
def get_export_outputs(export_outputs, predictions):
  """Returns validated export_outputs, building a default one when given None.

  Args:
    export_outputs: Describes the output signatures to be exported to
      `SavedModel` and used during serving. Should be a dict or None.
    predictions: Predictions `Tensor` or dict of `Tensor`. Used only to build
      the default output when `export_outputs` is None.

  Returns:
    Valid export_outputs dict

  Raises:
    TypeError: if export_outputs is not a dict or its values are not
      ExportOutput instances.
  """
  if export_outputs is None:
    # Nothing supplied: serve the predictions under the default signature key.
    export_outputs = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            export_output_lib.PredictOutput(predictions)
    }

  if not isinstance(export_outputs, dict):
    raise TypeError(
        f'`export_outputs` must be dict, received: {export_outputs}.')

  values_are_export_outputs = all(
      isinstance(candidate, export_output_lib.ExportOutput)
      for candidate in export_outputs.values())
  if not values_are_export_outputs:
    raise TypeError(
        'Values in `export_outputs` must be ExportOutput objects, '
        f'received: {export_outputs}.')

  _maybe_add_default_serving_output(export_outputs)

  return export_outputs
384def _maybe_add_default_serving_output(export_outputs):
385 """Add a default serving output to the export_outputs if not present.
387 Args:
388 export_outputs: Describes the output signatures to be exported to
389 `SavedModel` and used during serving. Should be a dict.
391 Returns:
392 export_outputs dict with default serving signature added if necessary
394 Raises:
395 ValueError: if multiple export_outputs were provided without a default
396 serving key.
397 """
398 if len(export_outputs) == 1:
399 (key, value), = export_outputs.items()
400 if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
401 export_outputs[
402 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value
403 if len(export_outputs) > 1:
404 if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
405 not in export_outputs):
406 raise ValueError(
407 'Multiple `export_outputs` were provided, but none of them are '
408 'specified as the default. Use'
409 '`tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY` to '
410 'specify a default.')
412 return export_outputs
413# LINT.ThenChange(//keras/saving/utils_v1/export_utils.py)