Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py: 22%
144 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""SignatureDef utility functions implementation."""
18from tensorflow.core.framework import types_pb2
19from tensorflow.core.protobuf import meta_graph_pb2
20from tensorflow.python.framework import errors
21from tensorflow.python.framework import ops
22from tensorflow.python.saved_model import signature_constants
23from tensorflow.python.saved_model import utils_impl as utils
24from tensorflow.python.util import deprecation
25from tensorflow.python.util.tf_export import tf_export
@tf_export(
    v1=[
        'saved_model.build_signature_def',
        'saved_model.signature_def_utils.build_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.build_signature_def')
def build_signature_def(inputs=None, outputs=None, method_name=None):
  """Assembles a `SignatureDef` protocol buffer from its component parts.

  Args:
    inputs: Optional mapping from input name to tensor info proto. Every entry
      is copied into the `inputs` map of the result.
    outputs: Optional mapping from output name to tensor info proto. Every
      entry is copied into the `outputs` map of the result.
    method_name: Optional method name string for the SignatureDef.

  Returns:
    A freshly constructed `SignatureDef`. Arguments left as `None` leave the
    corresponding field untouched.
  """
  result = meta_graph_pb2.SignatureDef()
  # Inputs are copied before outputs, mirroring the argument order.
  for source, target in ((inputs, result.inputs), (outputs, result.outputs)):
    if source is None:
      continue
    for name in source:
      target[name].CopyFrom(source[name])
  if method_name is not None:
    result.method_name = method_name
  return result
@tf_export(
    v1=[
        'saved_model.regression_signature_def',
        'saved_model.signature_def_utils.regression_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.regression_signature_def')
def regression_signature_def(examples, predictions):
  """Builds a regression SignatureDef from examples and predictions.

  The resulting signature is meant for the TensorFlow Serving Regress API
  (tensorflow_serving/apis/prediction_service.proto), so the input and output
  dtypes are restricted to those that API accepts.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    predictions: A float `Tensor`.

  Returns:
    A regression-flavored signature_def whose method name is
    `signature_constants.REGRESS_METHOD_NAME`.

  Raises:
    ValueError: If `examples` is `None` or not a `Tensor`, if `predictions` is
      `None`, if `examples` is not of string dtype, or if `predictions` is not
      `DT_FLOAT`.
  """
  # Argument presence/type checks come before any tensor info is built.
  if examples is None:
    raise ValueError('Regression `examples` cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Expected regression `examples` to be of type Tensor. '
                     f'Found `examples` of type {type(examples)}.')
  if predictions is None:
    raise ValueError('Regression `predictions` cannot be None.')

  examples_info = utils.build_tensor_info(examples)
  if examples_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Regression input tensors must be of type string. '
                     f'Found tensors with type {examples_info.dtype}.')

  predictions_info = utils.build_tensor_info(predictions)
  if predictions_info.dtype != types_pb2.DT_FLOAT:
    raise ValueError('Regression output tensors must be of type float. '
                     f'Found tensors with type {predictions_info.dtype}.')

  return build_signature_def(
      inputs={signature_constants.REGRESS_INPUTS: examples_info},
      outputs={signature_constants.REGRESS_OUTPUTS: predictions_info},
      method_name=signature_constants.REGRESS_METHOD_NAME)
@tf_export(
    v1=[
        'saved_model.classification_signature_def',
        'saved_model.signature_def_utils.classification_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.classification_signature_def')
def classification_signature_def(examples, classes, scores):
  """Creates classification signature from given examples and predictions.

  This function produces signatures intended for use with the TensorFlow Serving
  Classify API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    classes: A string `Tensor`, or `None`. Note that the ClassificationResponse
      message requires that class labels are strings, not integers or anything
      else.
    scores: A float `Tensor`, or `None`. At least one of `classes` and
      `scores` must be provided.

  Returns:
    A classification-flavored signature_def. Its outputs contain an entry only
    for each of `classes` / `scores` that was provided.

  Raises:
    ValueError: If `examples` is `None` or not a `Tensor`; if `classes` and
      `scores` are both `None`; if `examples` or `classes` is not of string
      dtype; or if `scores` is not `DT_FLOAT`.
  """
  if examples is None:
    raise ValueError('Classification `examples` cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Classification `examples` must be a string Tensor. '
                     f'Found `examples` of type {type(examples)}.')
  if classes is None and scores is None:
    raise ValueError('Classification `classes` and `scores` cannot both be '
                     'None.')

  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Classification input tensors must be of type string. '
                     f'Found tensors of type {input_tensor_info.dtype}.')
  signature_inputs = {signature_constants.CLASSIFY_INPUTS: input_tensor_info}

  # Only the outputs the caller actually supplied are added to the signature.
  signature_outputs = {}
  if classes is not None:
    classes_tensor_info = utils.build_tensor_info(classes)
    if classes_tensor_info.dtype != types_pb2.DT_STRING:
      raise ValueError('Classification classes must be of type string Tensor. '
                       f'Found tensors of type {classes_tensor_info.dtype}.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = (
        classes_tensor_info)
  if scores is not None:
    scores_tensor_info = utils.build_tensor_info(scores)
    if scores_tensor_info.dtype != types_pb2.DT_FLOAT:
      raise ValueError('Classification scores must be a float Tensor. '
                       f'Found tensors of type {scores_tensor_info.dtype}.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = (
        scores_tensor_info)

  signature_def = build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.CLASSIFY_METHOD_NAME)

  return signature_def
@tf_export(
    v1=[
        'saved_model.predict_signature_def',
        'saved_model.signature_def_utils.predict_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.predict_signature_def')
def predict_signature_def(inputs, outputs):
  """Builds a prediction SignatureDef from named input and output tensors.

  The resulting signature is meant for the TensorFlow Serving Predict API
  (tensorflow_serving/apis/prediction_service.proto), which places no
  restrictions on input or output dtypes.

  Args:
    inputs: dict of string to `Tensor`.
    outputs: dict of string to `Tensor`.

  Returns:
    A prediction-flavored signature_def whose method name is
    `signature_constants.PREDICT_METHOD_NAME`.

  Raises:
    ValueError: If `inputs` or `outputs` is `None` or empty.
  """
  # `not x` is True for both None and an empty mapping.
  if not inputs:
    raise ValueError('Prediction `inputs` cannot be None or empty.')
  if not outputs:
    raise ValueError('Prediction `outputs` cannot be None or empty.')

  def _to_tensor_infos(named_tensors):
    return {name: utils.build_tensor_info(value)
            for name, value in named_tensors.items()}

  return build_signature_def(
      _to_tensor_infos(inputs), _to_tensor_infos(outputs),
      signature_constants.PREDICT_METHOD_NAME)
215# LINT.IfChange
def supervised_train_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates a SignatureDef describing a supervised training step.

  Delegates to `_supervised_signature_def` with
  `signature_constants.SUPERVISED_TRAIN_METHOD_NAME` as the method name.

  Args:
    inputs: dict of string to `Tensor`; must be non-empty.
    loss: dict of string to `Tensor` representing computed loss, or `None`.
    predictions: Optional dict of string to `Tensor` with output predictions.
    metrics: Optional dict of string to `Tensor` representing metric ops.

  Returns:
    A train-flavored signature_def.
  """
  return _supervised_signature_def(
      method_name=signature_constants.SUPERVISED_TRAIN_METHOD_NAME,
      inputs=inputs,
      loss=loss,
      predictions=predictions,
      metrics=metrics)
