Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py: 17%

579 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras SavedModel deserialization.""" 

16 

17import os 

18import re 

19import types 

20 

21from google.protobuf import message 

22 

23from tensorflow.python.eager import context 

24from tensorflow.python.framework import ops 

25from tensorflow.python.framework import sparse_tensor 

26from tensorflow.python.framework import tensor_shape 

27from tensorflow.python.framework import tensor_spec 

28from tensorflow.python.keras import backend 

29from tensorflow.python.keras import regularizers 

30from tensorflow.python.keras.engine import input_spec 

31from tensorflow.python.keras.optimizer_v2 import optimizer_v2 

32from tensorflow.python.keras.protobuf import saved_metadata_pb2 

33from tensorflow.python.keras.protobuf import versions_pb2 

34from tensorflow.python.keras.saving import saving_utils 

35from tensorflow.python.keras.saving.saved_model import constants 

36from tensorflow.python.keras.saving.saved_model import json_utils 

37from tensorflow.python.keras.saving.saved_model import utils 

38from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints 

39from tensorflow.python.keras.utils import generic_utils 

40from tensorflow.python.keras.utils import metrics_utils 

41from tensorflow.python.keras.utils.generic_utils import LazyLoader 

42from tensorflow.python.ops.ragged import ragged_tensor 

43from tensorflow.python.platform import gfile 

44from tensorflow.python.platform import tf_logging as logging 

45from tensorflow.python.saved_model import load as tf_load 

46from tensorflow.python.saved_model import loader_impl 

47from tensorflow.python.saved_model import nested_structure_coder 

48from tensorflow.python.saved_model import revived_types 

49from tensorflow.python.trackable import base as trackable 

50from tensorflow.python.trackable import data_structures 

51from tensorflow.python.util import compat 

52from tensorflow.python.util import nest 

53 

54# To avoid circular dependencies between keras/engine and keras/saving, 

55# code in keras/saving must delay imports. 

56 

57# TODO(b/134426265): Switch back to single-quotes to match the rest of the file 

58# once the issue with copybara is fixed. 

59# pylint:disable=g-inconsistent-quotes 

60models_lib = LazyLoader("models_lib", globals(), 

61 "tensorflow.python.keras.models") 

62base_layer = LazyLoader( 

63 "base_layer", globals(), 

64 "tensorflow.python.keras.engine.base_layer") 

65layers_module = LazyLoader( 

66 "layers_module", globals(), 

67 "tensorflow.python.keras.layers") 

68input_layer = LazyLoader( 

69 "input_layer", globals(), 

70 "tensorflow.python.keras.engine.input_layer") 

71functional_lib = LazyLoader( 

72 "functional_lib", globals(), 

73 "tensorflow.python.keras.engine.functional") 

74training_lib = LazyLoader( 

75 "training_lib", globals(), 

76 "tensorflow.python.keras.engine.training") 

77training_lib_v1 = LazyLoader( 

78 "training_lib_v1", globals(), 

79 "tensorflow.python.keras.engine.training_v1") 

80metrics = LazyLoader("metrics", globals(), 

81 "tensorflow.python.keras.metrics") 

82recurrent = LazyLoader( 

83 "recurrent", globals(), 

84 "tensorflow.python.keras.layers.recurrent") 

85# pylint:enable=g-inconsistent-quotes 

86 

87 

# Names of all serialized-attribute endpoints (functions and checkpointable
# objects) plus the Keras attribute itself. `KerasObjectLoader.del_tracking`
# deletes the tracked dependencies with these names from every loaded layer.
PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
    CommonEndpoints.all_checkpointable_objects, (constants.KERAS_ATTR,))

91 

92 

def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin
  """Loads Keras objects from a SavedModel.

  Any Keras layer or model saved to the SavedModel will be loaded back
  as Keras objects. Other objects are loaded as regular trackable objects (same
  as `tf.saved_model.load`).

  The Keras metadata is read from the metadata file stored in the SavedModel
  directory (`constants.SAVED_METADATA_PATH`). SavedModels written before
  TF 2.5 do not have that file; for them the metadata is rebuilt from the
  object graph and a warning is logged. If no Keras objects are found, the
  result of `tf.saved_model.load` is returned unchanged.

  If `compile` is true and the loaded root is a Keras `Model` whose saved
  metadata contains a training config, the model is compiled from that
  config. Optimizer slot variables are not restored; a warning is logged when
  the re-created optimizer uses slots.

  Args:
    path: Path to SavedModel.
    compile: If true, compile the model after loading it.
    options: Optional `tf.saved_model.LoadOptions` object that specifies
      options for loading from SavedModel.

  Returns:
    Object loaded from SavedModel.

  Raises:
    IOError: If the Keras metadata file exists but cannot be parsed.
    ValueError: If a pre-TF 2.5 SavedModel contains Keras objects that lack
      Keras metadata.
  """
  # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
  # TODO(kathywu): Add code to load from objects that contain all endpoints

  # Look for metadata file or parse the SavedModel
  metadata = saved_metadata_pb2.SavedMetadata()
  meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0]
  object_graph_def = meta_graph_def.object_graph_def
  path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH)
  if gfile.Exists(path_to_metadata_pb):
    try:
      with gfile.GFile(path_to_metadata_pb, 'rb') as f:
        file_content = f.read()
      metadata.ParseFromString(file_content)
    except message.DecodeError as e:
      # Chain the protobuf error so its traceback is preserved as the cause.
      raise IOError('Cannot parse keras metadata {}: {}.'
                    .format(path_to_metadata_pb, str(e))) from e
  else:
    logging.warning('SavedModel saved prior to TF 2.5 detected when loading '
                    'Keras model. Please ensure that you are saving the model '
                    'with model.save() or tf.keras.models.save_model(), *NOT* '
                    'tf.saved_model.save(). To confirm, there should be a file '
                    'named "keras_metadata.pb" in the SavedModel directory.')
    _read_legacy_metadata(object_graph_def, metadata)

  if not metadata.nodes:
    # When there are no Keras objects, return the results from the core loader
    return tf_load.load(path, options=options)

  # Recreate layers and metrics using the info stored in the metadata.
  keras_loader = KerasObjectLoader(metadata, object_graph_def)
  keras_loader.load_layers(compile=compile)

  # Generate a dictionary of all loaded nodes, keyed by object path.
  nodes_to_load = {'root': None}
  for node_id, loaded_node in keras_loader.loaded_nodes.items():
    nodes_to_load[keras_loader.get_path(node_id)] = loaded_node
  loaded = tf_load.load_partial(path, nodes_to_load, options=options)

  # Finalize the loaded layers and remove the extra tracked dependencies.
  keras_loader.finalize_objects()
  keras_loader.del_tracking()

  model = loaded['root']

  # pylint: disable=protected-access
  if isinstance(model, training_lib.Model) and compile:
    # TODO(kathywu): Use compiled objects from SavedModel, instead of
    # creating new objects from the training config.
    training_config = model._serialized_attributes['metadata'].get(
        'training_config', None)
    if training_config is not None:
      model.compile(**saving_utils.compile_args_from_training_config(
          training_config), from_serialized=True)
      saving_utils.try_build_compiled_arguments(model)
      if isinstance(model.optimizer, optimizer_v2.OptimizerV2):
        if model.optimizer.get_slot_names():
          logging.warning('Your optimizer uses slots. '
                          'Slots cannot be restored from saved_model, '
                          'as a result, your model is starting with '
                          'a new initialized optimizer.')
    else:
      logging.warning('No training configuration found in save file, so the '
                      'model was *not* compiled. Compile it manually.')
  # pylint: enable=protected-access

  # Force variables and resources to initialize.
  if not context.executing_eagerly():
    sess = backend.get_session()  # Variables are initialized by this call.
    sess.run(ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS))

  return model

188 

189 

def _read_legacy_metadata(object_graph_def, metadata):
  """Builds a KerasMetadata proto from the SavedModel ObjectGraphDef.

  Older SavedModels store the Keras metadata directly on each user object in
  the object graph instead of in a separate metadata file. This adds one entry
  to `metadata.nodes` (node id, object path, identifier and metadata string)
  for every user object whose identifier is a Keras object identifier.

  Args:
    object_graph_def: The SavedModel's object graph.
    metadata: `SavedMetadata` proto that is populated in place.

  Raises:
    ValueError: If a Keras user object has no stored metadata.
  """
  # Older SavedModels store the metadata directly in the proto instead of the
  # separate pb file.
  node_paths = _generate_object_paths(object_graph_def)
  for node_id, proto in enumerate(object_graph_def.nodes):
    if (proto.WhichOneof('kind') == 'user_object' and
        proto.user_object.identifier in constants.KERAS_OBJECT_IDENTIFIERS):
      if not proto.user_object.metadata:
        # The adjacent literals below need explicit separating spaces;
        # previously the message rendered as "metadata.Please" and
        # "`model.save`or".
        raise ValueError('Unable to create a Keras model from this SavedModel. '
                         'This SavedModel was created with '
                         '`tf.saved_model.save`, and lacks the Keras metadata. '
                         'Please save your Keras model by calling `model.save` '
                         'or `tf.keras.models.save_model`.')
      metadata.nodes.add(
          node_id=node_id,
          node_path=node_paths[node_id],
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]),
          identifier=proto.user_object.identifier,
          metadata=proto.user_object.metadata)

211 

212 

def _generate_object_paths(object_graph_def):
  """Maps every node reachable from the root to a dotted attribute path.

  Node 0 is named 'root'; each child reference appends '.<local_name>' to its
  parent's path. Nodes are visited depth-first with an explicit stack, and a
  node keeps the first path under which it was discovered.

  Args:
    object_graph_def: Object graph whose `nodes[i].children` are references
      carrying `node_id` and `local_name`.

  Returns:
    Dict mapping node id to its path string.
  """
  node_paths = {0: 'root'}
  stack = [0]
  while stack:
    parent_id = stack.pop()
    parent_path = node_paths[parent_id]
    for ref in object_graph_def.nodes[parent_id].children:
      if ref.node_id not in node_paths:
        node_paths[ref.node_id] = f'{parent_path}.{ref.local_name}'
        stack.append(ref.node_id)
  return node_paths

229 

230 

def _is_graph_network(layer):
  """Determines whether the layer is a graph network.

  Networks revived from SavedModel functions (`RevivedNetwork`) never count.
  Otherwise the layer must be a `Functional` instance that is either flagged
  as a graph network or is a `Sequential` model.
  """
  # pylint: disable=protected-access
  if isinstance(layer, RevivedNetwork):
    return False
  if not isinstance(layer, functional_lib.Functional):
    return False
  return layer._is_graph_network or isinstance(layer, models_lib.Sequential)

240 

241 

class KerasObjectLoader(object):
  """Loader that recreates Keras objects (e.g. layers, models).

  Layers and models are revived from either the config or SavedModel following
  these rules:
  1. If object is a graph network (i.e. Sequential or Functional) then it will
     be initialized using the structure from the config only after the children
     layers have been created. Graph networks must be initialized with inputs
     and outputs, so all child layers must be created beforehand.
  2. If object's config exists and the class can be found, then revive from
     config.
  3. Object may have already been created if its parent was revived from config.
     In this case, do nothing.
  4. If nothing of the above applies, compose the various artifacts from the
     SavedModel to create a subclassed layer or model. At this time, custom
     metrics are not supported.

  Usage sequence (as in `load`): construct with the Keras metadata proto and
  the SavedModel object graph, call `load_layers`, pass the objects in
  `loaded_nodes` (keyed by `get_path(node_id)`) to the core loader, then call
  `finalize_objects` and finally `del_tracking`.
  """

260 

261 def __init__(self, metadata, object_graph_def): 

262 self._metadata = {x.node_id: x for x in metadata.nodes} 

263 self._proto = object_graph_def 

264 

265 self._node_paths = {node_data.node_id: node_data.node_path 

266 for node_data in metadata.nodes} 

267 self.loaded_nodes = {} # Maps node path -> loaded node 

268 

269 # Store all node ids that have already been traversed when tracking nodes 

270 # that were recreated from the config. 

271 self._traversed_nodes_from_config = set() 

272 

273 # Maps model id -> (blank model obj, list of child layer or their node ids) 

274 # This tracks all layers in functional and sequential models. These models 

275 # are only reconstructed after all of their child layers have been created. 

276 self.model_layer_dependencies = {} 

277 self._models_to_reconstruct = [] 

278 

279 def del_tracking(self): 

280 """Removes tracked references that are only used when loading the model.""" 

281 # Now that the node object has been fully loaded, and the checkpoint has 

282 # been restored, the object no longer needs to track objects added from 

283 # SerializedAttributes. (Note that saving a training checkpoint still 

284 # functions correctly, because layers and variables are tracked separately 

285 # by the Layer object.) 

286 # TODO(kathywu): Instead of outright deleting these nodes (which would 

287 # make restoring from a different checkpoint tricky), mark them as extra 

288 # dependencies that are OK to overwrite. 

289 for node in self.loaded_nodes.values(): 

290 node = node[0] 

291 if not isinstance(node, base_layer.Layer): 

292 # Loaded nodes can contain other trackable objects created when 

293 # loading layers from the config, such as variables. 

294 continue 

295 for name in PUBLIC_ATTRIBUTES: 

296 node._delete_tracking(name) # pylint: disable=protected-access 

297 

298 if isinstance(node, functional_lib.Functional): 

299 # Delete the temporary layer dependencies, which were used to restore 

300 # the checkpointed values. When the model is live, the user can delete 

301 # or add layers to the model at any time, so these layer dependencies 

302 # may be obsolete. 

303 dependencies = list(node._self_unconditional_dependency_names) # pylint: disable=protected-access 

304 for name in dependencies: 

305 if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None: 

306 node._delete_tracking(name) # pylint: disable=protected-access 

307 

  def _add_children_recreated_from_config(self, obj, proto, node_id):
    """Recursively records objects recreated from config.

    When an object is revived from its config, it may already own child
    trackables (sub-layers, variables, metrics). This walks the SavedModel
    children of `node_id`, looks each one up on `obj`, and records the
    matching Python object and the setter to attach it with in
    `self.loaded_nodes`. Its dotted path is recorded in `self._node_paths`.
    (`load` later hands these entries to `tf_load.load_partial`.)

    Args:
      obj: The object recreated from config for node `node_id`.
      proto: The SavedModel object proto for `node_id`; its `children` are
        walked.
      node_id: Node id of `obj` in the SavedModel object graph.
    """
    # pylint: disable=protected-access
    # Each node is traversed at most once, even if it is reachable from
    # several config-revived parents.
    if node_id in self._traversed_nodes_from_config:
      return

    parent_path = self._node_paths[node_id]
    self._traversed_nodes_from_config.add(node_id)
    obj._maybe_initialize_trackable()
    # Build unbuilt layers first (from the saved `build_input_shape`, or an
    # inferred one). Presumably this is so weights created in `build` exist
    # and can be matched to SavedModel children below.
    if isinstance(obj, base_layer.Layer) and not obj.built:
      metadata = json_utils.decode(self._metadata[node_id].metadata)
      self._try_build_layer(obj, node_id, metadata.get('build_input_shape'))

    # Create list of all possible children
    children = []
    # Look for direct children
    for reference in proto.children:
      obj_child = obj._lookup_dependency(reference.local_name)
      children.append((obj_child, reference.node_id, reference.local_name))

    # Add metrics that may have been added to the layer._metrics list.
    # This is stored in the SavedModel as layer.keras_api.layer_metrics in
    # SavedModels created after Tf 2.2.
    metric_list_node_id = self._search_for_child_node(
        node_id, [constants.KERAS_ATTR, 'layer_metrics'])
    if metric_list_node_id is not None and hasattr(obj, '_metrics'):
      obj_metrics = {m.name: m for m in obj._metrics}
      for reference in self._proto.nodes[metric_list_node_id].children:
        metric = obj_metrics.get(reference.local_name)
        if metric is not None:
          metric_path = '{}.layer_metrics.{}'.format(constants.KERAS_ATTR,
                                                      reference.local_name)
          children.append((metric, reference.node_id, metric_path))

    for (obj_child, child_id, child_name) in children:
      child_proto = self._proto.nodes[child_id]

      # Only trackable children are recorded.
      if not isinstance(obj_child, trackable.Trackable):
        continue
      # Pick how the child will be attached to its parent: the setter of a
      # registered revived type, the Keras revive setter, or plain `setattr`.
      if (child_proto.user_object.identifier in
          revived_types.registered_identifiers()):
        setter = revived_types.get_setter(child_proto.user_object)
      elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
        setter = _revive_setter
      else:
        setter = setattr
      # pylint: enable=protected-access

      if child_id in self.loaded_nodes:
        if self.loaded_nodes[child_id][0] is not obj_child:
          # This means that the same trackable object is referenced by two
          # different objects that were recreated from the config.
          logging.warning(
              'Looks like there is an object (perhaps variable or '
              'layer) that is shared between different layers/models. '
              'This may cause issues when restoring the variable '
              'values. Object: {}'.format(obj_child))
        # Already recorded; keep the first registration.
        continue

      # Overwrite variable names with the ones saved in the SavedModel.
      if (child_proto.WhichOneof('kind') == 'variable' and
          child_proto.variable.name):
        obj_child._handle_name = child_proto.variable.name + ':0'  # pylint: disable=protected-access

      # Trackable data structures get a no-op setter.
      if isinstance(obj_child, data_structures.TrackableDataStructure):
        setter = lambda *args: None

      # Record the child's path before recursing, because the recursive call
      # reads its own path from `self._node_paths`.
      child_path = '{}.{}'.format(parent_path, child_name)
      self._node_paths[child_id] = child_path
      self._add_children_recreated_from_config(
          obj_child, child_proto, child_id)
      self.loaded_nodes[child_id] = obj_child, setter

380 

381 def load_layers(self, compile=True): # pylint: disable=redefined-builtin 

382 """Load all layer nodes from the metadata.""" 

383 # Load metrics after models and layers, since it's likely that models 

384 # and layers will create the metric when initialized (this avoids wasting 

385 # time by creating objects multiple times). 

386 metric_list = [] 

387 for node_metadata in self._metadata.values(): 

388 if node_metadata.identifier == constants.METRIC_IDENTIFIER: 

389 metric_list.append(node_metadata) 

390 continue 

391 

392 self.loaded_nodes[node_metadata.node_id] = self._load_layer( 

393 node_metadata.node_id, node_metadata.identifier, 

394 node_metadata.metadata) 

395 

396 for node_metadata in metric_list: 

397 try: 

398 self.loaded_nodes[node_metadata.node_id] = self._load_layer( 

399 node_metadata.node_id, node_metadata.identifier, 

400 node_metadata.metadata) 

401 except ValueError: 

402 # Metrics are only needed when the model is compiled later. We ignore 

403 # errors when trying to load custom metrics when `compile=False` until 

404 # custom metrics are serialized properly (b/135550038). 

405 if compile: 

406 raise 

407 logging.warning('Unable to restore custom metric. Please ensure that ' 

408 'the layer implements `get_config` and `from_config` ' 

409 'when saving. In addition, please use the ' 

410 '`custom_objects` arg when calling `load_model()`.') 

411 

412 def _load_layer(self, node_id, identifier, metadata): 

413 """Load a single layer from a SavedUserObject proto.""" 

414 metadata = json_utils.decode(metadata) 

415 

416 # If node was already created 

417 if node_id in self.loaded_nodes: 

418 node, setter = self.loaded_nodes[node_id] 

419 

420 # Revive setter requires the object to have a `_serialized_attributes` 

421 # property. Add it here. 

422 _maybe_add_serialized_attributes(node, metadata) 

423 

424 config = metadata.get('config') 

425 if _is_graph_network(node) and generic_utils.validate_config(config): 

426 child_nodes = self._get_child_layer_node_ids(node_id) 

427 self.model_layer_dependencies[node_id] = (node, child_nodes) 

428 if not child_nodes: 

429 self._models_to_reconstruct.append(node_id) 

430 return node, setter 

431 

432 # Detect whether this object can be revived from the config. If not, then 

433 # revive from the SavedModel instead. 

434 obj, setter = self._revive_from_config(identifier, metadata, node_id) 

435 if obj is None: 

436 obj, setter = revive_custom_object(identifier, metadata) 

437 

438 # Add an attribute that stores the extra functions/objects saved in the 

439 # SavedModel. Most of these functions/objects are ignored, but some are 

440 # used later in the loading process (e.g. the list of regularization 

441 # losses, or the training config of compiled models). 

442 _maybe_add_serialized_attributes(obj, metadata) 

443 return obj, setter 

444 

445 def _revive_from_config(self, identifier, metadata, node_id): 

446 """Revives a layer/model from config, or returns None.""" 

447 if identifier == constants.METRIC_IDENTIFIER: 

448 obj = self._revive_metric_from_config(metadata) 

449 else: 

450 obj = ( 

451 self._revive_graph_network(identifier, metadata, node_id) or 

452 self._revive_layer_or_model_from_config(metadata, node_id)) 

453 

454 if obj is None: 

455 return None, None 

456 

457 setter = self._config_node_setter(_revive_setter) 

458 self._add_children_recreated_from_config( 

459 obj, self._proto.nodes[node_id], node_id) 

460 return obj, setter 

461 

462 def _revive_graph_network(self, identifier, metadata, node_id): 

463 """Revives a graph network from config.""" 

464 # Determine whether the metadata contains information for reviving a 

465 # functional or Sequential model. 

466 config = metadata.get('config') 

467 if not generic_utils.validate_config(config): 

468 return None 

469 

470 class_name = compat.as_str(metadata['class_name']) 

471 if generic_utils.get_registered_object(class_name) is not None: 

472 return None 

473 model_is_functional_or_sequential = ( 

474 metadata.get('is_graph_network', False) or 

475 class_name == 'Sequential' or 

476 class_name == 'Functional') 

477 if not model_is_functional_or_sequential: 

478 return None 

479 

480 # Revive functional and sequential models as blank model objects for now ( 

481 # must be initialized to enable setattr tracking and attribute caching). 

482 # Reconstruction of the network is deferred until all of the model's layers 

483 # have been revived. 

484 if class_name == 'Sequential': 

485 model = models_lib.Sequential(name=config['name']) 

486 # The model is a custom Sequential model. 

487 elif identifier == constants.SEQUENTIAL_IDENTIFIER: 

488 # Uses the custom class name, since the config does not have one. 

489 model = models_lib.Sequential(name=class_name) 

490 else: 

491 model = models_lib.Functional( 

492 inputs=[], outputs=[], name=config['name']) 

493 

494 # Record this model and its layers. This will later be used to reconstruct 

495 # the model. 

496 layers = self._get_child_layer_node_ids(node_id) 

497 self.model_layer_dependencies[node_id] = (model, layers) 

498 if not layers: 

499 self._models_to_reconstruct.append(node_id) 

500 return model 

501 

502 def _revive_layer_or_model_from_config(self, metadata, node_id): 

503 """Revives a layer/custom model from config; returns None if infeasible.""" 

504 # Check that the following requirements are met for reviving from config: 

505 # 1. Object can be deserialized from config. 

506 # 2. If the object needs to be built, then the build input shape can be 

507 # found. 

508 class_name = metadata.get('class_name') 

509 config = metadata.get('config') 

510 shared_object_id = metadata.get('shared_object_id') 

511 must_restore_from_config = metadata.get('must_restore_from_config') 

512 if not generic_utils.validate_config(config): 

513 return None 

514 

515 try: 

516 obj = layers_module.deserialize( 

517 generic_utils.serialize_keras_class_and_config( 

518 class_name, config, shared_object_id=shared_object_id)) 

519 except ValueError: 

520 if must_restore_from_config: 

521 raise RuntimeError( 

522 'Unable to restore a layer of class {cls}. Layers of ' 

523 'class {cls} require that the class be provided to ' 

524 'the model loading code, either by registering the ' 

525 'class using @keras.utils.register_keras_serializable ' 

526 'on the class def and including that file in your ' 

527 'program, or by passing the class in a ' 

528 'keras.utils.CustomObjectScope that wraps this load ' 

529 'call.'.format(cls=class_name)) 

530 else: 

531 return None 

532 

533 # Use the dtype, name, and trainable status. Often times these are not 

534 # specified in custom configs, so retrieve their values from the metadata. 

535 # pylint: disable=protected-access 

536 obj._name = metadata['name'] 

537 if metadata.get('trainable') is not None: 

538 obj.trainable = metadata['trainable'] 

539 if metadata.get('dtype') is not None: 

540 obj._set_dtype_policy(metadata['dtype']) 

541 if metadata.get('stateful') is not None: 

542 obj.stateful = metadata['stateful'] 

543 # Restore model save spec for subclassed models. (layers do not store a 

544 # SaveSpec) 

545 if isinstance(obj, training_lib.Model): 

546 save_spec = metadata.get('save_spec') 

547 if save_spec is not None: 

548 obj._set_save_spec(save_spec) 

549 # pylint: enable=protected-access 

550 

551 build_input_shape = metadata.get('build_input_shape') 

552 built = self._try_build_layer(obj, node_id, build_input_shape) 

553 

554 if not built: 

555 # If the layer cannot be built, revive a custom layer instead. 

556 return None 

557 return obj 

558 

559 def _revive_metric_from_config(self, metadata): 

560 """Revives a metric object using the config saved in the metadata.""" 

561 class_name = compat.as_str(metadata['class_name']) 

562 config = metadata.get('config') 

563 

564 if not generic_utils.validate_config(config): 

565 return None 

566 

567 try: 

568 obj = metrics.deserialize( 

569 generic_utils.serialize_keras_class_and_config(class_name, config)) 

570 except ValueError: 

571 return None 

572 

573 build_input_shape = metadata.get('build_input_shape') 

574 if build_input_shape is not None and hasattr(obj, '_build'): 

575 obj._build(build_input_shape) # pylint: disable=protected-access 

576 

577 return obj 

578 

579 def _try_build_layer(self, obj, node_id, build_input_shape): 

580 """Attempts to build the layer.""" 

581 if obj.built or hasattr(obj.build, '_is_default'): 

582 obj.built = True 

583 return True 

584 

585 if build_input_shape is None: 

586 build_input_shape = self._infer_inputs(node_id, convert_to_shapes=True) 

587 

588 if build_input_shape is not None: 

589 obj.build(build_input_shape) 

590 base_layer.Layer.build(obj, build_input_shape) 

591 return True 

592 

593 return False 

594 

595 def _load_edges(self): 

596 """Add edges for all nodes that are not waiting on initialization.""" 

597 for node_id, proto in enumerate(self._proto.nodes): 

598 if node_id not in self.model_layer_dependencies: 

599 self._add_object_graph_edges(proto, node_id) 

600 

601 def get_path(self, node_id): 

602 return self._node_paths[node_id] 

603 

604 def finalize_objects(self): 

605 """Finish setting up Keras objects. 

