Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/builder_impl.py: 23%

239 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""SavedModel builder implementation.""" 

16 

17import functools 

18import os 

19 

20from google.protobuf.any_pb2 import Any 

21 

22from tensorflow.core.framework import types_pb2 

23from tensorflow.core.protobuf import meta_graph_pb2 

24from tensorflow.core.protobuf import saved_model_pb2 

25from tensorflow.core.protobuf import saver_pb2 

26from tensorflow.python.framework import dtypes 

27from tensorflow.python.framework import ops 

28from tensorflow.python.lib.io import file_io 

29from tensorflow.python.ops import variables 

30from tensorflow.python.platform import tf_logging 

31from tensorflow.python.saved_model import fingerprinting_utils 

32from tensorflow.python.saved_model import path_helpers 

33from tensorflow.python.saved_model import signature_def_utils 

34from tensorflow.python.saved_model.pywrap_saved_model import constants 

35from tensorflow.python.saved_model.pywrap_saved_model import metrics 

36from tensorflow.python.training import saver as tf_saver 

37from tensorflow.python.util import compat 

38from tensorflow.python.util.deprecation import deprecated_args 

39from tensorflow.python.util.tf_export import tf_export 

40 

# Label passed to `metrics.IncrementWriteApi` whenever the v1 builder's
# `save()` is invoked, so builder-based writes can be counted separately.
_SAVE_BUILDER_LABEL = "save_v1_builder"

43 

44 

# Base class for the SavedModelBuilder that is only used by Tensorflow
# internally. Please use tf.compat.v1.saved_model.SavedModelBuilder instead.
@tf_export("__internal__.saved_model.SavedModelBuilder", v1=[])
class _SavedModelBuilder(object):
  """Builds the `SavedModel` protocol buffer and saves variables and assets.

  The `SavedModelBuilder` class provides the functionality to build a
  `SavedModel` protocol buffer. Specifically, this allows multiple meta
  graphs to be saved as part of a single language-neutral `SavedModel`,
  while sharing variables and assets.

  To build a SavedModel, the first meta graph must be saved with variables.
  Subsequent meta graphs will simply be saved with their graph definitions. If
  assets need to be saved and written or copied to disk, they can be provided
  when the meta graph def is added. If multiple meta graph defs are associated
  with an asset of the same name, only the first version is retained.

  Each meta graph added to the SavedModel must be annotated with tags. The tags
  provide a means to identify the specific meta graph to load and restore, along
  with the shared set of variables and assets.

  Typical usage for the `SavedModelBuilder`:

  ```python
  ...
  builder = tf.compat.v1.saved_model.Builder(export_dir)

  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph_and_variables(sess,
                                    ["foo-tag"],
                                    signature_def_map=foo_signatures,
                                    assets_list=foo_assets)
  ...

  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph(["bar-tag", "baz-tag"])
  ...

  builder.save()
  ```

  Note: This class will only be available through the v1 compatibility
  library as tf.compat.v1.saved_model.builder.SavedModelBuilder or
  tf.compat.v1.saved_model.Builder. Tensorflow 2.0 will introduce a new
  object-based method of creating SavedModels.
  """

  def __init__(self, export_dir):
    """Creates a builder that writes its output into `export_dir`.

    Args:
      export_dir: Directory the SavedModel is written to. It is created if it
        does not exist yet.

    Raises:
      AssertionError: If `export_dir` already exists and is not empty.
    """
    self._saved_model = saved_model_pb2.SavedModel()
    self._saved_model.saved_model_schema_version = (
        constants.SAVED_MODEL_SCHEMA_VERSION)

    self._export_dir = export_dir
    if file_io.file_exists(export_dir):
      if file_io.list_directory(export_dir):
        raise AssertionError(
            f"Export directory {export_dir} already exists, and isn't empty. "
            "Please choose a different export directory, or delete all the "
            "contents of the specified directory.")
    else:
      file_io.recursive_create_dir(self._export_dir)

    # Boolean to track whether variables and assets corresponding to the
    # SavedModel have been saved. Specifically, the first meta graph to be added
    # MUST use the add_meta_graph_and_variables() API. Subsequent add operations
    # on the SavedModel MUST use the add_meta_graph() API which does not save
    # weights.
    self._has_saved_variables = False
    # Destination asset paths that have already been copied into the export
    # directory. Shared across calls so each destination is copied only once.
    self._saved_asset_files = set()

  def _save_and_write_assets(self, meta_graph_def, assets_list=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      meta_graph_def: The meta graph def to which the assets will be added.
      assets_list: The list where the asset paths are setup.
    """
    # Creates a function that adds assets into the meta graph def.
    write_fn = functools.partial(_add_asset_to_metagraph, meta_graph_def)
    asset_filename_map = _maybe_save_assets(write_fn, assets_list)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir,
                                   self._saved_asset_files)

  def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
    """Tags the meta graph def and adds it to the SavedModel.

    Tags the meta graph def with the supplied tags, adds signature defs to it if
    provided and appends the meta graph def to the SavedModel proto.

    Args:
      meta_graph_def: The meta graph def to add to the SavedModel.
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
        def.
    """
    for tag in tags:
      meta_graph_def.meta_info_def.tags.append(tag)

    if signature_def_map is not None:
      for key in signature_def_map:
        meta_graph_def.signature_def[key].CopyFrom(signature_def_map[key])

    proto_meta_graph_def = self._saved_model.meta_graphs.add()
    proto_meta_graph_def.CopyFrom(meta_graph_def)

  def _validate_tensor_info(self, tensor_info):
    """Validates the `TensorInfo` proto.

    Checks if the `encoding` (`name` or `coo_sparse` or `type_spec`) and
    `dtype` fields exist and are non-empty. Composite tensors are validated
    recursively through their components.

    Args:
      tensor_info: `TensorInfo` protocol buffer to validate.

    Raises:
      AssertionError: If the `encoding` or `dtype` fields of the supplied
          `TensorInfo` proto are not populated.
    """
    if tensor_info is None:
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have the name "
          "and dtype fields set.")
    if tensor_info.WhichOneof("encoding") is None:
      # TODO(soergel) validate each of the fields of coo_sparse
      raise AssertionError(
          f"Invalid `tensor_info`: {tensor_info}. All TensorInfo protos used "
          "in the SignatureDefs must have one of the 'encoding' fields (e.g., "
          "name or coo_sparse) set.")
    if tensor_info.WhichOneof("encoding") == "composite_tensor":
      for component in tensor_info.composite_tensor.components:
        self._validate_tensor_info(component)
    elif tensor_info.dtype == types_pb2.DT_INVALID:
      raise AssertionError(
          f"Invalid `tensor_info`: {tensor_info}. All TensorInfo protos used in"
          " the SignatureDefs must have the dtype field set.")

  def _validate_signature_def_map(self, signature_def_map):
    """Validates the `SignatureDef` entries in the signature def map.

    Validation of entries in the signature def map includes ensuring that the
    `name` and `dtype` fields of the TensorInfo protos of the `inputs` and
    `outputs` of each `SignatureDef` are populated. Also ensures that reserved
    SignatureDef keys for the initialization and train ops are not used.

    Args:
      signature_def_map: The map of signature defs to be validated.

    Raises:
      AssertionError: If a TensorInfo is not valid.
      KeyError: If a reserved signature key is used in the map.
    """
    for signature_def_key in signature_def_map:
      signature_def = signature_def_map[signature_def_key]
      inputs = signature_def.inputs
      outputs = signature_def.outputs
      for inputs_key in inputs:
        self._validate_tensor_info(inputs[inputs_key])
      for outputs_key in outputs:
        self._validate_tensor_info(outputs[outputs_key])
    if constants.INIT_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          f"SignatureDef map key \"{constants.INIT_OP_SIGNATURE_KEY}\" is "
          "reserved for initialization. Please use a different key.")
    if constants.TRAIN_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          f"SignatureDef map key \"{constants.TRAIN_OP_SIGNATURE_KEY}\" is "
          "reserved for the train op. Please use a different key.")

  def _maybe_create_saver(self, saver=None):
    """Returns `saver`, or a new sharded V2 saver if none was supplied."""
    if not saver:
      # Initialize a saver to generate a sharded output for all saveables in the
      # current scope.
      saver = tf_saver.Saver(
          variables._all_saveable_objects(),  # pylint: disable=protected-access
          sharded=True,
          write_version=saver_pb2.SaverDef.V2,
          allow_empty=True)
    return saver

  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_list=None,
                     clear_devices=False,
                     init_op=None,
                     train_op=None,
                     saver=None):
    """Adds the current meta graph to the SavedModel.

    Creates a Saver in the current scope and uses the Saver to export the meta
    graph def. Invoking this API requires the `add_meta_graph_and_variables()`
    API to have been invoked before.

    Args:
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
        def. The map passed in is not modified.
      assets_list: Assets to be saved with SavedModel. Note
          that this list should be a subset of the assets saved as part of
          the first meta graph in the SavedModel.
      clear_devices: Set to true if the device info on the default graph should
          be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
          that when the init_op is specified it is run after the restore op at
          load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      saver: An instance of tf.compat.v1.train.Saver that will be used to export
        the metagraph. If None, a sharded Saver that restores all variables will
        be used.

    Raises:
      AssertionError: If the variables for the SavedModel have not been saved
        yet, or if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a key reserved for the init or
        train op.
    """
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated. Work on a shallow copy: the reserved init/train op
    # entries added below must not leak into the caller's dict, otherwise
    # reusing that dict in a later call would fail the reserved-key check.
    signature_def_map = dict(signature_def_map or {})
    self._validate_signature_def_map(signature_def_map)

    # Create a SignatureDef pointing to the graph initialization op, which will
    # be added to the MetaGraphDef.
    _add_op_to_signature_def_map(signature_def_map, init_op,
                                 constants.INIT_OP_SIGNATURE_KEY)
    _add_op_to_signature_def_map(signature_def_map, train_op,
                                 constants.TRAIN_OP_SIGNATURE_KEY)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=True)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_list=None,
                                   clear_devices=False,
                                   init_op=None,
                                   train_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    # pylint: disable=line-too-long
    """Adds the current meta graph to the SavedModel and saves variables.

    Creates a Saver to save the variables from the provided session. Exports the
    corresponding meta graph def. This function assumes that the variables to be
    saved have been initialized. For a given `SavedModelBuilder`, this API must
    be called exactly once and for the first meta graph to save. For subsequent
    meta graph defs to be added, the `add_meta_graph()` API must be used.

    Args:
      sess: The TensorFlow session from which to save the meta graph and
        variables.
      tags: The set of tags with which to save the meta graph.
      signature_def_map: The map of signature def map to add to the meta graph
        def. The map passed in is not modified.
      assets_list: Assets to be saved with SavedModel.
      clear_devices: Set to true if the device info on the default graph should
          be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
          that when the init_op is specified it is run after the restore op at
          load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      saver: An instance of tf.compat.v1.train.Saver that will be used to export the
        metagraph and save variables. If None, a sharded Saver that restores
        all variables will be used.

    Raises:
      AssertionError: If variables have already been saved by this builder, or
        if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a key reserved for the init or
        train op.
    """
    # pylint: enable=line-too-long
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated. Work on a shallow copy: the reserved init/train op
    # entries added below must not leak into the caller's dict, otherwise
    # reusing that dict in a later call would fail the reserved-key check.
    signature_def_map = dict(signature_def_map or {})
    self._validate_signature_def_map(signature_def_map)

    # Create a SignatureDef pointing to the graph initialization op, which will
    # be added to the MetaGraphDef.
    _add_op_to_signature_def_map(signature_def_map, init_op,
                                 constants.INIT_OP_SIGNATURE_KEY)
    _add_op_to_signature_def_map(signature_def_map, train_op,
                                 constants.TRAIN_OP_SIGNATURE_KEY)

    path_helpers.get_or_create_variables_dir(self._export_dir)
    variables_path = path_helpers.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True

  def save(self, as_text=False):
    """Writes a `SavedModel` protocol buffer to disk.

    The function writes the SavedModel protocol buffer to the export directory
    in a serialized format.

    Args:
      as_text: Writes the SavedModel protocol buffer in text format to
        disk. Protocol buffers in text format are useful for debugging, but
        parsing fails when it encounters an unknown field and so is not forward
        compatible. This means changes to TensorFlow may prevent deployment of
        new text format SavedModels to existing serving binaries. Do not deploy
        `as_text` SavedModels to production.

    Returns:
      The path to which the SavedModel protocol buffer was written.
    """
    metrics.IncrementWriteApi(_SAVE_BUILDER_LABEL)
    if not file_io.file_exists(self._export_dir):
      file_io.recursive_create_dir(self._export_dir)

    if as_text:
      path = file_io.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
      file_io.write_string_to_file(path, str(self._saved_model))
    else:
      path = file_io.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
      # Only the binary format needs the serialized bytes, so serialize here
      # rather than unconditionally.
      file_io.write_string_to_file(
          path, self._saved_model.SerializeToString(deterministic=True))
    tf_logging.info("SavedModel written to: %s", compat.as_text(path))
    metrics.IncrementWrite(write_version="1")

    # Placeholder for internal TF1 model fingerprint write

    return path

