Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py: 19%

461 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Import a trackable object from a SavedModel.""" 

16 

17import collections 

18import functools 

19import os 

20import sys 

21 

22from absl import logging 

23 

24from tensorflow.core.framework import graph_debug_info_pb2 

25from tensorflow.core.function.capture import restore_captures 

26from tensorflow.python.checkpoint import checkpoint 

27from tensorflow.python.checkpoint import checkpoint_options 

28from tensorflow.python.checkpoint import graph_view 

29from tensorflow.python.checkpoint import restore 

30from tensorflow.python.distribute import distribute_lib 

31from tensorflow.python.distribute import distribute_utils 

32from tensorflow.python.distribute import values_util 

33from tensorflow.python.eager import context 

34from tensorflow.python.eager import function 

35from tensorflow.python.eager.polymorphic_function import saved_model_utils as function_saved_model_utils 

36from tensorflow.python.framework import config 

37from tensorflow.python.framework import constant_op 

38from tensorflow.python.framework import dtypes 

39from tensorflow.python.framework import errors 

40from tensorflow.python.framework import ops 

41from tensorflow.python.ops import array_ops 

42from tensorflow.python.ops import control_flow_assert 

43from tensorflow.python.ops import control_flow_ops 

44from tensorflow.python.ops import lookup_ops 

45from tensorflow.python.ops import resource_variable_ops 

46from tensorflow.python.ops import variables 

47from tensorflow.python.saved_model import fingerprinting 

48from tensorflow.python.saved_model import fingerprinting_utils 

49from tensorflow.python.saved_model import function_deserialization 

50from tensorflow.python.saved_model import load_options 

51from tensorflow.python.saved_model import load_v1_in_v2 

52from tensorflow.python.saved_model import loader_impl 

53from tensorflow.python.saved_model import path_helpers 

54from tensorflow.python.saved_model import registration 

55from tensorflow.python.saved_model import revived_types 

56from tensorflow.python.saved_model import utils_impl as saved_model_utils 

57from tensorflow.python.saved_model.pywrap_saved_model import metrics 

58from tensorflow.python.trackable import asset 

59from tensorflow.python.trackable import autotrackable 

60from tensorflow.python.trackable import base 

61from tensorflow.python.trackable import data_structures 

62from tensorflow.python.trackable import resource 

63from tensorflow.python.trackable import trackable_utils 

64from tensorflow.python.training import py_checkpoint_reader 

65from tensorflow.python.training.saving import saveable_object_util 

66from tensorflow.python.util import nest 

67from tensorflow.python.util.tf_export import tf_export 

68 

# API label for SavedModel metrics.
_LOAD_V2_LABEL = "load_v2"

# Built-in registrations are looked up through the "oneof kind" field of the
# SavedObject proto rather than its "registered_name" field. `kind` plays
# nearly the same role as `registered_name`, but it only covers built-in
# TensorFlow types. This table maps those `kind` values that have a dedicated
# Python class to that class.
_BUILT_IN_REGISTRATIONS = dict(
    asset=asset.Asset,
    resource=resource.RestoredResource,
    constant=function_saved_model_utils.TrackableConstant,
)

79 

80 

def _unused_handle():
  """Returns a placeholder as a handle that is not supposed to be accessed.

  The returned resource placeholder has a control dependency on an `Assert`
  whose condition is a placeholder defaulting to `False`, so running anything
  that uses the handle is expected to surface `error_message`. When called
  while building a (non-eager) function graph, that graph is additionally
  marked as unsaveable with an explanatory message.

  Returns:
    A `tf.resource` placeholder guarded by a failing assertion.
  """
  error_message = ("Trying to access a placeholder that is not supposed to be "
                   "executed. This means you are executing a graph generated "
                   "from the cross-replica context in an in-replica context.")
  save_error_message = (
      "It seems that you are trying to save a "
      "tf.types.experimental.ConcreteFunction that involves a distributed "
      "model, and the model contains parts that are loaded from a SavedModel. "
      "It's not supported to save such tf.types.experimental.ConcreteFunction. "
      "Try saving a tf.function with input_signature instead, and file a bug if"
      " there are still issues.")

  # The condition defaults to False, so the assert fails whenever it runs.
  assert_op = control_flow_assert.Assert(
      array_ops.placeholder_with_default(False, shape=()), [error_message])
  if (not context.executing_eagerly()
     ) and ops.get_default_graph().building_function:
    # A function traced from the cross-replica context cannot be saved.
    ops.get_default_graph().mark_as_unsaveable(save_error_message)

  with ops.control_dependencies([assert_op]):
    return array_ops.placeholder(dtype=dtypes.resource)

102 

103 

class _WrapperFunction(function.ConcreteFunction):
  """Wraps a concrete function so its captures resolve per distribution context.

  When `load()` runs inside a `tf.distribute.Strategy` scope, the captured
  inputs of restored functions are distributed variables, and what must be
  passed for them depends on where the function is called:

  * In a replica context, or while saving a non-distributed copy of the
    model, each distributed variable is replaced by its `.handle`.
  * In the cross-replica context, calling the function only builds a graph
    that is not meant to be executed (typically while constructing a
    functional model). Each distributed variable is then replaced by a
    placeholder guarded by a failing assertion, so it is never accessed.

  Captures that are not distributed variables are passed through unchanged.
  """

  def __init__(self, concrete_function):
    # Adopt the wrapped function's attributes wholesale. This is a shallow
    # copy: attribute values are shared with `concrete_function`.
    self.__dict__.update(vars(concrete_function))

  def _call_flat(self, args, captured_inputs):
    # In the replica context `.handle` yields the replica-local handle of a
    # replicated variable; when saving a non-distributed model it yields the
    # primary handle, because only one copy of a replicated variable is saved.
    use_real_handles = (distribute_lib.get_replica_context() is not None or
                        values_util.is_saving_non_distributed())
    if use_real_handles:
      substitute = lambda var: var.handle
    else:  # cross-replica context
      substitute = lambda var: _unused_handle()
    resolved_inputs = [
        substitute(captured)
        if distribute_utils.is_distributed_variable(captured) else captured
        for captured in captured_inputs
    ]
    return super()._call_flat(args, resolved_inputs)

147 

148 

149class Loader(object): 

150 """Helper class to load an object-based SavedModel.""" 

151 

152 def __init__(self, object_graph_proto, saved_model_proto, export_dir, 

153 ckpt_options, save_options, filters): 

154 meta_graph = saved_model_proto.meta_graphs[0] 

155 self._asset_file_def = meta_graph.asset_file_def 

156 self._operation_attributes = { 

157 node.name: node.attr for node in meta_graph.graph_def.node} 

158 self._proto = object_graph_proto 

159 self._export_dir = export_dir 

160 self._concrete_functions = ( 

161 function_deserialization.load_function_def_library( 

162 library=meta_graph.graph_def.library, 

163 saved_object_graph=self._proto, 

164 wrapper_function=_WrapperFunction)) 

165 # Store a set of all concrete functions that have been set up with 

166 # captures. 

167 self._restored_concrete_functions = set() 

168 self._checkpoint_options = ckpt_options 

169 self._save_options = save_options 

170 

171 # Metagraph has a mapping from FunctionDef name to aliases 

172 self._concrete_function_aliases = meta_graph.meta_info_def.function_aliases 

173 # Create a mapping from alias to Function, which can be used with 

174 # SaveOptions 

175 self.function_aliases = {} 

176 

177 self._pretty_printer = checkpoint.ObjectGraphProtoPrettyPrinter(self._proto) 

178 

179 # Stores user-defined node_filters argument. 

180 self._node_filters = filters 

181 # Stores map of string paths to integers. 

182 self._node_path_to_id = self._convert_node_paths_to_ints() 

183 self._loaded_nodes = {} 

184 if isinstance(filters, dict): 

185 # If node_filters is a dict, then the values may contain already created 

186 # trackable objects. In this case, create a dictionary mapping node IDs to 

187 # the already created nodes. This dict will be updated in 

188 # `_retrieve_all_filtered_nodes` with tracked children. 

189 for node_path, node in filters.items(): 

190 if isinstance(node, tuple): 

191 self._loaded_nodes[self._node_path_to_id[node_path]] = node 

192 else: 

193 self._loaded_nodes[self._node_path_to_id[node_path]] = (node, setattr) 

194 

195 # Get a list of all integer node ids to load, or None if all nodes should be 

196 # loaded. This list includes ids of child nodes. 

197 self._filtered_nodes = self._retrieve_all_filtered_nodes() 

198 

199 # Order all nodes or filtered nodes using the dependencies. 

200 self._ordered_node_ids = self._generate_ordered_node_ids() 

201 

202 self._load_all() 

203 

204 if not save_options.experimental_skip_checkpoint: 

205 self._restore_checkpoint() 

206 for node in self._nodes: 

207 if isinstance(node, resource.CapturableResource): 

208 init_op = node._initialize() # pylint: disable=protected-access 

209 if not context.executing_eagerly(): 

210 ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) 

211 

212 def _convert_node_paths_to_ints(self): 

213 """Maps all string node paths in node_filters to the int node ids.""" 

214 if self._node_filters is None: 

215 return None 

216 path_to_int = {} 

217 for node_id in self._node_filters: 

218 int_node_id = None 

219 if isinstance(node_id, str): 

220 node_path = node_id.split(".") 

221 if node_path[0] != "root": 

222 raise ValueError( 

223 "When passing string identifiers to node_filters, the first name" 

224 f" must be root. Received {node_path[0]}.") 

225 int_node_id = 0 

226 for n, name in enumerate(node_path[1:]): 

227 int_node_id = self._find_node_child( 

228 int_node_id, name, ".".join(node_path[:n+2])) 

229 path_to_int[node_id] = int_node_id 

230 else: 

231 raise TypeError("Elements in node_filters must be strings.") 

232 return path_to_int 

233 

234 def _retrieve_all_filtered_nodes(self): 

235 """Traverses through the object graph to get the IDs of all nodes to load. 