606 

607 This function is executed after all objects and functions have been created. 

608 Call functions and losses are attached to each layer, and once all layers 

609 have been fully set up, graph networks are initialized. 

610 

611 Subclassed models that are revived from the SavedModel are treated like 

612 layers, and have their call/loss functions attached here. 

613 """ 

614 # Finish setting up layers and subclassed models. This step attaches call 

615 # functions and losses to each object, and sets model inputs/outputs. 

616 layers_revived_from_config = [] 

617 layers_revived_from_saved_model = [] 

618 for node_id, (node, _) in self.loaded_nodes.items(): 

619 if (not isinstance(node, base_layer.Layer) or 

620 # Don't finalize models until all layers have finished loading. 

621 node_id in self.model_layer_dependencies): 

622 continue 

623 

624 self._unblock_model_reconstruction(node_id, node) 

625 

626 if isinstance(node, input_layer.InputLayer): 

627 continue 

628 elif isinstance(node, metrics.Metric): 

629 continue 

630 

631 if isinstance(node, (RevivedLayer, RevivedInputLayer)): 

632 layers_revived_from_saved_model.append(node) 

633 else: 

634 layers_revived_from_config.append(node) 

635 

636 _finalize_saved_model_layers(layers_revived_from_saved_model) 

637 _finalize_config_layers(layers_revived_from_config) 

638 

639 # Initialize graph networks, now that layer dependencies have been resolved. 

640 self._reconstruct_all_models() 

641 

642 def _unblock_model_reconstruction(self, layer_id, layer): 

643 """Removes layer from blocking model reconstruction.""" 

644 for model_id, v in self.model_layer_dependencies.items(): 

645 _, layers = v 

646 if layer_id not in layers: 

647 continue 

648 layers[layers.index(layer_id)] = layer 

649 if all(isinstance(x, base_layer.Layer) for x in layers): 

650 self._models_to_reconstruct.append(model_id) 

651 

652 def _reconstruct_all_models(self): 

653 """Reconstructs the network structure of all models.""" 

654 all_initialized_models = set() 

655 while self._models_to_reconstruct: 

656 model_id = self._models_to_reconstruct.pop(0) 

657 all_initialized_models.add(model_id) 

658 model, layers = self.model_layer_dependencies[model_id] 

659 self._reconstruct_model(model_id, model, layers) 

660 _finalize_config_layers([model]) 

661 

662 if all_initialized_models != set(self.model_layer_dependencies.keys()): 

663 # This should not happen. 

664 uninitialized_model_ids = ( 

665 set(self.model_layer_dependencies.keys()) - all_initialized_models) 

666 uninitialized_model_names = [ 

667 self.model_layer_dependencies[model_id][0].name 

668 for model_id in uninitialized_model_ids] 

669 raise ValueError('Error when loading from SavedModel -- the following ' 

670 'models could not be initialized: {}' 

671 .format(uninitialized_model_names)) 

672 

  def _reconstruct_model(self, model_id, model, layers):
    """Reconstructs the network structure.

    Initializes the blank `model` (created by `_revive_graph_network`, or
    already existing) now that its child layers are available.
    Sequential models are re-initialized from `layers`, prepending an
    `InputLayer` described by the config when needed. Functional models are
    rebuilt from the config via `functional_lib.reconstruct_from_config`.
    Afterwards, models waiting on this one are unblocked.

    Args:
      model_id: Node id of the model.
      model: The model object; it is re-initialized by calling `__init__`.
      layers: Revived child layers in `layer-N` order. The list may be mutated
        (an `InputLayer` may be inserted at index 0).
    """
    config = json_utils.decode(self._metadata[model_id].metadata)['config']

    # Set up model inputs
    if model.inputs:
      # Inputs may already be created if the model is instantiated in another
      # object's __init__.
      pass
    elif isinstance(model, models_lib.Sequential):
      if not layers or not isinstance(layers[0], input_layer.InputLayer):
        # No leading InputLayer was revived. Recreate one when the saved
        # config describes the input: either an explicit InputLayer entry, or
        # a `batch_input_shape` on the first layer's config.
        # NOTE(review): the `batch_input_shape` branch reads `layers[0]`, so it
        # assumes `layers` is non-empty -- confirm.
        if config['layers'][0]['class_name'] == 'InputLayer':
          layers.insert(0, input_layer.InputLayer.from_config(
              config['layers'][0]['config']))
        elif 'batch_input_shape' in config['layers'][0]['config']:
          batch_input_shape = config['layers'][0]['config']['batch_input_shape']
          layers.insert(0, input_layer.InputLayer(
              input_shape=batch_input_shape[1:],
              batch_size=batch_input_shape[0],
              dtype=layers[0].dtype,
              name=layers[0].name + '_input'))
      # Re-run the constructor on the blank model created earlier.
      model.__init__(layers, name=config['name'])
      if not model.inputs:
        # Still no inputs: infer them from the first child layer's saved call
        # function.
        first_layer = self._get_child_layer_node_ids(model_id)[0]
        input_specs = self._infer_inputs(first_layer)
        input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True)
        model._set_inputs(input_specs)  # pylint: disable=protected-access
        # Build from the inferred shapes unless the inputs are a dict.
        if not model.built and not isinstance(input_specs, dict):
          model.build(input_shapes)
    else:  # Reconstruct functional model
      (inputs, outputs,
       created_layers) = functional_lib.reconstruct_from_config(
           config, created_layers={layer.name: layer for layer in layers})
      model.__init__(inputs, outputs, name=config['name'])
      functional_lib.connect_ancillary_layers(model, created_layers)

    # Set model dtype.
    _set_network_attributes_from_metadata(model)

    # Unblock models that are dependent on this model.
    self._unblock_model_reconstruction(model_id, model)

714 

715 def _get_child_layer_node_ids(self, node_id): 

716 """Returns the node ids of each layer in a Sequential/Functional model.""" 

717 # Sequential and Functional track layers with names following the format 

718 # "layer-N". Use this to generate the list of layers. 

719 num_layers = 0 

720 child_layers = {} 

721 pattern = re.compile('layer-(\\d+)') 

722 

723 for child in self._proto.nodes[node_id].children: 

724 m = pattern.match(child.local_name) 

725 if m is None: 

726 continue 

727 layer_n = int(m.group(1)) 

728 num_layers = max(layer_n + 1, num_layers) 

729 child_layers[layer_n] = child.node_id 

730 

731 ordered = [] 

732 for n in range(num_layers): 

733 child = child_layers.get(n) 

734 if child is None: 

735 break 

736 ordered.append(child) 

737 return ordered 

738 

739 def _search_for_child_node(self, parent_id, path_to_child): 

740 """Returns node id of child node. 