440 

441 

@tf_export(v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"])  # pylint: disable=missing-docstring
class SavedModelBuilder(_SavedModelBuilder):
  __doc__ = _SavedModelBuilder.__doc__.replace("assets_list",
                                               "assets_collection")

  def __init__(self, export_dir):
    super(SavedModelBuilder, self).__init__(export_dir=export_dir)

  def _add_collections(self, assets_collection, main_op, train_op):
    """Add asset and op collections to be saved.

    Args:
      assets_collection: Asset tensors to save, or None.
      main_op: Op added to the main-op collection, or None.
      train_op: Op or Tensor added to the train-op collection, or None.
    """
    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(assets_collection)

    self._maybe_add_main_op(main_op)

    self._add_train_op(train_op)

  def _save_and_write_assets(self, assets_collection_to_add=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Unlike the base class, assets are recorded in the graph collection
    `constants.ASSETS_KEY` rather than directly in a MetaGraphDef.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
    # Add assets to the collection with key `saved_model.ASSETS_KEY`, in the
    # graph.
    asset_filename_map = _maybe_save_assets(_add_asset_to_collection,
                                            assets_collection_to_add)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir,
                                   self._saved_asset_files)

  def _maybe_add_main_op(self, main_op):
    """Adds main op to the SavedModel.

    Args:
      main_op: Main op to run as part of graph initialization. If None, no main
        op will be added to the graph.

    Raises:
      TypeError: If the main op is provided but is not of type `Operation`.
      ValueError: if the Graph already contains an init op.
    """
    if main_op is None:
      return

    if not isinstance(main_op, ops.Operation):
      raise TypeError(f"Expected {main_op} to be an Operation but got type "
                      f"{type(main_op)} instead.")

    # Validate that no other init ops have been added to this graph already.
    # We check main_op and legacy_init_op for thoroughness and explicitness.
    for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
      if ops.get_collection(init_op_key):
        raise ValueError(
            "Graph already contains one or more main ops under the "
            f"collection {init_op_key}.")

    ops.add_to_collection(constants.MAIN_OP_KEY, main_op)

  def _add_train_op(self, train_op):
    """Add train op to the SavedModel.

    Note that this functionality is in development, and liable to be
    moved elsewhere.

    Args:
      train_op: Op or group of ops that are used for training. These are stored
        as a collection with key TRAIN_OP_KEY, but not executed. If None,
        nothing is added.

    Raises:
      TypeError: If `train_op` is neither a `Tensor` nor an `Operation`.
    """
    if train_op is not None:
      if (not isinstance(train_op, ops.Tensor) and
          not isinstance(train_op, ops.Operation)):
        raise TypeError(f"`train_op` {train_op} needs to be a Tensor or Op.")
      ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)

  # The public methods below get their docstrings from the base class (see the
  # `__doc__` assignments at the end of this class body).
  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_collection=None,
                     legacy_init_op=None,
                     clear_devices=False,
                     main_op=None,
                     strip_default_attrs=False,
                     saver=None):
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless.
    main_op = main_op if main_op is not None else legacy_init_op

    # Add assets and ops
    self._add_collections(assets_collection, main_op, None)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_collection=None,
                                   legacy_init_op=None,
                                   clear_devices=False,
                                   main_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless. Use an
    # explicit None check (as add_meta_graph does) instead of relying on the
    # truthiness of an op object.
    main_op = main_op if main_op is not None else legacy_init_op

    # Add assets and ops
    self._add_collections(assets_collection, main_op, None)

    path_helpers.get_or_create_variables_dir(self._export_dir)
    variables_path = path_helpers.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True

  add_meta_graph.__doc__ = _SavedModelBuilder.add_meta_graph.__doc__.replace(
      "assets_list", "assets_collection")
  add_meta_graph_and_variables.__doc__ = \
      _SavedModelBuilder.add_meta_graph_and_variables.__doc__.replace(
          "assets_list", "assets_collection")

