Coverage report for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/models/cloning.py: 14% of 281 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Code for model cloning, plus model-related API entries."""

import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src import metrics as metrics_module
from keras.src.engine import functional
from keras.src.engine import sequential
from keras.src.engine import training
from keras.src.engine import training_v1
from keras.src.engine.base_layer import AddMetric
from keras.src.engine.base_layer import Layer
from keras.src.engine.input_layer import Input
from keras.src.engine.input_layer import InputLayer
from keras.src.optimizers import optimizer_v1
from keras.src.saving.legacy import serialization
from keras.src.saving.legacy.saved_model.utils import keras_option_scope
from keras.src.saving.object_registration import CustomObjectScope
from keras.src.utils import generic_utils
from keras.src.utils import version_utils

# isort: off
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export

# API entries importable from `keras.models`:
Model = training.Model
Sequential = sequential.Sequential


# Callable used to clone a layer with weights preserved.
def share_weights(layer):
    return layer


def _clone_layer(layer):
    return layer.__class__.from_config(layer.get_config())
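

# A hedged usage sketch (`model` is illustrative, not defined in this
# module): these two callables are the typical values for the
# `clone_function`/`layer_fn` arguments used throughout this file.
# `_clone_layer` re-instantiates a layer from its config, yielding fresh
# weights, while `share_weights` returns the layer unchanged, so the clone
# reuses the original variables.
#
#     clone_fresh = clone_model(model)  # default `_clone_layer` behavior
#     clone_shared = clone_model(model, clone_function=share_weights)
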

def _insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes):
    """Inserts ancillary layers into the model with the proper order."""
    # Sort `AddMetric` layers so they agree with metrics_names.
    metric_layers = [
        layer for layer in ancillary_layers if isinstance(layer, AddMetric)
    ]
    metric_layers.sort(key=lambda layer: metrics_names.index(layer.metric_name))
    ancillary_layers = [
        layer for layer in ancillary_layers if not isinstance(layer, AddMetric)
    ] + metric_layers
    model._insert_layers(ancillary_layers, relevant_nodes=list(new_nodes))


def _make_new_nodes(nodes_by_depth, layer_fn, layer_map, tensor_map):
    """Make new nodes with the layers in `layer_map` based on `nodes_by_depth`.

    Args:
        nodes_by_depth: Provides structure information to create new nodes.
        layer_fn: Function to clone layers.
        layer_map: Map from layers in `model` to new layers.
        tensor_map: Map from tensors in `model` to newly computed tensors.

    Returns:
        A set of new nodes. `layer_map` and `tensor_map` are updated.
    """
    # Iterate over every node in the reference model, in depth order.
    new_nodes = set()
    depth_keys = list(nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    for depth in depth_keys:
        nodes = nodes_by_depth[depth]
        for node in nodes:
            # Recover the corresponding layer.
            layer = node.outbound_layer

            # Get or create layer.
            if layer not in layer_map:
                new_layer = layer_fn(layer)
                layer_map[layer] = new_layer
                layer = new_layer
            else:
                # Reuse previously cloned layer.
                layer = layer_map[layer]
                # Don't call InputLayer multiple times.
                if isinstance(layer, InputLayer):
                    continue

            # If all previous input tensors are available in tensor_map,
            # then call node.inbound_layer on them.
            if all(
                tensor in tensor_map
                for tensor in tf.nest.flatten(node.input_tensors)
            ):
                # Call layer.
                args = tf.nest.map_structure(
                    lambda t: tensor_map.get(t, t), node.call_args
                )
                kwargs = tf.nest.map_structure(
                    lambda t: tensor_map.get(t, t), node.call_kwargs
                )
                output_tensors = layer(*args, **kwargs)

                # Thread-safe way to keep track of what node was created.
                first_output_tensor = tf.nest.flatten(output_tensors)[0]
                new_nodes.add(
                    layer._inbound_nodes[
                        first_output_tensor._keras_history.node_index
                    ]
                )

                for x, y in zip(
                    tf.nest.flatten(node.output_tensors),
                    tf.nest.flatten(output_tensors),
                ):
                    tensor_map[x] = y
    return new_nodes


def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
    """Clone a functional `Model` instance.

    Model cloning is similar to calling a model on new inputs,
    except that it creates new layers (and thus new weights) instead
    of sharing the weights of the existing layers.

    Input layers are always cloned.

    Args:
        model: Instance of `Model`.
        input_tensors: optional list of input tensors
            to build the model upon. If not provided,
            placeholders will be created.
        layer_fn: callable to be applied on non-input layers in the model. By
            default, it clones the layer. Another option is to preserve the
            layer in order to share its weights. This is required when we
            create a per-replica copy of the model with distribution
            strategy; we want the weights to be shared but still feed inputs
            separately, so we create new input layers.

    Returns:
        An instance of `Model` reproducing the behavior
        of the original model, on top of new input tensors,
        using newly instantiated weights.

    Raises:
        ValueError: in case of invalid `model` argument value or `layer_fn`
            argument value.
    """
    if layer_fn is None:
        layer_fn = _clone_layer

    if not isinstance(model, Model):
        raise ValueError(
            "Expected `model` argument "
            f"to be a `Model` instance. Received: model={model}"
        )
    if isinstance(model, Sequential):
        raise ValueError(
            "Expected `model` argument "
            "to be a functional `Model` instance, "
            f"got a `Sequential` instance instead: {model}"
        )
    if not model._is_graph_network:
        raise ValueError(
            "Expected `model` argument "
            "to be a functional `Model` instance, "
            f"but got a subclassed model instead: {model}"
        )

    new_input_layers = {}  # Cache for created layers.
    if input_tensors is not None:
        # Make sure that all input tensors come from a Keras layer.
        input_tensors = tf.nest.flatten(input_tensors)
        for i, input_tensor in enumerate(input_tensors):
            original_input_layer = model._input_layers[i]

            # Cache input layer. Create a new layer if the tensor is
            # originally not from a Keras layer.
            if not backend.is_keras_tensor(input_tensor):
                name = original_input_layer.name
                input_tensor = Input(
                    tensor=input_tensor, name="input_wrapper_for_" + name
                )
                newly_created_input_layer = input_tensor._keras_history.layer
                new_input_layers[
                    original_input_layer
                ] = newly_created_input_layer
            else:
                new_input_layers[
                    original_input_layer
                ] = input_tensor._keras_history.layer

    if not callable(layer_fn):
        raise ValueError(
            "Expected `layer_fn` argument to be a callable. "
            f"Received: layer_fn={layer_fn}"
        )

    # For affected g3 users who need to default to old serialization in
    # cloning.
    if getattr(model, "use_legacy_config", False):
        with keras_option_scope(
            save_traces=False, in_tf_saved_model_scope=True
        ):
            model_configs, created_layers = _clone_layers_and_model_config(
                model, new_input_layers, layer_fn
            )
    else:
        model_configs, created_layers = _clone_layers_and_model_config(
            model, new_input_layers, layer_fn
        )
    # Reconstruct model from the config, using the cloned layers.
    (
        input_tensors,
        output_tensors,
        created_layers,
    ) = functional.reconstruct_from_config(
        model_configs, created_layers=created_layers
    )
    metrics_names = model.metrics_names
    if functional.has_functional_like_constructor(model.__class__):
        new_model = model.__class__(
            input_tensors, output_tensors, name=model.name
        )
    else:
        # This may be incorrect: the new model will end up having a different
        # class than the original. However, various existing models rely
        # on this behavior, so we keep it.
        new_model = Model(input_tensors, output_tensors, name=model.name)

    # Layers not directly tied to outputs of the Model, such as loss layers
    # created in `add_loss` and `add_metric`.
    ancillary_layers = [
        layer
        for layer in created_layers.values()
        if layer not in new_model.layers
    ]
    # TODO(b/162887610): This may need to adjust the inbound node index if
    # the created layers had already been used to define other models.
    if ancillary_layers:
        new_nodes = tf.nest.flatten(
            [
                layer.inbound_nodes[1:]
                if functional._should_skip_first_node(layer)
                else layer.inbound_nodes
                for layer in created_layers.values()
            ]
        )
        _insert_ancillary_layers(
            new_model, ancillary_layers, metrics_names, new_nodes
        )
    return new_model

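# A hedged illustration of the private helper above (the `Dense` import and
# the tiny graph are made up for this sketch; `share_weights` is defined in
# this module):
#
#     inputs = Input(shape=(4,))
#     outputs = Dense(2)(inputs)
#     model = Model(inputs, outputs)
#     fresh = _clone_functional_model(model)  # new layers, new weights
#     replica = _clone_functional_model(model, layer_fn=share_weights)
#     # `replica` gets new input layers but reuses `model`'s variables,
#     # matching the per-replica use case described in the docstring.
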

def _clone_layers_and_model_config(model, input_layers, layer_fn):
    """Clones all layers; returns the model config without serializing layers.

    This function ensures that only the node graph is retrieved when getting
    the model config. The `layer_fn` used to clone layers might not rely on
    `layer.get_config()`, and some custom layers do not define `get_config`;
    trying to retrieve their configs would result in errors.

    Args:
        model: A Functional model.
        input_layers: Dictionary mapping input layers in `model` to new input
            layers.
        layer_fn: Function used to clone all non-input layers.

    Returns:
        Model config object, and a dictionary of newly created layers.
    """
    created_layers = {}

    def _copy_layer(layer):
        # Whenever the network config attempts to get the layer
        # serialization, return a dummy dictionary.
        if layer in input_layers:
            created_layers[layer.name] = input_layers[layer]
        elif layer in model._input_layers:
            created_layers[layer.name] = InputLayer(**layer.get_config())
        else:
            created_layers[layer.name] = layer_fn(layer)
        return {}

    config = functional.get_network_config(
        model, serialize_layer_fn=_copy_layer
    )
    return config, created_layers


def _remove_ancillary_layers(model, layer_map, layers):
    """Removes and returns any ancillary layers from `layers` based on `model`.

    Ancillary layers are part of the model topology but not used to compute
    the model outputs, e.g., layers from `add_loss` and `add_metric`.

    Args:
        model: A Keras Model.
        layer_map: A map from layers in `model` to those in `layers`.
        layers: A list of all layers.

    Returns:
        Two lists of layers: (1) `layers` with the ancillary layers removed,
        and (2) the ancillary layers.
    """
    ancillary_layers = []  # Additional layers for computing losses and metrics.
    if not model._is_graph_network:
        return layers, ancillary_layers

    # Ancillary layers are those with depth < 0.
    depths = [depth for depth in model._nodes_by_depth.keys() if depth < 0]
    depths.sort(reverse=True)  # Order topologically from inputs to outputs.
    for depth in depths:
        for node in model._nodes_by_depth[depth]:
            ancillary_layers.append(layer_map[node.outbound_layer])

    return [l for l in layers if l not in ancillary_layers], ancillary_layers


def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
    """Clone a `Sequential` model instance.

    Model cloning is similar to calling a model on new inputs,
    except that it creates new layers (and thus new weights) instead
    of sharing the weights of the existing layers.

    Args:
        model: Instance of `Sequential`.
        input_tensors: optional list of input tensors
            to build the model upon. If not provided,
            placeholders will be created.
        layer_fn: callable to be applied on non-input layers in the model. By
            default, it clones the layer. Another option is to preserve the
            layer in order to share its weights. This is required when we
            create a per-replica copy of the model with distribution
            strategy; we want the weights to be shared but still feed inputs
            separately, so we create new input layers.

    Returns:
        An instance of `Sequential` reproducing the behavior
        of the original model, on top of new input tensors,
        using newly instantiated weights.

    Raises:
        ValueError: in case of invalid `model` argument value or `layer_fn`
            argument value.
    """
    if layer_fn is None:
        layer_fn = _clone_layer

    if not isinstance(model, Sequential):
        raise ValueError(
            "Expected `model` argument "
            "to be a `Sequential` model instance. "
            f"Received: model={model}"
        )

    if not callable(layer_fn):
        raise ValueError(
            "Expected `layer_fn` argument to be a callable. "
            f"Received: layer_fn={layer_fn}"
        )

    layers = []  # Layers needed to compute the model's outputs.
    layer_map = {}
    # Ensure that all layers are cloned. The model's layers
    # property will exclude the initial InputLayer (if it exists) in the
    # model, resulting in a different Sequential model structure.
    for layer in model._flatten_layers(include_self=False, recursive=False):
        if isinstance(layer, InputLayer) and input_tensors is not None:
            # If input tensors are provided, the original model's InputLayer
            # is overwritten with a different InputLayer.
            continue
        cloned_layer = (
            _clone_layer(layer)
            if isinstance(layer, InputLayer)
            else layer_fn(layer)
        )
        layers.append(cloned_layer)
        layer_map[layer] = cloned_layer
    layers, ancillary_layers = _remove_ancillary_layers(
        model, layer_map, layers
    )

    if input_tensors is None:
        cloned_model = Sequential(layers=layers, name=model.name)
    elif len(generic_utils.to_list(input_tensors)) != 1:
        raise ValueError(
            "To clone a `Sequential` model, we expect at most one tensor as "
            f"part of `input_tensors`. Received: input_tensors={input_tensors}"
        )
    else:
        # Overwrite the original model's input layer.
        if isinstance(input_tensors, tuple):
            input_tensors = list(input_tensors)
        x = generic_utils.to_list(input_tensors)[0]
        if backend.is_keras_tensor(x):
            origin_layer = x._keras_history.layer
            if isinstance(origin_layer, InputLayer):
                cloned_model = Sequential(
                    layers=[origin_layer] + layers, name=model.name
                )
            else:
                raise ValueError(
                    "Cannot clone a `Sequential` model on top "
                    "of a tensor that comes from a Keras layer "
                    "other than an `InputLayer`. "
                    "Use the Functional API instead. "
                    f"Received: input_tensors={input_tensors}"
                )
        else:
            input_tensor = Input(
                tensor=x, name="input_wrapper_for_" + str(x.name)
            )
            input_layer = input_tensor._keras_history.layer
            cloned_model = Sequential(
                layers=[input_layer] + layers, name=model.name
            )

    if not ancillary_layers:
        return cloned_model

    tensor_map = {}  # Maps tensors from `model` to those in `cloned_model`.
    for depth, cloned_nodes in cloned_model._nodes_by_depth.items():
        nodes = model._nodes_by_depth[depth]
        # This should be safe in a Sequential model. In an arbitrary network,
        # you need to sort using the outbound layer of the node as a key.
        for cloned_node, node in zip(cloned_nodes, nodes):
            if isinstance(cloned_node.output_tensors, list):
                for j, output_tensor in enumerate(cloned_node.output_tensors):
                    tensor_map[node.output_tensors[j]] = output_tensor
            else:
                tensor_map[node.output_tensors] = cloned_node.output_tensors
    # Ancillary nodes have negative depth.
    new_nodes = _make_new_nodes(
        {
            depth: nodes
            for depth, nodes in model._nodes_by_depth.items()
            if depth < 0
        },
        layer_fn,
        layer_map,
        tensor_map,
    )
    _insert_ancillary_layers(
        cloned_model, ancillary_layers, model.metrics_names, new_nodes
    )
    return cloned_model

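# A hedged illustration of the private helper above (the `Dense` import and
# the one-layer model are made up for this sketch):
#
#     model = Sequential([Dense(2, input_shape=(4,))])
#     clone = _clone_sequential_model(model)  # fresh weights
#     # Cloning on top of a new input tensor (at most one is accepted):
#     new_input = Input(shape=(4,))
#     clone_on_input = _clone_sequential_model(model, input_tensors=new_input)
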

@keras_export("keras.models.clone_model")
def clone_model(model, input_tensors=None, clone_function=None):
    """Clone a Functional or Sequential `Model` instance.

    Model cloning is similar to calling a model on new inputs,
    except that it creates new layers (and thus new weights) instead
    of sharing the weights of the existing layers.

    Note that `clone_model` will not preserve the uniqueness of shared
    objects within the model (e.g. a single variable attached to two
    distinct layers will be restored as two separate variables).

    Args:
        model: Instance of `Model`
            (could be a Functional model or a Sequential model).
        input_tensors: optional list of input tensors or InputLayer objects
            to build the model upon. If not provided,
            new `Input` objects will be created.
        clone_function: Callable to be used to clone each layer in the target
            model (except `InputLayer` instances). It takes as argument the
            layer instance to be cloned, and returns the corresponding layer
            instance to be used in the model copy. If unspecified, this
            callable defaults to the following serialization/deserialization
            function:
            `lambda layer: layer.__class__.from_config(layer.get_config())`.
            By passing a custom callable, you can customize your copy of the
            model, e.g. by wrapping certain layers of interest (you might
            want to replace all `LSTM` instances with equivalent
            `Bidirectional(LSTM(...))` instances, for example).

    Returns:
        An instance of `Model` reproducing the behavior
        of the original model, on top of new input tensors,
        using newly instantiated weights. The cloned model may behave
        differently from the original model if a custom `clone_function`
        modifies the layer.

    Example:

    ```python
    # Create a test Sequential model.
    model = keras.Sequential([
        keras.Input(shape=(728,)),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid'),
    ])
    # Create a copy of the test model (with freshly initialized weights).
    new_model = clone_model(model)
    ```
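
    A hedged sketch of a custom `clone_function`, following the
    `Bidirectional(LSTM(...))` substitution suggested in the Args section
    (the standard `keras.layers` names are assumed; `model` is illustrative):

    ```python
    def clone_fn(layer):
        # Replace every LSTM with a Bidirectional wrapper around a fresh copy.
        if isinstance(layer, keras.layers.LSTM):
            return keras.layers.Bidirectional(
                keras.layers.LSTM.from_config(layer.get_config())
            )
        # Default behavior: re-instantiate the layer from its config.
        return layer.__class__.from_config(layer.get_config())

    new_model = clone_model(model, clone_function=clone_fn)
    ```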

    Note that subclassed models cannot be cloned, since their internal
    layer structure is not known. To achieve functionality equivalent to
    `clone_model` in the case of a subclassed model, simply make sure
    that the model class implements `get_config()`
    (and optionally `from_config()`), and call:

    ```python
    new_model = model.__class__.from_config(model.get_config())
    ```
    """
    with serialization.DisableSharedObjectScope():
        if isinstance(model, Sequential):
            return _clone_sequential_model(
                model, input_tensors=input_tensors, layer_fn=clone_function
            )
        if isinstance(model, functional.Functional):
            # If the get_config() method is the same as a regular Functional
            # model's, we're safe to use _clone_functional_model (which
            # relies on a Functional constructor). In the case where the
            # get_config is custom, this may not necessarily work, but if
            # clone_function or input_tensors are passed, we attempt it
            # anyway in order to preserve backwards compatibility.
            if generic_utils.is_default(model.get_config) or (
                clone_function or input_tensors
            ):
                return _clone_functional_model(
                    model, input_tensors=input_tensors, layer_fn=clone_function
                )

        # Case of a custom model class.
        if clone_function or input_tensors:
            raise ValueError(
                "Arguments clone_function and input_tensors "
                "are only supported for Sequential models "
                "or Functional models. Received model of "
                f"type '{model.__class__.__name__}', with "
                f"clone_function={clone_function} and "
                f"input_tensors={input_tensors}"
            )
        # Note that a custom object scope may be required in this case.
        return model.__class__.from_config(model.get_config())


# "Clone" a subclassed model by resetting all of the attributes.
def _in_place_subclassed_model_reset(model):
    """Substitute for model cloning that works for subclassed models.

    Subclassed models cannot be cloned because their topology is not
    serializable. To "instantiate" an identical model in a new TF graph, we
    reuse the original model object, but we clear its state.

    After calling this function on a model instance, you can use the model
    instance as if it were a model clone (in particular you can use it in a
    new graph).

    This method clears the state of the input model. It is thus destructive.
    However, the original state can be restored fully by calling
    `_in_place_subclassed_model_state_restoration`.

    Args:
        model: Instance of a Keras model created via subclassing.

    Raises:
        ValueError: In case the model uses a subclassed model as inner layer.
    """
    assert (
        not model._is_graph_network
    )  # Only makes sense for subclassed networks.
    # Select correct base class for new Model.
    version_utils.swap_class(
        model.__class__,
        training.Model,
        training_v1.Model,
        tf.compat.v1.executing_eagerly_outside_functions(),
    )
    # Retrieve all layers tracked by the model as well as their attribute
    # names.
    attributes_cache = {}
    for name in dir(model):
        # Skip attrs that track other trackables.
        if name == "submodules" or name == "_self_tracked_trackables":
            continue

        try:
            value = getattr(model, name)
        except (AttributeError, ValueError, TypeError):
            continue
        if isinstance(value, Layer):
            attributes_cache[name] = value
            assert value in model.layers
            if hasattr(value, "layers") and value.layers:
                raise ValueError(
                    "We do not support the use of nested layers "
                    "in `model_to_estimator` at this time. Found nested "
                    f"layer: {value}"
                )
        elif isinstance(value, (list, tuple)) and name not in (
            "layers",
            "_layers",
            "metrics",
            "_compile_metric_functions",
            "_output_loss_metrics",
        ):
            # Handle case: list/tuple of layers (also tracked by the Network
            # API).
            if value and all(isinstance(val, Layer) for val in value):
                raise ValueError(
                    "We do not support the use of list-of-layers "
                    "attributes in subclassed models used with "
                    "`model_to_estimator` at this time. Found list "
                    f"model: {name}"
                )

    # Replace layers on the model with fresh layers.
    layers_to_names = {value: key for key, value in attributes_cache.items()}
    original_layers = list(
        model._flatten_layers(include_self=False, recursive=False)
    )
    setattr_tracking = model._setattr_tracking
    model._setattr_tracking = False
    model._self_tracked_trackables = []
    for layer in original_layers:  # We preserve layer order.
        config = layer.get_config()
        # This will not work for nested subclassed models used as layers.
        # This would be theoretically possible to support, but would add
        # complexity. Only do it if users complain.
        if isinstance(layer, training.Model) and not layer._is_graph_network:
            raise ValueError(
                "We do not support the use of nested subclassed models "
                "in `model_to_estimator` at this time. Found nested "
                f"model: {layer}"
            )
        fresh_layer = layer.__class__.from_config(config)
        name = layers_to_names[layer]
        setattr(model, name, fresh_layer)
        model._self_tracked_trackables.append(fresh_layer)

    # Cache original model build attributes (in addition to layers).
    if (
        not hasattr(model, "_original_attributes_cache")
        or model._original_attributes_cache is None
    ):
        if model.built:
            attributes_to_cache = [
                "inputs",
                "outputs",
                "total_loss",
                "optimizer",
                "train_function",
                "test_function",
                "predict_function",
                "_training_endpoints",
                "_collected_trainable_weights",
                "_feed_inputs",
                "_feed_input_names",
                "_feed_input_shapes",
            ]
            for name in attributes_to_cache:
                attributes_cache[name] = getattr(model, name)
    model._original_attributes_cache = attributes_cache
    _reset_build_compile_trackers(model)
    model._setattr_tracking = setattr_tracking


def _reset_build_compile_trackers(model):
    """Reset state trackers for model.

    Note that we do not actually zero out attributes such as optimizer,
    but instead rely on the expectation that all of the attrs will be
    over-written on calling build/compile/etc. This is somewhat fragile,
    insofar as we check elsewhere for the presence of these attributes as
    evidence of having been built/compiled/etc. Pending a better way to do
    this, we reset key attributes here to allow building and compiling.

    Args:
        model: the model that is being reset.
    """
    # Reset build state.
    model.built = False
    model.inputs = None
    model.outputs = None
    # Reset compile state.
    model._is_compiled = False
    if not tf.compat.v1.executing_eagerly_outside_functions():
        model._v1_compile_was_called = False
    model.optimizer = None


@keras_export(
    "keras.__internal__.models.in_place_subclassed_model_state_restoration",
    v1=[],
)
def in_place_subclassed_model_state_restoration(model):
    """Restores the original state of a model after it was "reset".

    This undoes the action of `_in_place_subclassed_model_reset`, which is
    called in `clone_and_build_model` if `in_place_reset` is set to True.

    Args:
        model: Instance of a Keras model created via subclassing, on which
            `_in_place_subclassed_model_reset` was previously called.
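
    A hedged usage sketch (pairs this helper with the private
    `_in_place_subclassed_model_reset` defined above; `MySubclassedModel`
    is illustrative):

    ```python
    model = MySubclassedModel()
    _in_place_subclassed_model_reset(model)  # destructive in-place "clone"
    # ... use `model` as if it were a fresh clone, e.g. in a new graph ...
    in_place_subclassed_model_state_restoration(model)  # undo the reset
    ```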

    """
    assert not model._is_graph_network
    # Restore layers and build attributes.
    if (
        hasattr(model, "_original_attributes_cache")
        and model._original_attributes_cache is not None
    ):
        # Models have sticky attribute assignment, so we want to be careful
        # to add back the previous attributes and track Layers by their
        # original names without adding dependencies on "utility" attributes
        # which Models exempt when they're constructed.
        setattr_tracking = model._setattr_tracking
        model._setattr_tracking = False
        model._self_tracked_trackables = []
        for name, value in model._original_attributes_cache.items():
            setattr(model, name, value)
            if isinstance(value, Layer):
                model._self_tracked_trackables.append(value)
        model._original_attributes_cache = None
        model._setattr_tracking = setattr_tracking
    else:
        # Restore to the state of a never-called model.
        _reset_build_compile_trackers(model)


@keras_export("keras.__internal__.models.clone_and_build_model", v1=[])
def clone_and_build_model(
    model,
    input_tensors=None,
    target_tensors=None,
    custom_objects=None,
    compile_clone=True,
    in_place_reset=False,
    optimizer_iterations=None,
    optimizer_config=None,
):
    """Clone a `Model` and build/compile it with the same settings used before.

    This function can be run in the same graph or in a separate graph from
    the model. When using a separate graph, `in_place_reset` must be `False`.

    Note that, currently, the clone produced from this function may not work
    with TPU DistributionStrategy. Try at your own risk.

    Args:
        model: `tf.keras.Model` object. Can be Functional, Sequential, or
            sub-classed.
        input_tensors: Optional list or dictionary of input tensors to build
            the model upon. If not provided, placeholders will be created.
        target_tensors: Optional list of target tensors for compiling the
            model. If not provided, placeholders will be created.
        custom_objects: Optional dictionary mapping string names to custom
            classes or functions.
        compile_clone: Boolean, whether to compile the model clone (default
            `True`).
        in_place_reset: Boolean, whether to reset the model in place. Only
            used if the model is a subclassed model. In the case of a
            subclassed model, this argument must be set to `True` (default
            `False`). To restore the original model, use the function
            `in_place_subclassed_model_state_restoration(model)`.
        optimizer_iterations: An iterations variable that will be incremented
            by the optimizer if the clone is compiled. This argument is used
            when a Keras model is cloned into an Estimator model function,
            because Estimators create their own global step variable.
        optimizer_config: Optimizer config dictionary or list of dictionaries
            returned from `get_config()`. This argument should be defined if
            `clone_and_build_model` is called in a different graph or session
            from the original model, and the optimizer is an instance of
            `OptimizerV2`.

    Returns:
        Clone of the model.

    Raises:
        ValueError: Cloning fails in the following cases:
            - cloning a subclassed model with `in_place_reset` set to False.
            - compiling the clone when the original model has not been
              compiled.
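
    Example (a hedged sketch; assumes `model` is an already-compiled
    `tf.keras.Model` and `x`, `y` are training data):

    ```python
    clone = clone_and_build_model(model, compile_clone=True)
    clone.fit(x, y)  # the clone was compiled with cloned metrics/optimizer
    ```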

    """
    # Grab optimizer now, as we reset-in-place for subclassed models, but
    # want to maintain access to the original optimizer.
    orig_optimizer = model.optimizer
    if compile_clone and not orig_optimizer:
        raise ValueError(
            "Error when cloning model: `compile_clone` was set to True, but "
            f"the original model has not been compiled. Received: model={model}"
        )

    if compile_clone:
        compile_args = model._get_compile_args()
        # Allows this method to be robust to switching graph and eager
        # classes.
        model._get_compile_args = lambda: compile_args

    with CustomObjectScope(custom_objects or {}):
        if model._is_graph_network:
            clone = clone_model(model, input_tensors=input_tensors)
        elif isinstance(model, Sequential):
            clone = clone_model(model, input_tensors=input_tensors)
            if (
                not clone._is_graph_network
                and model._build_input_shape is not None
            ):
                if tf.compat.v1.executing_eagerly_outside_functions():
                    clone.build(model._build_input_shape)
                else:
                    clone._set_inputs(
                        backend.placeholder(
                            model._build_input_shape,
                            dtype=model.inputs[0].dtype,
                        )
                    )
        else:
            try:
                # Prefer cloning the model if serial/deserial logic is
                # implemented for subclassed model.
                clone = model.__class__.from_config(model.get_config())
            except NotImplementedError:
                logging.warning(
                    "This model is a subclassed model. Please implement "
                    "`get_config` and `from_config` to better support "
                    "cloning the model."
                )
                if not in_place_reset:
                    raise ValueError(
                        f"This model ({model}) is a subclassed model. "
                        "Such a model cannot be cloned, but there is a "
                        "workaround where the model is reset in-place. "
                        "To use this, please set the "
                        "argument `in_place_reset` to `True`. This will reset "
                        "the attributes in the original model. "
                        "To restore the attributes, call "
                        "`in_place_subclassed_model_state_restoration(model)`."
                    )
                clone = model
                _in_place_subclassed_model_reset(clone)
            if input_tensors is not None:
                if (
                    isinstance(input_tensors, (list, tuple))
                    and len(input_tensors) == 1
                ):
                    input_tensors = input_tensors[0]
                clone._set_inputs(input_tensors)

    if compile_clone:
        if isinstance(orig_optimizer, optimizer_v1.TFOptimizer):
            optimizer = optimizer_v1.TFOptimizer(
                orig_optimizer.optimizer, optimizer_iterations
            )
            backend.track_tf_optimizer(optimizer)
        else:
            if not isinstance(orig_optimizer, (tuple, list)):
                orig_optimizer = [orig_optimizer]
            if optimizer_config is None:
                optimizer = [
                    opt.__class__.from_config(opt.get_config())
                    for opt in orig_optimizer
                ]
            elif isinstance(optimizer_config, dict):
                optimizer = [
                    orig_optimizer[0].__class__.from_config(optimizer_config)
                ]
            else:
                # optimizer config is a list of dicts, in the same order as
                # orig_optimizer.
                optimizer = [
                    opt.__class__.from_config(opt_config)
                    for (opt, opt_config) in zip(
                        orig_optimizer, optimizer_config
                    )
                ]
            if optimizer_iterations is not None:
                for opt in optimizer:
                    opt.iterations = optimizer_iterations

            if len(optimizer) == 1:
                optimizer = optimizer[0]

        compile_args["optimizer"] = optimizer
        if target_tensors is not None:
            compile_args["target_tensors"] = target_tensors
        # Ensure Metric objects in new model are separate from existing
        # model.
        compile_args["metrics"] = metrics_module.clone_metrics(
            compile_args["metrics"]
        )
        compile_args["weighted_metrics"] = metrics_module.clone_metrics(
            compile_args["weighted_metrics"]
        )
        clone.compile(**compile_args)

    return clone