
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Code for model cloning, plus model-related API entries."""

from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizer_v1
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training_v1
from tensorflow.python.keras.engine.base_layer import AddMetric
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.saving import model_config
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


# API entries importable from `keras.models`:
Model = training.Model  # pylint: disable=invalid-name
Sequential = sequential.Sequential  # pylint: disable=invalid-name
Functional = functional.Functional  # pylint: disable=invalid-name
save_model = save.save_model
load_model = save.load_model
model_from_config = model_config.model_from_config
model_from_yaml = model_config.model_from_yaml
model_from_json = model_config.model_from_json


# Callable used to clone a layer with weights preserved.
def share_weights(layer):
  return layer


def _clone_layer(layer):
  return layer.__class__.from_config(layer.get_config())
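
# Illustrative contrast between the two callables above; a minimal sketch kept
# as a comment so nothing executes at import time. `Dense` stands for
# `tf.keras.layers.Dense`, which is not imported in this module:
#
#   layer = Dense(4)
#   fresh = _clone_layer(layer)    # config-identical layer, new weights
#   shared = share_weights(layer)  # identity: `shared is layer`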

def _insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes):
  """Inserts ancillary layers into the model in the proper order."""
  # Sort `AddMetric` layers so they agree with `metrics_names`.
  metric_layers = [
      layer for layer in ancillary_layers if isinstance(layer, AddMetric)
  ]
  metric_layers.sort(key=lambda layer: metrics_names.index(layer.metric_name))
  ancillary_layers = [
      layer for layer in ancillary_layers if not isinstance(layer, AddMetric)
  ] + metric_layers
  model._insert_layers(ancillary_layers, relevant_nodes=list(new_nodes))


def _make_new_nodes(nodes_by_depth, layer_fn, layer_map, tensor_map):
  """Uses the layers in `layer_map` to make new nodes based on `nodes_by_depth`.

  Args:
    nodes_by_depth: Provides structure information to create new nodes.
    layer_fn: Function to clone layers.
    layer_map: Map from layers in `model` to new layers.
    tensor_map: Map from tensors in `model` to newly computed tensors.

  Returns:
    A set of new nodes. `layer_map` and `tensor_map` are updated.
  """
  # Iterate over every node in the reference model, in depth order.
  new_nodes = set()
  depth_keys = list(nodes_by_depth.keys())
  depth_keys.sort(reverse=True)
  for depth in depth_keys:
    nodes = nodes_by_depth[depth]
    for node in nodes:
      # Recover the corresponding layer.
      layer = node.outbound_layer

      # Get or create layer.
      if layer not in layer_map:
        new_layer = layer_fn(layer)
        layer_map[layer] = new_layer
        layer = new_layer
      else:
        # Reuse previously cloned layer.
        layer = layer_map[layer]
        # Don't call InputLayer multiple times.
        if isinstance(layer, InputLayer):
          continue

      # If all previous input tensors are available in `tensor_map`,
      # then call `node.inbound_layer` on them.
      if all(
          tensor in tensor_map for tensor in nest.flatten(node.input_tensors)):
        # Call layer.
        args = nest.map_structure(lambda t: tensor_map.get(t, t),
                                  node.call_args)
        kwargs = nest.map_structure(lambda t: tensor_map.get(t, t),
                                    node.call_kwargs)
        output_tensors = layer(*args, **kwargs)

        # Thread-safe way to keep track of what node was created.
        first_output_tensor = nest.flatten(output_tensors)[0]
        new_nodes.add(
            layer._inbound_nodes[first_output_tensor._keras_history.node_index])

        for x, y in zip(
            nest.flatten(node.output_tensors), nest.flatten(output_tensors)):
          tensor_map[x] = y
  return new_nodes


def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
  """Clone a functional `Model` instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  Input layers are always cloned.

  Args:
    model: Instance of `Model`.
    input_tensors: optional list of input tensors
      to build the model upon. If not provided,
      placeholders will be created.
    layer_fn: callable to be applied on non-input layers in the model. By
      default it clones the layer. Another example is to preserve the layer
      (to share weights). This is required when we create a per-replica
      copy of the model with distribution strategy; we want the weights to
      be shared but still feed inputs separately, so we create new input
      layers.

  Returns:
    An instance of `Model` reproducing the behavior
    of the original model, on top of new input tensors,
    using newly instantiated weights.

  Raises:
    ValueError: in case of invalid `model` argument value or `layer_fn`
      argument value.
  """
  if not isinstance(model, Model):
    raise ValueError('Expected `model` argument '
                     'to be a `Model` instance, got ', model)
  if isinstance(model, Sequential):
    raise ValueError('Expected `model` argument '
                     'to be a functional `Model` instance, '
                     'got a `Sequential` instance instead:', model)
  if not model._is_graph_network:
    raise ValueError('Expected `model` argument '
                     'to be a functional `Model` instance, '
                     'but got a subclass model instead.')

  new_input_layers = {}  # Cache for created layers.
  if input_tensors is not None:
    # Make sure that all input tensors come from a Keras layer.
    input_tensors = nest.flatten(input_tensors)
    for i, input_tensor in enumerate(input_tensors):
      original_input_layer = model._input_layers[i]

      # Cache input layer. Create a new layer if the tensor is originally not
      # from a Keras layer.
      if not backend.is_keras_tensor(input_tensor):
        name = original_input_layer.name
        input_tensor = Input(tensor=input_tensor,
                             name='input_wrapper_for_' + name)
        newly_created_input_layer = input_tensor._keras_history.layer
        new_input_layers[original_input_layer] = newly_created_input_layer
      else:
        new_input_layers[original_input_layer] = original_input_layer

  if not callable(layer_fn):
    raise ValueError('Expected `layer_fn` argument to be a callable.')

  model_configs, created_layers = _clone_layers_and_model_config(
      model, new_input_layers, layer_fn)
  # Reconstruct model from the config, using the cloned layers.
  input_tensors, output_tensors, created_layers = (
      functional.reconstruct_from_config(model_configs,
                                         created_layers=created_layers))
  metrics_names = model.metrics_names
  model = Model(input_tensors, output_tensors, name=model.name)
  # Layers not directly tied to outputs of the Model, such as loss layers
  # created in `add_loss` and `add_metric`.
  ancillary_layers = [
      layer for layer in created_layers.values() if layer not in model.layers
  ]
  # TODO(b/162887610): This may need to adjust the inbound node index if the
  # created layers had already been used to define other models.
  if ancillary_layers:
    new_nodes = nest.flatten([
        layer.inbound_nodes[1:]
        if functional._should_skip_first_node(layer)
        else layer.inbound_nodes for layer in created_layers.values()
    ])
    _insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes)
  return model
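
# Illustrative usage of `_clone_functional_model`; a minimal sketch kept as a
# comment so nothing executes at import time. `Dense` stands for
# `tf.keras.layers.Dense`, which is not imported here:
#
#   inputs = Input(shape=(4,))
#   outputs = Dense(2)(inputs)
#   model = Model(inputs, outputs)
#   fresh = _clone_functional_model(model)                    # new weights
#   shared = _clone_functional_model(model,
#                                    layer_fn=share_weights)  # shared weights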

def _clone_layers_and_model_config(model, input_layers, layer_fn):
  """Clones all layers, and returns the model config without serializing layers.

  This function ensures that only the node graph is retrieved when getting the
  model config. The `layer_fn` used to clone layers might not rely on
  `layer.get_config()`, since some custom layers do not define `get_config`;
  trying to retrieve their config would result in errors.

  Args:
    model: A Functional model.
    input_layers: Dictionary mapping input layers in `model` to new input
      layers.
    layer_fn: Function used to clone all non-input layers.

  Returns:
    Model config object, and a dictionary of newly created layers.
  """
  created_layers = {}

  def _copy_layer(layer):
    # Whenever the network config attempts to get the layer serialization,
    # return a dummy dictionary.
    if layer in input_layers:
      created_layers[layer.name] = input_layers[layer]
    elif layer in model._input_layers:
      created_layers[layer.name] = InputLayer(**layer.get_config())
    else:
      created_layers[layer.name] = layer_fn(layer)
    return {}

  config = functional.get_network_config(
      model, serialize_layer_fn=_copy_layer)
  return config, created_layers


def _remove_ancillary_layers(model, layer_map, layers):
  """Removes and returns any ancillary layers from `layers` based on `model`.

  Ancillary layers are part of the model topology but not used to compute the
  model outputs, e.g., layers from `add_loss` and `add_metric`.

  Args:
    model: A Keras Model.
    layer_map: A map from layers in `model` to those in `layers`.
    layers: A list of all layers.

  Returns:
    Two lists of layers: (1) `layers` with the ancillary layers removed, and
    (2) the ancillary layers.
  """
  ancillary_layers = []  # Additional layers for computing losses and metrics.
  if not model._is_graph_network:
    return layers, ancillary_layers

  # Ancillary layers are those with depth < 0.
  depths = [depth for depth in model._nodes_by_depth.keys() if depth < 0]
  depths.sort(reverse=True)  # Order topologically from inputs to outputs.
  for depth in depths:
    for node in model._nodes_by_depth[depth]:
      ancillary_layers.append(layer_map[node.outbound_layer])

  return [l for l in layers if l not in ancillary_layers], ancillary_layers


def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
  """Clone a `Sequential` model instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  Args:
    model: Instance of `Sequential`.
    input_tensors: optional list of input tensors
      to build the model upon. If not provided,
      placeholders will be created.
    layer_fn: callable to be applied on non-input layers in the model. By
      default it clones the layer. Another example is to preserve the layer
      (to share weights). This is required when we create a per-replica
      copy of the model with distribution strategy; we want the weights to
      be shared but still feed inputs separately, so we create new input
      layers.

  Returns:
    An instance of `Sequential` reproducing the behavior
    of the original model, on top of new input tensors,
    using newly instantiated weights.

  Raises:
    ValueError: in case of invalid `model` argument value or `layer_fn`
      argument value.
  """
  if not isinstance(model, Sequential):
    raise ValueError('Expected `model` argument '
                     'to be a `Sequential` model instance, '
                     'but got:', model)

  if not callable(layer_fn):
    raise ValueError('Expected `layer_fn` argument to be a callable.')

  layers = []  # Layers needed to compute the model's outputs.
  layer_map = {}
  # Ensure that all layers are cloned. The model's layers
  # property will exclude the initial InputLayer (if it exists) in the model,
  # resulting in a different Sequential model structure.
  for layer in model._flatten_layers(include_self=False, recursive=False):
    if isinstance(layer, InputLayer) and input_tensors is not None:
      # If input tensors are provided, the original model's InputLayer is
      # overwritten with a different InputLayer.
      continue
    cloned_layer = (
        _clone_layer(layer)
        if isinstance(layer, InputLayer) else layer_fn(layer))
    layers.append(cloned_layer)
    layer_map[layer] = cloned_layer
  layers, ancillary_layers = _remove_ancillary_layers(model, layer_map, layers)

  if input_tensors is None:
    cloned_model = Sequential(layers=layers, name=model.name)
  elif len(generic_utils.to_list(input_tensors)) != 1:
    raise ValueError('To clone a `Sequential` model, we expect '
                     'at most one tensor '
                     'as part of `input_tensors`.')
  else:
    # Overwrite the original model's input layer.
    if isinstance(input_tensors, tuple):
      input_tensors = list(input_tensors)
    x = generic_utils.to_list(input_tensors)[0]
    if backend.is_keras_tensor(x):
      origin_layer = x._keras_history.layer
      if isinstance(origin_layer, InputLayer):
        cloned_model = Sequential(
            layers=[origin_layer] + layers, name=model.name)
      else:
        raise ValueError('Cannot clone a `Sequential` model on top '
                         'of a tensor that comes from a Keras layer '
                         'other than an `InputLayer`. '
                         'Use the functional API instead.')
    else:
      input_tensor = Input(tensor=x, name='input_wrapper_for_' + str(x.name))
      input_layer = input_tensor._keras_history.layer
      cloned_model = Sequential(layers=[input_layer] + layers, name=model.name)

  if not ancillary_layers:
    return cloned_model

  tensor_map = {}  # Maps tensors from `model` to those in `cloned_model`.
  for depth, cloned_nodes in cloned_model._nodes_by_depth.items():
    nodes = model._nodes_by_depth[depth]
    # This should be safe in a Sequential model. In an arbitrary network, you
    # need to sort using the outbound layer of the node as a key.
    for cloned_node, node in zip(cloned_nodes, nodes):
      if isinstance(cloned_node.output_tensors, list):
        for j, output_tensor in enumerate(cloned_node.output_tensors):
          tensor_map[node.output_tensors[j]] = output_tensor
      else:
        tensor_map[node.output_tensors] = cloned_node.output_tensors
  # Ancillary nodes have negative depth.
  new_nodes = _make_new_nodes(
      {
          depth: nodes
          for depth, nodes in model._nodes_by_depth.items()
          if depth < 0
      }, layer_fn, layer_map, tensor_map)
  _insert_ancillary_layers(cloned_model, ancillary_layers, model.metrics_names,
                           new_nodes)
  return cloned_model
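
# Illustrative usage of `_clone_sequential_model`; a minimal sketch kept as a
# comment so nothing executes at import time. `Dense` stands for
# `tf.keras.layers.Dense`, which is not imported here:
#
#   model = Sequential([Dense(2, input_shape=(4,))])
#   fresh = _clone_sequential_model(model)           # new weights
#   on_new_input = _clone_sequential_model(
#       model, input_tensors=Input(shape=(4,)))      # rebuilt on a new Input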

@keras_export('keras.models.clone_model')
def clone_model(model, input_tensors=None, clone_function=None):
  """Clone a Functional or Sequential `Model` instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  Note that `clone_model` will not preserve the uniqueness of shared objects
  within the model (e.g. a single variable attached to two distinct layers
  will be restored as two separate variables).

  Args:
    model: Instance of `Model`
      (could be a Functional model or a Sequential model).
    input_tensors: optional list of input tensors or InputLayer objects
      to build the model upon. If not provided,
      new `Input` objects will be created.
    clone_function: Callable to be used to clone each layer in the target
      model (except `InputLayer` instances). It takes as argument the layer
      instance to be cloned, and returns the corresponding layer instance to
      be used in the model copy. If unspecified, this callable defaults to
      the following serialization/deserialization function:
      `lambda layer: layer.__class__.from_config(layer.get_config())`.
      By passing a custom callable, you can customize your copy of the
      model, e.g. by wrapping certain layers of interest (you might want to
      replace all `LSTM` instances with equivalent
      `Bidirectional(LSTM(...))` instances, for example).

  Returns:
    An instance of `Model` reproducing the behavior
    of the original model, on top of new input tensors,
    using newly instantiated weights. The cloned model may behave
    differently from the original model if a custom `clone_function`
    modifies the layer.

  Example:

  ```python
  # Create a test Sequential model.
  model = keras.Sequential([
      keras.Input(shape=(728,)),
      keras.layers.Dense(32, activation='relu'),
      keras.layers.Dense(1, activation='sigmoid'),
  ])
  # Create a copy of the test model (with freshly initialized weights).
  new_model = clone_model(model)
  ```

  Note that subclassed models cannot be cloned, since their internal
  layer structure is not known. To achieve equivalent functionality
  as `clone_model` in the case of a subclassed model, simply make sure
  that the model class implements `get_config()`
  (and optionally `from_config()`), and call:

  ```python
  new_model = model.__class__.from_config(model.get_config())
  ```
  """
  with generic_utils.DisableSharedObjectScope():
    if clone_function is None:
      clone_function = _clone_layer

    if isinstance(model, Sequential):
      return _clone_sequential_model(
          model, input_tensors=input_tensors, layer_fn=clone_function)
    else:
      return _clone_functional_model(
          model, input_tensors=input_tensors, layer_fn=clone_function)
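
# Example of a custom `clone_function`, matching the `Bidirectional(LSTM(...))`
# swap suggested in the docstring above; a hedged, comment-only sketch, with
# `layers` standing for `tf.keras.layers` (not imported here):
#
#   def maybe_wrap_lstm(layer):
#     if isinstance(layer, layers.LSTM):
#       return layers.Bidirectional(
#           layers.LSTM.from_config(layer.get_config()))
#     return layer.__class__.from_config(layer.get_config())
#
#   new_model = clone_model(model, clone_function=maybe_wrap_lstm)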

# "Clone" a subclassed model by resetting all of its attributes.
def _in_place_subclassed_model_reset(model):
  """Substitute for model cloning that works for subclassed models.

  Subclassed models cannot be cloned because their topology is not
  serializable. To "instantiate" an identical model in a new TF graph, we
  reuse the original model object, but we clear its state.

  After calling this function on a model instance, you can use the model
  instance as if it were a model clone (in particular you can use it in a new
  graph).

  This method clears the state of the input model. It is thus destructive.
  However the original state can be restored fully by calling
  `in_place_subclassed_model_state_restoration`.

  Args:
    model: Instance of a Keras model created via subclassing.

  Raises:
    ValueError: In case the model uses a subclassed model as inner layer.
  """
  assert not model._is_graph_network  # Only makes sense for subclassed networks
  # Select correct base class for new Model.
  version_utils.swap_class(model.__class__, training.Model, training_v1.Model,
                           ops.executing_eagerly_outside_functions())
  # Retrieve all layers tracked by the model as well as their attribute names.
  attributes_cache = {}
  for name in dir(model):
    # Skip attrs that track other trackables.
    if name == 'submodules' or name == '_self_tracked_trackables':
      continue

    try:
      value = getattr(model, name)
    except (AttributeError, ValueError, TypeError):
      continue
    if isinstance(value, Layer):
      attributes_cache[name] = value
      assert value in model.layers
      if hasattr(value, 'layers') and value.layers:
        raise ValueError('We do not support the use of nested layers '
                         'in `model_to_estimator` at this time. Found nested '
                         'layer: %s' % value)
    elif isinstance(
        value, (list, tuple)) and name not in ('layers', '_layers', 'metrics',
                                               '_compile_metric_functions',
                                               '_output_loss_metrics'):
      # Handle case: list/tuple of layers (also tracked by the Network API).
      if value and all(isinstance(val, Layer) for val in value):
        raise ValueError('We do not support the use of list-of-layers '
                         'attributes in subclassed models used with '
                         '`model_to_estimator` at this time. Found list '
                         'model: %s' % name)

  # Replace layers on the model with fresh layers.
  layers_to_names = {value: key for key, value in attributes_cache.items()}
  original_layers = list(
      model._flatten_layers(include_self=False, recursive=False))
  setattr_tracking = model._setattr_tracking
  model._setattr_tracking = False
  model._self_tracked_trackables = []
  for layer in original_layers:  # We preserve layer order.
    config = layer.get_config()
    # This will not work for nested subclassed models used as layers.
    # This would be theoretically possible to support, but would add
    # complexity. Only do it if users complain.
    if isinstance(layer, training.Model) and not layer._is_graph_network:
      raise ValueError('We do not support the use of nested subclassed models '
                       'in `model_to_estimator` at this time. Found nested '
                       'model: %s' % layer)
    fresh_layer = layer.__class__.from_config(config)
    name = layers_to_names[layer]
    setattr(model, name, fresh_layer)
    model._self_tracked_trackables.append(fresh_layer)

  # Cache original model build attributes (in addition to layers).
  if (not hasattr(model, '_original_attributes_cache') or
      model._original_attributes_cache is None):
    if model.built:
      attributes_to_cache = [
          'inputs',
          'outputs',
          'total_loss',
          'optimizer',
          'train_function',
          'test_function',
          'predict_function',
          '_training_endpoints',
          '_collected_trainable_weights',
          '_feed_inputs',
          '_feed_input_names',
          '_feed_input_shapes',
      ]
      for name in attributes_to_cache:
        attributes_cache[name] = getattr(model, name)
  model._original_attributes_cache = attributes_cache
  _reset_build_compile_trackers(model)
  model._setattr_tracking = setattr_tracking


def _reset_build_compile_trackers(model):
  """Reset state trackers for model.

  Note that we do not actually zero out attributes such as optimizer,
  but instead rely on the expectation that all of the attrs will be
  over-written on calling build/compile/etc. This is somewhat fragile,
  insofar as we check elsewhere for the presence of these attributes as
  evidence of having been built/compiled/etc. Pending a better way to do this,
  we reset key attributes here to allow building and compiling.

  Args:
    model: the model that is being reset.
  """
  # Reset build state.
  model.built = False
  model.inputs = None
  model.outputs = None
  # Reset compile state.
  model._is_compiled = False  # pylint:disable=protected-access
  if not ops.executing_eagerly_outside_functions():
    model._v1_compile_was_called = False
  model.optimizer = None


@keras_export(
    'keras.__internal__.models.in_place_subclassed_model_state_restoration',
    v1=[])
def in_place_subclassed_model_state_restoration(model):
  """Restores the original state of a model after it was "reset".

  This undoes the action of `_in_place_subclassed_model_reset`, which is
  called in `clone_and_build_model` if `in_place_reset` is set to True.

  Args:
    model: Instance of a Keras model created via subclassing, on which
      `_in_place_subclassed_model_reset` was previously called.
  """
  assert not model._is_graph_network
  # Restore layers and build attributes.
  if (hasattr(model, '_original_attributes_cache') and
      model._original_attributes_cache is not None):
    # Models have sticky attribute assignment, so we want to be careful to add
    # back the previous attributes and track Layers by their original names
    # without adding dependencies on "utility" attributes which Models exempt
    # when they're constructed.
    setattr_tracking = model._setattr_tracking
    model._setattr_tracking = False
    model._self_tracked_trackables = []
    for name, value in model._original_attributes_cache.items():
      setattr(model, name, value)
      if isinstance(value, Layer):
        model._self_tracked_trackables.append(value)
    model._original_attributes_cache = None
    model._setattr_tracking = setattr_tracking
  else:
    # Restore to the state of a never-called model.
    _reset_build_compile_trackers(model)
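
# The reset/restore pair above is meant to be used together; a comment-only
# sketch of the intended round trip (`subclassed_model` is hypothetical):
#
#   _in_place_subclassed_model_reset(subclassed_model)   # destructive reset
#   # ... use `subclassed_model` as a "clone", e.g. build it in a new graph ...
#   in_place_subclassed_model_state_restoration(subclassed_model)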


@keras_export('keras.__internal__.models.clone_and_build_model', v1=[])
def clone_and_build_model(
    model, input_tensors=None, target_tensors=None, custom_objects=None,
    compile_clone=True, in_place_reset=False, optimizer_iterations=None,
    optimizer_config=None):
  """Clone a `Model` and build/compile it with the same settings used before.

  This function can be run in the same graph or in a separate graph from the
  model. When using a separate graph, `in_place_reset` must be `False`.

  Note that, currently, the clone produced from this function may not work
  with TPU DistributionStrategy. Try at your own risk.

  Args:
    model: `tf.keras.Model` object. Can be Functional, Sequential, or
      sub-classed.
    input_tensors: Optional list or dictionary of input tensors to build the
      model upon. If not provided, placeholders will be created.
    target_tensors: Optional list of target tensors for compiling the model.
      If not provided, placeholders will be created.
    custom_objects: Optional dictionary mapping string names to custom classes
      or functions.
    compile_clone: Boolean, whether to compile model clone (default `True`).
    in_place_reset: Boolean, whether to reset the model in place. Only used if
      the model is a subclassed model. In the case of a subclassed model,
      this argument must be set to `True` (default `False`). To restore the
      original model, use the function
      `in_place_subclassed_model_state_restoration(model)`.
    optimizer_iterations: An iterations variable that will be incremented by
      the optimizer if the clone is compiled. This argument is used when a
      Keras model is cloned into an Estimator model function, because
      Estimators create their own global step variable.
    optimizer_config: Optimizer config dictionary or list of dictionaries
      returned from `get_config()`. This argument should be defined if
      `clone_and_build_model` is called in a different graph or session from
      the original model, and the optimizer is an instance of `OptimizerV2`.

  Returns:
    Clone of the model.

  Raises:
    ValueError: Cloning fails in the following cases:
      - cloning a subclassed model with `in_place_reset` set to False.
      - compiling the clone when the original model has not been compiled.
  """
  # Grab optimizer now, as we reset-in-place for subclassed models, but
  # want to maintain access to the original optimizer.
  orig_optimizer = model.optimizer
  if compile_clone and not orig_optimizer:
    raise ValueError(
        'Error when cloning model: compile_clone was set to True, but the '
        'original model has not been compiled.')

  if compile_clone:
    compile_args = model._get_compile_args()  # pylint: disable=protected-access
    # Allows this method to be robust to switching graph and eager classes.
    model._get_compile_args = lambda: compile_args

  with CustomObjectScope(custom_objects or {}):
    if model._is_graph_network:
      clone = clone_model(model, input_tensors=input_tensors)
    elif isinstance(model, Sequential):
      clone = clone_model(model, input_tensors=input_tensors)
      if (not clone._is_graph_network and
          model._build_input_shape is not None):
        if ops.executing_eagerly_outside_functions():
          clone.build(model._build_input_shape)
        else:
          clone._set_inputs(
              backend.placeholder(
                  model._build_input_shape, dtype=model.inputs[0].dtype))
    else:
      try:
        # Prefer cloning the model if serialization/deserialization logic is
        # implemented for the subclassed model.
        clone = model.__class__.from_config(model.get_config())
      except NotImplementedError:
        logging.warning('This model is a subclassed model. Please implement '
                        '`get_config` and `from_config` to better support '
                        'cloning the model.')
        if not in_place_reset:
          raise ValueError(
              'This model is a subclassed model. '
              'Such a model cannot be cloned, but there is a workaround where '
              'the model is reset in-place. To use this, please set the '
              'argument `in_place_reset` to `True`. This will reset the '
              'attributes in the original model. To restore the attributes, '
              'call `in_place_subclassed_model_state_restoration(model)`.')
        clone = model
        _in_place_subclassed_model_reset(clone)
      if input_tensors is not None:
        if isinstance(input_tensors,
                      (list, tuple)) and len(input_tensors) == 1:
          input_tensors = input_tensors[0]
        clone._set_inputs(input_tensors)

  if compile_clone:
    if isinstance(orig_optimizer, optimizer_v1.TFOptimizer):
      optimizer = optimizer_v1.TFOptimizer(
          orig_optimizer.optimizer, optimizer_iterations)
      backend.track_tf_optimizer(optimizer)
    else:
      if not isinstance(orig_optimizer, (tuple, list)):
        orig_optimizer = [orig_optimizer]
      if optimizer_config is None:
        optimizer = [
            opt.__class__.from_config(opt.get_config())
            for opt in orig_optimizer
        ]
      elif isinstance(optimizer_config, dict):
        optimizer = [orig_optimizer[0].__class__.from_config(optimizer_config)]
      else:
        # optimizer config is a list of dicts, in the same order as
        # orig_optimizer.
        optimizer = [
            opt.__class__.from_config(opt_config)
            for (opt, opt_config) in zip(orig_optimizer, optimizer_config)
        ]
      if optimizer_iterations is not None:
        for opt in optimizer:
          opt.iterations = optimizer_iterations

      if len(optimizer) == 1:
        optimizer = optimizer[0]

    compile_args['optimizer'] = optimizer
    if target_tensors is not None:
      compile_args['target_tensors'] = target_tensors
    # Ensure Metric objects in new model are separate from existing model.
    compile_args['metrics'] = metrics_module.clone_metrics(
        compile_args['metrics'])
    compile_args['weighted_metrics'] = metrics_module.clone_metrics(
        compile_args['weighted_metrics'])
    clone.compile(**compile_args)

  return clone
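
# Illustrative call; a hedged, comment-only sketch (assumes `model` was
# compiled beforehand, since `compile_clone` defaults to True):
#
#   clone = clone_and_build_model(
#       model,
#       input_tensors=None,
#       compile_clone=True,
#       in_place_reset=not model._is_graph_network)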