Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/model_utils/export_utils.py: 24%

139 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# LINT.IfChange
"""Utilities for creating SavedModels."""

import collections
import os
import time

from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import op_selector
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.saved_model.model_utils import export_output as export_output_lib
from tensorflow.python.saved_model.model_utils import mode_keys
from tensorflow.python.saved_model.model_utils.mode_keys import KerasModeKeys as ModeKeys
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity



# Mapping of the modes to appropriate MetaGraph tags in the SavedModel.
EXPORT_TAG_MAP = mode_keys.ModeKeyMap(**{
    ModeKeys.PREDICT: [tag_constants.SERVING],
    ModeKeys.TRAIN: [tag_constants.TRAINING],
    ModeKeys.TEST: [tag_constants.EVAL]})

# For every exported mode, a SignatureDef map should be created using the
# functions `export_outputs_for_mode` and `build_all_signature_defs`. By
# default, this map will contain a single Signature that defines the input
# tensors and output predictions, losses, and/or metrics (depending on the
# mode). The default keys used in the SignatureDef map are defined below.
SIGNATURE_KEY_MAP = mode_keys.ModeKeyMap(**{
    ModeKeys.PREDICT: signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
    ModeKeys.TRAIN: signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY,
    ModeKeys.TEST: signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY})

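# Note (illustrative, not part of the original module): at the time of
# writing, these constants resolve to the familiar string keys -- e.g.
# `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY` is
# 'serving_default' and `tag_constants.SERVING` is 'serve' -- so
# `SIGNATURE_KEY_MAP[ModeKeys.PREDICT]` yields the default serving
# signature key used throughout this file.
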

# Default names used in the SignatureDef input map, which maps strings to
# TensorInfo protos.
SINGLE_FEATURE_DEFAULT_NAME = 'feature'
SINGLE_RECEIVER_DEFAULT_NAME = 'input'
SINGLE_LABEL_DEFAULT_NAME = 'label'

### Below utilities are specific to SavedModel exports.



def _must_be_fed(op):
  return op.type == 'Placeholder'



def _ensure_servable(input_tensors, names_to_output_tensor_infos):
  """Check that the signature outputs don't depend on unreachable placeholders.

  Args:
    input_tensors: An iterable of `Tensor`s specified as the signature's inputs.
    names_to_output_tensor_infos: A mapping from output names to respective
      `TensorInfo`s corresponding to the signature's output tensors.

  Raises:
    ValueError: If any of the signature's outputs depend on placeholders not
      provided as signature's inputs.
  """
  plain_input_tensors = nest.flatten(input_tensors, expand_composites=True)

  graph = op_selector.get_unique_graph(plain_input_tensors)

  output_tensors = [
      utils.get_tensor_from_tensor_info(tensor, graph=graph)
      for tensor in names_to_output_tensor_infos.values()
  ]
  plain_output_tensors = nest.flatten(output_tensors, expand_composites=True)

  dependency_ops = op_selector.get_backward_walk_ops(
      plain_output_tensors, stop_at_ts=plain_input_tensors)

  fed_tensors = object_identity.ObjectIdentitySet(plain_input_tensors)
  for dependency_op in dependency_ops:
    if _must_be_fed(dependency_op) and (not all(
        output in fed_tensors for output in dependency_op.outputs)):
      input_tensor_names = [tensor.name for tensor in plain_input_tensors]
      output_tensor_keys = list(names_to_output_tensor_infos.keys())
      output_tensor_names = [tensor.name for tensor in plain_output_tensors]
      dependency_path = op_selector.show_path(dependency_op,
                                              plain_output_tensors,
                                              plain_input_tensors)
      raise ValueError(
          f'The signature\'s input tensors {input_tensor_names} are '
          f'insufficient to compute its output keys {output_tensor_keys} '
          f'(respectively, tensors {output_tensor_names}) because of the '
          f'dependency on `{dependency_op.name}` which is not given as '
          'a signature input, as illustrated by the following dependency path: '
          f'{dependency_path}')



def build_all_signature_defs(receiver_tensors,
                             export_outputs,
                             receiver_tensors_alternatives=None,
                             serving_only=True):
  """Build `SignatureDef`s for all export outputs.

  Args:
    receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying
      input nodes where this receiver expects to be fed by default. Typically,
      this is a single placeholder expecting serialized `tf.Example` protos.
    export_outputs: a dict of ExportOutput instances, each of which has
      an as_signature_def instance method that will be called to retrieve
      the signature_def for all export output tensors.
    receiver_tensors_alternatives: a dict of string to additional
      groups of receiver tensors, each of which may be a `Tensor` or a dict of
      string to `Tensor`. These named receiver tensor alternatives generate
      additional serving signatures, which may be used to feed inputs at
      different points within the input receiver subgraph. A typical usage is
      to allow feeding raw feature `Tensor`s *downstream* of the
      tf.io.parse_example() op. Defaults to None.
    serving_only: boolean; if true, resulting signature defs will only include
      valid serving signatures. If false, all requested signatures will be
      returned.

  Returns:
    A dict mapping signature names to `SignatureDef`s representing all passed
    args.

  Raises:
    ValueError: if export_outputs is not a dict
  """
  if not isinstance(receiver_tensors, dict):
    receiver_tensors = {SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
  if export_outputs is None or not isinstance(export_outputs, dict):
    raise ValueError('`export_outputs` must be a dict. Received '
                     f'{export_outputs} with type '
                     f'{type(export_outputs).__name__}.')

  signature_def_map = {}
  excluded_signatures = {}
  input_tensors = receiver_tensors.values()
  for output_key, export_output in export_outputs.items():
    signature_name = '{}'.format(output_key or 'None')
    try:
      signature = export_output.as_signature_def(receiver_tensors)
      _ensure_servable(input_tensors, signature.outputs)
      signature_def_map[signature_name] = signature
    except ValueError as e:
      excluded_signatures[signature_name] = str(e)

  if receiver_tensors_alternatives:
    for receiver_name, receiver_tensors_alt in (
        receiver_tensors_alternatives.items()):
      if not isinstance(receiver_tensors_alt, dict):
        receiver_tensors_alt = {
            SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors_alt
        }
      alt_input_tensors = receiver_tensors_alt.values()
      for output_key, export_output in export_outputs.items():
        signature_name = '{}:{}'.format(receiver_name or 'None', output_key or
                                        'None')
        try:
          signature = export_output.as_signature_def(receiver_tensors_alt)
          _ensure_servable(alt_input_tensors, signature.outputs)
          signature_def_map[signature_name] = signature
        except ValueError as e:
          excluded_signatures[signature_name] = str(e)

  _log_signature_report(signature_def_map, excluded_signatures)

  # The above calls to export_output_lib.as_signature_def should return only
  # valid signatures; if there is a validity problem, they raise a ValueError,
  # in which case we exclude that signature from signature_def_map above.
  # The is_valid_signature check ensures that the signatures produced are
  # valid for serving, and acts as an additional sanity check for export
  # signatures produced for serving. We skip this check for training and eval
  # signatures, which are not intended for serving.
  if serving_only:
    signature_def_map = {
        k: v
        for k, v in signature_def_map.items()
        if signature_def_utils.is_valid_signature(v)
    }
  return signature_def_map
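
# --- Illustrative usage (a sketch, not part of the original module) ---
# How the utilities above are typically combined when exporting for serving.
# The names 'examples', 'scores', and the model output tensor below are
# hypothetical placeholders, not definitions from this file.
#
#   serialized_examples = array_ops.placeholder(dtypes.string, shape=[None])
#   export_outputs = export_outputs_for_mode(
#       ModeKeys.PREDICT, predictions={'scores': scores_tensor})
#   signature_def_map = build_all_signature_defs(
#       receiver_tensors={'examples': serialized_examples},
#       export_outputs=export_outputs,
#       serving_only=True)
#   # signature_def_map now maps signature names (including the default
#   # serving key) to SignatureDef protos suitable for SavedModel export.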



_FRIENDLY_METHOD_NAMES = {
    signature_constants.CLASSIFY_METHOD_NAME: 'Classify',
    signature_constants.REGRESS_METHOD_NAME: 'Regress',
    signature_constants.PREDICT_METHOD_NAME: 'Predict',
    signature_constants.SUPERVISED_TRAIN_METHOD_NAME: 'Train',
    signature_constants.SUPERVISED_EVAL_METHOD_NAME: 'Eval',
}


def _log_signature_report(signature_def_map, excluded_signatures):
  """Log a report of which signatures were produced."""
  sig_names_by_method_name = collections.defaultdict(list)

  # We'll collect whatever method_names are present, but also we want to make
  # sure to output a line for each of the standard methods even if they have
  # no signatures.
  for method_name in _FRIENDLY_METHOD_NAMES:
    sig_names_by_method_name[method_name] = []

  for signature_name, sig in signature_def_map.items():
    sig_names_by_method_name[sig.method_name].append(signature_name)

  # TODO(b/67733540): consider printing the full signatures, not just names
  for method_name, sig_names in sig_names_by_method_name.items():
    if method_name in _FRIENDLY_METHOD_NAMES:
      method_name = _FRIENDLY_METHOD_NAMES[method_name]
    logging.info('Signatures INCLUDED in export for {}: {}'.format(
        method_name, sig_names if sig_names else 'None'))

  if excluded_signatures:
    logging.info('Signatures EXCLUDED from export because they cannot be '
                 'served via TensorFlow Serving APIs:')
    for signature_name, message in excluded_signatures.items():
      logging.info('\'{}\' : {}'.format(signature_name, message))

  if not signature_def_map:
    logging.warn('Export includes no signatures!')
  elif (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in
        signature_def_map):
    logging.warn('Export includes no default signature!')



# When we create a timestamped directory, there is a small chance that the
# directory already exists because another process is also creating these
# directories. In this case we just wait one second to get a new timestamp and
# try again. If this fails several times in a row, then something is seriously
# wrong.
MAX_DIRECTORY_CREATION_ATTEMPTS = 10


def get_timestamped_export_dir(export_dir_base):
  """Builds a path to a new subdirectory within the base directory.

  Each export is written into a new subdirectory named using the
  current time. This guarantees monotonically increasing version
  numbers even across multiple runs of the pipeline.
  The timestamp used is the number of seconds since epoch UTC.

  Args:
    export_dir_base: A string containing a directory to write the exported
      graph and checkpoints.

  Returns:
    The full path of the new subdirectory (which is not actually created yet).

  Raises:
    RuntimeError: if repeated attempts fail to obtain a unique timestamped
      directory name.
  """
  attempts = 0
  while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
    timestamp = int(time.time())

    result_dir = file_io.join(
        compat.as_bytes(export_dir_base), compat.as_bytes(str(timestamp)))
    if not gfile.Exists(result_dir):
      # Collisions are still possible (though extremely unlikely): this
      # directory is not actually created yet, but it will be almost
      # instantly on return from this function.
      return result_dir
    time.sleep(1)
    attempts += 1
    logging.warn('Directory {} already exists; retrying (attempt {}/{})'.format(
        compat.as_str(result_dir), attempts, MAX_DIRECTORY_CREATION_ATTEMPTS))
  raise RuntimeError('Failed to obtain a unique export directory name after '
                     f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')



def get_temp_export_dir(timestamped_export_dir):
  """Builds a directory name based on the argument but starting with 'temp-'.

  This relies on the fact that TensorFlow Serving ignores subdirectories of
  the base directory that can't be parsed as integers.

  Args:
    timestamped_export_dir: the name of the eventual export directory, e.g.
      /foo/bar/<timestamp>

  Returns:
    A sister directory prefixed with 'temp-', e.g. /foo/bar/temp-<timestamp>.
  """
  (dirname, basename) = os.path.split(timestamped_export_dir)
  if isinstance(basename, bytes):
    str_name = basename.decode('utf-8')
  else:
    str_name = str(basename)
  temp_export_dir = file_io.join(
      compat.as_bytes(dirname), compat.as_bytes('temp-{}'.format(str_name)))
  return temp_export_dir
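
# --- Illustrative usage (a sketch, not part of the original module) ---
# A common export pattern, assumed here: write the SavedModel into the
# 'temp-' sibling directory first, then rename it to the timestamped
# directory so TensorFlow Serving never observes a partially written export.
# `write_saved_model` is a hypothetical stand-in for the caller's own export
# logic.
#
#   export_dir = get_timestamped_export_dir('/models/my_model')
#   temp_export_dir = get_temp_export_dir(export_dir)
#   write_saved_model(temp_export_dir)          # caller-provided export step
#   gfile.Rename(temp_export_dir, export_dir)   # publish the finished export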



def export_outputs_for_mode(
    mode, serving_export_outputs=None, predictions=None, loss=None,
    metrics=None):
  """Util function for constructing an `ExportOutput` dict given a mode.

  The returned dict can be directly passed to `build_all_signature_defs` helper
  function as the `export_outputs` argument, used for generating a SignatureDef
  map.

  Args:
    mode: A `ModeKeys` specifying the mode.
    serving_export_outputs: Describes the output signatures to be exported to
      `SavedModel` and used during serving. Should be a dict or None.
    predictions: A dict of Tensors or single Tensor representing model
      predictions. This argument is only used if serving_export_outputs is not
      set.
    loss: A dict of Tensors or single Tensor representing calculated loss.
    metrics: A dict of (metric_value, update_op) tuples, or a single tuple.
      metric_value must be a Tensor, and update_op must be a Tensor or Op.

  Returns:
    Dictionary mapping a key to a `tf.estimator.export.ExportOutput` object.
    The key is the expected SignatureDef key for the mode.

  Raises:
    ValueError: if an appropriate ExportOutput cannot be found for the mode.
  """
  if mode not in SIGNATURE_KEY_MAP:
    raise ValueError(
        f'Export output type not found for `mode`: {mode}. Expected one of: '
        f'{list(SIGNATURE_KEY_MAP.keys())}.\n'
        'One likely error is that V1 Estimator Modekeys were somehow passed to '
        'this function. Please ensure that you are using the new ModeKeys.')
  signature_key = SIGNATURE_KEY_MAP[mode]
  if mode_keys.is_predict(mode):
    return get_export_outputs(serving_export_outputs, predictions)
  elif mode_keys.is_train(mode):
    return {signature_key: export_output_lib.TrainOutput(
        loss=loss, predictions=predictions, metrics=metrics)}
  else:
    return {signature_key: export_output_lib.EvalOutput(
        loss=loss, predictions=predictions, metrics=metrics)}
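
# --- Illustrative usage (a sketch, not part of the original module) ---
# Constructing the export outputs for a training export; the loss, metric,
# and prediction tensors below are hypothetical placeholders. The resulting
# dict is keyed by the default train signature key.
#
#   train_outputs = export_outputs_for_mode(
#       ModeKeys.TRAIN,
#       predictions={'scores': scores_tensor},
#       loss=loss_tensor,
#       metrics={'accuracy': (acc_value_tensor, acc_update_op)})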



def get_export_outputs(export_outputs, predictions):
  """Validate export_outputs or create default export_outputs.

  Args:
    export_outputs: Describes the output signatures to be exported to
      `SavedModel` and used during serving. Should be a dict or None.
    predictions: Predictions `Tensor` or dict of `Tensor`.

  Returns:
    Valid export_outputs dict

  Raises:
    TypeError: if export_outputs is not a dict or its values are not
      ExportOutput instances.
  """
  if export_outputs is None:
    default_output = export_output_lib.PredictOutput(predictions)
    export_outputs = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: default_output}

  if not isinstance(export_outputs, dict):
    raise TypeError(
        f'`export_outputs` must be dict, received: {export_outputs}.')
  for v in export_outputs.values():
    if not isinstance(v, export_output_lib.ExportOutput):
      raise TypeError(
          'Values in `export_outputs` must be ExportOutput objects, '
          f'received: {export_outputs}.')

  _maybe_add_default_serving_output(export_outputs)

  return export_outputs



def _maybe_add_default_serving_output(export_outputs):
  """Add a default serving output to the export_outputs if not present.

  Args:
    export_outputs: Describes the output signatures to be exported to
      `SavedModel` and used during serving. Should be a dict.

  Returns:
    export_outputs dict with default serving signature added if necessary

  Raises:
    ValueError: if multiple export_outputs were provided without a default
      serving key.
  """
  if len(export_outputs) == 1:
    (key, value), = export_outputs.items()
    if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
      export_outputs[
          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value
  if len(export_outputs) > 1:
    if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        not in export_outputs):
      raise ValueError(
          'Multiple `export_outputs` were provided, but none of them are '
          'specified as the default. Use '
          '`tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY` to '
          'specify a default.')

  return export_outputs
# LINT.ThenChange(//keras/saving/utils_v1/export_utils.py)