Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/saved_model/load.py: 15%

585 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras SavedModel deserialization.""" 

16 

17import re 

18import types 

19import warnings 

20 

21import tensorflow.compat.v1.logging as logging 

22import tensorflow.compat.v2 as tf 

23from google.protobuf import message 

24 

25from keras.src import backend 

26from keras.src import regularizers 

27from keras.src.engine import input_spec 

28from keras.src.optimizers.legacy import optimizer_v2 

29from keras.protobuf import saved_metadata_pb2 

30from keras.protobuf import versions_pb2 

31from keras.src.saving import object_registration 

32from keras.src.saving.legacy import model_config 

33from keras.src.saving.legacy import saving_utils 

34from keras.src.saving.legacy import serialization 

35from keras.src.saving.legacy.saved_model import constants 

36from keras.src.saving.legacy.saved_model import json_utils 

37from keras.src.saving.legacy.saved_model import utils 

38from keras.src.saving.legacy.saved_model.serialized_attributes import ( 

39 CommonEndpoints, 

40) 

41from keras.src.utils import layer_utils 

42from keras.src.utils import metrics_utils 

43from keras.src.utils import tf_inspect 

44from keras.src.utils.generic_utils import LazyLoader 

45 

46# To avoid circular dependencies between keras/engine and keras/saving, 

47# code in keras/saving must delay imports. 

48 

49# TODO(b/134426265): Switch back to single-quotes to match the rest of the file 

50# once the issue with copybara is fixed. 

51 

52models_lib = LazyLoader("models_lib", globals(), "keras.src.models") 

53base_layer = LazyLoader("base_layer", globals(), "keras.src.engine.base_layer") 

54layers_module = LazyLoader("layers_module", globals(), "keras.src.layers") 

55input_layer = LazyLoader("input_layer", globals(), "keras.src.engine.input_layer") 

56functional_lib = LazyLoader( 

57 "functional_lib", globals(), "keras.src.engine.functional" 

58) 

59training_lib = LazyLoader("training_lib", globals(), "keras.src.engine.training") 

60training_lib_v1 = LazyLoader( 

61 "training_lib_v1", globals(), "keras.src.engine.training_v1" 

62) 

63metrics = LazyLoader("metrics", globals(), "keras.src.metrics") 

64base_rnn = LazyLoader("base_rnn", globals(), "keras.src.layers.rnn.base_rnn") 

65 

66 

67PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union( 

68 CommonEndpoints.all_checkpointable_objects 

69) 

70PUBLIC_ATTRIBUTES.add(constants.KERAS_ATTR) 

71 

72 

def load(path, compile=True, options=None):
    """Loads Keras objects from a SavedModel.

    Any Keras layer or model saved to the SavedModel is loaded back as a
    Keras object. Other objects are loaded as regular trackable objects
    (same as `tf.saved_model.load`). If no Keras metadata nodes are found,
    the result of `tf.saved_model.load` is returned directly.

    When `compile` is true and the loaded root object is a Keras `Model`,
    the model is compiled from the `training_config` stored in its metadata
    (a warning is logged if there is none). The optimizer is created from
    that config; optimizer slots are not restored from the SavedModel.

    Args:
      path: Path to SavedModel.
      compile: If true, compile the model after loading it.
      options: Optional `tf.saved_model.LoadOptions` object that specifies
        options for loading from SavedModel.

    Returns:
      Object loaded from SavedModel.

    Raises:
      IOError: If the Keras metadata file exists but cannot be parsed. The
        original `DecodeError` is attached as the cause.
    """
    # TODO(kathywu): Add saving/loading of optimizer, compiled losses and
    # metrics.
    # TODO(kathywu): Add code to load from objects that contain all endpoints

    # Look for metadata file or parse the SavedModel
    metadata = saved_metadata_pb2.SavedMetadata()
    meta_graph_def = tf.__internal__.saved_model.parse_saved_model(
        path
    ).meta_graphs[0]
    object_graph_def = meta_graph_def.object_graph_def
    path_to_metadata_pb = tf.io.gfile.join(path, constants.SAVED_METADATA_PATH)
    if tf.io.gfile.exists(path_to_metadata_pb):
        try:
            with tf.io.gfile.GFile(path_to_metadata_pb, "rb") as f:
                file_content = f.read()
            metadata.ParseFromString(file_content)
        except message.DecodeError as e:
            # Chain the decode error so the root cause stays in the traceback.
            raise IOError(
                f"Cannot parse keras metadata at path {path_to_metadata_pb}: "
                f"Received error: {e}"
            ) from e
    else:
        logging.warning(
            "SavedModel saved prior to TF 2.5 detected when loading "
            "Keras model. Please ensure that you are saving the model "
            "with model.save() or tf.keras.models.save_model(), *NOT* "
            "tf.saved_model.save(). To confirm, there should be a file "
            'named "keras_metadata.pb" in the SavedModel directory.'
        )
        # Fall back to the metadata embedded in the object graph proto.
        _read_legacy_metadata(object_graph_def, metadata, path)

    if not metadata.nodes:
        # When there are no Keras objects, return the results from the core
        # loader
        return tf.saved_model.load(path, options=options)

    metadata = _update_to_current_version(metadata)
    # Recreate layers and metrics using the info stored in the metadata.
    keras_loader = KerasObjectLoader(metadata, object_graph_def)
    keras_loader.load_layers(compile=compile)

    # Generate a dictionary of all loaded nodes, keyed by object-graph path.
    # The "root" entry is a placeholder that is overwritten if the root node
    # itself was recreated by the Keras loader.
    nodes_to_load = {"root": None}
    for node_id, loaded_node in keras_loader.loaded_nodes.items():
        nodes_to_load[keras_loader.get_path(node_id)] = loaded_node
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message="Trying to load ShardedVariables"
        )
        loaded = tf.__internal__.saved_model.load_partial(
            path, nodes_to_load, options=options
        )

    # Finalize the loaded layers and remove the extra tracked dependencies.
    keras_loader.finalize_objects()
    keras_loader.del_tracking()

    model = loaded["root"]

    if isinstance(model, training_lib.Model) and compile:
        # TODO(kathywu): Use compiled objects from SavedModel, instead of
        # creating new objects from the training config.
        training_config = model._serialized_attributes["metadata"].get(
            "training_config", None
        )
        if training_config is not None:
            model.compile(
                **saving_utils.compile_args_from_training_config(
                    training_config
                ),
                from_serialized=True,
            )
            saving_utils.try_build_compiled_arguments(model)
            if isinstance(model.optimizer, optimizer_v2.OptimizerV2):
                if model.optimizer.get_slot_names():
                    logging.warning(
                        "Your optimizer uses slots. "
                        "Slots cannot be restored from saved_model, "
                        "as a result, your model is starting with "
                        "a new initialized optimizer."
                    )
        else:
            logging.warning(
                "No training configuration found in save file, so the "
                "model was *not* compiled. Compile it manually."
            )

    # Force variables and resources to initialize.
    if not tf.executing_eagerly():
        sess = backend.get_session()  # Variables are initialized by this call.
        sess.run(
            tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.TABLE_INITIALIZERS
            )
        )

    return model

193 

194 

def _update_to_current_version(metadata):
    """Upgrades legacy node metadata in place for backwards compatibility.

    Only nodes written by producer version 1 whose identifier is a model,
    sequential or network identifier are touched. For those, a stored
    `save_spec` is mirrored into the newer `full_save_spec` entry as
    `([save_spec], {})` and the node metadata is re-encoded.

    Args:
      metadata: Metadata proto whose `nodes` are inspected and updated.

    Returns:
      The same `metadata` object that was passed in.
    """
    for entry in metadata.nodes:
        if entry.version.producer != 1:
            continue
        if entry.identifier not in [
            constants.MODEL_IDENTIFIER,
            constants.SEQUENTIAL_IDENTIFIER,
            constants.NETWORK_IDENTIFIER,
        ]:
            continue

        decoded = json_utils.decode(entry.metadata)
        legacy_spec = decoded.get("save_spec")
        if legacy_spec is None:
            continue

        decoded["full_save_spec"] = ([legacy_spec], {})
        entry.metadata = json_utils.Encoder().encode(decoded)
    return metadata

210 

211 

def _read_legacy_metadata(object_graph_def, metadata, path):
    """Fills `metadata` from Keras metadata embedded in the ObjectGraphDef.

    Older SavedModels store the Keras metadata directly on the user-object
    nodes of the object graph instead of in a separate pb file. Every
    user-object node whose identifier is a Keras object identifier is copied
    into `metadata.nodes` with producer/min_consumer version 1.

    Args:
      object_graph_def: Object graph proto of the SavedModel.
      metadata: Metadata proto that receives one node per Keras object.
      path: SavedModel path, used only in the error message.

    Raises:
      ValueError: If a Keras user object has no embedded metadata.
    """
    paths_by_id = _generate_object_paths(object_graph_def)
    for index, node_proto in enumerate(object_graph_def.nodes):
        if node_proto.WhichOneof("kind") != "user_object":
            continue
        user_object = node_proto.user_object
        if user_object.identifier not in constants.KERAS_OBJECT_IDENTIFIERS:
            continue

        if not user_object.metadata:
            raise ValueError(
                "Unable to create a Keras model from SavedModel at "
                f"{path}. This SavedModel was exported with "
                "`tf.saved_model.save`, and lacks the Keras metadata file. "
                "Please save your Keras model by calling `model.save` "
                "or `tf.keras.models.save_model`. Note that "
                "you can still load this SavedModel with "
                "`tf.saved_model.load`."
            )

        legacy_version = versions_pb2.VersionDef(
            producer=1, min_consumer=1, bad_consumers=[]
        )
        metadata.nodes.add(
            node_id=index,
            node_path=paths_by_id[index],
            version=legacy_version,
            identifier=user_object.identifier,
            metadata=user_object.metadata,
        )

242 

243 

def _generate_object_paths(object_graph_def):
    """Maps every node reachable from node 0 to a dotted path string.

    Node 0 gets the path "root"; every other node gets
    "<parent path>.<local name>" for the first edge through which it is
    discovered. Nodes are expanded in last-in, first-out order.

    Args:
      object_graph_def: Object graph whose `nodes[i].children` entries carry
        `node_id` and `local_name`.

    Returns:
      Dict mapping node id to its path string.
    """
    paths = {0: "root"}
    pending = [0]

    while pending:
        parent_id = pending.pop()
        prefix = paths[parent_id]
        for edge in object_graph_def.nodes[parent_id].children:
            child_id = edge.node_id
            if child_id not in paths:
                paths[child_id] = f"{prefix}.{edge.local_name}"
                pending.append(child_id)

    return paths

260 

261 

def _is_graph_network(layer):
    """Returns whether `layer` should be treated as a graph network.

    A `RevivedNetwork` never counts. A `Functional` instance counts when its
    `_is_graph_network` flag is truthy or when it is a `Sequential`. Anything
    else does not count.
    """
    if isinstance(layer, RevivedNetwork):
        return False
    if not isinstance(layer, functional_lib.Functional):
        return False
    return layer._is_graph_network or isinstance(layer, models_lib.Sequential)

272 

273 

274class KerasObjectLoader: 

275 """Loader that recreates Keras objects (e.g. 

276 

277 layers, models). 

