Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/model_utils/export_output.py: 39%
152 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# LINT.IfChange
16"""Classes for different types of export output."""
18import abc
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_util
25from tensorflow.python.saved_model import signature_def_utils
class ExportOutput(metaclass=abc.ABCMeta):
  """Represents an output of a model that can be served.

  These typically correspond to model heads.

  This is an abstract base class: concrete subclasses must implement
  `as_signature_def`. The metaclass is declared with the Python 3 `metaclass=`
  keyword; a `__metaclass__` class attribute is ignored by Python 3 and would
  leave `@abc.abstractmethod` unenforced.
  """

  # Joins the parts of a tuple key (multi-head models) into a single string.
  _SEPARATOR_CHAR = '/'

  @abc.abstractmethod
  def as_signature_def(self, receiver_tensors):
    """Generate a SignatureDef proto for inclusion in a MetaGraphDef.

    The SignatureDef will specify outputs as described in this ExportOutput,
    and will use the provided receiver_tensors as inputs.

    Args:
      receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying
        input nodes that will be fed.
    """
    pass

  def _check_output_key(self, key, error_label):
    """Validates an output key and normalizes it to a single string.

    Args:
      key: a string, or a tuple of strings whose parts are joined with
        `_SEPARATOR_CHAR`.
      error_label: descriptive string used as the prefix of the error message.

    Returns:
      The key as a string.

    Raises:
      ValueError: if the key is neither a string nor a tuple of strings.
    """
    # For multi-head models, the key can be a tuple. Only join it when every
    # part is a string: `str.join` would otherwise raise a TypeError instead of
    # the ValueError callers are told to expect. A tuple with a bad part is
    # left as-is and rejected by the check below.
    if isinstance(key, tuple) and all(isinstance(part, str) for part in key):
      key = self._SEPARATOR_CHAR.join(key)

    if not isinstance(key, str):
      raise ValueError(
          '{} output key must be a string; got {}.'.format(error_label, key))
    return key

  def _wrap_and_check_outputs(
      self, outputs, single_output_default_name, error_label=None):
    """Wraps raw tensors as dicts and checks type.

    Note that we create a new dict here so that we can overwrite the keys
    if necessary.

    Args:
      outputs: A `Tensor` or a dict of string to `Tensor`.
      single_output_default_name: A string key for use in the output dict
        if the provided `outputs` is a raw tensor.
      error_label: descriptive string for use in error messages. If none,
        single_output_default_name will be used.

    Returns:
      A dict of tensors

    Raises:
      ValueError: if the outputs dict keys are not strings or tuples of strings
        or the values are not Tensors.
    """
    if not isinstance(outputs, dict):
      outputs = {single_output_default_name: outputs}

    # The label does not depend on the entry, so resolve it once.
    error_name = error_label or single_output_default_name

    output_dict = {}
    for key, value in outputs.items():
      key = self._check_output_key(key, error_name)
      if not isinstance(value, ops.Tensor):
        raise ValueError(
            '{} output value must be a Tensor; got {}.'.format(
                error_name, value))

      output_dict[key] = value
    return output_dict