635 

636 

def _maybe_save_assets(write_fn, assets_to_add=None):
  """Saves assets to the meta graph.

  Args:
    write_fn: A function callback that writes assets into meta graph. It is
      called as `write_fn(asset_filename, asset_tensor)` once per asset,
      including duplicates.
    assets_to_add: The list where the asset paths are setup.

  Returns:
    A dict of asset basenames for saving to the original full path to the asset.
    Empty when `assets_to_add` is None.

  Raises:
    ValueError: Indicating an invalid filepath tensor.
    TypeError: Propagated from `_asset_path_from_tensor` when an asset tensor
      is not a scalar-like string constant.
  """
  # Map of target file names to original filenames
  asset_filename_map = {}

  if assets_to_add is None:
    tf_logging.info("No assets to save.")
    return asset_filename_map

  # Iterate over the supplied assets, build the `AssetFile` proto and add them
  # to the meta graph.
  for asset_tensor in assets_to_add:
    asset_source_filepath = _asset_path_from_tensor(asset_tensor)
    if not asset_source_filepath:
      raise ValueError(f"Asset filepath tensor {asset_tensor} is invalid.")

    asset_filename = get_asset_filename_to_add(
        asset_source_filepath, asset_filename_map)

    # Call the passed-in function that builds AssetFileDef proto and adds it
    # to either the collection or asset_file_def field of the meta graph.
    # Note that this should be done even when the file is a duplicate of an
    # already-added file, as the tensor reference should still exist.
    write_fn(asset_filename, asset_tensor)

    # In the cases where we are adding a duplicate, this will result in the
    # last of the filepaths being the one used for copying the file to the
    # SavedModel. Since the files in question are the same, it doesn't matter
    # either way.
    asset_filename_map[asset_filename] = asset_source_filepath

  tf_logging.info("Assets added to graph.")
  return asset_filename_map