236 

237 As a side-effect, if node_filters is a dictionary that contains already- 

238 created objects, then the children tracked by those objects will be 

239 added to node_filters. 

240 

241 Returns: 

242 List of all nodes to load, or None if all nodes should be loaded. 

243 

244 """ 

245 if self._node_filters is None: 

246 return None # All nodes should be loaded. 

247 

248 all_filtered_nodes = set() 

249 nodes_to_visit = list(self._node_filters) 

250 

251 while nodes_to_visit: 

252 node_path = nodes_to_visit.pop(0) 

253 node_id = self._node_path_to_id[node_path] 

254 if node_id in all_filtered_nodes: 

255 continue 

256 all_filtered_nodes.add(node_id) 

257 

258 node, setter = self._loaded_nodes.get(node_id, (None, None)) 

259 if node is not None: 

260 if not isinstance(node, base.Trackable): 

261 raise TypeError( 

262 "Error when processing dictionary values passed to nodes_to_load." 

263 f"Object at {node_path} is expected to be a checkpointable (i.e. " 

264 "'trackable') TensorFlow object (e.g. tf.Variable, tf.Module or " 

265 "Keras layer).") 

266 node._maybe_initialize_trackable() # pylint: disable=protected-access 

267 

268 for reference in self._proto.nodes[node_id].children: 

269 child_object, _ = self._loaded_nodes.get( 

270 reference.node_id, (None, None)) 

271 

272 # See if node already tracks the child reference, in which case add the 

273 # child to the loaded_nodes dict. 

274 if child_object is None and node is not None: 

275 child_object = node._lookup_dependency(reference.local_name) # pylint: disable=protected-access 

276 if isinstance(child_object, data_structures.TrackableDataStructure): 

277 # Make setattr a noop to avoid overwriting already existing data 

278 # structures. 

279 setter = lambda *args: None 

280 

281 self._loaded_nodes[reference.node_id] = (child_object, setter) 

282 

283 child_path = "{}.{}".format(node_path, reference.local_name) 

284 self._node_path_to_id[child_path] = reference.node_id 

285 nodes_to_visit.append(child_path) 

286 

287 if 0 in all_filtered_nodes: 

288 return None 

289 return all_filtered_nodes 

290 

291 def _find_node_child(self, node_id, child_name, path): 

292 for reference in self._proto.nodes[node_id].children: 

293 if reference.local_name == child_name: 

294 return reference.node_id 

295 raise ValueError(f"Unable to find node {path}.") 

296 

297 def _load_all(self): 

298 """Loads all nodes and functions from the SavedModel and their edges.""" 

299 self._load_nodes() 

300 self._load_edges() 

301 

302 # Set up concrete functions that aren't part of the object graph 

303 # (e.g. gradient functions) 

304 self._setup_remaining_functions() 

305 self._load_checkpoint_save_and_restore_functions() 

306 

307 def _load_checkpoint_save_and_restore_functions(self): 

308 """Restores the checkpoint-related save/restore functions to all nodes.""" 

309 temp_session = [None] 

310 for node_id, proto in self._iter_all_nodes(): 

311 node = self.get(node_id) 

312 if proto.saveable_objects.keys() == { 

313 trackable_utils.SERIALIZE_TO_TENSORS_NAME}: 

314 # Restore Trackable serialize- and restore-from-tensor functions. 

315 assert len(proto.saveable_objects) == 1 

316 saveable_object_proto = next(iter(proto.saveable_objects.values())) 

317 save_fn_id = saveable_object_proto.save_function 

318 restore_fn_id = saveable_object_proto.restore_function 

319 node._serialize_to_tensors = self.get(save_fn_id) # pylint: disable=protected-access 

320 node._restore_from_tensors = self.get(restore_fn_id) # pylint: disable=protected-access 

321 else: 

322 # Restore legacy SaveableObject functions. 

323 saveable_fn_by_name = {} 

324 for name, saveable_object_proto in proto.saveable_objects.items(): 

325 save_fn_id = saveable_object_proto.save_function 

326 restore_fn_id = saveable_object_proto.restore_function 

327 saveable_fn_by_name[name] = (self.get(save_fn_id), 

328 self.get(restore_fn_id)) 

329 

330 node._self_saveable_object_factories = ( # pylint: disable=protected-access 

331 saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, 

332 temp_session)) 

333 

334 def _load_edges(self): 

335 """Adds edges from objects to other objects and functions.""" 

336 for node_id, object_proto in self._iter_all_nodes(): 

337 self._add_object_graph_edges(object_proto, node_id) 

338 

339 # If root object isn't loaded, then create edges from the root for 

340 # checkpoint compatibility. 

341 if self._filtered_nodes is not None and 0 not in self._filtered_nodes: 

342 root = self.get(0) 

343 for node_path in self._node_filters: 

344 loaded_node = self._nodes[self._node_path_to_id[node_path]] 

345 path = node_path.split(".") 

346 current_node = root 

347 for name in path[1:-1]: 

348 if not hasattr(current_node, name): 

349 setattr(current_node, name, self._recreate_base_user_object()[0]) 

350 current_node = getattr(current_node, name) 

351 if not hasattr(current_node, path[-1]): 

352 setattr(current_node, path[-1], loaded_node) 

353 

354 def _add_object_graph_edges(self, proto, node_id): 

355 """Adds edges from an object to its children.""" 

356 obj = self._nodes[node_id] 

357 setter = self._node_setters[node_id] 

358 

359 for reference in proto.children: 

360 setter(obj, reference.local_name, self._nodes[reference.node_id]) 

361 # Note: if an object has an attribute `__call__` add a class method 

362 # that allows `obj()` syntax to work. This is done per-instance to 

363 # allow `callable` to be used to find out if an object is callable. 

364 if reference.local_name == "__call__" and not callable(obj): 

365 setattr(type(obj), "__call__", _call_attribute) 

366 

367 def _setup_remaining_functions(self): 

368 concrete_function_names = sorted(self._proto.concrete_functions.keys()) 

369 for name in concrete_function_names: 

370 if name in self._restored_concrete_functions: 

371 continue 

372 self._setup_function_captures(name, self._nodes) 

373 

374 def _setup_function_captures(self, concrete_function_name, nodes): 

375 """Setup captures and variables in a restored function.""" 

376 if concrete_function_name in self._restored_concrete_functions: 

377 return 

378 self._restored_concrete_functions.add(concrete_function_name) 

379 concrete_function = self._concrete_functions[concrete_function_name] 

380 proto = self._proto.concrete_functions[concrete_function_name] 

381 inputs = [nodes[node_id] for node_id in proto.bound_inputs] 

382 restore_captures.restore_captures(concrete_function, inputs) 

383 

384 def _initialize_loaded_nodes(self): 

385 nodes = {} 

386 node_setters = {} 

387 for node_id, (node, setter) in self._loaded_nodes.items(): 

388 nodes[node_id] = node 

389 node_setters[node_id] = setter 

390 return nodes, node_setters 

391 

392 def _get_node_dependencies(self, proto): 

393 """Returns a dictionary of all dependencies of an object. 