278 

279 Layers and models are revived from either the config or SavedModel following 

280 these rules: 

281 1. If object is a graph network (i.e. Sequential or Functional) then it will 

282 be initialized using the structure from the config only after the 

283 children layers have been created. Graph networks must be initialized 

284 with inputs and outputs, so all child layers must be created beforehand. 

285 2. If object's config exists and the class can be found, then revive from 

286 config. 

287 3. Object may have already been created if its parent was revived from 

288 config. In this case, do nothing. 

289 4. If nothing of the above applies, compose the various artifacts from the 

290 SavedModel to create a subclassed layer or model. At this time, custom 

291 metrics are not supported. 

292 

293 """ 

294 

295 def __init__(self, metadata, object_graph_def): 

296 self._metadata = {x.node_id: x for x in metadata.nodes} 

297 self._proto = object_graph_def 

298 

299 self._node_paths = { 

300 node_data.node_id: node_data.node_path 

301 for node_data in metadata.nodes 

302 } 

303 self.loaded_nodes = {} # Maps node path -> loaded node 

304 

305 # Store all node ids that have already been traversed when tracking 

306 # nodes that were recreated from the config. 

307 self._traversed_nodes_from_config = set() 

308 

309 # Maps model id -> (blank model obj, list of child layer or their node 

310 # ids) This tracks all layers in functional and sequential models. These 

311 # models are only reconstructed after all of their child layers have 

312 # been created. 

313 self.model_layer_dependencies = {} 

314 self._models_to_reconstruct = [] 

315 

316 def del_tracking(self): 

317 """Removes tracked references that are only used when loading the 