681 

682 

def get_asset_filename_to_add(asset_filepath, asset_filename_map):
  """Picks the basename under which `asset_filepath` is stored in a SavedModel.

  Assets are supplied as full paths but written into the SavedModel under
  their basenames, so two different files may want the same name. This
  helper returns the plain basename when it is unused, when it already maps to
  this exact path, or when the already-mapped file has identical contents.
  Otherwise it returns an index-suffixed variant of the basename.

  Args:
    asset_filepath: the full path to the asset that is being saved
    asset_filename_map: a dict of filenames used for saving the asset in
      the SavedModel to full paths from which the filenames were derived.

  Returns:
    The plain basename if it can be reused, otherwise a uniquified basename.
  """
  basename = os.path.basename(asset_filepath)

  # Basename not taken yet: nothing to disambiguate.
  if basename not in asset_filename_map:
    return basename

  previous_filepath = asset_filename_map[basename]
  # The same path listed twice, or a different path with identical contents,
  # can safely share the existing basename.
  reusable = (previous_filepath == asset_filepath or
              file_io.filecmp(asset_filepath, previous_filepath))
  if reusable:
    return basename

  # A genuinely different file collides on the basename.
  return _get_unique_asset_filename(basename, asset_filename_map)

720 

721 

def _get_unique_asset_filename(asset_filename, asset_filename_map):
  """Returns a variant of `asset_filename` that is not a key of the map.

  If `asset_filename` is already absent from `asset_filename_map` it is
  returned unchanged. Otherwise suffixes `_1`, `_2`, ... are tried in order,
  and the first unused candidate is returned as bytes (built with
  `compat.as_bytes`).

  Args:
    asset_filename: The basename to make unique.
    asset_filename_map: Map whose keys are basenames that are already taken.

  Returns:
    The first candidate name not present in `asset_filename_map`.
  """
  candidate = asset_filename
  suffix = 0
  while candidate in asset_filename_map:
    suffix += 1
    candidate = compat.as_bytes("_").join(
        [compat.as_bytes(asset_filename), compat.as_bytes(str(suffix))])
  return candidate

