Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py: 19%

562 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Exports a SavedModel from a Trackable Python object.""" 

16 

17import collections 

18import os 

19import re 

20import sys 

21import traceback 

22 

23from absl import logging 

24 

25from tensorflow.core.framework import function_pb2 

26from tensorflow.core.framework import versions_pb2 

27from tensorflow.core.protobuf import meta_graph_pb2 

28from tensorflow.core.protobuf import saved_model_pb2 

29from tensorflow.core.protobuf import saved_object_graph_pb2 

30from tensorflow.python.checkpoint import checkpoint 

31from tensorflow.python.checkpoint import checkpoint_options 

32from tensorflow.python.checkpoint import functional_saver 

33from tensorflow.python.checkpoint import graph_view 

34from tensorflow.python.checkpoint import save_util_v1 

35from tensorflow.python.checkpoint import util as checkpoint_util 

36from tensorflow.python.eager import context 

37from tensorflow.python.eager import def_function 

38from tensorflow.python.eager import function as defun 

39from tensorflow.python.eager.polymorphic_function import polymorphic_function 

40from tensorflow.python.eager.polymorphic_function import saved_model_exported_concrete 

41from tensorflow.python.eager.polymorphic_function import saved_model_utils 

42from tensorflow.python.framework import dtypes 

43from tensorflow.python.framework import error_interpolation 

44from tensorflow.python.framework import errors 

45from tensorflow.python.framework import function as framework_fn 

46from tensorflow.python.framework import meta_graph 

47from tensorflow.python.framework import ops 

48from tensorflow.python.framework import tensor_util 

49from tensorflow.python.framework import versions 

50from tensorflow.python.lib.io import file_io 

51from tensorflow.python.ops import array_ops 

52from tensorflow.python.ops import control_flow_ops 

53from tensorflow.python.ops import resource_variable_ops 

54from tensorflow.python.saved_model import builder_impl 

55from tensorflow.python.saved_model import fingerprinting_utils 

56from tensorflow.python.saved_model import function_serialization 

57from tensorflow.python.saved_model import path_helpers 

58from tensorflow.python.saved_model import pywrap_saved_model 

59from tensorflow.python.saved_model import registration 

60from tensorflow.python.saved_model import revived_types 

61from tensorflow.python.saved_model import save_context 

62from tensorflow.python.saved_model import save_options 

63from tensorflow.python.saved_model import signature_constants 

64from tensorflow.python.saved_model import signature_def_utils 

65from tensorflow.python.saved_model import signature_serialization 

66from tensorflow.python.saved_model import tag_constants 

67from tensorflow.python.saved_model import tracing_utils 

68from tensorflow.python.saved_model import utils_impl 

69from tensorflow.python.saved_model.pywrap_saved_model import constants 

70from tensorflow.python.saved_model.pywrap_saved_model import metrics 

71from tensorflow.python.trackable import asset 

72from tensorflow.python.trackable import base 

73from tensorflow.python.trackable import resource 

74from tensorflow.python.trackable import trackable_utils 

75from tensorflow.python.training.saving import trace_saveable_util 

76from tensorflow.python.types import core as types_core 

77from tensorflow.python.util import compat 

78from tensorflow.python.util import object_identity 

79from tensorflow.python.util.tf_export import tf_export 

80 

# Dtypes treated as uncopiable: captured gradient inputs with one of these
# dtypes are skipped when gradient captures are serialized.
_UNCOPIABLE_DTYPES = frozenset([dtypes.resource, dtypes.variant])

# Container for tensors captured from external functions: the exported name
# of the tensor and the concrete function that captured it.
_CapturedTensor = collections.namedtuple(
    "_CapturedTensor", "name concrete_function")

# Maximum number of untraced function names listed in the log message emitted
# while building the saveable view.
_NUM_DISPLAY_UNTRACED_FUNCTIONS = 5

# API label for SavedModel metrics.
_SAVE_V2_LABEL = "save_v2"

92 

93 

class _AugmentedGraphView(graph_view.ObjectGraphView):
  """An extendable graph which also tracks functions attached to objects.

  Extensions through `set_signature` appear in the object graph and any
  checkpoints generated from it, even if they are not dependencies of the node
  they were attached to in the saving program. For example a `.signatures`
  attribute is added to exported SavedModel root objects without modifying the
  root object itself.

  The children of each object are gathered once through `list_children` and
  cached. Enumerating children only through this method ensures a consistent
  view of functions, even if object attributes create new functions every time
  they are accessed.
  """

  def __init__(self, root):
    super().__init__(root)

    # Cache the results of `list_children()` to ensure that the `Trackable`
    # children of each object are gathered exactly once.
    self._children_cache = object_identity.ObjectIdentityDictionary()

    # Cache shared between objects in the same object graph. This is passed to
    # `Trackable._trackable_children()`.
    self._serialization_cache = object_identity.ObjectIdentityDictionary()

    # Maps functions -> wrapped functions that capture non-cached variables.
    self._wrapped_functions = {}

    # Names of `def_function.Function` objects that had no children (no traced
    # concrete functions) when they were listed; reported to the user later.
    self.untraced_functions = []

  def set_signature(self, signature_map, wrapped_functions):
    """Attach signature to the root object.

    Args:
      signature_map: An object that contains signature functions.
      wrapped_functions: A dictionary mapping functions to functions that are
        guaranteed to not capture cached variables (functions that capture
        cached variables can't be saved).
    """
    # Populate the children cache for the root before overriding an entry.
    self.list_children(self.root)
    # Overrides existing dependency.
    name = signature_serialization.SIGNATURE_ATTRIBUTE_NAME
    self._children_cache[self.root][name] = signature_map
    self._wrapped_functions.update(wrapped_functions)

  def _breadth_first_traversal(self):
    """Returns all trackable objects in the SavedObjectGraph.

    This method is overridden to merge all equivalent constant tensors and
    Assets in the object graph: Assets sharing an `asset_path` tensor, and
    `TrackableConstant`s sharing a captured tensor, collapse into one node.
    """
    trackable_objects, _ = super()._breadth_first_traversal()

    # Canonical representative per asset path / captured constant. When
    # several objects share a key, the last one visited wins.
    asset_paths = object_identity.ObjectIdentityDictionary()
    constant_captures = object_identity.ObjectIdentityDictionary()
    for obj in trackable_objects:
      if isinstance(obj, asset.Asset):
        asset_paths[obj.asset_path] = obj
      if isinstance(obj, saved_model_utils.TrackableConstant):
        constant_captures[obj.capture] = obj

    def _get_merged_trackable(x):
      if isinstance(x, asset.Asset):
        return asset_paths[x.asset_path]
      if isinstance(x, saved_model_utils.TrackableConstant):
        # A constant capturing an asset path tensor is merged into the Asset
        # that owns that path.
        if x.capture in asset_paths:
          return asset_paths[x.capture]
        else:
          return constant_captures[x.capture]
      return x

    # Drop cache entries for objects that were merged away, and point every
    # remaining child reference at its canonical representative.
    for obj in list(self._children_cache.keys()):
      if _get_merged_trackable(obj) is not obj:
        del self._children_cache[obj]
        continue
      for name, child in self._children_cache[obj].items():
        self._children_cache[obj][name] = _get_merged_trackable(child)

    # Traverse again over the rewritten children cache.
    return super()._breadth_first_traversal()

  def list_children(self, obj):
    """Lists children of `obj` for SavedModel.

    Children are computed on the first call for `obj` and cached; subsequent
    calls yield the cached children. Concrete functions that capture cached
    variables are replaced by wrapped versions.

    Args:
      obj: A `Trackable` object.

    Yields:
      `base.TrackableReference(name, child)` for each child of `obj`.
    """
    if obj not in self._children_cache:
      children = self._children_cache[obj] = {}

      for name, child in super().list_children(
          obj,
          save_type=base.SaveType.SAVEDMODEL,
          cache=self._serialization_cache):
        if isinstance(child, defun.ConcreteFunction):
          child = self._maybe_uncache_variable_captures(child)
        children[name] = child

      # Keep track of untraced functions for later reporting to the user.
      if isinstance(obj, def_function.Function) and not children:
        self.untraced_functions.append(obj.name)

    for name, child in self._children_cache[obj].items():
      yield base.TrackableReference(name, child)

  def get_child(self, obj, name):
    """Returns the cached child `name` of `obj` (requires a prior listing)."""
    return self._children_cache[obj][name]

  def _maybe_uncache_variable_captures(self, concrete_function):
    """Returns a serializable version of `concrete_function`.

    If any captured input has a `_cached_variable` attribute, the function is
    wrapped with `function_serialization.wrap_cached_variables` and the wrapper
    is memoized in `self._wrapped_functions`. Otherwise the function is
    returned unchanged.
    """
    if concrete_function in self._wrapped_functions:
      return self._wrapped_functions[concrete_function]
    if any(hasattr(capture, "_cached_variable")
           for capture in concrete_function.captured_inputs):
      wrapped = function_serialization.wrap_cached_variables(concrete_function)
      self._wrapped_functions[concrete_function] = wrapped
      return wrapped
    return concrete_function

  def list_dependencies(self, obj):
    """Yields `Trackables` that must be loaded before `obj`.

    Dependencies and children are both dictionaries of `Trackables`. Children
    define the object graph structure (used in both checkpoints and SavedModel),
    while dependency defines the order used to load the SavedModel

    Args:
      obj: A `Trackable` object

    Yields:
      Tuple of dependency names and trackable objects.

    Raises:
      TypeError: if any of the returned dependencies are not instances of
        `Trackable`.
    """
    if obj not in self._children_cache:
      # Slot variables do not appear in the children_cache.
      children = {}
    else:
      children = self._children_cache[obj]
    for name, dep in obj._deserialization_dependencies(children).items():  # pylint: disable=protected-access
      if not isinstance(dep, base.Trackable):
        raise TypeError(
            f"The dependency of type {type(dep)} is not an instance `Trackable`"
            ", and can't be saved to SavedModel. Please check the "
            "implementation of `_deserialization_dependencies` in the parent "
            f"object {obj}.")
      yield name, dep

239 

240 

class _SaveableView(object):
  """Provides a frozen view over a trackable root.

  This class helps to create a single stable view over an object to save. The
  saving code should access properties and functions via this class and not via
  the original object, as there are cases where objects construct their
  trackable attributes and functions dynamically per call and will yield
  different objects if invoked more than once.

  Changes to the graph, for example adding objects, must happen in
  `augmented_graph_view` (an `_AugmentedGraphView`) before the `_SaveableView`
  is constructed. Changes after the `_SaveableView` has been constructed will be
  ignored.
  """

  def __init__(self, augmented_graph_view, options):
    """Initializes a SaveableView.

    Args:
      augmented_graph_view: A GraphView object.
      options: A SaveOptions instance.
    """

    self.augmented_graph_view = augmented_graph_view
    self._options = options

    (self._trackable_objects, self.node_paths, self.node_ids,
     self._slot_variables, self.object_names) = (
         checkpoint_util.objects_ids_and_slot_variables_and_paths(
             self.augmented_graph_view))

    untraced_functions = self.augmented_graph_view.untraced_functions
    if untraced_functions:
      logging.info(
          "Found untraced functions such as %s while saving (showing %d of %d)."
          " These functions will not be directly callable after loading.",
          ", ".join(untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]),
          min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(untraced_functions)),
          len(untraced_functions))

    self._initialize_save_and_restore_functions()
    self._initialize_nodes_and_concrete_functions()

    # Maps captured tensors to the id of the node that provides them.
    self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()

  def _initialize_save_and_restore_functions(self):
    """Generates all checkpoint save/restore functions.

    The save and restore functions are generated in the eager context (or in the
    user's Graph/Session) before being copied to the exported GraphDef. These
    functions record the ops for saving/restoring the entire object or
    individual objects (e.g. variables and hash tables).

    The global save and restore functions are generated for compatibility with
    TF1 and loading from C++, and is saved in the `MetaGraphDef.saver_def`.

    The individual functions are generated for the Python TF2 use case, where
    users use the loaded SavedModel as-is, or compose new models using parts
    of the object loaded from the SavedModel. These functions are recorded in
    the `saveable_objects` map in the `SavedObject` proto.
    """
    checkpoint_factory_map, registered_savers = (
        save_util_v1.get_checkpoint_factories_and_keys(self.object_names))
    # Reverse index: trackable -> name of the registered saver handling it.
    self._obj_to_registered_saver = object_identity.ObjectIdentityDictionary()
    for saver_name, trackables in registered_savers.items():
      for trackable in trackables.values():
        self._obj_to_registered_saver[trackable] = saver_name
    self._saveable_objects_map = (
        _gen_save_and_restore_functions(checkpoint_factory_map))

  def _initialize_nodes_and_concrete_functions(self):
    """Creates graph with nodes for trackable objects and functions.

    Adds functions for each trackable object to `self.nodes` and associated
    concrete functions to `self.concrete_functions` for serialization.
    """
    self.nodes = list(self._trackable_objects)
    self.gradient_functions = []
    self.gradient_defs = []

    # Iterate over a snapshot of the trackable objects: `self.nodes` is
    # extended inside the loop, and the appended save/restore functions never
    # have saveable entries of their own.
    trackable_nodes = tuple(self.nodes)
    for obj in trackable_nodes:
      if obj not in self._saveable_objects_map:
        continue
      for save_fn, restore_fn in self._saveable_objects_map[obj].values():
        self.node_ids[save_fn] = len(self.nodes)
        self.nodes.append(save_fn)

        self.node_ids[restore_fn] = len(self.nodes)
        self.nodes.append(restore_fn)

    self.concrete_functions = [
        obj for obj in self.nodes if isinstance(obj, defun.ConcreteFunction)
    ]

  @property
  def concrete_and_gradient_functions(self):
    """Concrete functions followed by traced gradient functions."""
    return self.concrete_functions + self.gradient_functions

  @property
  def root(self):
    """The first node of the view (node id 0)."""
    return self.nodes[0]

  def fill_object_graph_proto(self, proto):
    """Populate the nodes, children and slot_variables of a SavedObjectGraph.

    One `proto.nodes` entry is added per element of `self.nodes`, in node id
    order, recording slot variables, children, dependencies, and either the
    saveable save/restore function ids or the registered saver name.
    """
    for node_id, node in enumerate(self.nodes):
      assert self.node_ids[node] == node_id
      object_proto = proto.nodes.add()
      object_proto.slot_variables.extend(self._slot_variables.get(node, ()))
      # Captured tensors are recorded without children, dependencies or
      # saveables.
      if isinstance(node, _CapturedTensor):
        continue
      for child in self.augmented_graph_view.list_children(node):
        child_proto = object_proto.children.add()
        child_proto.node_id = self.node_ids[child.ref]
        child_proto.local_name = child.name
      for name, ref in self.augmented_graph_view.list_dependencies(node):
        child_proto = object_proto.dependencies.add()
        child_proto.node_id = self.node_ids[ref]
        child_proto.local_name = name

      if node in self._saveable_objects_map:
        assert node not in self._obj_to_registered_saver, (
            "Objects can't have both SaveableObjects and a registered saver")

        for local_name, (save_fn, restore_fn) in (
            self._saveable_objects_map[node].items()):
          saveable_object_proto = object_proto.saveable_objects[local_name]
          saveable_object_proto.save_function = self.node_ids[save_fn]
          saveable_object_proto.restore_function = self.node_ids[restore_fn]

      elif node in self._obj_to_registered_saver:
        object_proto.registered_saver = self._obj_to_registered_saver[node]

  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas the
    objects in `self.nodes` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with resources. Nodes are visited in dependency
    order, and tensors returned by each node's export hook are recorded in
    `self.captured_tensor_node_ids`.

    Returns:
      A tuple of (object_map, tensor_map, asset_info):
        object_map: A dictionary mapping from objects in `self.nodes` to
          replacement objects created to hold the new resource tensors.
        tensor_map: A dictionary mapping from resource tensors extracted from
          `self.nodes` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from `self.nodes`.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(b/205007558): Handle MirroredVariables and other types of variables
    # which may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    tensor_map = object_identity.ObjectIdentityDictionary()
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource=object_identity.ObjectIdentityDictionary(),
        asset_filename_map={},
        asset_index={})

    for node_id in _dependency_sorted_node_ids(self):
      obj = self.nodes[node_id]
      tensors = obj._export_to_saved_model_graph(  # pylint: disable=protected-access
          object_map=object_map,
          tensor_map=tensor_map,
          options=self._options)
      if isinstance(obj, asset.Asset):
        _add_asset_info(obj, asset_info, tensor_map[obj.asset_path])
      if tensors:
        for tensor in tensors:
          self.captured_tensor_node_ids[tensor] = node_id

    return object_map, tensor_map, asset_info

  def add_capture_and_node(self, capture, node):
    """Appends `node` and registers it and `capture` under a new node id.

    Args:
      capture: The captured tensor the new node stands in for.
      node: The object appended to `self.nodes`.

    Returns:
      The id assigned to `node`.
    """
    node_id = len(self.nodes)
    self.nodes.append(node)
    self.node_ids[capture] = node_id
    self.node_ids[node] = node_id
    self.captured_tensor_node_ids[capture] = node_id
    return node_id

  def get_concrete_resource_initializers(self):
    """Returns concrete `_initialize` functions of all CapturableResources."""
    concrete_initializers = []
    for obj in self.nodes:
      if isinstance(obj, resource.CapturableResource):
        concrete_initializers.append(
            self.augmented_graph_view.get_child(
                obj, "_initialize").get_concrete_function())
    return concrete_initializers

431 

432 

def _gen_save_and_restore_functions(checkpoint_factory_map):
  """Traces per-object save and restore concrete functions.

  For every object in `checkpoint_factory_map` that is not a resource variable
  and has at least one factory, this traces functions that save and restore
  the object's value tensors.

  This function is intended to run on the output of
  `save_util_v1.get_checkpoint_factories_and_keys(object_names)`, which returns
  a map of `_CheckpointFactoryData`.

  Args:
    checkpoint_factory_map: A dictionary mapping trackable objects to
      a list of `_CheckpointFactoryData`.

  Returns:
    An `ObjectIdentityDictionary` mapping
    obj -> factory name -> (concrete save function, concrete restore function).
    Skipped objects (resource variables, empty factory lists) have no entry.
  """
  saveable_fn_map = object_identity.ObjectIdentityDictionary()

  for obj, factory_data_list in checkpoint_factory_map.items():
    if resource_variable_ops.is_resource_variable(obj) or not factory_data_list:
      # There is no need to trace the save and restore functions for variables.
      continue

    if factory_data_list[0].name == trackable_utils.SERIALIZE_TO_TENSORS_NAME:
      # Trace Trackable save and restore functions.
      assert len(factory_data_list) == 1
      saveable_fn_map[obj] = {trackable_utils.SERIALIZE_TO_TENSORS_NAME: (
          tracing_utils.trace_save_and_restore(obj))}
    else:
      # Trace deprecated SaveableObject save and restore functions.
      saveable_fn_map[obj] = (
          trace_saveable_util.trace_save_restore_function_map(
              obj, factory_data_list))
  return saveable_fn_map

473 

474 

475def _tensor_dict_to_tensorinfo(tensor_dict): 

476 return { 

477 key: utils_impl.build_tensor_info_internal(value) 

478 for key, value in tensor_dict.items() 

479 } 

480 

481 

482def _to_safe_name_scope(signature_key, user_input_name): 

483 """Creates a sanitized name scope from user signature and input names. 

484 

485 Concatenates signature and input names, sanitizing as needed to be a valid 

486 scope name. 

487 

488 Args: 

489 signature_key: The user-provided key for the signature. 

490 user_input_name: The user-provided name for the input placeholder. 

491 

492 Returns: 

493 A name scope that is safe to be used in tf.name_scope(). 

494 """ 

495 name_scope = "{}_{}".format(signature_key, user_input_name) 

496 if re.match(r"^[A-Za-z0-9.][A-Za-z0-9_.\\-]*$", name_scope): 

497 return name_scope 

498 invalid_prefix_stripped = re.sub(r"^[^A-Za-z0-9.]*", "", name_scope) 

499 return re.sub(r"[^A-Za-z0-9_.\\-]", "_", invalid_prefix_stripped) 

500 

501 

502def _map_function_arguments_to_created_inputs( 

503 function_arguments, signature_key, function_name, defaults=None 

504): 

505 """Creates exterior placeholders in the exported graph for function arguments. 

506 

507 Functions have two types of inputs: tensors captured from the outside (eager) 

508 context, and arguments to the function which we expect to receive from the 

509 user at each call. `_map_captures_to_created_tensors` replaces 

510 captured tensors with stand-ins (typically these are resource dtype tensors 

511 associated with variables). `_map_function_inputs_to_created_inputs` runs over 

512 every argument, creating a new placeholder for each which will belong to the 

513 exported graph rather than the function body. 

514 

515 Args: 

516 function_arguments: A list of argument placeholders in the function body. 

517 signature_key: The name of the signature being exported, for error messages. 

518 function_name: The name of the function, for error messages. 

519 defaults: A dictionary mapping (signature_key, user_specified_name) to 

520 Tensor representing default values. 

521 

522 Returns: 

523 A tuple of (mapped_inputs, exterior_placeholders) 

524 mapped_inputs: A list with entries corresponding to `function_arguments` 

525 containing all of the inputs of the function gathered from the exported 

526 graph (both captured resources and arguments). 

527 exterior_argument_placeholders: A dictionary mapping from argument names 

528 to placeholders in the exported graph, containing the explicit arguments 

529 to the function which a user is expected to provide. 

530 

531 Raises: 

532 ValueError: If argument names are not unique. 

533 """ 

534 # `exterior_argument_placeholders` holds placeholders which are outside the 

535 # function body, directly contained in a MetaGraph of the SavedModel. The 

536 # function body itself contains nearly identical placeholders used when 

537 # running the function, but these exterior placeholders allow Session-based 

538 # APIs to call the function using feeds and fetches which name Tensors in the 

539 # MetaGraph. 

540 exterior_argument_placeholders = {} 

541 mapped_inputs = [] 

542 for placeholder in function_arguments: 

543 # `export_captures` contains an exhaustive set of captures, so if we don't 

544 # find the input there then we now know we have an argument. 

545 user_input_name = compat.as_str_any( 

546 placeholder.op.get_attr("_user_specified_name")) 

547 # If the internal placeholders for a function have names which were 

548 # uniquified by TensorFlow, then a single user-specified argument name 

549 # must refer to multiple Tensors. The resulting signatures would be 

550 # confusing to call. Instead, we throw an exception telling the user to 

551 # specify explicit names. 

552 if user_input_name != placeholder.op.name: 

553 # This should be unreachable, since concrete functions may not be 

554 # generated with non-unique argument names. 

555 raise ValueError( 

556 "Got non-flat/non-unique argument names for SavedModel signature " 

557 f"'{signature_key}': more than one argument to " 

558 f"'{compat.as_str_any(function_name)}' was named " 

559 f"'{user_input_name}'. " 

560 "Signatures have one Tensor per named input, so to have " 

561 "predictable names Python functions used to generate these " 

562 "signatures should avoid *args and Tensors in nested " 

563 "structures unless unique names are specified for each. Use " 

564 "tf.TensorSpec(..., name=...) to provide a name for a Tensor " 

565 "input.") 

566 default_value = defaults.get((signature_key, user_input_name)) 

567 if default_value is not None: 

568 placeholder_with_default = array_ops.placeholder_with_default( 

569 input=default_value.numpy(), 

570 shape=placeholder.shape, 

571 name=_to_safe_name_scope(signature_key, user_input_name), 

572 ) 

573 exterior_argument_placeholders[user_input_name] = placeholder_with_default 

574 mapped_inputs.append(placeholder_with_default) 

575 else: 

576 arg_placeholder = array_ops.placeholder( 

577 shape=placeholder.shape, 

578 dtype=placeholder.dtype, 

579 name=_to_safe_name_scope(signature_key, user_input_name), 

580 ) 

581 exterior_argument_placeholders[user_input_name] = arg_placeholder 

582 mapped_inputs.append(arg_placeholder) 

583 return mapped_inputs, exterior_argument_placeholders 

584 

585 

586def _generate_signatures(signature_functions, object_map, defaults=None): 

587 """Validates and calls `signature_functions` in the exported graph. 

588 

589 Args: 

590 signature_functions: A dictionary mapping string keys to concrete TensorFlow 

591 functions (e.g. from `signature_serialization.canonicalize_signatures`) 

592 which will be used to generate SignatureDefs. 

593 object_map: A dictionary that contains mappings from signature functions to 

594 concrete functions in the exported graph. 

595 defaults: A dictionary mapping (signature_key, user_specified_name) to 

596 Tensor representing default values. 

597 

598 Returns: 

599 Each function in the `signature_functions` dictionary is called with 

600 placeholder Tensors, generating a function call operation and output 

601 Tensors. The placeholder Tensors, the function call operation, and the 

602 output Tensors from the function call are part of the default Graph. 

603 

604 This function then returns a dictionary with the same structure as 

605 `signature_functions`, with the concrete functions replaced by SignatureDefs 

606 implicitly containing information about how to call each function from a 

607 TensorFlow 1.x Session / the C++ Loader API. These SignatureDefs reference 

608 the generated placeholders and Tensor outputs by name. 

609 

610 The caller is expected to include the default Graph set while calling this 

611 function as a MetaGraph in a SavedModel, including the returned 

612 SignatureDefs as part of that MetaGraph. 

613 """ 

614 signatures = {} 

615 for signature_key, function in sorted(signature_functions.items()): 

616 if function.graph.captures: 

617 argument_inputs = function.graph.inputs[:-len(function.graph.captures)] 

618 else: 

619 argument_inputs = function.graph.inputs 

620 mapped_inputs, exterior_argument_placeholders = ( 

621 _map_function_arguments_to_created_inputs( 

622 argument_inputs, signature_key, function.name, defaults 

623 ) 

624 ) 

625 kwarg_names = list( 

626 sorted( 

627 object_map[function].function.structured_input_signature[1].keys())) 

628 outputs = object_map[function](**{ 

629 kwarg_name: mapped_input 

630 for kwarg_name, mapped_input in zip(kwarg_names, mapped_inputs) 

631 }) 

632 signatures[signature_key] = signature_def_utils.build_signature_def( 

633 _tensor_dict_to_tensorinfo(exterior_argument_placeholders), 

634 _tensor_dict_to_tensorinfo(outputs), 

635 method_name=signature_constants.PREDICT_METHOD_NAME) 

636 return signatures 

637 

638 

639_AssetInfo = collections.namedtuple( 

640 "_AssetInfo", 

641 [ 

642 # List of AssetFileDef protocol buffers 

643 "asset_defs", 

644 # Map from asset variable resource Tensors to their init ops 

645 "asset_initializers_by_resource", 

646 # Map from base asset filenames to full paths 

647 "asset_filename_map", 

648 # Map from Asset to index of corresponding AssetFileDef 

649 "asset_index" 

650 ]) 

651 

652 

def _add_asset_info(trackable_asset, asset_info, mapped_path_variable):
  """Add `trackable_asset` to `asset_info`.

  Appends an AssetFileDef for the asset and records its filename, path
  initializer and AssetFileDef index in `asset_info` (updated in place).

  Args:
    trackable_asset: The `asset.Asset` being exported.
    asset_info: The `_AssetInfo` accumulator to update.
    mapped_path_variable: The export-graph object mapped from the asset's path
      tensor; its `initial_value.name` and `initializer` are recorded.
  """
  path_tensor = trackable_asset.asset_path
  source_path = tensor_util.constant_value(path_tensor)
  # Convert a numpy value to `str`; anything without `astype` is assumed to
  # already be a string.
  try:
    source_path = str(source_path.astype(str))
  except AttributeError:
    pass

  filename = builder_impl.get_asset_filename_to_add(
      asset_filepath=source_path,
      asset_filename_map=asset_info.asset_filename_map)
  asset_info.asset_filename_map[filename] = source_path

  asset_file_def = meta_graph_pb2.AssetFileDef()
  asset_file_def.filename = filename
  asset_file_def.tensor_info.name = mapped_path_variable.initial_value.name
  asset_info.asset_defs.append(asset_file_def)

  asset_info.asset_initializers_by_resource[path_tensor] = (
      mapped_path_variable.initializer)
  # Index of the AssetFileDef just appended.
  asset_info.asset_index[trackable_asset] = len(asset_info.asset_defs) - 1

674 

675 

def _iterate_op_types(fn):
  """Iterates through each op in the function and yields the op type and op.

  For a regular function, yields `(gradient_op_type, op)` for every op in
  `fn.graph` that has a `_gradient_op_type` attribute. A `_DefinedFunction`
  never yields anything; instead a `ValueError` is raised if any of its nodes
  has a non-empty `_gradient_op_type` attribute.

  Args:
    fn: A function from a graph's function library.

  Yields:
    Tuples of (gradient op type, op).

  Raises:
    ValueError: If `fn` is a `_DefinedFunction` with custom gradient ops.
  """
  if isinstance(fn, framework_fn._DefinedFunction):  # pylint: disable=protected-access
    for node in fn.definition.node_def:
      # Check membership first: indexing a protobuf message map with a missing
      # key inserts an empty entry, which would add a spurious
      # `_gradient_op_type` attr to the nodes of `fn.definition`.
      if "_gradient_op_type" not in node.attr:
        continue
      op_type = node.attr["_gradient_op_type"].s
      if op_type:
        raise ValueError(
            "Unable to save gradient functions when exporting a "
            "_DefinedFunction (generally created through graph freezing utils "
            "or through V1 graph importers). Please save with "
            "`options=tf.SaveOptions(experimental_custom_gradients=False)`")
  else:
    for op in fn.graph.get_operations():
      try:
        op_type = op.get_attr("_gradient_op_type")
      except ValueError:
        continue
      yield op_type, op

694 

695 

696def _get_outer_most_capture(fn, capture, func_graph_map): 

697 """Tries to find the original captured tensor if capture more than once.""" 

698 outer_fn = fn 

699 while outer_fn is not None and not isinstance(capture, ops.EagerTensor): 

700 if capture.graph is not outer_fn.graph: 

701 outer_fn = func_graph_map.get(outer_fn.graph.outer_graph) 

702 else: 

703 try: 

704 capture_index = outer_fn.graph.internal_captures.index(capture) 

705 except ValueError: 

706 break # Capture is a tensor inside function, and not captured from 

707 # another external function 

708 capture = outer_fn.graph.external_captures[capture_index] 

709 outer_fn = func_graph_map.get(outer_fn.graph.outer_graph) 

710 return outer_fn, capture 

711 

712 

def _trace_gradient_functions(graph, saveable_view):
  """Traces gradient functions and records them in the SaveableView.

  For every function in `graph`'s function library, each distinct
  `_gradient_op_type` found on its ops is looked up in
  `ops.gradient_registry`. Every registered custom gradient is traced into a
  concrete function, and its captured inputs are mapped to nodes of
  `saveable_view`. The traced function is then added to `graph`, and a
  `RegisteredGradient` entry is recorded.

  Args:
    graph: The exported Graph. Its `_functions` are scanned, and the traced
      gradient functions are added to it.
    saveable_view: The `_SaveableView` being exported. Its
      `captured_tensor_node_ids`, `gradient_functions` and `gradient_defs`
      are updated.

  Raises:
    ValueError: If tracing a gradient fails, if a capture cannot be resolved
      to a saved node, or if some captured inputs cannot be serialized.
  """
  functions = list(graph._functions.values())  # pylint: disable=protected-access
  # Maps each function graph back to its function so that captures can be
  # followed to enclosing functions (see `_get_outer_most_capture`).
  func_graph_map = {f.graph: f for f in functions if hasattr(f, "graph")}
  seen_op_types = set()

  for fn in functions:
    for op_type, op in _iterate_op_types(fn):
      # Each gradient op type only needs to be traced once.
      if op_type in seen_op_types:
        continue
      seen_op_types.add(op_type)

      try:
        custom_gradient = ops.gradient_registry.lookup(op_type)
      except LookupError:
        # No gradient registered for this op type; nothing to save.
        continue

      try:
        # NOTE(review): the leading `None` presumably stands in for the `op`
        # argument of the registered gradient, with the op's inputs used as
        # the incoming gradients for tracing -- confirm.
        grad_fn = (
            def_function.function(custom_gradient).get_concrete_function(
                None, *op.inputs))
      except Exception as exc:
        traceback.print_exc()
        raise ValueError(
            "Error when tracing gradients for SavedModel.\n\n"
            "Check the error log to see the error that was raised when "
            "converting a gradient function to a concrete function. You may "
            "need to update the custom gradient, or disable saving gradients "
            "with the option "
            "tf.saved_model.SaveOptions(experimental_custom_gradients=False)"
            f".\n\tProblematic op name: {op.name}\n\tGradient inputs: "
            f"{op.inputs}") from exc

      with graph.as_default():
        # The gradient function will capture all intermediate values. These
        # captures must be serialized so that they can be re-bound to the
        # function when loading.
        bad_captures = []
        for capture in grad_fn.captured_inputs:
          if capture.dtype in _UNCOPIABLE_DTYPES:
            continue
          # Tries to find the outermost capture in case the tensor is a constant
          # or not actually captured in the current function (this could happen
          # if the function is a while loop body, in which case the captured
          # input is not the internal captured tensor).
          outer_fn, outer_capture = _get_outer_most_capture(
              fn, capture, func_graph_map
          )
          if outer_fn is None or isinstance(outer_capture, ops.EagerTensor):
            # The capture does not come from a known function graph, so it
            # must already be a captured tensor tracked by the saveable view;
            # reuse that node id.
            if outer_capture not in saveable_view.captured_tensor_node_ids:
              raise ValueError(
                  f"Found invalid capture {outer_capture} when "
                  "saving custom gradients."
              )
            saveable_view.captured_tensor_node_ids[capture] = (
                saveable_view.captured_tensor_node_ids[outer_capture]
            )
          elif outer_capture.graph is outer_fn.graph:
            # The capture is a tensor of a known function: record it as a
            # `_CapturedTensor` node that refers to that function by name.
            capture_name = outer_capture.name
            # It's possible for AtomicFunctions to save different names
            # for input tensors when serialized to FunctionDef (all
            # non-alphanumeric characters are converted to '_').
            if isinstance(outer_fn, defun.AtomicFunction):  # pylint:disable=protected-access
              try:
                arg_index = outer_fn.graph.inputs.index(outer_capture)
                capture_name = (
                    outer_fn.cached_definition.signature.input_arg[
                        arg_index
                    ].name
                    + ":0"
                )
              except ValueError:
                # Not a graph input; keep the tensor's own name.
                pass

            node = _CapturedTensor(capture_name, outer_fn.name)
            saveable_view.add_capture_and_node(capture, node)
          else:
            bad_captures.append(capture.name)
        if not bad_captures:
          grad_fn.add_to_graph(graph)
        else:
          raise ValueError(
              f"Cannot save custom gradient {op_type} called in function {fn} "
              "because SavedModel is unable to serialize the captured "
              f"inputs: {bad_captures}"
          )

      saveable_view.gradient_functions.append(grad_fn)
      # Register the gradient's graph too, so captures found in later
      # iterations can be resolved through it.
      func_graph_map[grad_fn.graph] = grad_fn

      # Record which traced function implements the gradient for `op_type`.
      grad_def = function_pb2.RegisteredGradient()
      grad_def.gradient_func = grad_fn.name
      grad_def.registered_op_type = op_type
      saveable_view.gradient_defs.append(grad_def)

807 

808 

def _fill_meta_graph_def(
    meta_graph_def,
    saveable_view,
    signature_functions,
    namespace_whitelist,
    save_custom_gradients,
    defaults=None,
):
  """Generates a MetaGraph which calls `signature_functions`.

  Args:
    meta_graph_def: The MetaGraphDef proto to fill.
    saveable_view: The _SaveableView being exported.
    signature_functions: A dictionary mapping signature keys to concrete
      functions containing signatures to add to the MetaGraph.
    namespace_whitelist: List of strings containing whitelisted op namespaces.
    save_custom_gradients: Whether to save custom gradients.
    defaults: A dictionary mapping (signature_key, user_specified_name) to
      Tensor representing default values.

  Returns:
    A tuple of (_AssetInfo, Graph) containing the captured assets and
    exported Graph generated from tracing the saveable_view.

  Raises:
    ValueError: Propagated from `_dependency_sorted_node_ids` (untracked or
      cyclic dependencies) or `_verify_ops` (non-whitelisted op namespaces).
  """
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  resource_initializers = saveable_view.get_concrete_resource_initializers()
  exported_graph = ops.Graph()
  resource_initializer_ops = []
  with exported_graph.as_default():
    object_map, tensor_map, asset_info = saveable_view.map_resources()
    signatures = _generate_signatures(signature_functions, object_map, defaults)
  if save_custom_gradients:
    # Custom gradient functions must be traced in the same context in which
    # they were registered, so this call is made outside
    # `exported_graph.as_default()`.
    _trace_gradient_functions(exported_graph, saveable_view)
  with exported_graph.as_default():
    # Create initializers for assets and resources.
    for resource_initializer_function in resource_initializers:
      # A resource initializer must run after the initializers of any assets
      # it captures.
      asset_dependencies = []
      for capture in resource_initializer_function.graph.external_captures:
        asset_initializer = asset_info.asset_initializers_by_resource.get(
            capture, None)
        if asset_initializer is not None:
          asset_dependencies.append(asset_initializer)
      with ops.control_dependencies(asset_dependencies):
        mapped_initializer = object_map[resource_initializer_function]
        resource_initializer_ops.append(mapped_initializer())
    resource_initializer_ops.extend(
        asset_info.asset_initializers_by_resource.values())
    # A single no-op that depends on every initializer acts as the init op.
    with ops.control_dependencies(resource_initializer_ops):
      init_op = control_flow_ops.no_op()
    # Add the same op to the main_op collection and to the init_op
    # signature. The collection is for compatibility with older loader APIs;
    # only one will be executed.
    meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value.append(
        init_op.name)
    meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
        signature_def_utils.op_signature_def(init_op,
                                             constants.INIT_OP_SIGNATURE_KEY))

  # Saving an object-based checkpoint again gathers variables. We need to do the
  # gathering from the eager context so Optimizers save the right set of
  # variables, but want any operations associated with the save/restore to be in
  # the exported graph (thus the `to_graph` argument).
  def call_with_mapped_captures(function, args):
    """Calls `function` on `args`, preferring its copy in `object_map`."""
    if function in object_map:
      return object_map[function](*args)
    # Registered saver/restore functions do not appear in `object_map`, because
    # they are not in the object graph.
    return saved_model_exported_concrete.ExportedConcreteFunction(
        function, tensor_map)(*args)

  for obj in object_map.values():
    obj._maybe_initialize_trackable()  # pylint: disable=protected-access
  named_saveable_objects, registered_savers = (
      save_util_v1.frozen_saveables_and_savers(
          graph_view=saveable_view.augmented_graph_view,
          object_map=object_map,
          to_graph=exported_graph,
          call_with_mapped_captures=call_with_mapped_captures))
  saver = functional_saver.MultiDeviceSaver.from_saveables(
      named_saveable_objects, registered_savers, call_with_mapped_captures)

  with exported_graph.as_default():
    saver_def = saver.to_proto()
    meta_graph_def.saver_def.CopyFrom(saver_def)

  # At this point all nodes that can be added to the SavedObjectGraph have been
  # added, so run the following to validate deserialization dependencies.
  _dependency_sorted_node_ids(saveable_view)

  graph_def = exported_graph.as_graph_def(add_shapes=True)
  graph_def.library.registered_gradients.extend(saveable_view.gradient_defs)
  _verify_ops(graph_def, namespace_whitelist)

  meta_graph_def.graph_def.CopyFrom(graph_def)
  meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
  meta_graph_def.meta_info_def.tensorflow_version = versions.__version__
  meta_graph_def.meta_info_def.tensorflow_git_version = (
      versions.__git_version__)
  # We currently always strip default attributes.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
  meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
      meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def))
  meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
  for signature_key, signature in signatures.items():
    meta_graph_def.signature_def[signature_key].CopyFrom(signature)
  meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
  # Store tensor_content in little-endian format: on big-endian hosts, swap
  # the byte order of function tensor contents before writing.
  if sys.byteorder == "big":
    utils_impl.swap_function_tensor_content(meta_graph_def, "big", "little")
  return asset_info, exported_graph

922 

923 

924def _verify_ops(graph_def, namespace_whitelist): 

925 """Verifies that all namespaced ops in the graph are whitelisted. 

926 

927 Args: 

928 graph_def: the GraphDef to validate. 

929 namespace_whitelist: a list of namespaces to allow. If `None`, all will be 

930 allowed. If an op does not have a namespace, it will be allowed. 

931 

932 Raises: 

933 ValueError: If the graph contains ops that violate the whitelist. 

934 """ 

935 # By default, if the user has not specified a whitelist, we want to allow 

936 # everything. We check for None directly rather than falseness, since the 

937 # user may instead want to pass an empty list to disallow all custom 

938 # namespaced ops. 

939 if namespace_whitelist is None: 

940 return 

941 

942 invalid_ops = [] 

943 invalid_namespaces = set() 

944 

945 all_operations = [] 

946 all_operations.extend(meta_graph.ops_used_by_graph_def(graph_def)) 

947 

948 for op in all_operations: 

949 if ">" in op: 

950 namespace = op.split(">")[0] 

951 if namespace not in namespace_whitelist: 

952 invalid_ops.append(op) 

953 invalid_namespaces.add(namespace) 

954 if invalid_ops: 

955 raise ValueError( 

956 "Attempted to save ops from non-whitelisted namespaces to SavedModel: " 

957 f"{invalid_ops}.\nPlease verify that these ops should be saved, since " 

958 "they must be available when loading the SavedModel. If loading from " 

959 "Python, you must import the library defining these ops. From C++, " 

960 "link the custom ops to the serving binary. Once you've confirmed this," 

961 " add the following namespaces to the `namespace_whitelist` " 

962 f"argument in tf.saved_model.SaveOptions: {invalid_namespaces}.") 

963 

964 

def _dependency_sorted_node_ids(saveable_view):
  """Returns the node ids of `saveable_view` ordered by their dependencies.

  Builds a map from each node id to the ids of the nodes it depends on and
  orders it with `trackable_utils.order_by_dependency`. `_CapturedTensor`
  nodes are not `Trackable` and are given no dependencies.

  Args:
    saveable_view: The `_SaveableView` being exported.

  Returns:
    The node ids in the order produced by
    `trackable_utils.order_by_dependency`.

  Raises:
    ValueError: If a node depends on an object that is not itself a node of
      the view, or if the dependencies contain a cycle.
  """
  dependency_map = {}
  for node in saveable_view.nodes:
    node_deps = []
    dependency_map[saveable_view.node_ids[node]] = node_deps
    # TODO(kathywu): Remove once all of these have been converted to trackable.
    if isinstance(node, _CapturedTensor):
      # Captured tensors are not `Trackable`, so they have no dependencies.
      continue
    for _, dependency in saveable_view.augmented_graph_view.list_dependencies(
        node):
      if dependency not in saveable_view.node_ids:
        node_path = trackable_utils.pretty_print_node_path(
            saveable_view.node_paths[node])
        raise ValueError(
            f"Found an untracked dependency. Object {node_path} depends "
            f"on {dependency}, but this dependency isn't listed as a child. "
            "Please track this child by overriding `_trackable_children` "
            "or use `._track_trackable`.")
      node_deps.append(saveable_view.node_ids[dependency])

  try:
    return trackable_utils.order_by_dependency(dependency_map)
  except trackable_utils.CyclicDependencyError as err:
    node_lines = []
    dependency_lines = []
    for node_id, deps in err.leftover_dependency_map.items():
      cyclic_node = saveable_view.nodes[node_id]
      node_path = trackable_utils.pretty_print_node_path(
          saveable_view.node_paths[cyclic_node])
      node_lines.append(
          f"\tNode {node_id} = {node_path} (type {type(cyclic_node)})")
      dependency_lines.append(f"\tNode {node_id} depends on nodes {deps}")
    pretty_printed_nodes = "\n".join(node_lines)
    pretty_printed_dependencies = "\n".join(dependency_lines)
    raise ValueError(
        "There is one or more dependency cycle in the saved Trackable object. "
        "Saving cannot continue until this cycle is resolved."
        f"\n>> Unresolved nodes:\n{pretty_printed_nodes}"
        f"\n>> Unresolved cyclic dependencies:\n{pretty_printed_dependencies}")

1003 

1004 

def _serialize_object_graph(saveable_view, asset_file_def_index):
  """Builds the `SavedObjectGraph` proto for the objects in `saveable_view`.

  The SavedObjectGraph plays a role similar to the TrackableObjectGraph proto
  in a checkpoint and is later embedded in the SavedModel.

  Args:
    saveable_view: The `_SaveableView` being exported.
    asset_file_def_index: Mapping from `asset.Asset` objects to the value
      written into their proto's `asset_file_def_index` field.

  Returns:
    The populated `saved_object_graph_pb2.SavedObjectGraph`.
  """
  graph_proto = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(graph_proto)

  # Serialize every concrete and gradient function; functions for which the
  # serializer returns None are left out of `concrete_functions`.
  for concrete_fn in saveable_view.concrete_and_gradient_functions:
    fn_name = compat.as_text(concrete_fn.name)
    fn_proto = function_serialization.serialize_concrete_function(
        concrete_fn, saveable_view.captured_tensor_node_ids)
    if fn_proto is None:
      continue
    graph_proto.concrete_functions[fn_name].CopyFrom(fn_proto)

  # `graph_proto.nodes` is aligned with `saveable_view.nodes`.
  for node, node_proto in zip(saveable_view.nodes, graph_proto.nodes):
    _write_object_proto(node, node_proto, asset_file_def_index,
                        saveable_view.augmented_graph_view.list_children)
  return graph_proto

1023 

1024 

def _write_object_proto(obj, proto, asset_file_def_index, list_children_fn):
  """Fills the `SavedObject` proto `proto` that describes `obj`.

  The first matching case decides which kind of object `proto` records:
  asset, resource variable, `tf.function`, bare concrete function, captured
  tensor, capturable resource, or (for anything else) a user object. User
  objects are serialized through `revived_types`, or fall back to a generic
  `SavedUserObject` when no registration matches. Separately, objects whose
  class has a registered name get `registered_name` set, and any proto
  returned by their `_serialize_to_proto` is packed into
  `serialized_user_proto`.

  Args:
    obj: The object being saved.
    proto: The `SavedObject` proto to fill in.
    asset_file_def_index: Mapping used to look up the value stored in
      `proto.asset.asset_file_def_index` for `asset.Asset` objects.
    list_children_fn: Callable listing an object's children; the `ref` of each
      child is passed along when serializing a `tf.function`.
  """
  if isinstance(obj, asset.Asset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
  elif resource_variable_ops.is_resource_variable(obj):
    save_opts = save_context.get_save_options()
    obj._write_object_proto(proto, save_opts)  # pylint: disable=protected-access
  elif isinstance(obj, def_function.Function):
    child_refs = [child.ref for child in list_children_fn(obj)]
    proto.function.CopyFrom(
        function_serialization.serialize_function(obj, child_refs))
  elif isinstance(obj, defun.ConcreteFunction):
    proto.bare_concrete_function.CopyFrom(
        function_serialization.serialize_bare_concrete_function(obj))
  elif isinstance(obj, _CapturedTensor):
    proto.captured_tensor.name = obj.name
    proto.captured_tensor.concrete_function = obj.concrete_function
  elif isinstance(obj, resource.CapturableResource):
    proto.resource.device = obj._resource_device  # pylint: disable=protected-access
  else:
    user_object = revived_types.serialize(obj)
    if user_object is None:
      # No revived-type registration matched; fall back to a generic user
      # object identified by the object's `_object_identifier`.
      user_object = saved_object_graph_pb2.SavedUserObject(
          identifier=obj._object_identifier,  # pylint: disable=protected-access
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]))
    proto.user_object.CopyFrom(user_object)

  registered_name = registration.get_registered_class_name(obj)
  if not registered_name:
    return
  proto.registered_name = registered_name
  user_proto = obj._serialize_to_proto(object_proto=proto)  # pylint: disable=protected-access
  if user_proto is not None:
    proto.serialized_user_proto.Pack(user_proto)

1063 

1064 

def _export_debug_info(exported_graph, export_dir):
  """Writes a GraphDebugInfo file for the functions of `exported_graph`.

  Collects the operations of every `AtomicFunction` in the graph's function
  library (functions of other types are skipped), builds a GraphDebugInfo
  from them, and writes it atomically into the SavedModel's debug directory,
  creating that directory if necessary.

  Args:
    exported_graph: A Graph that has been created by tracing a saveable view.
    export_dir: SavedModel directory in which to write the debug info.
  """
  exported_operations = []
  for fn_name in exported_graph._functions:  # pylint: disable=protected-access
    fn = exported_graph._get_function(fn_name)  # pylint: disable=protected-access
    if isinstance(fn, defun.AtomicFunction):  # pylint: disable=protected-access
      exported_operations.extend(
          (fn_name, fn_op) for fn_op in fn.graph.get_operations())

  debug_info = error_interpolation.create_graph_debug_info_def(
      exported_operations)
  debug_path = file_io.join(
      path_helpers.get_or_create_debug_dir(export_dir),
      constants.DEBUG_INFO_FILENAME_PB)
  file_io.atomic_write_string_to_file(
      debug_path, debug_info.SerializeToString(deterministic=True))

1092 

1093 

@tf_export(
    "saved_model.save",
    v1=["saved_model.save", "saved_model.experimental.save"])
def save(obj, export_dir, signatures=None, options=None):
  # pylint: disable=line-too-long
  """Exports a [tf.Module](https://www.tensorflow.org/api_docs/python/tf/Module) (and subclasses) `obj` to [SavedModel format](https://www.tensorflow.org/guide/saved_model#the_savedmodel_format_on_disk).

  The `obj` must inherit from the [`Trackable`
  class](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/tracking/base.py#L591).

  Example usage:

  >>> class Adder(tf.Module):
  ...   @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
  ...   def add(self, x):
  ...     return x + x

  >>> model = Adder()
  >>> tf.saved_model.save(model, '/tmp/adder')

  The resulting SavedModel is then servable with an input named "x", a scalar
  with dtype float32.

  _Signatures_

  Signatures define the input and output types for a computation. The optional
  save `signatures` argument controls which methods in `obj` will be
  available to programs which consume `SavedModel`s, for example, serving
  APIs. Python functions may be decorated with
  `@tf.function(input_signature=...)` and passed as signatures directly, or
  lazily with a call to `get_concrete_function` on the method decorated with
  `@tf.function`.

  Example:

  >>> class Adder(tf.Module):
  ...   @tf.function
  ...   def add(self, x):
  ...     return x + x

  >>> model = Adder()
  >>> tf.saved_model.save(
  ...   model, '/tmp/adder', signatures=model.add.get_concrete_function(
  ...     tf.TensorSpec([], tf.float32)))

  If a `@tf.function` does not have an input signature and
  `get_concrete_function` is not called on that method, the function will not
  be directly callable in the restored SavedModel.

  Example:

  >>> class Adder(tf.Module):
  ...   @tf.function
  ...   def add(self, x):
  ...     return x + x

  >>> model = Adder()
  >>> tf.saved_model.save(model, '/tmp/adder')
  >>> restored = tf.saved_model.load('/tmp/adder')
  >>> restored.add(1.)
  Traceback (most recent call last):
  ...
  ValueError: Found zero restored functions for caller function.

  If the `signatures` argument is omitted, `obj` will be searched for
  `@tf.function`-decorated methods. If exactly one traced `@tf.function` is
  found, that method will be used as the default signature for the SavedModel.
  Otherwise, any `@tf.function` attached to `obj` or its dependencies will be
  exported for use with `tf.saved_model.load`.

  When invoking a signature in an exported SavedModel, `Tensor` arguments are
  identified by name. These names will come from the Python function's argument
  names by default. They may be overridden by specifying a `name=...` argument
  in the corresponding `tf.TensorSpec` object. Explicit naming is required if
  multiple `Tensor`s are passed through a single argument to the Python
  function.

  The outputs of functions used as `signatures` must either be flat lists, in
  which case outputs will be numbered, or a dictionary mapping string keys to
  `Tensor`, in which case the keys will be used to name outputs.

  Signatures are available in objects returned by `tf.saved_model.load` as a
  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
  on an object with a custom `.signatures` attribute will raise an exception.

  _Using `tf.saved_model.save` with Keras models_

  While Keras has its own [saving and loading
  API](https://www.tensorflow.org/guide/keras/save_and_serialize),
  this function can be used to export Keras models. For example, exporting with
  a signature specified:

  >>> class Adder(tf.keras.Model):
  ...   @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  ...   def concat(self, x):
  ...      return x + x

  >>> model = Adder()
  >>> tf.saved_model.save(model, '/tmp/adder')

  Exporting from a function without a fixed signature:

  >>> class Adder(tf.keras.Model):
  ...   @tf.function
  ...   def concat(self, x):
  ...      return x + x

  >>> model = Adder()
  >>> tf.saved_model.save(
  ...   model, '/tmp/adder',
  ...   signatures=model.concat.get_concrete_function(
  ...     tf.TensorSpec(shape=[], dtype=tf.string, name="string_input")))

  `tf.keras.Model` instances constructed from inputs and outputs already have a
  signature and so do not require a `@tf.function` decorator or a `signatures`
  argument. If neither is specified, the model's forward pass is exported.

  >>> x = tf.keras.layers.Input((4,), name="x")
  >>> y = tf.keras.layers.Dense(5, name="out")(x)
  >>> model = tf.keras.Model(x, y)
  >>> tf.saved_model.save(model, '/tmp/saved_model/')

  The exported SavedModel takes "x" with shape [None, 4] and returns "out"
  with shape [None, 5].

  _Variables and Checkpoints_

  Variables must be tracked by assigning them to an attribute of a tracked
  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
  uses, and an exported `Checkpoint` object may be restored as a training
  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
  "variables/" subdirectory.

  `tf.function` does not hard-code device annotations from outside the function
  body; instead, it uses the calling context's device. This means for example
  that exporting a model that runs on a GPU and serving it on a CPU will
  generally work, with some exceptions:

    * `tf.device` annotations inside the body of the function will be hard-coded
      in the exported model; this type of annotation is discouraged.
    * Device-specific operations, e.g. with "cuDNN" in the name or with
      device-specific layouts, may cause issues.
    * For `ConcreteFunctions`, active distribution strategies will cause device
      placements to be hard-coded in the function.

  SavedModels exported with `tf.saved_model.save` [strip default-valued
  attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
  automatically, which removes one source of incompatibilities when the consumer
  of a SavedModel is running an older TensorFlow version than the
  producer. There are, however, other sources of incompatibilities which are not
  handled automatically, such as when the exported model contains operations
  which the consumer does not have definitions for.

  Args:
    obj: A trackable object (e.g. tf.Module or tf.train.Checkpoint) to export.
    export_dir: A directory in which to write the SavedModel. May be a string
      or an `os.PathLike` object.
    signatures: Optional, one of three types:
      * A `tf.function` with an input signature specified, which will use the
        default serving signature key.
      * The result of `f.get_concrete_function` on a `@tf.function`-decorated
        function `f`, in which case `f` will be used to generate a signature for
        the SavedModel under the default serving signature key.
      * A dictionary, which maps signature keys to either `tf.function`
        instances with input signatures or concrete functions. Keys of such a
        dictionary may be arbitrary strings, but will typically be from the
        `tf.saved_model.signature_constants` module.
    options: `tf.saved_model.SaveOptions` object for configuring save options.

  Raises:
    ValueError: If `obj` is not trackable.

  @compatibility(eager)
  Not well supported when graph building. From TensorFlow 1.x,
  `tf.compat.v1.enable_eager_execution()` should run first. Calling
  tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
  add new save operations to the default graph each iteration.

  May not be called from within a function body.
  @end_compatibility
  """
  # pylint: enable=line-too-long
  # Convert path-like objects (e.g. `pathlib.Path`) to a plain path first.
  if isinstance(export_dir, os.PathLike):
    export_dir = os.fspath(export_dir)
  # The API-call counter is bumped up front; the write counter below is only
  # reached if `save_and_return_nodes` returns without raising.
  metrics.IncrementWriteApi(_SAVE_V2_LABEL)
  save_and_return_nodes(obj, export_dir, signatures, options)

  metrics.IncrementWrite(write_version="2")

1283 

1284 

def save_and_return_nodes(obj,
                          export_dir,
                          signatures=None,
                          options=None,
                          experimental_skip_checkpoint=False):
  """Saves a SavedModel while returning all saved nodes and their paths.

  Please see `tf.saved_model.save` for details.

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel. May be a string
      or an `os.PathLike` object.
    signatures: A function or dictionary of functions to save in the SavedModel
      as signatures.
    options: `tf.saved_model.SaveOptions` object for configuring save options.
    experimental_skip_checkpoint: If set to `True`, the checkpoint will not be
      written.

  Returns:
    A tuple of (a list of saved nodes in the order they are serialized to the
      `SavedObjectGraph`, dictionary mapping nodes to one possible path from
      the root node to the key node)

  Raises:
    FileNotFoundError: If, when executing eagerly, waiting for the pending
      save operations raises `errors.NotFoundError` (for example when saving
      from a device that cannot access the export directory).
  """
  # Accept path-like objects, as `save` does. Callers may invoke this function
  # directly, and `compat.as_str` below expects `str`/`bytes`.
  if isinstance(export_dir, os.PathLike):
    export_dir = os.fspath(export_dir)
  options = options or save_options.SaveOptions()
  saved_model = saved_model_pb2.SavedModel()
  meta_graph_def = saved_model.meta_graphs.add()

  _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
      _build_meta_graph(obj, signatures, options, meta_graph_def))
  saved_model.saved_model_schema_version = (
      constants.SAVED_MODEL_SCHEMA_VERSION)

  # Write the checkpoint, copy assets into the assets directory, and write out
  # the SavedModel proto itself.
  if not experimental_skip_checkpoint:
    path_helpers.get_or_create_variables_dir(export_dir)
    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    object_saver.save(
        path_helpers.get_variables_path(export_dir), options=ckpt_options)
  builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                              export_dir)

  # Save debug info, if requested. This is done before saved_model.pb is
  # written so that the .pb file remains the last file written (see below).
  if options.save_debug_info:
    _export_debug_info(exported_graph, export_dir)

  # Note that this needs to be the last file operation when saving the
  # SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
  # indication that the SavedModel is completely written.
  if context.executing_eagerly():
    try:
      context.async_wait()  # Ensure save operations have completed.
    except errors.NotFoundError as err:
      raise FileNotFoundError(
          f"{err}\n You may be trying to save on a different device from the "
          "computational device. Consider setting the "
          "`experimental_io_device` option in `tf.saved_model.SaveOptions` "
          "to the io_device such as '/job:localhost'.") from err

  # We will slowly migrate code in this function to pywrap_saved_model.Save
  # as we build up the C++ API.
  pywrap_saved_model.Save(export_dir)

  saved_model_serialized = saved_model.SerializeToString(deterministic=True)

  fingerprinting_utils.write_fingerprint(export_dir, saved_model_serialized)

  path = file_io.join(
      compat.as_str(export_dir),
      compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
  file_io.atomic_write_string_to_file(path, saved_model_serialized)

  # For privacy concerns, please see the note in
  #  tensorflow/cc/saved_model/metrics.h
  metrics.SetWritePath(saved_model_path=str(export_dir))
  # Clean reference cycles so repeated export()s don't make work for the garbage
  # collector. Before this point, we need to keep references to captured
  # constants in the saved graph.
  ops.dismantle_graph(exported_graph)

  return saved_nodes, node_paths

1365 

1366 

def export_meta_graph(obj, filename, signatures=None, options=None):
  """Writes the MetaGraph proto of `obj` to `filename`.

  Builds the object's MetaGraph the same way `tf.saved_model.save` does, but
  writes only that proto and skips the checkpoint. This is useful when only
  the graph defining the model is needed. When `options.save_debug_info` is
  set, debug info is written relative to the directory containing
  `filename`.

  Args:
    obj: A trackable object to build the MetaGraph from.
    filename: The file into which to write the MetaGraph.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `@tf.function`-decorated function `f`, in which case `f` will be used to
      generate a signature for the SavedModel under the default serving
      signature key. `signatures` may also be a dictionary, in which case it
      maps from signature keys to either `tf.function` instances with input
      signatures or concrete functions. The keys of such a dictionary may be
      arbitrary strings, but will typically be from the
      `tf.saved_model.signature_constants` module.
    options: Optional, `tf.saved_model.SaveOptions` object that specifies
      options for saving.
  """
  save_opts = options or save_options.SaveOptions()
  target_dir = os.path.dirname(filename)
  (meta_graph_def, exported_graph, _unused_saver, _unused_asset_info,
   _unused_nodes, _unused_node_paths) = _build_meta_graph(
       obj, signatures, save_opts)

  serialized_meta_graph = meta_graph_def.SerializeToString(deterministic=True)
  file_io.atomic_write_string_to_file(filename, serialized_meta_graph)

  if save_opts.save_debug_info:
    _export_debug_info(exported_graph, target_dir)

  # Break reference cycles so repeated exports don't leave work for the
  # garbage collector; captured constants had to stay alive until now.
  ops.dismantle_graph(exported_graph)

1406 

1407 

1408def _build_meta_graph_impl(obj, signatures, options, meta_graph_def=None): 

1409 """Creates a MetaGraph containing the resources and functions of an object.""" 

1410 if ops.inside_function(): 

1411 raise AssertionError( 

1412 "`tf.saved_model.save` is not supported inside a traced @tf.function. " 

1413 "Move the call to the outer eagerly-executed context.") 

1414 # pylint: enable=line-too-long 

1415 if not isinstance(obj, base.Trackable): 

1416 raise ValueError( 

1417 "Expected an object of type `Trackable`, such as `tf.Module` or a " 

1418 f"subclass of the `Trackable` class, for export. Got {obj} " 

1419 f"with type {type(obj)}.") 

1420 meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef() 

1421 

1422 augmented_graph_view = _AugmentedGraphView(obj) 

1423 if signatures is None: 

1424 signatures = signature_serialization.find_function_to_export( 

1425 augmented_graph_view) 

1426 

1427 signatures, wrapped_functions, defaults = ( 

1428 signature_serialization.canonicalize_signatures(signatures) 

1429 ) 

1430 signature_serialization.validate_augmented_graph_view(augmented_graph_view) 

1431 signature_map = signature_serialization.create_signature_map(signatures) 

1432 augmented_graph_view.set_signature(signature_map, wrapped_functions) 

1433 

1434 # Use _SaveableView to provide a frozen listing of properties and functions. 

1435 saveable_view = _SaveableView(augmented_graph_view, options) 

1436 object_saver = checkpoint.TrackableSaver(augmented_graph_view) 

1437 asset_info, exported_graph = _fill_meta_graph_def( 

1438 meta_graph_def, 

1439 saveable_view, 

1440 signatures, 

1441 options.namespace_whitelist, 

1442 options.experimental_custom_gradients, 

1443 defaults, 

1444 ) 

1445 if options.function_aliases: 

1446 function_aliases = meta_graph_def.meta_info_def.function_aliases 

1447 for alias, func in options.function_aliases.items(): 

1448 if isinstance(func, types_core.ConcreteFunction): 

1449 function_aliases[func.name] = alias 

1450 elif isinstance(func, polymorphic_function.Function): 

1451 for fdef in func._list_all_concrete_functions(): # pylint: disable=protected-access 

1452 function_aliases[fdef.name] = alias 

1453 else: 

1454 raise TypeError( 

1455 f"Unsupported type f{type(func)}. Functions in `function_aliases`" 

1456 " should be created by tf.function, or concrete functions." 

1457 ) 

1458 object_graph_proto = _serialize_object_graph(saveable_view, 

1459 asset_info.asset_index) 

1460 meta_graph_def.object_graph_def.CopyFrom(object_graph_proto) 

1461 return (meta_graph_def, exported_graph, object_saver, asset_info, 

1462 saveable_view.nodes, saveable_view.node_paths) 

1463 

1464 

1465def _build_meta_graph(obj, signatures, options, meta_graph_def=None): 

1466 """Creates a MetaGraph under a save context. 

1467 

1468 Args: 

1469 obj: A trackable object to build the MetaGraph from. 

1470 signatures: Can be a `tf.function` with an input signature specified or the 

1471 result of `f.get_concrete_function` on a `@tf.function`-decorated function 

1472 `f`. `signatures` may also be a dictionary, in which case it maps from 

1473 signature keys to `tf.function` instances. If None, finds signature to 

1474 export from the `@tf.function`-decorated methods in `obj`. 

1475 options: `tf.saved_model.SaveOptions` object that specifies options for 

1476 saving. 

1477 meta_graph_def: Optional, the MetaGraphDef proto fill. 

1478 

1479 Raises: 

1480 AssertionError: If `export_meta_graph` is executing inside a `tf.function`. 

1481 ValueError: If `obj` is not trackable. 

1482 

1483 Returns: 

1484 meta_graph_def: Filled MetaGraphDef proto 

1485 exported_graph: `tf.Graph` object generated from `obj`. 

1486 object_saver: `checkpoint.TrackableSaver` of the `obj` and its dependencies. 

1487 asset_info: `_AssetInfo` tuple containing external assets in the `obj`. 

1488 saveable_view.nodes: _SaveableView nodes. 

1489 saveable_view.node_paths: _SaveableView paths. 

1490 """ 

1491 

1492 with save_context.save_context(options): 

1493 return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)