394 

395 Args: 

396 proto: A SavedObject proto. 

397 

398 Returns: 

399 Dict mapping string dependency name *or* int node id to the node id. 

400 The int node id key is used for mapping function captures. 

401 """ 

402 dependencies = {ref.local_name: ref.node_id for ref in proto.dependencies} 

403 kind = proto.WhichOneof("kind") 

404 if kind == "function": 

405 concrete_functions = proto.function.concrete_functions 

406 for fn_name in concrete_functions: 

407 for bound_input in self._proto.concrete_functions[fn_name].bound_inputs: 

408 dependencies[bound_input] = bound_input 

409 elif kind == "bare_concrete_function": 

410 fn_name = proto.bare_concrete_function.concrete_function_name 

411 for bound_input in self._proto.concrete_functions[fn_name].bound_inputs: 

412 dependencies[bound_input] = bound_input 

413 elif kind == "resource": 

414 # Make sure that the resource creator is listed as a dependency. 

415 for child in proto.children: 

416 if child.local_name == "_create_resource": 

417 dependencies["_create_resource"] = child.node_id 

418 return dependencies 

419 

420 def _generate_ordered_node_ids(self): 

421 """Orders the node ids so that dependencies appear first.""" 

422 if self._filtered_nodes is None: 

423 unordered_ids = range(len(self._proto.nodes)) 

424 else: 

425 unordered_ids = list(self._filtered_nodes) 

426 

427 # Maps node ids -> list of dependencies (ids of other nodes that must be 

428 # loaded before it). 

429 dependency_map = collections.defaultdict(list) 

430 for node_id in unordered_ids: 

431 deps = dependency_map[node_id] 

432 if self._loaded_nodes.get(node_id) is not None: 

433 # Deps are only used if the node has not been created. 

434 continue 

435 proto = self._proto.nodes[node_id] 

436 for dep in set(self._get_node_dependencies(proto).values()): 

437 deps.append(dep) 

438 if self._filtered_nodes is not None and dep not in self._filtered_nodes: 

439 raise ValueError( 

440 "Unable to partially load SavedModel since the specified filter " 

441 "does not include all required objects for loading (e.g. " 

442 "variables used in functions or deserialization dependencies). " 

443 "Please include this path in the filter: " 

444 f"{self._pretty_printer.node_names[dep]}") 

445 

446 # Add optimizer slot variable to dependency map. 

447 prev_slot = None 

448 for slot_variable_proto in proto.slot_variables: 

449 slot_variable_node_id = slot_variable_proto.slot_variable_node_id 

450 # The optimizer and original variable must be created before the slot 

451 # variable, since the slot variable is generated using the Optimizer's 

452 # add_slot API. 

453 slot_deps = dependency_map[slot_variable_node_id] 

454 slot_deps.append(node_id) 

455 slot_deps.append(slot_variable_proto.original_variable_node_id) 

456 

457 if prev_slot is not None: 

458 # Add previous slot to deps so that the optimizer slot variables are 

459 # added in order. The ordering is needed because the slot name and 

460 # variable are both added to ordered lists, which are exposed to the 

461 # user via `Optimizer.get_slot_names()` and `Optimizer.weights`. 

462 # TODO(kathywu): Maybe enforce some sort of deterministic ordering in 

463 # `order_by_dependency` to avoid doing this? 

464 slot_deps.append(prev_slot) 

465 prev_slot = slot_variable_node_id 

466 try: 

467 return list(trackable_utils.order_by_dependency(dependency_map)) 

468 except trackable_utils.CyclicDependencyError: 

469 # This should not happen since there is already a validation for cycles 

470 # when saving, but raise an error just in case. 

471 raise ValueError("Encountered a cycle in the deserialization dependencies" 

472 "in the SavedModel. This is extremely unexpected, please" 

473 "file a bug and make sure you are not manually modifying" 

474 " the SavedModel.") 

475 

476 def _iter_all_nodes(self): 

477 for node_id in self._ordered_node_ids: 

478 yield node_id, self._proto.nodes[node_id] 

479 

480 def _load_nodes(self): 

481 """Load all saved objects.""" 

482 # `nodes` maps from node ids to recreated objects 

483 # `node_setters` maps from node ids to setter functions 

484 # (same signature as setattr) for setting children. 

485 nodes, node_setters = self._initialize_loaded_nodes() 

486 

487 # Figure out which objects are slot variables. These objects are created 

488 # with Optimizer.add_slot rather than _recreate_variable. 

489 # Maps slot node id -> optimizer node id, SlotVariableReference proto 

490 slot_variable_node_ids = {} 

491 

492 for node_id, proto in self._iter_all_nodes(): 

493 for slot_variable_proto in proto.slot_variables: 

494 slot_variable_node_id = slot_variable_proto.slot_variable_node_id 

495 slot_variable_node_ids[slot_variable_node_id] = (node_id, 

496 slot_variable_proto) 

497 

498 # Re-create everything. 

499 for node_id, proto in self._iter_all_nodes(): 

500 if nodes.get(node_id) is not None: 

501 continue 

502 elif node_id in slot_variable_node_ids: 

503 # Use the public Optimizer interface when creating slot variables. 

504 optimizer_node_id, slot_variable_proto = slot_variable_node_ids[node_id] 

505 optimizer_object = nodes[optimizer_node_id] 

506 optimized_variable = nodes[ 

507 slot_variable_proto.original_variable_node_id] 

508 slot_variable = optimizer_object.add_slot( 

509 var=optimized_variable, 

510 slot_name=slot_variable_proto.slot_name) 

511 nodes[slot_variable_proto.slot_variable_node_id] = slot_variable 

512 node_setters[slot_variable_proto.slot_variable_node_id] = setattr 

513 else: 

514 node, setter = self._recreate(proto, node_id, nodes) 

515 nodes[node_id] = node 

516 node_setters[node_id] = setter 

517 

518 # If root object is not loaded, add a dummy root object for checkpoint 

519 # compatibility. 

520 if 0 not in nodes: 

521 nodes[0] = self._recreate_base_user_object()[0] 

522 

523 self._nodes = [nodes.get(node_id) 

524 for node_id in range(len(self._proto.nodes))] 

525 self._node_setters = node_setters 

526 

  def _restore_checkpoint(self):
    """Load state from checkpoint into the deserialized objects.

    Restores the checkpoint stored under the SavedModel's variables path into
    the object graph rooted at node 0. If `allow_partial_checkpoint` is set,
    only a nontrivial match is asserted. Otherwise, every existing object must
    be matched.

    In graph mode, the per-object restore ops are also wired in:
      * resource variables: into their `_initializer_op`;
      * lookup tables / capturable resources: into the TABLE_INITIALIZERS
        collection.

    Raises:
      NotImplementedError: In graph mode, if an object uses a registered
        checkpoint saver, or if its restore ops cannot be wired into an
        initializer.
    """
    variables_path = path_helpers.get_variables_path(self._export_dir)
    # TODO(b/205010730): Clean use of private methods of TrackableSaver.
    # pylint: disable=protected-access
    saver = checkpoint.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
    # Create the file-prefix constant on the CPU device.
    with ops.device("CPU"):
      saver._file_prefix_placeholder = constant_op.constant(variables_path)
    if self._save_options.allow_partial_checkpoint:
      load_status = saver.restore(variables_path,
                                  self._checkpoint_options).expect_partial()
      load_status.assert_nontrivial_match()
    else:
      load_status = saver.restore(variables_path, self._checkpoint_options)
      load_status.assert_existing_objects_matched()
    ckpt = load_status._checkpoint

    if not context.executing_eagerly():
      reader = py_checkpoint_reader.NewCheckpointReader(variables_path)

      # When running in eager mode, the `restore` call above has already run and
      # restored the state of trackables, and calling `position.restore_ops()`
      # would re-run the restore. In graph mode, that will return a cached list
      # of ops that must run to restore the object on that position. We have to
      # wire them in the initializers of the objects so that they get
      # initialized properly when using common practices (e.g. the ones used by
      # ManagedSession) without further user action.
      # NOTE(review): iterates over a dict() snapshot, presumably so that the
      # mapping can be safely updated while restore ops are built — confirm.
      for object_id, obj in dict(ckpt.object_by_proto_id).items():
        position = restore.CheckpointPosition(checkpoint=ckpt,
                                              proto_id=object_id)
        registered_saver = position.get_registered_saver_name()
        if registered_saver:
          raise NotImplementedError(
              "Loading a SavedModel that uses registered checkpoint saver is "
              f"not supported in graph mode. The loaded object {obj} uses the "
              f"saver registered with the name {registered_saver}.")

        restore_ops = position.restore_ops(reader)
        if restore_ops:
          if resource_variable_ops.is_resource_variable(obj):
            # A single op is used directly; several ops are grouped into one.
            if len(restore_ops) == 1:
              obj._initializer_op = restore_ops[0]
            else:
              obj._initializer_op = control_flow_ops.group(*restore_ops)
          elif (isinstance(obj, lookup_ops.LookupInterface) or
                isinstance(obj, resource.CapturableResource)):
            # We don't need to check for eager execution here, since this code
            # path should only be taken if we are restoring in graph mode.
            ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops)
          else:
            raise NotImplementedError(
                f"Unable to restore state of object {obj} from the checkpoint.")

579 

580 def adjust_debug_info_func_names(self, debug_info): 

581 """Rewrite func names in the debug info by using the concrete func names.""" 

582 output_debug_info = graph_debug_info_pb2.GraphDebugInfo() 

583 output_debug_info.files[:] = debug_info.files 

584 for key in debug_info.traces: 

585 node, func = key.split("@") 

586 new_func = "" 

587 if func in self._concrete_functions: 

588 new_func = self._concrete_functions[func].function_def.signature.name 

589 output_debug_info.traces[node + "@" + new_func].CopyFrom( 

590 debug_info.traces[key]) 

591 return output_debug_info 

592 

593 def get(self, node_id): 

594 if isinstance(node_id, str): 

595 node_id = self._node_path_to_id[node_id] 

596 return self._nodes[node_id] 

597 

598 def _recreate(self, proto, node_id, nodes): 

599 """Creates a Python object from a SavedObject protocol buffer. 