741 

742 A helper method for traversing the object graph proto. 

743 

744 As an example, say that the object graph proto in the SavedModel contains an 

745 object with the following child and grandchild attributes: 

746 

747 `parent.child_a.child_b` 

748 

749 This method can be used to retrieve the node id of `child_b` using the 

750 parent's node id by calling: 

751 

752 `_search_for_child_node(parent_id, ['child_a', 'child_b'])`. 

753 

754 Args: 

755 parent_id: node id of parent node 

756 path_to_child: list of children names. 

757 

758 Returns: 

759 node_id of child, or None if child isn't found. 

760 """ 

761 if not path_to_child: 

762 return parent_id 

763 

764 for child in self._proto.nodes[parent_id].children: 

765 if child.local_name == path_to_child[0]: 

766 return self._search_for_child_node(child.node_id, path_to_child[1:]) 

767 return None 

768 

769 def _infer_inputs(self, layer_node_id, convert_to_shapes=False): 

770 """Infers input shape of layer from SavedModel functions.""" 

771 call_fn_id = self._search_for_child_node( 

772 layer_node_id, ['call_and_return_all_conditional_losses']) 

773 if call_fn_id is None: 

774 return None 

775 

776 concrete_functions = ( 

777 self._proto.nodes[call_fn_id].function.concrete_functions) 

778 if not concrete_functions: 

779 return None 

780 call_fn_name = concrete_functions[0] 

781 call_fn_proto = self._proto.concrete_functions[call_fn_name] 

782 structured_input_signature = nested_structure_coder.decode_proto( 

783 call_fn_proto.canonicalized_input_signature) 

784 inputs = structured_input_signature[0][0] 

785 if convert_to_shapes: 

786 return nest.map_structure(lambda spec: spec.shape, inputs) 

787 else: 

788 return inputs 

789 

790 def _config_node_setter(self, setter): 

791 """Creates edges for nodes that are recreated from config.""" 

792 def setattr_wrapper(obj, name, value): 

793 # Avoid overwriting attributes of objects recreated from the config. 

794 if obj._lookup_dependency(name) is None: # pylint: disable=protected-access 

795 setter(obj, name, value) 

796 return setattr_wrapper 

797 

798 

def _finalize_saved_model_layers(layers):
  """Runs the final steps of loading Keras Layers from SavedModel.

  Applies to layers revived from the SavedModel rather than from their config.
  Every layer is marked as built and gets a `call` implementation. Then, for
  each layer, network attributes and inputs are set (for `RevivedNetwork`),
  unconditional and activity losses are re-added, and the metrics list is
  restored.

  Args:
    layers: iterable of revived layers. It is iterated twice, so a one-shot
      iterator would make the second pass a no-op.
  """
  # pylint: disable=protected-access
  # 1. Set up call functions for all layers initialized from the SavedModel (
  # and not the config)
  for layer in layers:
    layer.built = True
    layer_call = getattr(_get_keras_attr(layer),
                         'call_and_return_conditional_losses', None)
    if layer_call and layer_call.concrete_functions:
      # Route `call` through the restored, traced function.
      layer.call = utils.use_wrapped_call(
          layer, layer_call, return_method=True)
      expects_training_arg = layer._serialized_attributes['metadata'][
          'expects_training_arg']
      if 'training' in layer_call.function_spec.arg_names:
        # This could change the value of `expects_training_arg` if this layer
        # doesn't expect a training arg, but has a child layer that does.
        expects_training_arg = True
      layer._init_call_fn_args(expects_training_arg)
    else:
      # No traced call function was saved; calling the layer will raise a
      # ValueError describing the serialization issue.
      layer.call = types.MethodType(
          _unable_to_call_layer_due_to_serialization_issue, layer)

  for layer in layers:
    # 2. Set model inputs and outputs.
    if isinstance(layer, RevivedNetwork):
      _set_network_attributes_from_metadata(layer)

      if hasattr(_get_keras_attr(layer), 'call_and_return_conditional_losses'):
        call_fn = _get_keras_attr(layer).call_and_return_conditional_losses
        if not call_fn.concrete_functions:
          # NOTE(review): this `continue` also skips steps 3 and 4 below for
          # this network (losses and metrics are not restored); confirm this
          # is intended.
          continue
        if call_fn.input_signature is None:
          # No explicit signature: merge the input specs of all traced
          # concrete functions into one compatible spec.
          inputs = infer_inputs_from_restored_call_function(call_fn)
        else:
          inputs = call_fn.input_signature[0]
        layer._set_inputs(inputs)  # pylint: disable=protected-access

    # 3. Add losses that aren't generated by the layer.call function.
    _restore_layer_unconditional_losses(layer)
    _restore_layer_activation_loss(layer)

    # 4. Restore metrics list
    _restore_layer_metrics(layer)

  # pylint: enable=protected-access

845 

846 

847def _unable_to_call_layer_due_to_serialization_issue( 

848 layer, *unused_args, **unused_kwargs): 

849 """Replaces the `layer.call` if the layer was not fully serialized. 