318 model.""" 

319 # Now that the node object has been fully loaded, and the checkpoint has 

320 # been restored, the object no longer needs to track objects added from 

321 # SerializedAttributes. (Note that saving a training checkpoint still 

322 # functions correctly, because layers and variables are tracked 

323 # separately by the Layer object.) 

324 # TODO(kathywu): Instead of outright deleting these nodes (which would 

325 # make restoring from a different checkpoint tricky), mark them as extra 

326 # dependencies that are OK to overwrite. 

327 for node in self.loaded_nodes.values(): 

328 node = node[0] 

329 if not isinstance(node, base_layer.Layer): 

330 # Loaded nodes can contain other trackable objects created when 

331 # loading layers from the config, such as variables. 

332 continue 

333 for name in PUBLIC_ATTRIBUTES: 

334 node._delete_tracking(name) 

335 

336 if isinstance(node, functional_lib.Functional): 

337 # Delete the temporary layer dependencies, which were used to 

338 # restore the checkpointed values. When the model is live, the 

339 # user can delete or add layers to the model at any time, so 

340 # these layer dependencies may be obsolete. 

341 dependencies = list(node._self_unconditional_dependency_names) 

342 for name in dependencies: 

343 if ( 

344 re.match(r"^layer(_with_weights)?-[\d+]", name) 

345 is not None 

346 ): 

347 node._delete_tracking(name) 

348 

349 def _add_children_recreated_from_config(self, obj, proto, node_id): 

350 """Recursively records objects recreated from config.""" 

351 

352 if node_id in self._traversed_nodes_from_config: 

353 return 

354 

355 parent_path = self._node_paths[node_id] 

356 self._traversed_nodes_from_config.add(node_id) 

357 obj._maybe_initialize_trackable() 

358 if isinstance(obj, base_layer.Layer) and not obj.built: 

359 metadata = json_utils.decode(self._metadata[node_id].metadata) 

360 self._try_build_layer( 

361 obj, node_id, metadata.get("build_input_shape") 

362 ) 

363 

364 # Create list of all possible children 

365 children = [] 

366 # Look for direct children 

367 for reference in proto.children: 

368 obj_child = obj._lookup_dependency(reference.local_name) 

369 children.append( 

370 (obj_child, reference.node_id, reference.local_name) 

371 ) 

372 

373 # Add metrics that may have been added to the layer._metrics list. 

374 # This is stored in the SavedModel as layer.keras_api.layer_metrics in 

375 # SavedModels created after Tf 2.2. 

376 metric_list_node_id = self._search_for_child_node( 

377 node_id, [constants.KERAS_ATTR, "layer_metrics"] 

378 ) 

379 if metric_list_node_id is not None and hasattr(obj, "_metrics"): 

380 obj_metrics = {m.name: m for m in obj._metrics} 

381 for reference in self._proto.nodes[metric_list_node_id].children: 

382 metric = obj_metrics.get(reference.local_name) 

383 if metric is not None: 

384 metric_path = "{}.layer_metrics.{}".format( 

385 constants.KERAS_ATTR, reference.local_name 

386 ) 

387 children.append((metric, reference.node_id, metric_path)) 

388 

389 for obj_child, child_id, child_name in children: 

390 child_proto = self._proto.nodes[child_id] 

391 

392 if not isinstance(obj_child, tf.__internal__.tracking.Trackable): 

393 continue 

394 if ( 

395 child_proto.user_object.identifier 

396 in tf.__internal__.saved_model.load.registered_identifiers() 

397 ): 

398 setter = tf.__internal__.saved_model.load.get_setter( 

399 child_proto.user_object 

400 ) 

401 elif ( 

402 obj_child._object_identifier 

403 in constants.KERAS_OBJECT_IDENTIFIERS 

404 ): 

405 setter = _revive_setter 

406 else: 

407 setter = setattr 

408 

409 if child_id in self.loaded_nodes: 

410 if self.loaded_nodes[child_id][0] is not obj_child: 

411 # This means that the same trackable object is referenced by 

412 # two different objects that were recreated from the config. 

413 logging.warning( 

414 "Looks like there is an object (perhaps variable or " 

415 "layer) that is shared between different " 

416 "layers/models. This may cause issues when restoring " 

417 "the variable values. Object: {}".format(obj_child) 

418 ) 

419 continue 

420 

421 # Overwrite variable names with the ones saved in the SavedModel. 

422 if ( 

423 child_proto.WhichOneof("kind") == "variable" 

424 and child_proto.variable.name 

425 ): 

426 obj_child._handle_name = child_proto.variable.name + ":0" 

427 

428 if isinstance( 

429 obj_child, tf.__internal__.tracking.TrackableDataStructure 

430 ): 

431 setter = lambda *args: None 

432 

433 child_path = f"{parent_path}.{child_name}" 

434 self._node_paths[child_id] = child_path 

435 self._add_children_recreated_from_config( 

436 obj_child, child_proto, child_id 

437 ) 

438 self.loaded_nodes[child_id] = obj_child, setter 

439 

440 def load_layers(self, compile=True): 

441 """Load all layer nodes from the metadata.""" 

442 # Load metrics after models and layers, since it's likely that models 

443 # and layers will create the metric when initialized (this avoids 

444 # wasting time by creating objects multiple times). 

445 metric_list = [] 

446 for node_metadata in self._metadata.values(): 

447 if node_metadata.identifier == constants.METRIC_IDENTIFIER: 

448 metric_list.append(node_metadata) 

449 continue 

450 

451 self.loaded_nodes[node_metadata.node_id] = self._load_layer( 

452 node_metadata.node_id, 

453 node_metadata.identifier, 

454 node_metadata.metadata, 

455 ) 

456 

457 for node_metadata in metric_list: 

458 try: 

459 self.loaded_nodes[node_metadata.node_id] = self._load_layer( 

460 node_metadata.node_id, 

461 node_metadata.identifier, 

462 node_metadata.metadata, 

463 ) 

464 except ValueError as e: 

465 # Metrics are only needed when the model is compiled later. We 

466 # ignore errors when trying to load custom metrics when 

467 # `compile=False` until custom metrics are serialized properly 

468 # (b/135550038). 

469 if compile: 

470 raise e 

471 logging.warning( 

472 "Unable to restore custom metric. Please ensure that " 

473 "the layer implements `get_config` and `from_config` " 

474 "when saving. In addition, please use the " 

475 "`custom_objects` arg when calling `load_model()`." 

476 ) 

477 

478 def _load_layer(self, node_id, identifier, metadata): 

479 """Load a single layer from a SavedUserObject proto.""" 

480 metadata = json_utils.decode(metadata) 

481 

482 # If node was already created 

483 if node_id in self.loaded_nodes: 

484 node, setter = self.loaded_nodes[node_id] 

485 

486 # Revive setter requires the object to have a 

487 # `_serialized_attributes` property. Add it here. 

488 _maybe_add_serialized_attributes(node, metadata) 

489 

490 config = metadata.get("config") 

491 if _is_graph_network(node) and serialization.validate_config( 

492 config 

493 ): 

494 child_nodes = self._get_child_layer_node_ids(node_id) 

495 self.model_layer_dependencies[node_id] = (node, child_nodes) 

496 if not child_nodes: 

497 self._models_to_reconstruct.append(node_id) 

498 return node, setter 

499 

500 # Detect whether this object can be revived from the config. If not, 

501 # then revive from the SavedModel instead. 

502 obj, setter = self._revive_from_config(identifier, metadata, node_id) 

503 if obj is None: 

504 obj, setter = revive_custom_object(identifier, metadata) 

505 

506 # Add an attribute that stores the extra functions/objects saved in the 

507 # SavedModel. Most of these functions/objects are ignored, but some are 

508 # used later in the loading process (e.g. the list of regularization 

509 # losses, or the training config of compiled models). 

510 _maybe_add_serialized_attributes(obj, metadata) 

511 return obj, setter 

512 

513 def _revive_from_config(self, identifier, metadata, node_id): 

514 """Revives a layer/model from config, or returns None.""" 

515 if identifier == constants.METRIC_IDENTIFIER: 

516 obj = self._revive_metric_from_config(metadata) 

517 else: 

518 obj = self._revive_graph_network( 

519 identifier, metadata, node_id 

520 ) or self._revive_layer_or_model_from_config(metadata, node_id) 

521 

522 if obj is None: 

523 return None, None 

524 

525 setter = self._config_node_setter(_revive_setter) 

526 self._add_children_recreated_from_config( 

527 obj, self._proto.nodes[node_id], node_id 

528 ) 

529 return obj, setter 

530 

531 def _revive_graph_network(self, identifier, metadata, node_id): 

532 """Revives a graph network from config.""" 

533 # Determine whether the metadata contains information for reviving a 

534 # functional or Sequential model. 

535 config = metadata.get("config") 

536 if not serialization.validate_config(config): 

537 return None 

538 

539 class_name = tf.compat.as_str(metadata["class_name"]) 

540 if object_registration.get_registered_object(class_name) is not None: 

541 return None 

542 model_is_functional_or_sequential = ( 

543 metadata.get("is_graph_network", False) 

544 or class_name == "Sequential" 

545 or class_name == "Functional" 

546 ) 

547 if not model_is_functional_or_sequential: 

548 return None 

549 

550 # Revive functional and sequential models as blank model objects for now 

551 # ( must be initialized to enable setattr tracking and attribute 

552 # caching). Reconstruction of the network is deferred until all of the 

553 # model's layers have been revived. 

554 if class_name == "Sequential": 

555 model = models_lib.Sequential(name=config["name"]) 

556 # The model is a custom Sequential model. 

557 elif identifier == constants.SEQUENTIAL_IDENTIFIER: 

558 # Uses the custom class name, since the config does not have one. 

559 model = models_lib.Sequential(name=class_name) 

560 else: 

561 model = models_lib.Functional( 

562 inputs=[], outputs=[], name=config["name"] 

563 ) 

564 

565 # Record this model and its layers. This will later be used to 

566 # reconstruct the model. 

567 layers = self._get_child_layer_node_ids(node_id) 

568 self.model_layer_dependencies[node_id] = (model, layers) 

569 if not layers: 

570 self._models_to_reconstruct.append(node_id) 

571 return model 

572 

573 def _revive_layer_or_model_from_config(self, metadata, node_id): 

574 """Revives a layer/custom model from config; returns None if 