600 

601 Args: 

602 proto: a SavedObject proto 

603 node_id: int, the index of this object in the SavedObjectGraph node list. 

604 nodes: dict mapping int node_ids -> created objects. 

605 

606 Returns: 

607 The recreated object, and the set-attribute function for reconnecting 

608 the trackable children. 

609 """ 

610 registered_class = registration.get_registered_class(proto.registered_name) 

611 if registered_class is None: 

612 registered_class = _BUILT_IN_REGISTRATIONS.get(proto.WhichOneof("kind")) 

613 

614 dependencies = {} 

615 for key, dep_node_id in self._get_node_dependencies(proto).items(): 

616 dependencies[key] = nodes[dep_node_id] 

617 

618 if registered_class: 

619 obj = registered_class._deserialize_from_proto( # pylint: disable=protected-access 

620 proto=proto.serialized_user_proto, 

621 object_proto=proto, 

622 dependencies=dependencies, 

623 export_dir=self._export_dir, 

624 asset_file_def=self._asset_file_def, 

625 operation_attributes=self._operation_attributes) 

626 if isinstance(obj, base.Trackable): 

627 setter = type(obj)._add_trackable_child # pylint: disable=protected-access 

628 else: 

629 # Returned object may be non-Trackable (e.g. when restoring captures). 

630 setter = setattr 

631 return obj, setter 

632 else: 

633 return self._recreate_default(proto, node_id, dependencies) 

634 

635 def _recreate_default(self, proto, node_id, deps): 

636 """Creates a Python object from a SavedObject protocol buffer.""" 

637 factory = { 

638 "user_object": ( 

639 lambda: self._recreate_user_object(proto.user_object, node_id)), 

640 "function": lambda: self._recreate_function(proto.function, deps), 

641 "bare_concrete_function": functools.partial( 

642 self._recreate_bare_concrete_function, 

643 proto=proto.bare_concrete_function, dependencies=deps), 

644 "variable": lambda: self._recreate_variable(proto.variable), 

645 "captured_tensor": functools.partial( 

646 self._get_tensor_from_fn, proto.captured_tensor), 

647 } 

648 kind = proto.WhichOneof("kind") 

649 if kind not in factory: 

650 raise ValueError(f"Unknown SavedObject type: {kind}. Expected one of " 

651 f"{list(factory.keys())}.") 

652 return factory[kind]() 

653 

654 def _recreate_user_object(self, proto, node_id): 

655 """Instantiates a SavedUserObject.""" 

656 if proto.identifier == "optimizer": 

657 # Make sure that the Keras optimizers module is imported. This is needed 

658 # to be able to load the "optimizer" object (OptimizerV2), which has 

659 # special logic around adding slot variables with `add_slot` in this file. 

660 try: 

661 import keras.optimizers.legacy as _ # pylint: disable=g-import-not-at-top 

662 except ImportError: 

663 try: 

664 import keras.optimizers.optimizer_v2 as _ # pylint: disable=g-import-not-at-top 

665 except ImportError as e: 

666 raise ImportError( 

667 "Error when importing Keras. Unable to load SavedModel that " 

668 "contains an optimizer without the Keras module.") from e 

669 looked_up = revived_types.deserialize(proto) 

670 if looked_up is None: 

671 return self._recreate_base_user_object(proto, node_id) 

672 return looked_up 

673 

  def _recreate_base_user_object(self, proto=None, node_id=None):
    """Creates a new, empty `AutoTrackable` user object.

    Args:
      proto: Unused; accepted so callers can pass the same arguments they
        pass to `_recreate_user_object`.
      node_id: Unused.

    Returns:
      A `(object, setattr)` pair: the new object and the setter used to
      attach its children.
    """
    del proto, node_id
    # Note: each user object has its own class. This allows making each one
    # individually callable by adding a `__call__` method to the classes of
    # the object instances that have a `__call__` property.

    class _UserObject(autotrackable.AutoTrackable):
      pass

    return _UserObject(), setattr

684 

685 def _recreate_function(self, proto, dependencies): 

686 fn = function_deserialization.recreate_function( 

687 proto, self._concrete_functions) 

688 for name in proto.concrete_functions: 

689 self._setup_function_captures(name, dependencies) 

690 

691 if self._save_options.experimental_load_function_aliases: 

692 for name in proto.concrete_functions: 

693 if name in self._concrete_function_aliases: 

694 alias = self._concrete_function_aliases[name] 

695 self.function_aliases[alias] = fn 

696 # We only need to save the mapping from alias to a tf.Function 

697 # once even though it can appear multiple times in 

698 # self._concrete_function_aliases due to one-to-many mapping from 

699 # tf.Function to concrete functions. 

700 break 

701 

702 return fn, setattr 

703 

704 def _recreate_bare_concrete_function(self, proto, dependencies): 

705 fn = function_deserialization.setup_bare_concrete_function( 

706 proto, self._concrete_functions) 

707 self._setup_function_captures(proto.concrete_function_name, dependencies) 

708 return fn, setattr 

709 

710 def _recreate_variable(self, proto): 

711 name = proto.name if proto.name else None 

712 if name is not None: 

713 dbg_name = name 

714 else: 

715 dbg_name = "<variable loaded from saved model>" 

716 synchronization, aggregation, trainable = ( 

717 variables.validate_synchronization_aggregation_trainable( 

718 proto.synchronization, proto.aggregation, proto.trainable, 

719 name=dbg_name)) 

720 

721 def uninitialized_variable_creator(next_creator, **kwargs): 

722 """A variable creator that creates uninitialized variables.""" 

723 del next_creator 

724 return resource_variable_ops.UninitializedVariable(**kwargs) 

725 

726 # Create a variable_creator_scope that creates uninitialized variables with 

727 # a lower priority such that a potential distributed variable_creator_scope 

728 # can take precedence. 

729 with ops.get_default_graph()._variable_creator_scope( # pylint: disable=protected-access 

730 uninitialized_variable_creator, 

731 priority=50): 

732 saved_device = proto.device 

733 load_with_device = ( 

734 self._save_options.experimental_variable_policy 

735 ._save_variable_devices() and config.get_soft_device_placement() and 

736 saved_device) 

737 if load_with_device: 

738 with ops.device(saved_device): 

739 return variables.Variable( 

740 shape=proto.shape, 

741 dtype=proto.dtype, 

742 name=name, 

743 trainable=trainable, 

744 synchronization=synchronization, 

745 aggregation=aggregation), setattr 

746 else: 

747 return variables.Variable( 

748 shape=proto.shape, 

749 dtype=proto.dtype, 

750 name=name, 

751 trainable=trainable, 

752 synchronization=synchronization, 

753 aggregation=aggregation), setattr 

754 

755 def _get_tensor_from_fn(self, proto): 

756 outer_graph = self._concrete_functions[proto.concrete_function].graph 

757 captured_tensor = outer_graph.get_tensor_by_name(proto.name) 

758 return captured_tensor, setattr 

759 

760 

761def _call_attribute(instance, *args, **kwargs): 

762 return instance.__call__(*args, **kwargs) 

763 

764 

765@tf_export("saved_model.load", v1=["saved_model.load_v2"]) 

766def load(export_dir, tags=None, options=None): 

767 """Load a SavedModel from `export_dir`. 