850 

851 Keras Model/Layer serialization is relatively relaxed because SavedModels 

852 are not always loaded back as keras models. Thus, when there is an issue 

853 tracing a non-signature function, a warning is logged instead of raising an 

854 error. This results in a SavedModel where the model's call function is saved, 

855 but the internal layer call functions are not. 

856 

857 When deserialized with `tf.keras.models.load_model`, the internal layers 

858 which do not have serialized call functions should raise an error when called. 

859 

860 Args: 

861 layer: Layer without the serialized call function. 

862 

863 Raises: 

864 ValueError 

865 """ 

866 

867 raise ValueError( 

868 'Cannot call custom layer {} of type {}, because the call function was ' 

869 'not serialized to the SavedModel.' 

870 'Please try one of the following methods to fix this issue:' 

871 '\n\n(1) Implement `get_config` and `from_config` in the layer/model ' 

872 'class, and pass the object to the `custom_objects` argument when ' 

873 'loading the model. For more details, see: ' 

874 'https://www.tensorflow.org/guide/keras/save_and_serialize' 

875 '\n\n(2) Ensure that the subclassed model or layer overwrites `call` ' 

876 'and not `__call__`. The input shape and dtype will be automatically ' 

877 'recorded when the object is called, and used when saving. To manually ' 

878 'specify the input shape/dtype, decorate the call function with ' 

879 '`@tf.function(input_signature=...)`.'.format(layer.name, type(layer))) 

880 

881 

def _finalize_config_layers(layers):
  """Finishes loading Keras layers that were recreated from their configs.

  For each layer:
  - graph networks get their unconditional losses re-added;
  - the activity loss and the metrics list are restored;
  - stateful RNNs get their saved `states` back, with each state variable
    tracked by the backend;
  - the layer's own `finalize_state` is called.

  Args:
    layers: iterable of layers rebuilt from config.
  """
  for layer in layers:
    # Layers are assumed to recreate their unconditional losses when rebuilt
    # from config. Functional and Sequential models are the exception: their
    # configs only hold conditional (input-dependent) losses, so unconditional
    # ones such as weight regularization are revived from the SavedModel.
    if _is_graph_network(layer):
      _restore_layer_unconditional_losses(layer)

    # Not every layer records its activation loss in its config (Dense does),
    # so it may be missing after restoring from config/hdf5.
    # TODO(kathywu): Investigate ways to improve the config to ensure consistent
    # loading behavior between HDF5 and SavedModel.
    _restore_layer_activation_loss(layer)
    _restore_layer_metrics(layer)

    is_stateful_rnn = isinstance(layer, recurrent.RNN) and layer.stateful
    if is_stateful_rnn and hasattr(_get_keras_attr(layer), 'states'):
      layer.states = getattr(_get_keras_attr(layer), 'states', None)
      for variable in nest.flatten(layer.states):
        backend.track_variable(variable)

    # Let the layer run any finalization it defines for its state.
    layer.finalize_state()

913 

914 

def _finalize_metric(metric):
  """Rebinds a revived metric's `update_state` and `result` to its Keras API.

  `update_state` is wrapped with `metrics_utils.update_state_wrapper` and bound
  to `metric` as a method. `result` is taken directly from the saved API.
  """
  keras_api = metric.keras_api
  wrapped_update = metrics_utils.update_state_wrapper(keras_api.update_state)
  metric.update_state = types.MethodType(wrapped_update, metric)
  metric.result = keras_api.result

919 

920 

def _restore_layer_unconditional_losses(layer):
  """Re-adds the unconditional losses saved in the SavedModel to `layer`.

  Prefers the separately saved `layer_regularization_losses` Keras attribute.
  Earlier SavedModels may not have it, in which case the layer's serialized
  `regularization_losses` (default: none) are used.
  """
  keras_attributes = _get_keras_attr(layer)
  if hasattr(keras_attributes, 'layer_regularization_losses'):
    saved_losses = getattr(keras_attributes, 'layer_regularization_losses', [])
  else:
    # Fallback for older SavedModels without layer_regularization_losses.
    saved_losses = layer._serialized_attributes.get('regularization_losses', [])  # pylint: disable=protected-access
  for loss in saved_losses:
    layer.add_loss(loss)

932 

933 

def _restore_layer_activation_loss(layer):
  """Restores the activation (activity regularizer) loss from the SavedModel.

  The saved, wrapped `activity_regularizer_fn` is installed only when the layer
  has no activity regularizer of its own, i.e. none was created during
  initialization.
  """
  saved_regularizer = getattr(
      _get_keras_attr(layer), 'activity_regularizer_fn', None)
  if not saved_regularizer or layer.activity_regularizer:
    return
  try:
    layer.activity_regularizer = saved_regularizer
  except AttributeError:
    # A layer wrapper saved with an activity regularizer exposes an activity
    # regularizer that cannot be set; leave it as-is in that case.
    pass

947 

948 

def revive_custom_object(identifier, metadata):
  """Revives a Keras object from SavedModel metadata.

  A class named after `metadata['class_name']` is created dynamically. Its
  parents are the Revived* mixin and Keras base class registered for
  `identifier`. That class then initializes itself from `metadata`.

  Args:
    identifier: Keras object identifier recorded in the SavedModel (input
      layer, layer, model, network or sequential identifier from `constants`).
    metadata: decoded metadata dict; must contain 'class_name'.

  Returns:
    The result of the revived class's `_init_from_metadata`, a
    `(revived_object, setter)` pair.

  Raises:
    ValueError: if `identifier` is not a supported Keras object type.
  """
  if ops.executing_eagerly_outside_functions():
    model_class = training_lib.Model
  else:
    model_class = training_lib_v1.Model

  revived_classes = {
      constants.INPUT_LAYER_IDENTIFIER: (
          RevivedInputLayer, input_layer.InputLayer),
      constants.LAYER_IDENTIFIER: (RevivedLayer, base_layer.Layer),
      constants.MODEL_IDENTIFIER: (RevivedNetwork, model_class),
      constants.NETWORK_IDENTIFIER: (RevivedNetwork, functional_lib.Functional),
      constants.SEQUENTIAL_IDENTIFIER: (RevivedNetwork, models_lib.Sequential),
  }
  parent_classes = revived_classes.get(identifier, None)
  if parent_classes is None:
    raise ValueError('Unable to restore custom object of type {} currently. '
                     'Please make sure that the layer implements `get_config` '
                     'and `from_config` when saving. In addition, please use '
                     'the `custom_objects` arg when calling `load_model()`.'
                     .format(identifier))

  revived_cls = type(
      compat.as_str(metadata['class_name']), parent_classes, {})
  return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access

977 

978 

def _restore_layer_metrics(layer):
  """Appends saved `layer_metrics` that `layer` does not already track.

  Metrics are matched by name against the layer's current `_metrics`, because
  custom layers may add metrics during initialization or building.
  """
  saved_metrics = getattr(_get_keras_attr(layer), 'layer_metrics', {})
  existing_names = {m.name for m in layer._metrics}  # pylint: disable=protected-access
  for metric_name, metric in saved_metrics.items():
    if metric_name in existing_names:
      continue
    layer._metrics.append(metric)  # pylint: disable=protected-access

986 

987 

# TODO(kathywu): Centrally define keys and functions for both serialization and
# deserialization.
class RevivedLayer(object):
  """Keras layer loaded from a SavedModel.

  Used as a mixin: `revive_custom_object` combines it with `base_layer.Layer`
  in a dynamically created class.
  """

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Creates a revived layer from the metadata stored in the SavedModel proto.

    Returns:
      A `(revived_obj, setter)` tuple, where `setter` is `_revive_setter`.
    """
    init_args = {'name': metadata['name'], 'trainable': metadata['trainable']}
    for optional_arg in ('dtype', 'batch_input_shape'):
      if metadata.get(optional_arg) is not None:
        init_args[optional_arg] = metadata[optional_arg]
    revived_obj = cls(**init_args)

    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']
      config = metadata.get('config')
      if generic_utils.validate_config(config):
        revived_obj._config = config

      input_spec_config = metadata.get('input_spec')
      if input_spec_config is not None:
        revived_obj.input_spec = recursively_deserialize_keras_object(
            input_spec_config,
            module_objects={'InputSpec': input_spec.InputSpec})

      regularizer_config = metadata.get('activity_regularizer')
      if regularizer_config is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            regularizer_config)

      is_feature_layer = metadata.get('_is_feature_layer')
      if is_feature_layer is not None:
        revived_obj._is_feature_layer = is_feature_layer

      stateful = metadata.get('stateful')
      if stateful is not None:
        revived_obj.stateful = stateful
      # pylint:enable=protected-access

    return revived_obj, _revive_setter

  @property
  def keras_api(self):
    """The saved Keras-specific attributes, or None if they were not saved."""
    return self._serialized_attributes.get(constants.KERAS_ATTR, None)

  def get_config(self):
    """Returns the saved config; raises NotImplementedError if there is none."""
    if not hasattr(self, '_config'):
      raise NotImplementedError
    return self._config

