Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/loader_impl.py: 31%

170 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Loader implementation for SavedModel with hermetic, language-neutral exports. 

16""" 

17 

18import os 

19import sys 

20 

21from google.protobuf import message 

22from google.protobuf import text_format 

23 

24from tensorflow.core.framework import graph_debug_info_pb2 

25from tensorflow.core.protobuf import meta_graph_pb2 

26from tensorflow.core.protobuf import saved_model_pb2 

27from tensorflow.python.framework import ops 

28from tensorflow.python.lib.io import file_io 

29from tensorflow.python.ops import variables 

30from tensorflow.python.platform import tf_logging 

31from tensorflow.python.saved_model import constants 

32from tensorflow.python.saved_model import path_helpers 

33from tensorflow.python.saved_model import signature_def_utils 

34from tensorflow.python.saved_model import utils_impl as saved_model_utils 

35from tensorflow.python.saved_model.pywrap_saved_model import metrics 

36from tensorflow.python.training import saver as tf_saver 

37from tensorflow.python.util import compat 

38from tensorflow.python.util import deprecation 

39from tensorflow.python.util.tf_export import tf_export 

40 

41# API label for SavedModel metrics. 

42_LOADER_LABEL = "loader" 

43 

44 

def parse_saved_model_with_debug_info(export_dir):
  """Reads a SavedModel together with its optional GraphDebugInfo proto.

  Args:
    export_dir: Directory containing the SavedModel and GraphDebugInfo files.

  Returns:
    A `(SavedModel, GraphDebugInfo)` tuple of protocol buffers. When the debug
    info file is absent, the second element is an empty `GraphDebugInfo`.

  Raises:
    IOError: If the saved model file does not exist or cannot be successfully
      parsed, or if the debug info file exists but cannot be parsed. A missing
      graph debug info file is not an error.
  """
  saved_model = parse_saved_model(export_dir)

  debug_dir = path_helpers.get_debug_dir(export_dir)
  debug_info_path = file_io.join(debug_dir, constants.DEBUG_INFO_FILENAME_PB)
  debug_info = graph_debug_info_pb2.GraphDebugInfo()

  # Debug info is optional: return the empty proto when the file is missing.
  if not file_io.file_exists(debug_info_path):
    return (saved_model, debug_info)

  with file_io.FileIO(debug_info_path, "rb") as debug_file:
    serialized_debug_info = debug_file.read()
  try:
    debug_info.ParseFromString(serialized_debug_info)
  except message.DecodeError as e:
    raise IOError(f"Cannot parse file {debug_info_path}: {e}.")
  return (saved_model, debug_info)

72 

73 

@tf_export("__internal__.saved_model.parse_saved_model", v1=[])
def parse_saved_model(export_dir):
  """Reads the binary or text-format file containing the `SavedModel` proto.

  The binary file (`constants.SAVED_MODEL_FILENAME_PB`) takes precedence. The
  text-format file (`constants.SAVED_MODEL_FILENAME_PBTXT`) is only consulted
  when the binary one does not exist.

  Args:
    export_dir: String or Pathlike, path to the directory containing the
      SavedModel file.

  Returns:
    A `SavedModel` protocol buffer.

  Raises:
    IOError: If neither file exists, or the chosen file cannot be successfully
      parsed (including a text-format file that is not valid UTF-8).
  """
  # Normalize the (possibly path-like) directory to bytes once; both
  # candidate file paths are built from it.
  export_dir_bytes = compat.as_bytes(compat.path_to_str(export_dir))
  path_to_pbtxt = file_io.join(
      export_dir_bytes, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
  path_to_pb = file_io.join(
      export_dir_bytes, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))

  saved_model = saved_model_pb2.SavedModel()

  if file_io.file_exists(path_to_pb):
    with file_io.FileIO(path_to_pb, "rb") as f:
      file_content = f.read()
    try:
      saved_model.ParseFromString(file_content)
    except message.DecodeError as e:
      raise IOError(f"Cannot parse file {path_to_pb}: {str(e)}.") from e
    return saved_model

  if file_io.file_exists(path_to_pbtxt):
    with file_io.FileIO(path_to_pbtxt, "rb") as f:
      file_content = f.read()
    try:
      # The UTF-8 decode can fail on its own. UnicodeDecodeError is caught
      # alongside ParseError so that callers see the documented IOError.
      text_format.Merge(file_content.decode("utf-8"), saved_model)
    except (text_format.ParseError, UnicodeDecodeError) as e:
      raise IOError(f"Cannot parse file {path_to_pbtxt}: {str(e)}.") from e
    return saved_model

  raise IOError(
      f"SavedModel file does not exist at: {export_dir}{os.path.sep}"
      f"{{{constants.SAVED_MODEL_FILENAME_PBTXT}|"
      f"{constants.SAVED_MODEL_FILENAME_PB}}}")

120 

121 

def get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
  """Gets the asset tensors, if defined in the meta graph def to load.

  Asset definitions are taken from `meta_graph_def_to_load.asset_file_def`
  when it is non-empty. Otherwise they are unpacked from the collection stored
  under `constants.ASSETS_KEY`, if that collection exists.

  Args:
    export_dir: Directory where the SavedModel is located. May be a string,
      bytes, or an `os.PathLike` object.
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    import_scope: Optional `string` -- if specified, prepend this followed by
      '/' to all returned asset tensor names.

  Returns:
    A dictionary keyed by asset tensor name. Each value is the asset file's
    path, built with `file_io.join` from the SavedModel's assets directory and
    the asset's filename (both passed as bytes). The dictionary is empty when
    no assets are defined.
  """
  collection_def = meta_graph_def_to_load.collection_def

  if meta_graph_def_to_load.asset_file_def:
    asset_protos = meta_graph_def_to_load.asset_file_def
  elif constants.ASSETS_KEY in collection_def:
    # Older exports store assets as `Any`-packed AssetFileDefs in a collection.
    asset_protos = []
    for asset_any_proto in collection_def[constants.ASSETS_KEY].any_list.value:
      asset_proto = meta_graph_pb2.AssetFileDef()
      asset_any_proto.Unpack(asset_proto)
      asset_protos.append(asset_proto)
  else:
    asset_protos = []

  # Accept path-like export directories, as parse_saved_model does;
  # compat.as_bytes alone rejects os.PathLike objects.
  assets_directory = file_io.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.ASSETS_DIRECTORY))

  asset_tensor_dict = {}
  for asset_proto in asset_protos:
    tensor_name = asset_proto.tensor_info.name
    if import_scope:
      tensor_name = "%s/%s" % (import_scope, tensor_name)
    asset_tensor_dict[tensor_name] = file_io.join(
        compat.as_bytes(assets_directory),
        compat.as_bytes(asset_proto.filename))

  return asset_tensor_dict

163 

164 

def _get_main_op_tensor(
    meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY):
  """Returns the single op registered under `init_op_key`, or None.

  The proto's `collection_def` is only used to check that the collection
  exists and holds exactly one node. The op object itself is fetched from the
  current default graph's collection of the same key via `ops.get_collection`.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    init_op_key: Name of the collection to check; should be one of
      MAIN_OP_KEY or the deprecated LEGACY_INIT_OP_KEY.

  Returns:
    The op from the default graph's collection, or `None` when the
    meta graph def has no collection named `init_op_key`.

  Raises:
    RuntimeError: If the collection def for `init_op_key` does not contain
      exactly one node.
  """
  # TODO(kathywu): Rename this method to _get_op_from_collection when
  # dependency from SavedModelEstimator is removed.
  collection_def = meta_graph_def_to_load.collection_def
  if init_op_key not in collection_def:
    return None

  op_nodes = collection_def[init_op_key].node_list.value
  num_ops = len(op_nodes)
  if num_ops != 1:
    raise RuntimeError("Expected exactly one SavedModel init op. "
                       f"Found {num_ops}: {op_nodes}.")
  return ops.get_collection(init_op_key)[0]

192 

193 

def _get_op_from_collection(meta_graph_def, op_key):
  """Returns the single op stored in collection `op_key`, or None if absent.

  Delegates to `_get_main_op_tensor`, which raises RuntimeError when the
  collection exists but does not hold exactly one node.
  """
  return _get_main_op_tensor(meta_graph_def, init_op_key=op_key)

196 

197 

198def _get_op_from_signature_def(meta_graph_def, op_signature_key, import_scope): 

199 """Retrieve op stored in the imported meta graph's signature def.""" 

