Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py: 22%

144 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""SignatureDef utility functions implementation.""" 

16 

17 

18from tensorflow.core.framework import types_pb2 

19from tensorflow.core.protobuf import meta_graph_pb2 

20from tensorflow.python.framework import errors 

21from tensorflow.python.framework import ops 

22from tensorflow.python.saved_model import signature_constants 

23from tensorflow.python.saved_model import utils_impl as utils 

24from tensorflow.python.util import deprecation 

25from tensorflow.python.util.tf_export import tf_export 

26 

27 

@tf_export(
    v1=[
        'saved_model.build_signature_def',
        'saved_model.signature_def_utils.build_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.build_signature_def')
def build_signature_def(inputs=None, outputs=None, method_name=None):
  """Assembles a SignatureDef protocol buffer from its component parts.

  Every argument is optional; any argument left as `None` is simply not set on
  the returned proto.

  Args:
    inputs: Optional mapping (e.g. a dict or proto map) from string key to
      tensor info. Each entry is copied into the `inputs` map of the result.
    outputs: Optional mapping (e.g. a dict or proto map) from string key to
      tensor info. Each entry is copied into the `outputs` map of the result.
    method_name: Optional method name of the SignatureDef, as a string.

  Returns:
    A newly constructed `meta_graph_pb2.SignatureDef`.
  """
  signature_def = meta_graph_pb2.SignatureDef()
  # Inputs are populated before outputs; each entry is copied with CopyFrom
  # into the corresponding map of the new proto.
  for source, destination in ((inputs, signature_def.inputs),
                              (outputs, signature_def.outputs)):
    if source is None:
      continue
    for name in source:
      destination[name].CopyFrom(source[name])
  if method_name is not None:
    signature_def.method_name = method_name
  return signature_def

58 

59 

@tf_export(
    v1=[
        'saved_model.regression_signature_def',
        'saved_model.signature_def_utils.regression_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.regression_signature_def')
def regression_signature_def(examples, predictions):
  """Creates regression signature from given examples and predictions.

  This function produces signatures intended for use with the TensorFlow Serving
  Regress API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    predictions: A float (`DT_FLOAT`) `Tensor`.

  Returns:
    A regression-flavored signature_def.

  Raises:
    ValueError: If `examples` is `None` or not a `Tensor`, if `predictions` is
      `None`, if `examples` is not a string tensor, or if `predictions` is not
      a `DT_FLOAT` tensor.
  """
  if examples is None:
    raise ValueError('Regression `examples` cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Expected regression `examples` to be of type Tensor. '
                     f'Found `examples` of type {type(examples)}.')
  if predictions is None:
    raise ValueError('Regression `predictions` cannot be None.')

  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    # Report the enum name (e.g. DT_INT32) rather than its raw integer value.
    dtype_name = types_pb2.DataType.Name(input_tensor_info.dtype)
    raise ValueError('Regression input tensors must be of type string. '
                     f'Found tensors with type {dtype_name}.')
  signature_inputs = {signature_constants.REGRESS_INPUTS: input_tensor_info}

  output_tensor_info = utils.build_tensor_info(predictions)
  if output_tensor_info.dtype != types_pb2.DT_FLOAT:
    dtype_name = types_pb2.DataType.Name(output_tensor_info.dtype)
    raise ValueError('Regression output tensors must be of type float. '
                     f'Found tensors with type {dtype_name}.')
  signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor_info}

  signature_def = build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.REGRESS_METHOD_NAME)

  return signature_def

109 

110 

@tf_export(
    v1=[
        'saved_model.classification_signature_def',
        'saved_model.signature_def_utils.classification_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.classification_signature_def')
def classification_signature_def(examples, classes, scores):
  """Creates classification signature from given examples, classes and scores.

  This function produces signatures intended for use with the TensorFlow Serving
  Classify API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    classes: A string `Tensor`, or `None`. Note that the ClassificationResponse
      message requires that class labels are strings, not integers or anything
      else.
    scores: A float (`DT_FLOAT`) `Tensor`, or `None`. At least one of `classes`
      and `scores` must be provided.

  Returns:
    A classification-flavored signature_def.

  Raises:
    ValueError: If `examples` is `None` or not a `Tensor`, if both `classes`
      and `scores` are `None`, if `examples` or `classes` is not a string
      tensor, or if `scores` is not a `DT_FLOAT` tensor.
  """
  if examples is None:
    raise ValueError('Classification `examples` cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Classification `examples` must be a string Tensor. '
                     f'Found `examples` of type {type(examples)}.')
  if classes is None and scores is None:
    raise ValueError('Classification `classes` and `scores` cannot both be '
                     'None.')

  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    # Report the enum name (e.g. DT_INT32) rather than its raw integer value.
    dtype_name = types_pb2.DataType.Name(input_tensor_info.dtype)
    raise ValueError('Classification input tensors must be of type string. '
                     f'Found tensors of type {dtype_name}.')
  signature_inputs = {signature_constants.CLASSIFY_INPUTS: input_tensor_info}

  signature_outputs = {}
  if classes is not None:
    classes_tensor_info = utils.build_tensor_info(classes)
    if classes_tensor_info.dtype != types_pb2.DT_STRING:
      dtype_name = types_pb2.DataType.Name(classes_tensor_info.dtype)
      raise ValueError('Classification classes must be of type string Tensor. '
                       f'Found tensors of type {dtype_name}.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = (
        classes_tensor_info)
  if scores is not None:
    scores_tensor_info = utils.build_tensor_info(scores)
    if scores_tensor_info.dtype != types_pb2.DT_FLOAT:
      dtype_name = types_pb2.DataType.Name(scores_tensor_info.dtype)
      raise ValueError('Classification scores must be a float Tensor. '
                       f'Found tensors of type {dtype_name}.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = (
        scores_tensor_info)

  signature_def = build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.CLASSIFY_METHOD_NAME)

  return signature_def

172 

173 

@tf_export(
    v1=[
        'saved_model.predict_signature_def',
        'saved_model.signature_def_utils.predict_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.predict_signature_def')
def predict_signature_def(inputs, outputs):
  """Builds a prediction SignatureDef from named input and output tensors.

  The result is meant for the TensorFlow Serving Predict API
  (tensorflow_serving/apis/prediction_service.proto), which places no
  restriction on input or output dtypes.

  Args:
    inputs: dict of string to `Tensor`; must be non-empty.
    outputs: dict of string to `Tensor`; must be non-empty.

  Returns:
    A prediction-flavored signature_def.

  Raises:
    ValueError: If `inputs` or `outputs` is `None` or empty.
  """
  # A falsy check rejects both None and an empty mapping.
  if not inputs:
    raise ValueError('Prediction `inputs` cannot be None or empty.')
  if not outputs:
    raise ValueError('Prediction `outputs` cannot be None or empty.')

  def _to_tensor_infos(named_tensors):
    # Convert each tensor to its TensorInfo, keeping the caller's keys.
    return {name: utils.build_tensor_info(value)
            for name, value in named_tensors.items()}

  return build_signature_def(
      _to_tensor_infos(inputs),
      _to_tensor_infos(outputs),
      signature_constants.PREDICT_METHOD_NAME)

213 

214 

215# LINT.IfChange 

def supervised_train_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates a SignatureDef for a supervised training step.

  Delegates to `_supervised_signature_def` with the supervised-train method
  name.

  Args:
    inputs: dict of string to `Tensor`; must be non-empty.
    loss: dict of string to `Tensor` representing computed loss, or `None`.
    predictions: Optional dict of string to `Tensor` of output predictions.
    metrics: Optional dict of string to `Tensor` of metric ops.

  Returns:
    A train-flavored signature_def.
  """
  return _supervised_signature_def(
      method_name=signature_constants.SUPERVISED_TRAIN_METHOD_NAME,
      inputs=inputs,
      loss=loss,
      predictions=predictions,
      metrics=metrics)