class ClassificationOutput(ExportOutput):
  """Represents the output of a classification head.

  Either classes or scores or both must be set.

  The classes `Tensor` must provide string labels, not integer class IDs.

  If only classes is set, it is interpreted as providing top-k results in
  descending order.

  If only scores is set, it is interpreted as providing a score for every class
  in order of class ID.

  If both classes and scores are set, they are interpreted as zipped, so each
  score corresponds to the class at the same index.  Clients should not depend
  on the order of the entries.
  """

  def __init__(self, scores=None, classes=None):
    """Constructor for `ClassificationOutput`.

    Args:
      scores: A float `Tensor` giving scores (sometimes but not always
        interpretable as probabilities) for each class. May be `None`, but
        only if `classes` is set. Interpretation varies-- see class doc.
        Any floating dtype passes the check, although the error message
        mentions float32.
      classes: A string `Tensor` giving predicted class labels. May be `None`,
        but only if `scores` is set. Interpretation varies-- see class doc.

    Raises:
      ValueError: if neither classes nor scores is set, or one of them is not a
        `Tensor` with the correct dtype.
    """
    # Each supplied argument is validated before the "both missing" check, so
    # a wrongly typed argument is always reported first.
    if scores is not None:
      scores_ok = isinstance(scores, ops.Tensor) and scores.dtype.is_floating
      if not scores_ok:
        raise ValueError(
            f'Classification scores must be a float32 Tensor; got {scores}')

    if classes is not None:
      classes_ok = (
          isinstance(classes, ops.Tensor)
          and dtypes.as_dtype(classes.dtype) == dtypes.string)
      if not classes_ok:
        raise ValueError(
            f'Classification classes must be a string Tensor; got {classes}')

    if scores is None and classes is None:
      raise ValueError('Cannot create a ClassificationOutput with empty '
                       'arguments. At least one of `scores` and `classes` '
                       'must be defined.')

    self._scores = scores
    self._classes = classes

  @property
  def scores(self):
    """The scores `Tensor` given at construction, or `None`."""
    return self._scores

  @property
  def classes(self):
    """The classes `Tensor` given at construction, or `None`."""
    return self._classes

  def as_signature_def(self, receiver_tensors):
    """Builds the classification SignatureDef for a single string input.

    Args:
      receiver_tensors: a dict holding exactly one entry whose value has
        string dtype.

    Returns:
      The result of `signature_def_utils.classification_signature_def`.

    Raises:
      ValueError: if `receiver_tensors` does not hold exactly one entry, or
        that entry is not of string dtype.
    """

    def _invalid_input_error():
      # Both failure modes report the same message.
      return ValueError(
          'Classification signatures can only accept a single tensor input of '
          'type tf.string. Please check to make sure that you have structured '
          'the serving_input_receiver_fn so that it creates a single string '
          'placeholder. If your model function expects multiple inputs, then '
          'use `tf.io.parse_example()` to parse the string into multiple '
          f'tensors.\n Received: {receiver_tensors}')

    if len(receiver_tensors) != 1:
      raise _invalid_input_error()

    (_, examples), = receiver_tensors.items()
    if dtypes.as_dtype(examples.dtype) != dtypes.string:
      raise _invalid_input_error()

    return signature_def_utils.classification_signature_def(
        examples, self.classes, self.scores)
class RegressionOutput(ExportOutput):
  """Represents the output of a regression head."""

  def __init__(self, value):
    """Constructor for `RegressionOutput`.

    Args:
      value: a float `Tensor` giving the predicted values.  Required. Any
        floating dtype passes the check, although the error message mentions
        float32.

    Raises:
      ValueError: if the value is not a `Tensor` with a floating dtype.
    """
    value_ok = isinstance(value, ops.Tensor) and value.dtype.is_floating
    if not value_ok:
      raise ValueError(
          f'Regression output value must be a float32 Tensor; got {value}')
    self._value = value

  @property
  def value(self):
    """The value `Tensor` given at construction."""
    return self._value

  def as_signature_def(self, receiver_tensors):
    """Builds the regression SignatureDef for a single string input.

    Args:
      receiver_tensors: a dict holding exactly one entry whose value has
        string dtype.

    Returns:
      The result of `signature_def_utils.regression_signature_def`.

    Raises:
      ValueError: if `receiver_tensors` does not hold exactly one entry, or
        that entry is not of string dtype.
    """

    def _invalid_input_error():
      # Both failure modes report the same message.
      return ValueError(
          'Regression signatures can only accept a single tensor input of '
          'type tf.string. Please check to make sure that you have structured '
          'the serving_input_receiver_fn so that it creates a single string '
          'placeholder. If your model function expects multiple inputs, then '
          'use `tf.io.parse_example()` to parse the string into multiple '
          f'tensors.\n Received: {receiver_tensors}')

    if len(receiver_tensors) != 1:
      raise _invalid_input_error()

    (_, examples), = receiver_tensors.items()
    if dtypes.as_dtype(examples.dtype) != dtypes.string:
      raise _invalid_input_error()

    return signature_def_utils.regression_signature_def(examples, self.value)