200 if op_signature_key in meta_graph_def.signature_def: 

201 return signature_def_utils.load_op_from_signature_def( 

202 meta_graph_def.signature_def[op_signature_key], op_signature_key, 

203 import_scope) 

204 else: 

205 return None 

206 

207 

def get_init_op(meta_graph_def, import_scope=None):
  """Returns the init op of `meta_graph_def`, or None if it defines none.

  Sources are tried in order, and the first one that yields a non-None element
  wins:
    1. the signature def keyed by `constants.INIT_OP_SIGNATURE_KEY`,
    2. the `constants.MAIN_OP_KEY` collection,
    3. the deprecated `constants.LEGACY_INIT_OP_KEY` collection.

  Args:
    meta_graph_def: The meta graph def to read the init op from.
    import_scope: Optional scope prefix applied when resolving the op stored
      in a signature def.

  Returns:
    The init graph element, or `None` when none of the sources define one.
  """
  # Explicit `is None` checks (as in get_train_op) instead of chaining with
  # `or`. The init element is not guaranteed to be an Operation, and calling
  # bool() on a symbolic tf.Tensor raises TypeError.
  init_op = _get_op_from_signature_def(
      meta_graph_def, constants.INIT_OP_SIGNATURE_KEY, import_scope)
  if init_op is None:
    init_op = _get_op_from_collection(meta_graph_def, constants.MAIN_OP_KEY)
  if init_op is None:
    init_op = _get_op_from_collection(
        meta_graph_def, constants.LEGACY_INIT_OP_KEY)
  return init_op