def supervised_eval_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates a SignatureDef describing a supervised evaluation step.

  Delegates to `_supervised_signature_def` with
  `signature_constants.SUPERVISED_EVAL_METHOD_NAME` as the method name.

  Args:
    inputs: dict of string to `Tensor`; must be non-empty.
    loss: dict of string to `Tensor` representing computed loss, or `None`.
    predictions: Optional dict of string to `Tensor` with output predictions.
    metrics: Optional dict of string to `Tensor` representing metric ops.

  Returns:
    An eval-flavored signature_def.
  """
  return _supervised_signature_def(
      method_name=signature_constants.SUPERVISED_EVAL_METHOD_NAME,
      inputs=inputs,
      loss=loss,
      predictions=predictions,
      metrics=metrics)
def _supervised_signature_def(
    method_name, inputs, loss=None, predictions=None,
    metrics=None):
  """Creates a signature for training and eval data.

  This function produces signatures that describe the inputs and outputs
  of a supervised process, such as training or evaluation, that
  results in loss, metrics, and the like. Only `inputs` is required; each of
  `loss`, `predictions` and `metrics` may be `None`, in which case it
  contributes no outputs.

  Args:
    method_name: Method name of the SignatureDef as a string.
    inputs: dict of string to `Tensor`; must be non-empty.
    loss: Optional dict of string to `Tensor` representing computed loss.
    predictions: Optional dict of string to `Tensor` representing the output
      predictions.
    metrics: Optional dict of string to `Tensor` representing metric ops.

  Returns:
    A train- or eval-flavored signature_def.

  Raises:
    ValueError: If `inputs` is `None` or empty. Outputs are not validated.
  """
  # `not inputs` covers both None and an empty mapping.
  if not inputs:
    raise ValueError(f'{method_name} `inputs` cannot be None or empty.')

  signature_inputs = {key: utils.build_tensor_info(tensor)
                      for key, tensor in inputs.items()}

  # All provided output groups are merged into one flat outputs map. If two
  # groups share a key, the later group (in the order loss, predictions,
  # metrics) overwrites the earlier entry.
  signature_outputs = {}
  for output_set in (loss, predictions, metrics):
    if output_set is not None:
      signature_outputs.update(
          {key: utils.build_tensor_info(tensor)
           for key, tensor in output_set.items()})

  return build_signature_def(
      signature_inputs, signature_outputs, method_name)
270# LINT.ThenChange(//keras/saving/utils_v1/signature_def_utils.py)
@tf_export(
    v1=[
        'saved_model.is_valid_signature',
        'saved_model.signature_def_utils.is_valid_signature'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.is_valid_signature')
def is_valid_signature(signature_def):
  """Determine whether a SignatureDef can be served by TensorFlow Serving.

  A signature qualifies if it passes the classification, regression, or
  predict validity check. The checks are tried in that order, stopping at the
  first that passes.

  Args:
    signature_def: A `SignatureDef` proto, or `None`.

  Returns:
    `True` if some check accepts `signature_def`, otherwise `False`. `None`
    always yields `False`.
  """
  if signature_def is None:
    return False
  validators = (_is_valid_classification_signature,
                _is_valid_regression_signature,
                _is_valid_predict_signature)
  return any(validate(signature_def) for validate in validators)
def _is_valid_predict_signature(signature_def):
  """Determine whether the argument is a servable 'predict' SignatureDef.

  Requires the predict method name, at least one input, and at least one
  output. No dtype restrictions apply.
  """
  return (signature_def.method_name == signature_constants.PREDICT_METHOD_NAME
          and len(signature_def.inputs) > 0
          and len(signature_def.outputs) > 0)
def _is_valid_regression_signature(signature_def):
  """Determine whether the argument is a servable 'regress' SignatureDef.

  Requires:
  - the regress method name;
  - exactly one string input under `REGRESS_INPUTS`;
  - exactly one `DT_FLOAT` output under `REGRESS_OUTPUTS`.
  """
  if signature_def.method_name != signature_constants.REGRESS_METHOD_NAME:
    return False

  input_key = signature_constants.REGRESS_INPUTS
  output_key = signature_constants.REGRESS_OUTPUTS

  if set(signature_def.inputs) != {input_key}:
    return False
  if signature_def.inputs[input_key].dtype != types_pb2.DT_STRING:
    return False

  if set(signature_def.outputs) != {output_key}:
    return False
  return signature_def.outputs[output_key].dtype == types_pb2.DT_FLOAT
def _is_valid_classification_signature(signature_def):
  """Determine whether the argument is a servable 'classify' SignatureDef.

  Requires:
  - the classify method name;
  - exactly one string input under `CLASSIFY_INPUTS`;
  - a non-empty set of outputs drawn only from `CLASSIFY_OUTPUT_CLASSES`
    (string) and `CLASSIFY_OUTPUT_SCORES` (`DT_FLOAT`).
  """
  if signature_def.method_name != signature_constants.CLASSIFY_METHOD_NAME:
    return False

  inputs = signature_def.inputs
  input_key = signature_constants.CLASSIFY_INPUTS
  if set(inputs) != {input_key}:
    return False
  if inputs[input_key].dtype != types_pb2.DT_STRING:
    return False

  outputs = signature_def.outputs
  # Each permitted output key, mapped to the dtype it must carry.
  required_dtypes = {
      signature_constants.CLASSIFY_OUTPUT_CLASSES: types_pb2.DT_STRING,
      signature_constants.CLASSIFY_OUTPUT_SCORES: types_pb2.DT_FLOAT,
  }
  output_keys = set(outputs)
  if not output_keys or not output_keys.issubset(required_dtypes):
    return False
  # Every key here is known to be present, so indexing the map is read-only.
  return all(outputs[name].dtype == required_dtypes[name]
             for name in output_keys)
def op_signature_def(op, key):
  """Builds a SignatureDef whose single output refers to `op`.

  `op` is not strictly required to be an Op object and may be a Tensor.
  For Tensors, prefer the build_signature_def() function.

  Args:
    op: An Op (or possibly Tensor).
    key: Key under which the graph element is stored in the SignatureDef
      outputs.

  Returns:
    A SignatureDef with a single output pointing to the op.
  """
  # build_tensor_info_from_op builds the TensorInfo from the element's name.
  op_info = utils.build_tensor_info_from_op(op)
  return build_signature_def(outputs={key: op_info})
def load_op_from_signature_def(signature_def, key, import_scope=None):
  """Load an Op from a SignatureDef created by op_signature_def().

  Args:
    signature_def: a SignatureDef proto
    key: string key to op in the SignatureDef outputs.
    import_scope: Scope used to import the op

  Returns:
    Op (or possibly Tensor) in the graph with the same name as saved in the
    SignatureDef.

  Raises:
    NotFoundError: If `key` is not among the SignatureDef outputs, or if the
      op could not be found in the graph.
  """
  # Test membership first: indexing a protobuf message map with a missing key
  # inserts a default entry, which would silently mutate the caller's proto.
  if key not in signature_def.outputs:
    raise errors.NotFoundError(
        None, None,
        f'The key "{key}" could not be found in the SignatureDef outputs.')
  tensor_info = signature_def.outputs[key]
  try:
    # The init and train ops are not strictly enforced to be operations, so
    # retrieve any graph element (can be either op or tensor).
    return utils.get_element_from_tensor_info(
        tensor_info, import_scope=import_scope)
  except KeyError as e:
    raise errors.NotFoundError(
        None, None,
        f'The key "{key}" could not be found in the graph. Please make sure the'
        ' SavedModel was created by the internal _SavedModelBuilder. If you '
        'are using the public API, please make sure the SignatureDef in the '
        f'SavedModel does not contain the key "{key}".') from e