575 infeasible.""" 

576 # Check that the following requirements are met for reviving from 

577 # config: 

578 # 1. Object can be deserialized from config. 

579 # 2. If the object needs to be built, then the build input shape can 

580 # be found. 

581 class_name = metadata.get("class_name") 

582 config = metadata.get("config") 

583 shared_object_id = metadata.get("shared_object_id") 

584 must_restore_from_config = metadata.get("must_restore_from_config") 

585 if not serialization.validate_config(config): 

586 return None 

587 

588 try: 

589 try: 

590 obj = model_config.model_from_config( 

591 serialization.serialize_keras_class_and_config( 

592 class_name, config, shared_object_id=shared_object_id 

593 ) 

594 ) 

595 except (TypeError, KeyError) as e: 

596 # A name conflict has occurred. The `class_name` is in the Keras 

597 # native framework; however, the value in the framework is 

598 # different from the user's class definition which confuses the 

599 # KerasObjectLoader. 

600 builtin_layer = layers_module.get_builtin_layer(class_name) 

601 if builtin_layer: 

602 raise RuntimeError( 

603 f"Unable to restore object of class '{class_name}'. " 

604 "One of several possible causes could be " 

605 "a missing custom object. " 

606 "Decorate your custom object with " 

607 "`@keras.utils.register_keras_serializable()` and " 

608 "include that file in your program, " 

609 "or pass your class in a " 

610 "`keras.utils.CustomObjectScope` " 

611 "that wraps this load call. " 

612 f"\n\nException: {e}" 

613 ) from e 

614 else: 

615 raise 

616 except Exception as e: 

617 if must_restore_from_config: 

618 raise e 

619 else: 

620 return None 

621 

622 # Use the dtype, name, and trainable status. Often times these are not 

623 # specified in custom configs, so retrieve their values from the 

624 # metadata. 

625 

626 obj._name = metadata["name"] 

627 if metadata.get("trainable") is not None: 

628 obj.trainable = metadata["trainable"] 

629 if metadata.get("dtype") is not None: 

630 obj._set_dtype_policy(metadata["dtype"]) 

631 if metadata.get("stateful") is not None: 

632 obj.stateful = metadata["stateful"] 

633 if metadata.get("autocast") is not None: 

634 obj._autocast = metadata["autocast"] 

635 # Restore model save spec for subclassed models. (layers do not store a 

636 # SaveSpec) 

637 if isinstance(obj, training_lib.Model): 

638 full_save_spec = metadata.get("full_save_spec") 

639 if full_save_spec is not None: 

640 args_spec, kwargs_spec = full_save_spec 

641 inputs_spec = args_spec.pop(0) 

642 obj._set_save_spec(inputs_spec, args_spec, kwargs_spec) 

643 

644 build_input_shape = metadata.get("build_input_shape") 

645 built = self._try_build_layer(obj, node_id, build_input_shape) 

646 

647 if not built: 

648 # If the layer cannot be built, revive a custom layer instead. 

649 return None 

650 return obj 

651 

652 def _revive_metric_from_config(self, metadata): 

653 """Revives a metric object using the config saved in the metadata.""" 

654 class_name = tf.compat.as_str(metadata["class_name"]) 

655 config = metadata.get("config") 

656 

657 if not serialization.validate_config(config): 

658 return None 

659 

660 try: 

661 obj = metrics.deserialize( 

662 serialization.serialize_keras_class_and_config( 

663 class_name, config 

664 ) 

665 ) 

666 except ValueError: 

667 return None 

668 

669 build_input_shape = metadata.get("build_input_shape") 

670 if build_input_shape is not None and hasattr(obj, "_build"): 

671 obj._build(build_input_shape) 

672 

673 return obj 

674 

675 def _try_build_layer(self, obj, node_id, build_input_shape): 

676 """Attempts to build the layer.""" 

677 if obj.built or hasattr(obj.build, "_is_default"): 

678 obj.built = True 

679 return True 

680 

681 if build_input_shape is None: 

682 build_input_shape = self._infer_inputs( 

683 node_id, convert_to_shapes=True 

684 ) 

685 

686 if build_input_shape is not None: 

687 obj.build(build_input_shape) 

688 base_layer.Layer.build(obj, build_input_shape) 

689 return True 

690 

691 return False 

692 

693 def get_path(self, node_id): 

694 return self._node_paths[node_id] 

695 

696 def finalize_objects(self): 

697 """Finish setting up Keras objects. 

698 

699 This function is executed after all objects and functions have been 

700 created. Call functions and losses are attached to each layer, and once 

701 all layers have been fully set up, graph networks are initialized. 

702 

703 Subclassed models that are revived from the SavedModel are treated like 

704 layers, and have their call/loss functions attached here. 