213 

214 

def get_train_op(meta_graph_def, import_scope=None):
  """Returns the train op of `meta_graph_def`, or None if it defines none.

  The signature def keyed by `constants.TRAIN_OP_SIGNATURE_KEY` is consulted
  first. The `constants.TRAIN_OP_KEY` collection is used only when that
  lookup yields None.
  """
  from_signature = _get_op_from_signature_def(
      meta_graph_def, constants.TRAIN_OP_SIGNATURE_KEY, import_scope)
  if from_signature is not None:
    return from_signature
  return _get_op_from_collection(meta_graph_def, constants.TRAIN_OP_KEY)

221 

222 

@tf_export(v1=[
    "saved_model.contains_saved_model",
    "saved_model.maybe_saved_model_directory",
    "saved_model.loader.maybe_saved_model_directory"
])
@deprecation.deprecated_endpoints(
    "saved_model.loader.maybe_saved_model_directory")
def maybe_saved_model_directory(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Only file existence is checked; no data is loaded. A `False` result means
  the directory definitely does not contain a SavedModel. A `True` result
  means either the text or the binary SavedModel file is present, with no
  guarantee that it can actually be loaded.

  Args:
    export_dir: Absolute string path to possible export location. For example,
      '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  # The text-format file is checked first, then the binary one.
  candidate_files = (constants.SAVED_MODEL_FILENAME_PBTXT,
                     constants.SAVED_MODEL_FILENAME_PB)
  return any(
      file_io.file_exists(file_io.join(export_dir, filename))
      for filename in candidate_files)

248 

249 

@tf_export("saved_model.contains_saved_model", v1=[])
def contains_saved_model(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Only file existence is checked; no data is loaded. A `False` result means
  the directory definitely does not contain a SavedModel. A `True` result
  means a SavedModel file is present, with no guarantee that it can be loaded.

  Args:
    export_dir: Absolute path to possible export location. For example,
      '/my/foo/model'. `os.PathLike` objects are converted with `os.fspath`.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  normalized_dir = (
      os.fspath(export_dir)
      if isinstance(export_dir, os.PathLike) else export_dir)
  return maybe_saved_model_directory(normalized_dir)

269 

270 

@tf_export(v1=["saved_model.load", "saved_model.loader.load"])
@deprecation.deprecated(
    None,
    "Use `tf.saved_model.load` instead.")
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
  """Loads the model from a SavedModel as specified by tags.

  This is a convenience wrapper. It constructs a `SavedModelLoader` for
  `export_dir` and delegates to its `load` method.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
      correspond to the tags used when saving the variables using the
      SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
      to be loaded are located.
    import_scope: Optional `string` -- if specified, prepend this string
      followed by '/' to all loaded tensor names. This scope is applied to
      tensor instances loaded into the passed session, but it is *not* written
      through to the static `MetaGraphDef` protocol buffer that is returned.
    **saver_kwargs: Optional keyword arguments passed through to Saver.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: MetaGraphDef associated with the tags cannot be found.
    IOError: The SavedModel file does not exist or cannot be parsed.

  @compatibility(TF2)

  `tf.compat.v1.saved_model.load` or `tf.compat.v1.saved_model.loader.load` is
  not compatible with eager execution. Please use `tf.saved_model.load` instead
  to load your model. You can refer to the [SavedModel guide]
  (https://www.tensorflow.org/guide/saved_model) for more information as well as
  "Importing SavedModels from TensorFlow 1.x" in the [`tf.saved_model.load`]
  (https://www.tensorflow.org/api_docs/python/tf/saved_model/load) docstring.

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `sess`                | Not supported   | -                          |
  | `tags`                | `tags`          | -                          |
  | `export_dir`          | `export_dir`    | -                          |
  | `import_scope`        | Not supported   | Name scopes are not needed.
  :                       :                 : By default, variables are  :
  :                       :                 : associated with the loaded :
  :                       :                 : object and function names  :
  :                       :                 : are deduped.               :
  | `saver_kwargs`        | Not supported   | -                          |

  #### Before & After Usage Example

  Before:

  ```
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    tf.compat.v1.saved_model.loader.load(sess, ["foo-tag"], export_dir)
  ```

  After:

  ```
  model = tf.saved_model.load(export_dir, tags=["foo-tag"])
  ```
  @end_compatibility
  """
  return SavedModelLoader(export_dir).load(
      sess, tags, import_scope, **saver_kwargs)

339 

340 

class SavedModelLoader(object):
  """Load graphs and restore variable values from a `SavedModel`."""

  def __init__(self, export_dir):
    """Creates a `SavedModelLoader`.

    The SavedModel proto is read and parsed here once and cached on the
    instance. The other methods work from that cached proto.

    Args:
      export_dir: Directory in which the SavedModel protocol buffer and
        variables to be loaded are located.

    Raises:
      IOError: If the SavedModel file does not exist or cannot be parsed.
    """
    self._export_dir = export_dir
    self._variables_path = path_helpers.get_variables_path(export_dir)
    self._saved_model = parse_saved_model(export_dir)

  @property
  def export_dir(self):
    """Directory containing the SavedModel."""
    return self._export_dir

  @property
  def variables_path(self):
    """Path to variable checkpoint files."""
    return self._variables_path

  @property
  def saved_model(self):
    """SavedModel object parsed from the export directory."""
    return self._saved_model

  def get_meta_graph_def_from_tags(self, tags):
    """Return MetaGraphDef with the exact specified tags.

    Tags are compared as sets, so order and duplicates do not matter. The
    first MetaGraphDef whose tag set equals `set(tags)` is returned.

    Args:
      tags: A list or set of string tags that identify the MetaGraphDef.

    Returns:
      MetaGraphDef with the same tags.

    Raises:
      RuntimeError: if no metagraphs were found with the associated tags.
    """
    wanted_tags = set(tags)  # Loop-invariant; computed once.
    available_tags = []
    for meta_graph_def in self._saved_model.meta_graphs:
      graph_tags = set(meta_graph_def.meta_info_def.tags)
      if graph_tags == wanted_tags:
        return meta_graph_def
      available_tags.append(graph_tags)

    raise RuntimeError(
        f"MetaGraphDef associated with tags {str(tags).strip('[]')} "
        "could not be found in SavedModel, with available tags "
        f"'{available_tags}'. To inspect available tag-sets in"
        " the SavedModel, please use the SavedModel CLI: `saved_model_cli`.")

  def load_graph(self, graph, tags, import_scope=None, **saver_kwargs):
    """Load ops and nodes from SavedModel MetaGraph into graph.

    On big-endian hosts, the function tensor content of the selected
    MetaGraphDef is first byte-swapped from "little" to "big". That proto is
    the one cached on this loader, so it is modified in place.

    Args:
      graph: tf.Graph object.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      A tuple of
        * Saver defined by the MetaGraph, which can be used to restore the
          variable values.
        * List of `Operation`/`Tensor` objects returned from
          `tf.import_graph_def` (may be `None`).
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    if sys.byteorder == "big":
      saved_model_utils.swap_function_tensor_content(meta_graph_def, "little",
                                                     "big")
    with graph.as_default():
      return tf_saver._import_meta_graph_with_return_elements(  # pylint: disable=protected-access
          meta_graph_def, import_scope=import_scope, **saver_kwargs)

  def restore_variables(self, sess, saver, import_scope=None):
    """Restore SavedModel variable values into the session.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      saver: a tf.compat.v1.train.Saver object. Can be None if there are no
        variables in graph. This may be the saver returned by the load_graph()
        function, or a default `tf.compat.v1.train.Saver()`.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.

    Raises:
      ValueError: if `saver` is not a `tf_saver.Saver`, unless it is `None`
        and the graph has no saveable objects under `import_scope`.
    """
    with sess.graph.as_default():
      if (saver is None and
          not variables._all_saveable_objects(scope=import_scope)):  # pylint: disable=protected-access
        tf_logging.info("The specified SavedModel has no variables; no "
                        "checkpoints were restored.")
      elif isinstance(saver, tf_saver.Saver):
        saver.restore(sess, self._variables_path)
      else:
        raise ValueError(
            "No tf.train.Saver object was passed to the function "
            "`SavedModelLoader.restore_variables`. Since there are variables in"
            " the graph, a saver is required.")

  def run_init_ops(self, sess, tags, import_scope=None):
    """Run initialization ops defined in the `MetaGraphDef`.

    Asset tensors, if any, are fed with their asset file paths while the init
    op runs. Nothing is run when the MetaGraphDef defines no init op.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    with sess.graph.as_default():
      # Get asset tensors, if any.
      asset_tensors_dictionary = get_asset_tensors(
          self._export_dir, meta_graph_def, import_scope=import_scope)

      init_op = get_init_op(meta_graph_def, import_scope)
      if init_op is not None:
        sess.run(fetches=[init_op], feed_dict=asset_tensors_dictionary)

  def load(self, sess, tags, import_scope=None, **saver_kwargs):
    """Load the MetaGraphDef graph and restore variable values into the session.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      `MetagraphDef` proto of the graph that was loaded.
    """
    metrics.IncrementReadApi(_LOADER_LABEL)

    with sess.graph.as_default():
      saver, _ = self.load_graph(sess.graph, tags, import_scope,
                                 **saver_kwargs)
      self.restore_variables(sess, saver, import_scope)
      self.run_init_ops(sess, tags, import_scope)
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)

    # Classify the write version from the proto already parsed in __init__
    # instead of re-reading and re-parsing the SavedModel file from disk.
    meta_graphs = self._saved_model.meta_graphs
    if len(meta_graphs) == 1 and meta_graphs[0].HasField("object_graph_def"):
      metrics.IncrementRead(write_version="2")
    else:
      metrics.IncrementRead(write_version="1")

    return meta_graph_def