730 

731 

def _asset_path_from_tensor(path_tensor):
  """Returns the filepath value stored in constant `path_tensor`.

  Args:
    path_tensor: Tensor of a file-path. It must be the output of a `Const` op
      of dtype string.

  Returns:
    The string value i.e. path of the tensor, if valid: the single entry of the
    constant's `string_val`.

  Raises:
    TypeError: If the tensor does not match the expected op type, dtype or
      value count.
  """
  if not isinstance(path_tensor, ops.Tensor):
    raise TypeError(f"Asset path tensor {path_tensor} must be a Tensor.")
  if path_tensor.op.type != "Const":
    raise TypeError(f"Asset path tensor {path_tensor} must be of type constant. "
                    f"Has type {path_tensor.op.type} instead.")
  if path_tensor.dtype != dtypes.string:
    raise TypeError(f"Asset path tensor {path_tensor} must be of dtype string. "
                    f"Has type {path_tensor.dtype} instead.")
  # The check below requires exactly one stored string value; it does not
  # inspect the tensor's shape itself.
  str_values = path_tensor.op.get_attr("value").string_val
  if len(str_values) != 1:
    raise TypeError(f"Asset path tensor {path_tensor} must be a scalar.")
  return str_values[0]

756 

757 

def _add_asset_to_metagraph(meta_graph_def, asset_filename, asset_tensor):
  """Appends an `AssetFileDef` entry describing one asset to a meta graph def.

  Args:
    meta_graph_def: The meta graph def whose `asset_file_def` field receives a
      new entry.
    asset_filename: The filename recorded in the new entry.
    asset_tensor: The tensor whose `name` becomes the entry's tensor info name.
  """
  entry = meta_graph_def.asset_file_def.add()
  entry.filename = asset_filename
  tensor_info = entry.tensor_info
  tensor_info.name = asset_tensor.name