705 """ 

706 # Finish setting up layers and subclassed models. This step attaches 

707 # call functions and losses to each object, and sets model 

708 # inputs/outputs. 

709 layers_revived_from_config = [] 

710 layers_revived_from_saved_model = [] 

711 for node_id, (node, _) in self.loaded_nodes.items(): 

712 if ( 

713 not isinstance(node, base_layer.Layer) 

714 # Don't finalize models until all layers have finished loading. 

715 or node_id in self.model_layer_dependencies 

716 ): 

717 continue 

718 

719 self._unblock_model_reconstruction(node_id, node) 

720 

721 if isinstance(node, input_layer.InputLayer): 

722 continue 

723 elif isinstance(node, metrics.Metric): 

724 continue 

725 

726 if isinstance(node, (RevivedLayer, RevivedInputLayer)): 

727 layers_revived_from_saved_model.append(node) 

728 else: 

729 layers_revived_from_config.append(node) 

730 

731 _finalize_saved_model_layers(layers_revived_from_saved_model) 

732 _finalize_config_layers(layers_revived_from_config) 

733 

734 # Initialize graph networks, now that layer dependencies have been 

735 # resolved. 

736 self._reconstruct_all_models() 

737 

738 def _unblock_model_reconstruction(self, layer_id, layer): 

739 """Removes layer from blocking model reconstruction.""" 

740 for model_id, v in self.model_layer_dependencies.items(): 

741 _, layers = v 

742 if layer_id not in layers: 

743 continue 

744 layers[layers.index(layer_id)] = layer 

745 if all(isinstance(x, base_layer.Layer) for x in layers): 

746 self._models_to_reconstruct.append(model_id) 

747 

748 def _reconstruct_all_models(self): 

749 """Reconstructs the network structure of all models.""" 

750 all_initialized_models = set() 

751 while self._models_to_reconstruct: 

752 model_id = self._models_to_reconstruct.pop(0) 

753 all_initialized_models.add(model_id) 

754 model, layers = self.model_layer_dependencies[model_id] 

755 self._reconstruct_model(model_id, model, layers) 

756 _finalize_config_layers([model]) 

757 

758 if all_initialized_models != set(self.model_layer_dependencies.keys()): 

759 # This should not happen. 

760 uninitialized_model_ids = ( 

761 set(self.model_layer_dependencies.keys()) 

762 - all_initialized_models 

763 ) 

764 uninitialized_model_names = [ 

765 self.model_layer_dependencies[model_id][0].name 

766 for model_id in uninitialized_model_ids 

767 ] 

768 raise ValueError( 

769 "Error loading model(s) in the SavedModel format. " 

770 "The following model(s) could not be initialized: " 

771 f"{uninitialized_model_names}" 

772 ) 

773 

def _reconstruct_model(self, model_id, model, layers):
    """Reconstructs the network structure.

    The model config is decoded from the metadata stored for `model_id`, and
    `model` is re-initialized in place by calling `model.__init__` again.
    Three cases are handled:

    * `model.inputs` is already set: no re-initialization is performed.
    * `model` is a `Sequential`: an `InputLayer` may be prepended to
      `layers`, the model is re-initialized from the layer list, and, if it
      still has no inputs, they are inferred from the first child layer.
    * Anything else is rebuilt as a functional model from the config.

    Args:
        model_id: key of the model in `self._metadata`; also used as the
            node id when looking up the model's child layers.
        model: model object to re-initialize in place.
        layers: list of layer objects belonging to the model. This list may
            be mutated (an `InputLayer` can be inserted at index 0).
    """
    config = json_utils.decode(self._metadata[model_id].metadata)["config"]

    # Set up model inputs
    if model.inputs:
        # Inputs may already be created if the model is instantiated in
        # another object's __init__.
        pass
    elif isinstance(model, models_lib.Sequential):
        if not layers or not isinstance(layers[0], input_layer.InputLayer):
            # The layer list does not start with an InputLayer; try to
            # recreate one from the first layer entry of the config.
            if config["layers"][0]["class_name"] == "InputLayer":
                layers.insert(
                    0,
                    input_layer.InputLayer.from_config(
                        config["layers"][0]["config"]
                    ),
                )
            elif "batch_input_shape" in config["layers"][0]["config"]:
                batch_input_shape = config["layers"][0]["config"][
                    "batch_input_shape"
                ]
                # NOTE(review): `layers[0]` is read below; if `layers` is
                # empty on this path it would raise IndexError -- confirm
                # callers never pass an empty list here.
                layers.insert(
                    0,
                    input_layer.InputLayer(
                        input_shape=batch_input_shape[1:],
                        batch_size=batch_input_shape[0],
                        dtype=layers[0].dtype,
                        name=layers[0].name + "_input",
                    ),
                )
        model.__init__(layers, name=config["name"])
        if not model.inputs:
            # Still no inputs after re-initialization: infer them from the
            # saved call function of the first "layer-N" child.
            first_layer = self._get_child_layer_node_ids(model_id)[0]
            input_specs = self._infer_inputs(first_layer)
            input_shapes = self._infer_inputs(
                first_layer, convert_to_shapes=True
            )
            model._set_inputs(input_specs)
            # `build` is skipped when the inferred specs are a dict.
            if not model.built and not isinstance(input_specs, dict):
                model.build(input_shapes)
    else:  # Reconstruct functional model
        (
            inputs,
            outputs,
            created_layers,
        ) = functional_lib.reconstruct_from_config(
            config, created_layers={layer.name: layer for layer in layers}
        )
        model.__init__(inputs, outputs, name=config["name"])
        functional_lib.connect_ancillary_layers(model, created_layers)

    # Set model dtype.
    _set_network_attributes_from_metadata(model)

    # Unblock models that are dependent on this model.
    self._unblock_model_reconstruction(model_id, model)

831 

832 def _get_child_layer_node_ids(self, node_id): 

833 """Returns the node ids of each layer in a Sequential/Functional 

834 model.""" 

835 # Sequential and Functional track layers with names following the format 

836 # "layer-N". Use this to generate the list of layers. 

837 num_layers = 0 

838 child_layers = {} 

839 pattern = re.compile("layer-(\\d+)") 

840 

841 for child in self._proto.nodes[node_id].children: 

842 m = pattern.match(child.local_name) 

843 if m is None: 

844 continue 

845 layer_n = int(m.group(1)) 

846 num_layers = max(layer_n + 1, num_layers) 

847 child_layers[layer_n] = child.node_id 

848 

849 ordered = [] 

850 for n in range(num_layers): 

851 child = child_layers.get(n) 

852 if child is None: 

853 break 

854 ordered.append(child) 

855 return ordered 

856 

857 def _search_for_child_node(self, parent_id, path_to_child): 

858 """Returns node id of child node. 

859 

860 A helper method for traversing the object graph proto. 

861 

862 As an example, say that the object graph proto in the SavedModel 

863 contains an object with the following child and grandchild attributes: 

864 

865 `parent.child_a.child_b` 

866 

867 This method can be used to retrieve the node id of `child_b` using the 

868 parent's node id by calling: 

869 

870 `_search_for_child_node(parent_id, ['child_a', 'child_b'])`. 

871 

872 Args: 

873 parent_id: node id of parent node 

874 path_to_child: list of children names. 

875 

876 Returns: 

877 node_id of child, or None if child isn't found. 