768 

769 Signatures associated with the SavedModel are available as functions: 

770 

771 ```python 

772 imported = tf.saved_model.load(path) 

773 f = imported.signatures["serving_default"] 

774 print(f(x=tf.constant([[1.]]))) 

775 ``` 

776 

777 Objects exported with `tf.saved_model.save` additionally have trackable 

778 objects and functions assigned to attributes: 

779 

780 ```python 

781 exported = tf.train.Checkpoint(v=tf.Variable(3.)) 

782 exported.f = tf.function( 

783 lambda x: exported.v * x, 

784 input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)]) 

785 tf.saved_model.save(exported, path) 

786 imported = tf.saved_model.load(path) 

787 assert 3. == imported.v.numpy() 

788 assert 6. == imported.f(x=tf.constant(2.)).numpy() 

789 ``` 

790 

791 _Loading Keras models_ 

792 

793 Keras models are trackable, so they can be saved to SavedModel. The object 

794 returned by `tf.saved_model.load` is not a Keras object (i.e. doesn't have 

795 `.fit`, `.predict`, etc. methods). A few attributes and functions are still 

796 available: `.variables`, `.trainable_variables` and `.__call__`. 

797 

798 ```python 

799 model = tf.keras.Model(...) 

800 tf.saved_model.save(model, path) 

801 imported = tf.saved_model.load(path) 

802 outputs = imported(inputs) 

803 ``` 