770 

771 

def copy_assets_to_destination_dir(asset_filename_map, destination_dir,
                                   saved_files=None):
  """Copies every asset into the assets subdirectory of `destination_dir`.

  An asset is copied only when its source exists, its source differs from the
  destination path, and the destination has not already been copied (as
  recorded in `saved_files`). Each destination that gets copied is added to
  `saved_files`, which is mutated in place when supplied by the caller.

  Args:
    asset_filename_map: a dict of filenames used for saving the asset in
      the SavedModel to full paths from which the filenames were derived.
    destination_dir: the destination directory that assets are stored in.
    saved_files: a set of destination filepaths that have already been copied
      and will be skipped
  """
  already_copied = set() if saved_files is None else saved_files

  assets_dir = path_helpers.get_or_create_assets_dir(destination_dir)

  for basename, source_path in asset_filename_map.items():
    target_path = file_io.join(
        compat.as_bytes(assets_dir), compat.as_bytes(basename))

    should_copy = (file_io.file_exists(source_path) and
                   source_path != target_path and
                   target_path not in already_copied)
    if not should_copy:
      continue

    file_io.copy(source_path, target_path, overwrite=True)
    already_copied.add(target_path)

  tf_logging.info("Assets written to: %s", compat.as_text(assets_dir))

806 

807 

def _add_asset_to_collection(asset_filename, asset_tensor):
  """Records one asset in the graph's `constants.ASSETS_KEY` collection.

  The asset is described by an `AssetFileDef` proto, wrapped in a
  `google.protobuf.Any`, and the `Any` is added to the collection.

  Args:
    asset_filename: The filename of the asset to be added.
    asset_tensor: The asset tensor whose `name` populates the tensor info of
      the asset proto.
  """
  asset_def = meta_graph_pb2.AssetFileDef(filename=asset_filename)
  asset_def.tensor_info.name = asset_tensor.name

  wrapped = Any()
  wrapped.Pack(asset_def)
  ops.add_to_collection(constants.ASSETS_KEY, wrapped)

823 

824 

def _add_op_to_signature_def_map(signature_def_map, op, key):
  """Stores an op `SignatureDef` for `op` under `key` in `signature_def_map`.

  The map is mutated in place. When `op` is None nothing happens.

  Args:
    signature_def_map: Mutable mapping of signature keys to SignatureDefs.
    op: The op to wrap, or None.
    key: Signature key under which the new SignatureDef is stored.
  """
  if op is None:
    return
  signature_def_map[key] = signature_def_utils.op_signature_def(op, key)