class PredictOutput(ExportOutput):
  """Represents the output of a generic prediction head.

  A generic prediction need not be either a classification or a regression.

  Named outputs must be provided as a dict from string to `Tensor`.
  """

  # Key used when a single raw tensor is given instead of a dict.
  _SINGLE_OUTPUT_DEFAULT_NAME = 'output'

  def __init__(self, outputs):
    """Constructor for PredictOutput.

    Args:
      outputs: A `Tensor` or a dict of string to `Tensor` representing the
        predictions. A raw tensor is stored under the key 'output'.

    Raises:
      ValueError: if any of the keys are not strings (or tuples of strings),
        or any of the values are not `Tensor`s.
    """
    checked_outputs = self._wrap_and_check_outputs(
        outputs,
        self._SINGLE_OUTPUT_DEFAULT_NAME,
        error_label='Prediction')
    self._outputs = checked_outputs

  @property
  def outputs(self):
    """The validated dict of output name to `Tensor`."""
    return self._outputs

  def as_signature_def(self, receiver_tensors):
    """Builds the predict SignatureDef from the receivers and stored outputs."""
    prediction_outputs = self.outputs
    return signature_def_utils.predict_signature_def(
        receiver_tensors, prediction_outputs)
class _SupervisedOutput(ExportOutput, metaclass=abc.ABCMeta):
  """Represents the output of a supervised training or eval process.

  This is an abstract class: subclasses provide the SignatureDef builder via
  `_get_signature_def_fn`. The metaclass is declared with the Python 3
  `metaclass=` keyword; a `__metaclass__` class attribute is ignored by
  Python 3 and would leave `@abc.abstractmethod` unenforced.
  """

  LOSS_NAME = 'loss'
  PREDICTIONS_NAME = 'predictions'
  METRICS_NAME = 'metrics'

  METRIC_VALUE_SUFFIX = 'value'
  METRIC_UPDATE_SUFFIX = 'update_op'

  # Class-level defaults: an output that was not passed to the constructor
  # reads back as None through the properties below.
  _loss = None
  _predictions = None
  _metrics = None

  def __init__(self, loss=None, predictions=None, metrics=None):
    """Constructor for SupervisedOutput (ie, Train or Eval output).

    Args:
      loss: dict of Tensors or single Tensor representing calculated loss.
      predictions: dict of Tensors or single Tensor representing model
        predictions.
      metrics: Dict of metric results keyed by name.
        The values of the dict can be one of the following:
        (1) instance of `Metric` class.
        (2) (metric_value, update_op) tuples, or a single tuple.
        metric_value must be a Tensor, and update_op must be a Tensor or Op.

    Raises:
      ValueError: if any of the outputs' dict keys are not strings or tuples of
        strings or the values are not Tensors (or Operations in the case of
        update_op).
    """

    if loss is not None:
      loss_dict = self._wrap_and_check_outputs(loss, self.LOSS_NAME)
      self._loss = self._prefix_output_keys(loss_dict, self.LOSS_NAME)
    if predictions is not None:
      pred_dict = self._wrap_and_check_outputs(
          predictions, self.PREDICTIONS_NAME)
      self._predictions = self._prefix_output_keys(
          pred_dict, self.PREDICTIONS_NAME)
    if metrics is not None:
      self._metrics = self._wrap_and_check_metrics(metrics)

  def _prefix_output_keys(self, output_dict, output_name):
    """Prepend output_name to the output_dict keys if it doesn't exist.

    This produces predictable prefixes for the pre-determined outputs
    of SupervisedOutput.

    Args:
      output_dict: dict of string to Tensor, assumed valid.
      output_name: prefix string to prepend to existing keys.

    Returns:
      dict with updated keys and existing values.
    """
    return {
        self._prefix_key(key, output_name): val
        for key, val in output_dict.items()
    }

  def _prefix_key(self, key, output_name):
    """Returns `key`, prefixed with `output_name` and the separator if needed.

    A key that already starts with `output_name` is returned unchanged.

    Args:
      key: string output key.
      output_name: prefix string.

    Returns:
      The (possibly prefixed) key.
    """
    if not key.startswith(output_name):
      key = output_name + self._SEPARATOR_CHAR + key
    return key

  def _wrap_and_check_metrics(self, metrics):
    """Handle the saving of metrics.

    Metrics is either a tuple of (value, update_op), or a dict of such tuples.
    Here, we separate out the tuples and create a dict with names to tensors.

    Args:
      metrics: Dict of metric results keyed by name.
        The values of the dict can be one of the following:
        (1) instance of `Metric` class.
        (2) (metric_value, update_op) tuples, or a single tuple.
        metric_value must be a Tensor, and update_op must be a Tensor or Op.

    Returns:
      dict of output_names to tensors

    Raises:
      ValueError: if the dict key is not a string, the metric values or ops
        are not tensors, or a metric object does not expose exactly one
        update op.
    """
    if not isinstance(metrics, dict):
      metrics = {self.METRICS_NAME: metrics}

    outputs = {}
    for key, value in metrics.items():
      if isinstance(value, tuple):
        metric_val, metric_op = value
      else:  # value is a keras.Metrics object
        metric_val = value.result()
        # We expect only one update op. This is input validation, so raise
        # explicitly: an `assert` would be stripped when running with -O.
        if len(value.updates) != 1:
          raise ValueError(
              '{} metric must have exactly one update op; got {}.'.format(
                  key, value.updates))
        metric_op = value.updates[0]
      key = self._check_output_key(key, self.METRICS_NAME)
      key = self._prefix_key(key, self.METRICS_NAME)

      val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX
      op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX
      if not isinstance(metric_val, ops.Tensor):
        raise ValueError(
            '{} output value must be a Tensor; got {}.'.format(
                key, metric_val))
      if not (tensor_util.is_tf_type(metric_op) or
              isinstance(metric_op, ops.Operation)):
        raise ValueError(
            '{} update_op must be a Tensor or Operation; got {}.'.format(
                key, metric_op))

      # We must wrap any ops (or variables) in a Tensor before export, as the
      # SignatureDef proto expects tensors only. See b/109740581
      metric_op_tensor = metric_op
      if not isinstance(metric_op, ops.Tensor):
        with ops.control_dependencies([metric_op]):
          metric_op_tensor = constant_op.constant([], name='metric_op_wrapper')

      outputs[val_name] = metric_val
      outputs[op_name] = metric_op_tensor

    return outputs

  @property
  def loss(self):
    """Dict of prefixed loss names to tensors, or None if no loss was given."""
    return self._loss

  @property
  def predictions(self):
    """Dict of prefixed prediction names to tensors, or None if not given."""
    return self._predictions

  @property
  def metrics(self):
    """Dict of metric value/update_op names to tensors, or None if not given."""
    return self._metrics

  @abc.abstractmethod
  def _get_signature_def_fn(self):
    """Returns a function that produces a SignatureDef given desired outputs."""
    pass

  def as_signature_def(self, receiver_tensors):
    """Builds the SignatureDef using the subclass-provided builder function."""
    signature_def_fn = self._get_signature_def_fn()
    return signature_def_fn(
        receiver_tensors, self.loss, self.predictions, self.metrics)
class TrainOutput(_SupervisedOutput):
  """Represents the output of a supervised training process.

  This class generates the appropriate signature def for exporting
  training output by type-checking and wrapping loss, predictions, and metrics
  values.
  """

  def _get_signature_def_fn(self):
    """Returns the builder used for supervised-training SignatureDefs."""
    train_signature_fn = signature_def_utils.supervised_train_signature_def
    return train_signature_fn
class EvalOutput(_SupervisedOutput):
  """Represents the output of a supervised eval process.

  This class generates the appropriate signature def for exporting
  eval output by type-checking and wrapping loss, predictions, and metrics
  values.
  """

  def _get_signature_def_fn(self):
    """Returns the builder used for supervised-eval SignatureDefs."""
    eval_signature_fn = signature_def_utils.supervised_eval_signature_def
    return eval_signature_fn
425# LINT.ThenChange(//keras/saving/utils_v1/export_output.py)