804 

805 Use `tf.keras.models.load_model` to restore the Keras model. 

806 

807 _Importing SavedModels from TensorFlow 1.x_ 

808 

809 SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat 

810 graph instead of `tf.function` objects. These SavedModels will be loaded with 

811 the following attributes: 

812 

813 * `.signatures`: A dictionary mapping signature names to functions. 

814 * `.prune(feeds, fetches) `: A method which allows you to extract 

815 functions for new subgraphs. This is equivalent to importing the SavedModel 

816 and naming feeds and fetches in a Session from TensorFlow 1.x. 

817 

818 ```python 

819 imported = tf.saved_model.load(path_to_v1_saved_model) 

820 pruned = imported.prune("x:0", "out:0") 

821 pruned(tf.ones([])) 

822 ``` 

823 

824 See `tf.compat.v1.wrap_function` for details. 

825 * `.variables`: A list of imported variables. 

826 * `.graph`: The whole imported graph. 

827 * `.restore(save_path)`: A function that restores variables from a checkpoint 

828 saved from `tf.compat.v1.Saver`. 

829 

830 _Consuming SavedModels asynchronously_ 

831 

832 When consuming SavedModels asynchronously (the producer is a separate 

833 process), the SavedModel directory will appear before all files have been 

834 written, and `tf.saved_model.load` will fail if pointed at an incomplete 