221 

222 

def supervised_eval_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates a SignatureDef for a supervised evaluation step.

  Delegates to `_supervised_signature_def` with the supervised-eval method
  name.

  Args:
    inputs: dict of string to `Tensor`; must be non-empty.
    loss: dict of string to `Tensor` representing computed loss, or `None`.
    predictions: Optional dict of string to `Tensor` of output predictions.
    metrics: Optional dict of string to `Tensor` of metric ops.

  Returns:
    An eval-flavored signature_def.
  """
  return _supervised_signature_def(
      method_name=signature_constants.SUPERVISED_EVAL_METHOD_NAME,
      inputs=inputs,
      loss=loss,
      predictions=predictions,
      metrics=metrics)

228 

229 

def _supervised_signature_def(
    method_name, inputs, loss=None, predictions=None,
    metrics=None):
  """Creates a signature for training and eval data.

  This function produces signatures that describe the inputs and outputs
  of a supervised process, such as training or evaluation, that
  results in loss, metrics, and the like. Only `inputs` is required; every
  output group may be `None`.

  All non-`None` output groups are merged into one `outputs` map in the order
  `loss`, `predictions`, `metrics`. If the same key appears in more than one
  group, the entry from the later group replaces the earlier one.

  Args:
    method_name: Method name of the SignatureDef as a string.
    inputs: dict of string to `Tensor`. Must be non-empty.
    loss: dict of string to `Tensor` representing computed loss, or `None`.
    predictions: dict of string to `Tensor` representing the output
      predictions, or `None`.
    metrics: dict of string to `Tensor` representing metric ops, or `None`.

  Returns:
    A train- or eval-flavored signature_def.

  Raises:
    ValueError: If `inputs` is `None` or empty. Outputs are not validated; a
      signature with no outputs is allowed.
  """
  # A falsy check rejects both None and an empty mapping.
  if not inputs:
    raise ValueError(f'{method_name} `inputs` cannot be None or empty.')

  signature_inputs = {key: utils.build_tensor_info(tensor)
                      for key, tensor in inputs.items()}

  signature_outputs = {}
  for output_set in (loss, predictions, metrics):
    if output_set is not None:
      # dict.update: later groups win on key collisions (see docstring).
      signature_outputs.update(
          {key: utils.build_tensor_info(tensor)
           for key, tensor in output_set.items()})

  signature_def = build_signature_def(
      signature_inputs, signature_outputs, method_name)

  return signature_def

270# LINT.ThenChange(//keras/saving/utils_v1/signature_def_utils.py) 

271 

272 

@tf_export(
    v1=[
        'saved_model.is_valid_signature',
        'saved_model.signature_def_utils.is_valid_signature'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.is_valid_signature')
def is_valid_signature(signature_def):
  """Returns whether a SignatureDef can be served by TensorFlow Serving.

  A signature qualifies if it passes the classification, regression or predict
  validity check, tried in that order. `None` never qualifies.
  """
  if signature_def is None:
    return False
  validators = (_is_valid_classification_signature,
                _is_valid_regression_signature,
                _is_valid_predict_signature)
  # any() short-circuits on the first validator that accepts the signature.
  return any(validate(signature_def) for validate in validators)

287 

288 

def _is_valid_predict_signature(signature_def):
  """Returns True iff `signature_def` is a servable 'predict' SignatureDef.

  The method name must be the predict method name, and there must be at least
  one input and at least one output. Dtypes are not checked.
  """
  return (signature_def.method_name == signature_constants.PREDICT_METHOD_NAME
          and bool(signature_def.inputs.keys())
          and bool(signature_def.outputs.keys()))

297 return True 

298 

299 

def _is_valid_regression_signature(signature_def):
  """Returns True iff `signature_def` is a servable 'regress' SignatureDef.

  Requires the regress method name, exactly one input under REGRESS_INPUTS
  with string dtype, and exactly one output under REGRESS_OUTPUTS with
  DT_FLOAT dtype.
  """
  if signature_def.method_name != signature_constants.REGRESS_METHOD_NAME:
    return False

  input_key = signature_constants.REGRESS_INPUTS
  output_key = signature_constants.REGRESS_OUTPUTS

  # The key set is checked before indexing, so the map lookups below never
  # touch a missing key.
  if set(signature_def.inputs.keys()) != {input_key}:
    return False
  if signature_def.inputs[input_key].dtype != types_pb2.DT_STRING:
    return False

  if set(signature_def.outputs.keys()) != {output_key}:
    return False
  return signature_def.outputs[output_key].dtype == types_pb2.DT_FLOAT

320 

321 

def _is_valid_classification_signature(signature_def):
  """Returns True iff `signature_def` is a servable 'classify' SignatureDef.

  Requires the classify method name and exactly one string input under
  CLASSIFY_INPUTS. Outputs must be a non-empty subset of
  {CLASSIFY_OUTPUT_CLASSES, CLASSIFY_OUTPUT_SCORES}, where classes (if present)
  are DT_STRING and scores (if present) are DT_FLOAT.
  """
  if signature_def.method_name != signature_constants.CLASSIFY_METHOD_NAME:
    return False

  input_map = signature_def.inputs
  output_map = signature_def.outputs
  input_key = signature_constants.CLASSIFY_INPUTS

  if set(input_map.keys()) != {input_key}:
    return False
  if input_map[input_key].dtype != types_pb2.DT_STRING:
    return False

  # Each permitted output key, mapped to the dtype it must carry.
  required_dtype_by_key = {
      signature_constants.CLASSIFY_OUTPUT_CLASSES: types_pb2.DT_STRING,
      signature_constants.CLASSIFY_OUTPUT_SCORES: types_pb2.DT_FLOAT,
  }

  if not output_map.keys():
    return False
  if set(output_map.keys()) - set(required_dtype_by_key):
    return False
  for output_key, required_dtype in required_dtype_by_key.items():
    # Membership is tested first so a missing key is never indexed.
    if output_key in output_map and (
        output_map[output_key].dtype != required_dtype):
      return False

  return True

353 

354 

def op_signature_def(op, key):
  """Builds a SignatureDef whose single output refers to `op`.

  `op` is not strictly required to be an Op object and may also be a Tensor,
  because the TensorInfo is built from the element's name. For Tensors,
  `build_signature_def()` is the recommended alternative.

  Args:
    op: An Op (or possibly Tensor).
    key: Key under which the element is stored in the SignatureDef outputs.

  Returns:
    A SignatureDef with a single output pointing to the op.
  """
  op_tensor_info = utils.build_tensor_info_from_op(op)
  return build_signature_def(outputs={key: op_tensor_info})

371 

372 

def load_op_from_signature_def(signature_def, key, import_scope=None):
  """Load an Op from a SignatureDef created by op_signature_def().

  Args:
    signature_def: a SignatureDef proto
    key: string key to op in the SignatureDef outputs.
    import_scope: Scope used to import the op

  Returns:
    Op (or possibly Tensor) in the graph with the same name as saved in the
    SignatureDef.

  Raises:
    NotFoundError: If `key` is not in the SignatureDef outputs, or if the op
      could not be found in the graph.
  """
  # Check membership explicitly: indexing a protobuf message-valued map with a
  # missing key inserts a default entry, which would silently mutate the
  # caller's SignatureDef before failing.
  if key not in signature_def.outputs:
    raise errors.NotFoundError(
        None, None,
        f'The key "{key}" could not be found in the SignatureDef outputs.')
  tensor_info = signature_def.outputs[key]
  try:
    # The init and train ops are not strictly enforced to be operations, so
    # retrieve any graph element (can be either op or tensor).
    return utils.get_element_from_tensor_info(
        tensor_info, import_scope=import_scope)
  except KeyError as err:
    raise errors.NotFoundError(
        None, None,
        f'The key "{key}" could not be found in the graph. Please make sure the'
        ' SavedModel was created by the internal _SavedModelBuilder. If you '
        'are using the public API, please make sure the SignatureDef in the '
        f'SavedModel does not contain the key "{key}".') from err