878 """ 

879 if not path_to_child: 

880 return parent_id 

881 

882 for child in self._proto.nodes[parent_id].children: 

883 if child.local_name == path_to_child[0]: 

884 return self._search_for_child_node( 

885 child.node_id, path_to_child[1:] 

886 ) 

887 return None 

888 

889 def _infer_inputs(self, layer_node_id, convert_to_shapes=False): 

890 """Infers input shape of layer from SavedModel functions.""" 

891 call_fn_id = self._search_for_child_node( 

892 layer_node_id, ["call_and_return_all_conditional_losses"] 

893 ) 

894 if call_fn_id is None: 

895 return None 

896 

897 concrete_functions = self._proto.nodes[ 

898 call_fn_id 

899 ].function.concrete_functions 

900 if not concrete_functions: 

901 return None 

902 call_fn_name = concrete_functions[0] 

903 call_fn_proto = self._proto.concrete_functions[call_fn_name] 

904 structured_input_signature = tf.__internal__.saved_model.decode_proto( 

905 call_fn_proto.canonicalized_input_signature 

906 ) 

907 inputs = structured_input_signature[0][0] 

908 if convert_to_shapes: 

909 return tf.nest.map_structure(lambda spec: spec.shape, inputs) 

910 else: 

911 return inputs 

912 

913 def _config_node_setter(self, setter): 

914 """Creates edges for nodes that are recreated from config.""" 

915 

916 def setattr_wrapper(obj, name, value): 

917 # Avoid overwriting attributes of objects recreated from the config. 

918 if obj._lookup_dependency(name) is None: 

919 setter(obj, name, value) 

920 

921 return setattr_wrapper 

922 

923 

def _finalize_saved_model_layers(layers):
    """Runs the final steps of loading Keras Layers from SavedModel.

    `layers` is iterated twice: the first pass installs a `call`
    implementation on every layer, the second pass restores inputs (for
    revived networks), losses and metrics.

    Args:
        layers: collection of layers revived from the SavedModel. It must
            support being iterated more than once.
    """

    # 1. Set up call functions for all layers initialized from the
    # SavedModel (and not the config).
    for layer in layers:
        layer.built = True
        layer_call = getattr(
            _get_keras_attr(layer), "call_and_return_conditional_losses", None
        )
        if layer_call and layer_call.concrete_functions:
            call_spec = layer_utils.CallFunctionSpec(
                tf_inspect.getfullargspec(layer_call)
            )
            layer.call = utils.use_wrapped_call(
                layer, layer_call, call_spec, return_method=True
            )
            expects_training_arg = layer._serialized_attributes["metadata"][
                "expects_training_arg"
            ]
            if "training" in layer_call.function_spec.arg_names:
                # This could change the value of `expects_training_arg` if this
                # layer doesn't expect a training arg, but has a child layer
                # that does.
                expects_training_arg = True
            layer._init_call_fn_args(expects_training_arg)
        else:
            # No usable saved call function: calling the layer raises.
            layer.call = types.MethodType(
                _unable_to_call_layer_due_to_serialization_issue, layer
            )

    for layer in layers:
        # 2. Set model inputs and outputs.
        if isinstance(layer, RevivedNetwork):
            _set_network_attributes_from_metadata(layer)

            if hasattr(
                _get_keras_attr(layer), "call_and_return_conditional_losses"
            ):
                call_fn = _get_keras_attr(
                    layer
                ).call_and_return_conditional_losses
                if not call_fn.concrete_functions:
                    # NOTE(review): this `continue` also skips steps 3 and
                    # 4 below for this layer -- confirm that is intended.
                    continue
                if call_fn.input_signature is None:
                    args, kwargs = infer_inputs_from_restored_call_function(
                        call_fn
                    )
                    args = list(args)
                    inputs = args.pop(0)
                else:
                    args = call_fn.input_signature
                    args = list(args)
                    inputs = args.pop(0)
                    kwargs = None
                # `inputs` is the first positional entry; `args` holds the
                # remaining positional entries.
                layer._set_save_spec(inputs, args, kwargs)

                # V1 models require calling _set_inputs to set the `.inputs`
                # attr. Skip this step when there are multiple tensor inputs
                # (this behavior is not well supported in V1 models).
                if not any(
                    isinstance(x, tf.TensorSpec)
                    for x in tf.nest.flatten([args, kwargs])
                ):
                    layer._set_inputs(inputs)

        # 3. Add losses that aren't generated by the layer.call function.
        _restore_layer_unconditional_losses(layer)
        _restore_layer_activation_loss(layer)

        # 4. Restore metrics list
        _restore_layer_metrics(layer)

996 

997 

998def _unable_to_call_layer_due_to_serialization_issue( 

999 layer, *unused_args, **unused_kwargs 

1000): 

1001 """Replaces the `layer.call` if the layer was not fully serialized. 

1002 

1003 Keras Model/Layer serialization is relatively relaxed because SavedModels 

1004 are not always loaded back as keras models. Thus, when there is an issue 

1005 tracing a non-signature function, a warning is logged instead of raising an 

1006 error. This results in a SavedModel where the model's call function is 

1007 saved, but the internal layer call functions are not. 

1008 

1009 When deserialized with `tf.keras.models.load_model`, the internal layers 

1010 which do not have serialized call functions should raise an error when 

1011 called. 

1012 

1013 Args: 

1014 layer: Layer without the serialized call function. 

1015 

1016 Raises: 

1017 ValueError 