835 SavedModel. Rather than checking for the directory, check for 

836 "saved_model_dir/saved_model.pb". This file is written atomically as the last 

837 `tf.saved_model.save` file operation. 

838 

839 Args: 

840 export_dir: The SavedModel directory to load from. 

841 tags: A tag or sequence of tags identifying the MetaGraph to load. Optional 

842 if the SavedModel contains a single MetaGraph, as for those exported from 

843 `tf.saved_model.save`. 

844 options: `tf.saved_model.LoadOptions` object that specifies options for 

845 loading. 

846 

847 Returns: 

848 A trackable object with a `signatures` attribute mapping from signature 

849 keys to functions. If the SavedModel was exported by `tf.saved_model.save`, 

850 it also points to trackable objects, functions, debug info which it has been 

851 saved. 

852 

853 Raises: 

854 ValueError: If `tags` don't match a MetaGraph in the SavedModel. 

855 """ 

856 if isinstance(export_dir, os.PathLike): 

857 export_dir = os.fspath(export_dir) 

858 result = load_partial(export_dir, None, tags, options)["root"] 

859 return result 

860 

861 

862@tf_export("__internal__.saved_model.load_partial", v1=[]) 

863def load_partial(export_dir, filters, tags=None, options=None): 

864 """Partially load a SavedModel (saved from V2). 

865 

866 Similar to `tf.saved_model.load`, but with an additional argument that 

867 lets you specify which nodes to load. 

868 `tf.saved_model.load_partial(export_dir, ["root"])` and 

869 `tf.saved_model.load(export_dir)` are equivalent. 

870 

871 Note: This only works for SavedModels saved with TensorFlow V2 from 

872 `tf.saved_model.save` or Keras. This will not load SavedModels save from 

873 the Estimator API. 

874 

875 In Tensorflow V2, SavedModel stores the **object graph** of the saved object. 

876 The graph contains nodes (`tf.Module`, `tf.Variable`, `tf.function`, Keras 

877 layers, etc.) and edges that are the name of the attributes connecting the 

878 objects. 

879 

880 *Example 1* 

881 

882 ``` 

883 model = tf.Module() 

884 model.child_layer = tf.Module() 

885 model.child_layer.v = tf.Variable(5.) 

886 tf.saved_model.save(model, '/tmp/model') 

887 loaded = tf.__internal__.saved_model.load_partial( 

888 ... '/tmp/model', 

889 ... ['root.child_layer', 'root.child_layer.v']) 

890 loaded['root.child_layer'].v.numpy() 

891 5. 

892 loaded['root.child_layer'].v is loaded['root.child_layer.v'] 

893 True 

894 

895 *Example 2* 

896 model = tf.Module() 

897 model.child_layer = tf.Module() 

898 model.child_layer.v = tf.Variable(5.) 

899 >>> 

900 tf.saved_model.save(model, '/tmp/model') 

901 # Create a variable 

902 new_variable = tf.Variable(0.) 

903 loaded = tf.__internal__.saved_model.load_partial( 

904 ... '/tmp/model', 

905 ... {'root.child_layer': None, 'root.child_layer.v': new_variable}) 

906 loaded['root.child_layer'].v.numpy() 

907 5. 

908 new_variable.numpy() 

909 5. 

910 ``` 

911 

912 **Loading under different distribution strategies** 

913 You can load different parts of the model under different distribution 

914 strategies. Note that this is very experimental so use with care. 