1036 

1037 

def _revive_setter(layer, name, value):
  """Setter function that saves some attributes to separate dictionary.

  Args:
    layer: the revived layer receiving the attribute.
    name: attribute / edge name from the SavedModel object graph.
    value: the object to attach.

  Names in `PUBLIC_ATTRIBUTES` go into `layer._serialized_attributes`; if the
  value is trackable it is also tracked. On Functional models, "layer-N" and
  "layer_with_weights-N" edges are tracked as dependencies only. Existing
  non-None attributes are never overwritten; anything else is set with
  `setattr`.
  """
  # Many attributes in the SavedModel conflict with properties defined in
  # Layer and Model. Save these attributes to a separate dictionary.
  if name in PUBLIC_ATTRIBUTES:
    # pylint: disable=protected-access
    if isinstance(value, trackable.Trackable):
      layer._track_trackable(value, name=name)
    layer._serialized_attributes[name] = value
    # pylint: enable=protected-access
  elif (isinstance(layer, functional_lib.Functional) and
        # `\d+` means "one or more digits"; the previous `[\d+]` was a
        # character class that also accepted a literal '+'.
        re.match(r'^layer(_with_weights)?-\d+', name) is not None):
    # Edges named "layer-n" or "layer_with_weights-n", which are tracked in
    # network._track_layers, should not be added as an attribute. They should
    # be temporarily added as a dependency so that checkpointed values can be
    # restored. These dependencies are manually deleted in
    # KerasObjectLoader.del_tracking.

    # Set `overwrite=True` in the case that `layer` already tracks a different
    # layer-n. This may cause variable values to not be loaded properly in the
    # original layer-n, but we already warn the users about this
    # (ctrl-f "shared between different layers/models").
    layer._track_trackable(value, name, overwrite=True)  # pylint: disable=protected-access
  elif getattr(layer, name, None) is not None:
    # Don't overwrite already defined attributes.
    pass
  else:
    setattr(layer, name, value)