1018 """ 

1019 

1020 raise ValueError( 

1021 f"Cannot call custom layer {layer.name} of type {type(layer)}, because " 

1022 "the call function was not serialized to the SavedModel." 

1023 "Please try one of the following methods to fix this issue:" 

1024 "\n\n(1) Implement `get_config` and `from_config` in the layer/model " 

1025 "class, and pass the object to the `custom_objects` argument when " 

1026 "loading the model. For more details, see: " 

1027 "https://www.tensorflow.org/guide/keras/save_and_serialize" 

1028 "\n\n(2) Ensure that the subclassed model or layer overwrites `call` " 

1029 "and not `__call__`. The input shape and dtype will be automatically " 

1030 "recorded when the object is called, and used when saving. To manually " 

1031 "specify the input shape/dtype, decorate the call function with " 

1032 "`@tf.function(input_signature=...)`." 

1033 ) 

1034 

1035 

1036def _finalize_config_layers(layers): 

1037 """Runs the final steps of loading Keras Layers from config.""" 

1038 for layer in layers: 

1039 # It is assumed that layers define their unconditional losses after 

1040 # being recreated from the config and built. The exceptions to this are 

1041 # Functional and Sequential models, which only store conditional losses 

1042 # (losses dependent on the inputs) in the config. Unconditional losses 

1043 # like weight regularization must be revived from the SavedModel. 

1044 if _is_graph_network(layer): 

1045 _restore_layer_unconditional_losses(layer) 

1046 

1047 # Some layers, like Dense, record their activation loss function in the 

1048 # config. However, not all layers do this, so the activation loss may be 

1049 # missing when restored from the config/hdf5. 

1050 # TODO(kathywu): Investigate ways to improve the config to ensure 

1051 # consistent loading behavior between HDF5 and SavedModel. 

1052 _restore_layer_activation_loss(layer) 

1053 

1054 # Restore metrics list. 

1055 _restore_layer_metrics(layer) 

1056 

1057 # Restore RNN layer states. 

1058 if ( 

1059 isinstance(layer, base_rnn.RNN) 

1060 and layer.stateful 

1061 and hasattr(_get_keras_attr(layer), "states") 

1062 ): 

1063 layer.states = getattr(_get_keras_attr(layer), "states", None) 

1064 for variable in tf.nest.flatten(layer.states): 

1065 backend.track_variable(variable) 

1066 

1067 # Perform any layer defined finalization of the layer state. 

1068 layer.finalize_state() 

1069 

1070 

def _finalize_metric(metric):
    """Rebinds `update_state` and `result` of a metric to its saved functions.

    `update_state` is taken from `metric.keras_api`, passed through
    `metrics_utils.update_state_wrapper` and bound to `metric` as a method;
    `result` is assigned from `metric.keras_api` as-is.

    Args:
        metric: metric object exposing a `keras_api` attribute.
    """
    saved_api = metric.keras_api
    wrapped_update = metrics_utils.update_state_wrapper(
        saved_api.update_state
    )
    metric.update_state = types.MethodType(wrapped_update, metric)
    metric.result = saved_api.result

1077 

1078 

def _restore_layer_unconditional_losses(layer):
    """Adds the unconditional losses stored in the SavedModel to `layer`.

    Args:
        layer: layer on which `add_loss` is called for every saved loss.
    """
    keras_attr = _get_keras_attr(layer)
    if hasattr(keras_attr, "layer_regularization_losses"):
        saved_losses = getattr(keras_attr, "layer_regularization_losses", [])
    else:
        # Earlier SavedModels may not serialize layer_regularization_losses
        # separately; use the regularization_losses list in that case.
        saved_losses = layer._serialized_attributes.get(
            "regularization_losses", []
        )
    for saved_loss in saved_losses:
        layer.add_loss(saved_loss)

1092 

1093 

def _restore_layer_activation_loss(layer):
    """Restore activation loss from SavedModel.

    The saved `activity_regularizer_fn` is installed as the layer's
    activity regularizer, but only when the layer does not already have
    one (i.e. it was not created during initialization).

    Args:
        layer: layer whose `activity_regularizer` may be assigned.
    """
    saved_regularizer = getattr(
        _get_keras_attr(layer), "activity_regularizer_fn", None
    )
    if not saved_regularizer or layer.activity_regularizer:
        return

    try:
        layer.activity_regularizer = saved_regularizer
    except AttributeError:
        # A layer wrapper saved with an activity regularizer can end up
        # here: the wrapper object's activity regularizer is unsettable.
        pass

1109 

1110 

def revive_custom_object(identifier, metadata):
    """Revives object from SavedModel.

    A new class named after `metadata["class_name"]` is created on the fly
    from the (revived class, Keras base class) pair registered for
    `identifier`, and instantiated through its `_init_from_metadata`.

    Args:
        identifier: object identifier used to select the parent classes.
        metadata: mapping with at least a "class_name" entry; forwarded to
            `_init_from_metadata`.

    Returns:
        Whatever the revived class' `_init_from_metadata` returns.

    Raises:
        ValueError: if no revived class is registered for `identifier`.
    """
    model_class = (
        training_lib.Model
        if tf.compat.v1.executing_eagerly_outside_functions()
        else training_lib_v1.Model
    )

    revived_classes = {
        constants.INPUT_LAYER_IDENTIFIER: (
            RevivedInputLayer,
            input_layer.InputLayer,
        ),
        constants.LAYER_IDENTIFIER: (RevivedLayer, base_layer.Layer),
        constants.MODEL_IDENTIFIER: (RevivedNetwork, model_class),
        constants.NETWORK_IDENTIFIER: (
            RevivedNetwork,
            functional_lib.Functional,
        ),
        constants.SEQUENTIAL_IDENTIFIER: (
            RevivedNetwork,
            models_lib.Sequential,
        ),
    }
    bases = revived_classes.get(identifier, None)
    class_name = tf.compat.as_str(metadata["class_name"])

    if bases is None:
        raise ValueError(
            f'Unable to restore custom object of class "{class_name}" '
            f"(type {identifier}). Please make sure that this class is "
            "included in the `custom_objects` arg when calling `load_model()`. "
            "Also, check that the class implements `get_config` and "
            f"`from_config`.\n\nComplete metadata: {metadata}"
        )

    revived_cls = type(class_name, bases, {})
    return revived_cls._init_from_metadata(metadata)

1149 

1150 

def _restore_layer_metrics(layer):
    """Appends saved metrics that the layer does not already hold.

    Saved metrics come from the `layer_metrics` mapping of the layer's
    Keras attributes; one is added to `layer._metrics` only when no
    existing metric has the same name.

    Args:
        layer: layer whose `_metrics` list is extended in place.
    """
    saved_metrics = getattr(_get_keras_attr(layer), "layer_metrics", {})
    # Metrics may already have been added while a custom layer was
    # initialized or built.
    known_names = {existing.name for existing in layer._metrics}
    for name, metric in saved_metrics.items():
        if name in known_names:
            continue
        layer._metrics.append(metric)

1159 

1160 

1161# TODO(kathywu): Centrally define keys and functions for both serialization and 

1162# deserialization. 

class RevivedLayer:
    """Keras layer loaded from a SavedModel."""

    @classmethod
    def _init_from_metadata(cls, metadata):
        """Create revived layer from metadata stored in the SavedModel proto.

        Args:
            metadata: mapping of layer metadata. "name", "trainable" and
                "expects_training_arg" are required; the other entries are
                applied only when present and not None.

        Returns:
            Tuple of the revived object and the setter function
            (`_revive_setter`) to use for its attributes.
        """
        init_args = {
            "name": metadata["name"],
            "trainable": metadata["trainable"],
        }
        for optional_arg in ("dtype", "batch_input_shape"):
            if metadata.get(optional_arg) is not None:
                init_args[optional_arg] = metadata[optional_arg]

        revived_obj = cls(**init_args)

        with utils.no_automatic_dependency_tracking_scope(revived_obj):
            revived_obj._call_spec.expects_training_arg = metadata[
                "expects_training_arg"
            ]

            config = metadata.get("config")
            if serialization.validate_config(config):
                revived_obj._config = config

            if metadata.get("input_spec") is not None:
                revived_obj.input_spec = recursively_deserialize_keras_object(
                    metadata["input_spec"],
                    module_objects={"InputSpec": input_spec.InputSpec},
                )

            if metadata.get("activity_regularizer") is not None:
                revived_obj.activity_regularizer = regularizers.deserialize(
                    metadata["activity_regularizer"]
                )

            # Remaining entries are copied verbatim onto the attribute
            # listed next to their metadata key.
            plain_copies = (
                ("_is_feature_layer", "_is_feature_layer"),
                ("stateful", "stateful"),
                ("autocast", "_autocast"),
                (
                    "preserve_input_structure_in_config",
                    "_preserve_input_structure_in_config",
                ),
            )
            for metadata_key, attribute in plain_copies:
                if metadata.get(metadata_key) is not None:
                    setattr(revived_obj, attribute, metadata[metadata_key])

        return revived_obj, _revive_setter

    @property
    def keras_api(self):
        """Entry stored under `constants.KERAS_ATTR`, or None if absent."""
        return self._serialized_attributes.get(constants.KERAS_ATTR, None)

    def get_config(self):
        """Returns the saved config.

        Raises:
            NotImplementedError: if no config was stored on this object.
        """
        if not hasattr(self, "_config"):
            raise NotImplementedError
        return self._config

1216 

1217 

def _revive_setter(layer, name, value):
    """Setter function that saves some attributes to separate dictionary.

    Dispatch, in order:

    1. `name` in `PUBLIC_ATTRIBUTES`: the value is stored in
       `layer._serialized_attributes` (and tracked first if it is a
       `Trackable`).
    2. `layer` is a Functional model and `name` starts with "layer-N" or
       "layer_with_weights-N": the value is only tracked as a dependency.
    3. The layer already has a non-None attribute `name`: nothing is done.
    4. Otherwise the attribute is set on the layer.

    Args:
        layer: revived object receiving the attribute.
        name: attribute (edge) name from the SavedModel.
        value: attribute value.
    """
    # Many attributes in the SavedModel conflict with properties defined in
    # Layer and Model. Save these attributes to a separate dictionary.
    if name in PUBLIC_ATTRIBUTES:

        if isinstance(value, tf.__internal__.tracking.Trackable):
            layer._track_trackable(value, name=name)
        layer._serialized_attributes[name] = value

    elif (
        isinstance(layer, functional_lib.Functional)
        # `\d+` requires a digit after the dash. The previous `[\d+]` was a
        # character class that also accepted a literal "+".
        and re.match(r"^layer(_with_weights)?-\d+", name) is not None
    ):
        # Edges named "layer-n" or "layer_with_weights-n", which are tracked in
        # network._track_layers, should not be added as an attribute. They
        # should be temporarily added as a dependency so that checkpointed
        # values can be restored. These dependencies are manually deleted in
        # KerasObjectLoader.del_tracking.

        # Set `overwrite=True` in the case that `layer` already tracks a
        # different layer-n. This may cause variable values to not be loaded
        # properly in the original layer-n, but we already warn the users about
        # this (ctrl-f "shared between different layers/models").
        layer._track_trackable(value, name, overwrite=True)
    elif getattr(layer, name, None) is not None:
        # Don't overwrite already defined attributes.
        pass
    else:
        setattr(layer, name, value)

1248 

1249 

class RevivedInputLayer:
    """InputLayer loaded from a SavedModel."""

    @classmethod
    def _init_from_metadata(cls, metadata):
        """Revives the saved InputLayer from the Metadata.

        Args:
            metadata: mapping providing "name", "dtype", "sparse", "ragged",
                "batch_input_shape" and "config".

        Returns:
            Tuple of the revived object and the builtin `setattr`, which is
            used as its attribute setter.
        """
        constructor_keys = (
            "name",
            "dtype",
            "sparse",
            "ragged",
            "batch_input_shape",
        )
        revived_obj = cls(**{key: metadata[key] for key in constructor_keys})
        with utils.no_automatic_dependency_tracking_scope(revived_obj):
            revived_obj._config = metadata["config"]

        return revived_obj, setattr

    def get_config(self):
        """Returns the config stored when the layer was revived."""
        return self._config

1271 

1272 

def recursively_deserialize_keras_object(config, module_objects=None):
    """Deserialize Keras object from a nested structure.

    Args:
        config: nested structure of dicts, tuples and lists. A dict that
            contains a "class_name" key is deserialized as a Keras object;
            any other dict is rebuilt key by key.
        module_objects: forwarded to `deserialize_keras_object`.

    Returns:
        The deserialized structure. Tuples and lists both come back as
        lists.

    Raises:
        ValueError: if a value is neither a dict, a tuple nor a list.
    """
    if isinstance(config, (tuple, list)):
        return [
            recursively_deserialize_keras_object(item, module_objects)
            for item in config
        ]

    if not isinstance(config, dict):
        raise ValueError(
            "Unable to decode Keras layer config. Config should be a "
            f"dictionary, tuple or list. Received: config={config}"
        )

    if "class_name" in config:
        return serialization.deserialize_keras_object(
            config, module_objects=module_objects
        )

    return {
        key: recursively_deserialize_keras_object(value, module_objects)
        for key, value in config.items()
    }

1297 

1298 

def infer_inputs_from_restored_call_function(fn):
    """Returns TypeSpec of inputs from a restored call function.

    The structured input signatures of all concrete functions of `fn` are
    merged leaf by leaf into their most specific common supertype.

    Args:
        fn: Restored layer call function. It is assumed that `fn` has at
            least one concrete function and that the inputs are in the
            first argument.

    Returns:
        TypeSpec of call function inputs in the form of (args, kwargs)

    Raises:
        TypeError: if two corresponding specs have no common supertype.
    """

    def merge_leaf(current, other):
        if not isinstance(current, tf.TypeSpec):
            # The returned value is irrelevant here: non-TypeSpec leaves
            # get filtered out later in _set_input_shape.
            return current

        supertype = (
            current._without_tensor_names().most_specific_common_supertype(
                [other._without_tensor_names()]
            )
        )
        if supertype is None:
            # Please file a bug if you are being hindered by this error.
            raise TypeError(f"No common supertype of {current} and {other}.")
        return supertype

    concrete_functions = fn.concrete_functions
    merged = concrete_functions[0].structured_input_signature
    for concrete in concrete_functions[1:]:
        merged = tf.nest.map_structure(
            merge_leaf, merged, concrete.structured_input_signature
        )
    return merged

1329 

1330 

class RevivedNetwork(RevivedLayer):
    """Keras network of layers loaded from a SavedModel."""

    @classmethod
    def _init_from_metadata(cls, metadata):
        """Create revived network from metadata stored in the SavedModel
        proto.

        Args:
            metadata: mapping of network metadata. "name" and
                "expects_training_arg" are required; "config",
                "activity_regularizer" and "autocast" are optional.

        Returns:
            Tuple of the revived object and the setter function
            (`_revive_setter`) to use for its attributes.
        """
        revived_obj = cls(name=metadata["name"])

        # Attributes revived from SerializedAttributes (those listed in
        # CommonEndpoints, or "keras_api" for keras-specific ones) are kept
        # in an un-tracked dictionary, hence the no-tracking scope.
        with utils.no_automatic_dependency_tracking_scope(revived_obj):
            call_spec = revived_obj._call_spec
            call_spec.expects_training_arg = metadata["expects_training_arg"]

            config = metadata.get("config")
            if serialization.validate_config(config):
                revived_obj._config = config

            if metadata.get("activity_regularizer") is not None:
                regularizer_config = metadata["activity_regularizer"]
                revived_obj.activity_regularizer = regularizers.deserialize(
                    regularizer_config
                )

            if metadata.get("autocast") is not None:
                revived_obj._autocast = metadata["autocast"]

        return revived_obj, _revive_setter

1360 

1361 

def _set_network_attributes_from_metadata(revived_obj):
    """Sets attributes recorded in the metadata.

    The dtype policy (when a dtype was recorded) and the `_trainable` flag
    are applied from `revived_obj._serialized_attributes["metadata"]`.

    Args:
        revived_obj: revived network carrying `_serialized_attributes`.
    """
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
        metadata = revived_obj._serialized_attributes["metadata"]

        saved_dtype = metadata.get("dtype")
        if saved_dtype is not None:
            revived_obj._set_dtype_policy(saved_dtype)

        revived_obj._trainable = metadata["trainable"]

1370 

1371 

1372def _maybe_add_serialized_attributes(layer, metadata): 

1373 # Store attributes revived from SerializedAttributes in a un-tracked 

1374 # dictionary. The attributes are the ones listed in CommonEndpoints or 

1375 # "keras_api" for keras-specific attributes. 

1376 if not hasattr(layer, "_serialized_attributes"): 

1377 with utils.no_automatic_dependency_tracking_scope(layer): 

1378 layer._serialized_attributes = {"metadata": metadata} 

1379 

1380 

def _get_keras_attr(layer):
    """Returns the layer's saved Keras attributes, if any.

    Args:
        layer: object that may carry a `_serialized_attributes` dict.

    Returns:
        The entry stored under `constants.KERAS_ATTR`, or None when the
        layer has no `_serialized_attributes` or no such entry.
    """
    serialized_attributes = getattr(layer, "_serialized_attributes", {})
    return serialized_attributes.get(constants.KERAS_ATTR, None)

1385