915 

916 ``` 

917 model = tf.Module() 

918 model.layer_1 = tf.Module() 

919 model.layer_1.v = tf.Variable(5.) 

920 model.layer_2 = tf.Module() 

921 model.layer_2.v = tf.Variable(7.) 

922 tf.saved_model.save(model, '/tmp/model') 

923 # Load with no strategy 

924 loaded = tf.__internal__.saved_model.load_partial( 

925 ... '/tmp/model', 

926 ... ['root.layer_1']) 

927 loaded['root.layer_1'].v 

928 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0> 

929 strategy = tf.distribute.MirroredStrategy() 

930 with strategy.scope(): 

931 ... loaded2 = tf.__internal__.saved_model.load_partial( 

932 ... '/tmp/model', 

933 ... ['root.layer_2']) 

934 loaded2['root.layer_2'].v 

935 MirroredVariable:{ 

936 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0> 

937 } 

938 ``` 

939 

940 Args: 

941 export_dir: The SavedModel directory to load from. 

942 filters: A list or dictionary where each element or key is a string 

943 path to nodes that should be loaded. Node paths consist of all the child 

944 attribute names to reach that node in the form: `root.{attribute_name}`. 

945 The loader will load all of the specified nodes and their recursive 

946 descendants. When this option is defined, the loader will return a 

947 dictionary mapping the node paths to the loaded objects. 

948 tags: A tag or sequence of tags identifying the MetaGraph to load. Optional 

949 if the SavedModel contains a single MetaGraph, as for those exported from 

950 `tf.saved_model.save`. 

951 options: `tf.saved_model.LoadOptions` object that specifies options for 

952 loading. 

953 

954 Returns: 

955 A dictionary mapping node paths from the filter to loaded objects. 

956 """ 

957 options = options or load_options.LoadOptions() 

958 if tags is not None and not isinstance(tags, set): 

959 # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered 

960 # sequences for nest.flatten, so we put those through as-is. 

961 tags = nest.flatten(tags) 

962 saved_model_proto, debug_info = ( 

963 loader_impl.parse_saved_model_with_debug_info(export_dir)) 

964 

965 if (len(saved_model_proto.meta_graphs) == 1 and 

966 saved_model_proto.meta_graphs[0].HasField("object_graph_def")): 

967 metrics.IncrementReadApi(_LOAD_V2_LABEL) 

968 meta_graph_def = saved_model_proto.meta_graphs[0] 

969 # tensor_content field contains raw bytes in litle endian format 

970 # which causes problems when loaded on big-endian systems 

971 # requiring byteswap 

972 if sys.byteorder == "big": 

973 saved_model_utils.swap_function_tensor_content(meta_graph_def, "little", 

974 "big") 

975 if (tags is not None 

976 and set(tags) != set(meta_graph_def.meta_info_def.tags)): 

977 raise ValueError( 

978 f"Got an incompatible argument to `tags`: {tags}. The SavedModel at " 

979 f"{export_dir} has one MetaGraph with tags " 

980 f"{meta_graph_def.meta_info_def.tags}. You may omit the argument, " 

981 "pass 'None', or pass matching tags.") 

982 object_graph_proto = meta_graph_def.object_graph_def 

983 

984 ckpt_options = checkpoint_options.CheckpointOptions( 

985 experimental_io_device=options.experimental_io_device) 

986 with ops.init_scope(): 

987 try: 

988 loader = Loader(object_graph_proto, saved_model_proto, export_dir, 

989 ckpt_options, options, filters) 

990 except errors.NotFoundError as err: 

991 raise FileNotFoundError( 

992 str(err) + "\n You may be trying to load on a different device " 

993 "from the computational device. Consider setting the " 

994 "`experimental_io_device` option in `tf.saved_model.LoadOptions` " 

995 "to the io_device such as '/job:localhost'.") 

996 root = loader.get(0) 

997 root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info) 

998 root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version 

999 root.tensorflow_git_version = ( 

1000 meta_graph_def.meta_info_def.tensorflow_git_version) 

1001 metrics.IncrementRead(write_version="2") 

1002 else: 

1003 if filters: 

1004 raise ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any" 

1005 " version) cannot be loaded with node filters.") 

1006 with ops.init_scope(): 

1007 root = load_v1_in_v2.load(export_dir, tags) 

1008 root.graph_debug_info = debug_info 

1009 # For privacy concerns, please see the note in 

1010 # tensorflow/cc/saved_model/metrics.h 

1011 metrics.SetReadPath(saved_model_path=str(export_dir)) 

1012 

1013 # Read and log SavedModel checksum, if it is nonzero. 

1014 try: 

1015 fingerprint = fingerprinting.read_fingerprint(export_dir) 

1016 except FileNotFoundError: 

1017 logging.info( 

1018 "Fingerprint not found. Saved model loading will continue.") 

1019 singleprint = "" 

1020 except RuntimeError: 

1021 logging.exception( 

1022 "Fingerprint was found, but there was an error when reading the proto.") 

1023 singleprint = "" 

1024 else: 

1025 metrics.SetReadFingerprint( 

1026 fingerprint=fingerprinting_utils.to_proto( 

1027 fingerprint).SerializeToString()) 

1028 singleprint = fingerprint.singleprint() 

1029 metrics.SetReadPathAndSingleprint(path=export_dir, singleprint=singleprint) 

1030 

1031 if options.experimental_load_function_aliases: 

1032 if hasattr(root, "function_aliases"): 

1033 raise ValueError( 

1034 "Could not load with experimental_load_function_aliases option" 

1035 " because the top-level object already has an attributed with name" 

1036 " 'function_aliases'" 

1037 ) 

1038 root.function_aliases = loader.function_aliases 

1039 

1040 if filters: 

1041 return {node_id: loader.get(node_id) for node_id in filters} 

1042 else: 

1043 return {"root": root}