1066 

1067 

class RevivedInputLayer(object):
  """InputLayer loaded from a SavedModel.

  Used as a mixin together with `input_layer.InputLayer` by
  `revive_custom_object`.
  """

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Revives the saved InputLayer from its metadata.

    Returns:
      A `(revived_obj, setter)` tuple; the builtin `setattr` is the setter.
    """
    constructor_keys = ('name', 'dtype', 'sparse', 'ragged', 'batch_input_shape')
    revived_obj = cls(**{key: metadata[key] for key in constructor_keys})
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      revived_obj._config = metadata['config']  # pylint:disable=protected-access
    return revived_obj, setattr

  def get_config(self):
    """Returns the config recorded in the SavedModel metadata."""
    return self._config

1088 

1089 

def recursively_deserialize_keras_object(config, module_objects=None):
  """Deserializes every Keras object found in a nested config structure.

  - Dicts containing a 'class_name' key are deserialized with
    `generic_utils.deserialize_keras_object`.
  - Other dicts are processed value by value.
  - Lists and tuples are processed element by element and always come back as
    a list.

  Args:
    config: a dict, list or tuple, possibly nested.
    module_objects: optional mapping used to resolve 'class_name' entries.

  Returns:
    The structure with its Keras objects deserialized.

  Raises:
    ValueError: if a value that is not a dict, list or tuple is reached.
  """
  if isinstance(config, (tuple, list)):
    return [recursively_deserialize_keras_object(item, module_objects)
            for item in config]
  if not isinstance(config, dict):
    raise ValueError('Unable to decode config: {}'.format(config))
  if 'class_name' in config:
    return generic_utils.deserialize_keras_object(
        config, module_objects=module_objects)
  return {key: recursively_deserialize_keras_object(config[key], module_objects)
          for key in config}

1105 

1106 

def get_common_shape(x, y):
  """Find a `TensorShape` that is compatible with both `x` and `y`.

  Dimensions that differ, or that are unknown in either shape, become None.
  If the ranks differ or are unknown, the result has unknown rank.

  Args:
    x: a `TensorShape`, or None when the associated input was not a Tensor.
    y: a `TensorShape`, or None when the associated input was not a Tensor.

  Returns:
    A `TensorShape`, or None if both `x` and `y` are None.

  Raises:
    RuntimeError: if exactly one of `x` and `y` is None.
    TypeError: if `x` or `y` is neither None nor a `TensorShape`.
  """
  # The parentheses are required: `x is None != y is None` is a chained
  # comparison (`x is None and None != y and y is None`) that is never true.
  if (x is None) != (y is None):
    raise RuntimeError(
        'Cannot find a common shape when LHS shape is None but RHS shape '
        'is not (or vice versa): %s vs. %s' % (x, y))
  if x is None:
    return None  # The associated input was not a Tensor, no shape generated.
  if not isinstance(x, tensor_shape.TensorShape):
    raise TypeError('Expected x to be a TensorShape but saw %s' % (x,))
  if not isinstance(y, tensor_shape.TensorShape):
    raise TypeError('Expected y to be a TensorShape but saw %s' % (y,))
  if x.rank != y.rank or x.rank is None:
    return tensor_shape.TensorShape(None)
  dims = []
  for dim_x, dim_y in zip(x.dims, y.dims):
    if (dim_x != dim_y
        or tensor_shape.dimension_value(dim_x) is None
        or tensor_shape.dimension_value(dim_y) is None):
      dims.append(None)
    else:
      dims.append(tensor_shape.dimension_value(dim_x))
  return tensor_shape.TensorShape(dims)

1130 

1131 

def infer_inputs_from_restored_call_function(fn):
  """Returns TensorSpec of inputs from a restored call function.

  When several concrete functions were traced, their input specs are merged
  into one spec with shapes compatible with all of them (see
  `get_common_shape`).

  Args:
    fn: Restored layer call function. It is assumed that `fn` has at least
      one concrete function and that the inputs are in the first argument.

  Returns:
    TensorSpec of call function inputs.
  """
  def common_spec(x, y):
    """Returns a spec of the same kind as `x` with the shape shared with `y`."""
    common_shape = get_common_shape(x.shape, y.shape)
    if isinstance(x, sparse_tensor.SparseTensorSpec):
      return sparse_tensor.SparseTensorSpec(common_shape, x.dtype)
    elif isinstance(x, ragged_tensor.RaggedTensorSpec):
      # Preserve the ragged structure of `x`. Without an explicit
      # `ragged_rank`, RaggedTensorSpec derives it from the merged shape, and
      # fails when that shape has unknown rank (e.g. the traced ranks
      # differed). `row_splits_dtype` would otherwise revert to its default.
      return ragged_tensor.RaggedTensorSpec(
          common_shape, x.dtype, ragged_rank=x.ragged_rank,
          row_splits_dtype=x.row_splits_dtype)
    return tensor_spec.TensorSpec(common_shape, x.dtype, x.name)

  spec = fn.concrete_functions[0].structured_input_signature[0][0]
  for concrete in fn.concrete_functions[1:]:
    spec2 = concrete.structured_input_signature[0][0]
    spec = nest.map_structure(common_spec, spec, spec2)
  return spec

1155 

1156 

class RevivedNetwork(RevivedLayer):
  """Keras network of layers loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Creates a revived network from the metadata stored in the SavedModel proto.

    Returns:
      A `(revived_obj, setter)` tuple, where `setter` is `_revive_setter`.
    """
    revived_obj = cls(name=metadata['name'])

    # Assigned inside `no_automatic_dependency_tracking_scope`, presumably so
    # these attributes are not tracked as dependencies of the network.
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']
      config = metadata.get('config')
      if generic_utils.validate_config(config):
        revived_obj._config = config

      regularizer_config = metadata.get('activity_regularizer')
      if regularizer_config is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            regularizer_config)
      # pylint:enable=protected-access

    return revived_obj, _revive_setter

1181 

1182 

def _set_network_attributes_from_metadata(revived_obj):
  """Applies the dtype policy and trainable flag recorded in the metadata.

  The dtype policy is only set when the metadata has a non-None 'dtype'.
  """
  with utils.no_automatic_dependency_tracking_scope(revived_obj):
    # pylint:disable=protected-access
    saved_metadata = revived_obj._serialized_attributes['metadata']
    saved_dtype = saved_metadata.get('dtype')
    if saved_dtype is not None:
      revived_obj._set_dtype_policy(saved_dtype)
    revived_obj._trainable = saved_metadata['trainable']
    # pylint:enable=protected-access

1192 

1193 

def _maybe_add_serialized_attributes(layer, metadata):
  """Gives `layer` a `_serialized_attributes` dict if it does not have one.

  The dict starts as `{'metadata': metadata}` and is assigned without automatic
  dependency tracking. It later holds attributes revived from
  SerializedAttributes: the ones listed in CommonEndpoints, or "keras_api" for
  keras-specific attributes.
  """
  if hasattr(layer, '_serialized_attributes'):
    return
  with utils.no_automatic_dependency_tracking_scope(layer):
    layer._serialized_attributes = {'metadata': metadata}  # pylint: disable=protected-access

1201 

1202 

def _get_keras_attr(layer):
  """Returns the attributes saved under `constants.KERAS_ATTR`, or None.

  Layers without `_serialized_attributes` also yield None.
  """
  serialized = getattr(layer, '_serialized_attributes', {})
  return serialized.get(constants.KERAS_ATTR, None)