Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/functional.py: 1%

737 statements  

coverage.py v7.3.2, created at 2023-10-05 06:32 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""A `Network` is way to compose layers: the topological form of a `Model`.""" 

17 

18import collections 

19import copy 

20import itertools 

21import warnings 

22 

23import tensorflow.compat.v2 as tf 

24 

25from keras.src import backend 

26from keras.src.dtensor import layout_map as layout_map_lib 

27from keras.src.engine import base_layer 

28from keras.src.engine import base_layer_utils 

29from keras.src.engine import functional_utils 

30from keras.src.engine import input_layer as input_layer_module 

31from keras.src.engine import input_spec 

32from keras.src.engine import node as node_module 

33from keras.src.engine import training as training_lib 

34from keras.src.engine import training_utils 

35from keras.src.saving import serialization_lib 

36from keras.src.saving.legacy import serialization 

37from keras.src.saving.legacy.saved_model import json_utils 

38from keras.src.saving.legacy.saved_model import network_serialization 

39from keras.src.saving.legacy.saved_model import utils as saved_model_utils 

40from keras.src.utils import generic_utils 

41from keras.src.utils import tf_inspect 

42from keras.src.utils import tf_utils 

43 

44# isort: off 

45from tensorflow.python.platform import tf_logging as logging 

46from tensorflow.tools.docs import doc_controls 

47 

48 

49class Functional(training_lib.Model): 

50 """A `Functional` model is a `Model` defined as a directed graph of layers. 

51 

52 Three types of `Model` exist: subclassed `Model`, `Functional` model, 

53 and `Sequential` (a special case of `Functional`). 

54 In general, more Keras features are supported with `Functional` 

55 than with subclassed `Model`s, specifically: 

56 

57 - Model cloning (`keras.models.clone`) 

58 - Serialization (`model.get_config()/from_config`, `model.to_json()`)

59 - Whole-model saving (`model.save()`) 

60 

61 A `Functional` model can be instantiated by passing two arguments to 

62 `__init__`. The first argument is the `keras.Input` Tensors that represent 

63 the inputs to the model. The second argument specifies the output 

64 tensors that represent the outputs of this model. Both arguments can be a 

65 nested structure of tensors. 

66 

67 Example: 

68 

69 ```python

70 inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))} 

71 t = keras.layers.Dense(1, activation='relu')(inputs['x1']) 

72 outputs = keras.layers.Add()([t, inputs['x2']])

73 model = keras.Model(inputs, outputs) 

74 ``` 

75 

76 A `Functional` model constructed using the Functional API can also include 

77 raw TensorFlow functions, with the exception of functions that create 

78 Variables or assign ops. 

79 

80 Example: 

81 

82 ```python 

83 inputs = keras.Input(shape=(10,)) 

84 x = keras.layers.Dense(1)(inputs) 

85 outputs = tf.nn.relu(x) 

86 model = keras.Model(inputs, outputs) 

87 ``` 

88 

89 A new `Functional` model can also be created by using the 

90 intermediate tensors. This enables you to quickly extract sub-components 

91 of the model. 

92 

93 Example: 

94 

95 ```python 

96 inputs = keras.Input(shape=(None, None, 3)) 

97 processed = keras.layers.RandomCrop(width=32, height=32)(inputs) 

98 conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed) 

99 pooling = keras.layers.GlobalAveragePooling2D()(conv) 

100 feature = keras.layers.Dense(10)(pooling) 

101 

102 full_model = keras.Model(inputs, feature) 

103 backbone = keras.Model(processed, conv) 

104 activations = keras.Model(conv, feature) 

105 ``` 

106 

107 Note that the `backbone` and `activations` models are not

108 created with `keras.Input` objects, but with the tensors that originate

109 from `keras.Input` objects. Under the hood, the layers and weights are

110 shared across these models, so that the user can train the `full_model` and

111 use `backbone` or `activations` for feature extraction.

112 The inputs and outputs of the model can be nested structures of tensors as

113 well, and the created models are standard `Functional` models that support

114 all the existing APIs.

115 

116 Args: 

117 inputs: List of input tensors (must be created via `tf.keras.Input()` or 

118 originated from `tf.keras.Input()`). 

119 outputs: List of output tensors. 

120 name: String, optional. Name of the model. 

121 trainable: Boolean, optional. If the model's variables should be 

122 trainable. 

123 """ 

124 

125 # See tf.Module for the usage of this property. 

126 # The key of _layer_call_argspecs is a layer. tf.Module._flatten will fail 

127 # to flatten the key since it is trying to convert Trackable/Layer to a 

128 # string. 

129 _TF_MODULE_IGNORED_PROPERTIES = frozenset( 

130 itertools.chain( 

131 ( 

132 "_layer_call_argspecs", 

133 "_output_mask_cache", 

134 "_output_tensor_cache", 

135 "_output_shape_cache", 

136 ), 

137 training_lib.Model._TF_MODULE_IGNORED_PROPERTIES, 

138 ) 

139 ) 

140 

141 @tf.__internal__.tracking.no_automatic_dependency_tracking 

142 def __init__(self, inputs, outputs, name=None, trainable=True, **kwargs): 

143 # This is used by the Model class, since we have some logic to swap the

144 # class in the __new__ method, which causes __init__ to be invoked

145 # twice. Use skip_init to skip one of the __init__ invocations and

146 # avoid any side effects.

147 skip_init = kwargs.pop("skip_init", False) 

148 if skip_init: 

149 return 

150 generic_utils.validate_kwargs(kwargs, {}) 

151 super().__init__(name=name, trainable=trainable) 

152 # Check if the inputs contain any intermediate `KerasTensor` (not 

153 # created by tf.keras.Input()). In this case we need to clone the `Node` 

154 # and `KerasTensor` objects to mimic rebuilding a new model from new 

155 # inputs. This feature is only enabled in TF2 not in v1 graph mode. 

156 if tf.compat.v1.executing_eagerly_outside_functions(): 

157 if not all( 

158 [ 

159 functional_utils.is_input_keras_tensor(t) 

160 for t in tf.nest.flatten(inputs) 

161 ] 

162 ): 

163 inputs, outputs = functional_utils.clone_graph_nodes( 

164 inputs, outputs 

165 ) 

166 self._init_graph_network(inputs, outputs) 

167 

168 @tf.__internal__.tracking.no_automatic_dependency_tracking 

169 def _init_graph_network(self, inputs, outputs): 

170 # This method is needed for Sequential to reinitialize graph network 

171 # when layer is added or removed. 

172 

173 base_layer.keras_api_gauge.get_cell("Functional").set(True) 

174 self._is_graph_network = True 

175 

176 # Normalize and set self.inputs, self.outputs. 

177 if isinstance(inputs, list) and len(tf.nest.flatten(inputs)) == 1: 

178 inputs = inputs[0] 

179 if isinstance(outputs, list) and len(tf.nest.flatten(outputs)) == 1: 

180 outputs = outputs[0] 

181 self._nested_inputs = inputs 

182 self._nested_outputs = outputs 

183 self.inputs = tf.nest.flatten(inputs) 

184 self.outputs = tf.nest.flatten(outputs) 

185 

186 # Models constructed with a single Tensor or list of Tensors can 

187 # be called with a dict, where the keys of the dict are the names 

188 # of the `Input` objects. Extra keys are ignored with warning. 
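# Illustrative sketch of this mapping (hypothetical input names, not part
# of this file): a model built from named inputs, e.g.
#     a = tf.keras.Input(shape=(1,), name="a")
#     b = tf.keras.Input(shape=(1,), name="b")
#     model = tf.keras.Model([a, b], outputs)
# can later be called as model({"a": x1, "b": x2}); the dict values are
# mapped back onto the positional inputs by input name.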

189 if not tf.nest.is_nested(self._nested_inputs): 

190 self._enable_dict_to_input_mapping = True 

191 elif isinstance(self._nested_inputs, (list, tuple)) and not any( 

192 tf.nest.is_nested(t) for t in self._nested_inputs 

193 ): 

194 self._enable_dict_to_input_mapping = True 

195 elif isinstance(self._nested_inputs, dict) and not any( 

196 tf.nest.is_nested(t) for t in self._nested_inputs.values() 

197 ): 

198 self._enable_dict_to_input_mapping = True 

199 else: 

200 self._enable_dict_to_input_mapping = False 

201 

202 if not tf.compat.v1.executing_eagerly_outside_functions(): 

203 if any( 

204 not hasattr(tensor, "_keras_history") for tensor in self.outputs 

205 ): 

206 base_layer_utils.create_keras_history(self._nested_outputs) 

207 

208 self._validate_graph_inputs_and_outputs() 

209 

210 # A Network does not create weights of its own, thus it is already 

211 # built. 

212 self.built = True 

213 self._build_input_shape = tf.nest.map_structure( 

214 lambda x: x.shape, inputs 

215 ) 

216 self._compute_output_and_mask_jointly = True 

217 # `_expects_training_arg` is True since the `training` argument is 

218 # always present in the signature of the `call` method of a graph 

219 # network. 

220 self._call_spec.expects_training_arg = True 

221 self._call_spec.expects_mask_arg = True 

222 # A graph network does not autocast inputs, as its layers will cast them 

223 # instead. 

224 self._autocast = False 

225 

226 self._input_layers = [] 

227 self._output_layers = [] 

228 self._input_coordinates = [] 

229 self._output_coordinates = [] 

230 

231 # This is for performance optimization when calling the Network on new 

232 # inputs. Every time the Network is called on a set of input tensors, we

233 # compute the output tensors, output masks and output shapes in one 

234 # pass, then cache them here. When any of these outputs is queried 

235 # later, we retrieve it from there instead of recomputing it. 

236 self._output_mask_cache = {} 

237 self._output_tensor_cache = {} 

238 self._output_shape_cache = {} 

239 

240 # Build self._output_layers: 

241 for x in self.outputs: 

242 ( 

243 layer, 

244 node_index, 

245 tensor_index, 

246 ) = x._keras_history 

247 self._output_layers.append(layer) 

248 self._output_coordinates.append((layer, node_index, tensor_index)) 

249 

250 # Build self._input_layers: 

251 for x in self.inputs: 

252 ( 

253 layer, 

254 node_index, 

255 tensor_index, 

256 ) = x._keras_history 

257 # It's supposed to be an input layer, so only one node 

258 # and one tensor output. 

259 assert node_index == 0 

260 assert tensor_index == 0 

261 self._input_layers.append(layer) 

262 self._input_coordinates.append((layer, node_index, tensor_index)) 

263 

264 # Keep track of the network's nodes and layers. 

265 nodes, nodes_by_depth, layers, _ = _map_graph_network( 

266 self.inputs, self.outputs 

267 ) 

268 self._network_nodes = nodes 

269 self._nodes_by_depth = nodes_by_depth 

270 self._self_tracked_trackables = layers 

271 self._layer_call_argspecs = {} 

272 for layer in self._self_tracked_trackables: 

273 self._layer_call_argspecs[layer] = tf_inspect.getfullargspec( 

274 layer.call 

275 ) 

276 

277 # Build self.input_names and self.output_names. 

278 self._set_output_names() 

279 self.input_names = [] 

280 self._feed_input_names = [] 

281 self._feed_inputs = [] 

282 self._feed_input_shapes = [] 

283 for layer in self._input_layers: 

284 self.input_names.append(layer.name) 

285 if layer.is_placeholder: 

286 self._feed_input_names.append(layer.name) 

287 # Use batch_input_shape here because non-eager composite tensors 

288 # may not have a shape attribute that's meaningful (sparse, for 

289 # instance, has a tensor that's non-constant and needs to be 

290 # fed). This means that input layers that create placeholders 

291 # will need to have the batch_input_shape attr to allow for 

292 # input shape validation. 

293 self._feed_input_shapes.append(layer._batch_input_shape) 

294 self._feed_inputs.append(layer.input) 

295 

296 self._compute_tensor_usage_count() 

297 self._set_save_spec(self._nested_inputs) 

298 tf_utils.assert_no_legacy_layers(self.layers) 

299 

300 # Note that this method is used by both functional and sequential 

301 # models, so we can't just have this method in functional.__init__, 

302 # which would not cover the Sequential model.

303 if self._layout_map is not None: 

304 layout_map_lib._map_functional_model_variable( 

305 self, self._layout_map 

306 ) 

307 

308 @property 

309 def input(self): 

310 """Retrieves the input tensor(s) of a layer. 

311 

312 Only applicable if the layer has exactly one input, 

313 i.e. if it is connected to one incoming layer. 

314 

315 Returns: 

316 Input tensor or list of input tensors. 

317 

318 Raises: 

319 RuntimeError: If called in Eager mode. 

320 AttributeError: If no inbound nodes are found. 

321 """ 

322 return self._nested_inputs 

323 

324 @property 

325 def input_shape(self): 

326 """Retrieves the input shape(s) of a layer. 

327 

328 Only applicable if the layer has exactly one input, 

329 i.e. if it is connected to one incoming layer, or if all inputs 

330 have the same shape. 

331 

332 Returns: 

333 Input shape, as an integer shape tuple 

334 (or list of shape tuples, one tuple per input tensor). 

335 

336 Raises: 

337 AttributeError: if the layer has no defined input_shape. 

338 RuntimeError: if called in Eager mode. 

339 """ 

340 return tf.nest.map_structure(backend.int_shape, self.input) 

341 

342 @property 

343 def input_spec(self): 

344 if hasattr(self, "_manual_input_spec"): 

345 return self._manual_input_spec 

346 if isinstance(self._nested_inputs, (dict, list, tuple)) and len( 

347 self._nested_inputs 

348 ) != len(self.inputs): 

349 # Case where we have a nested structure. 

350 # In such a case we can't safely run any checks. 

351 return None 

352 if isinstance(self._nested_inputs, dict): 

353 # Case where `_nested_inputs` is a plain dict of Inputs. 

354 names = sorted(self._nested_inputs.keys()) 

355 return [ 

356 input_spec.InputSpec( 

357 shape=shape_with_no_batch_size(self._nested_inputs[name]), 

358 allow_last_axis_squeeze=True, 

359 name=name, 

360 ) 

361 for name in names 

362 ] 

363 else: 

364 # Single input, or list / tuple of inputs. 

365 # The data may be passed as a dict keyed by input name. 

366 return [ 

367 input_spec.InputSpec( 

368 shape=shape_with_no_batch_size(x), 

369 allow_last_axis_squeeze=True, 

370 name=x._keras_history.layer.name, 

371 ) 

372 for x in self.inputs 

373 ] 

374 

375 @input_spec.setter 

376 def input_spec(self, value): 

377 self._manual_input_spec = value 

378 

379 @property 

380 def output(self): 

381 """Retrieves the output tensor(s) of a layer. 

382 

383 Only applicable if the layer has exactly one output, 

384 i.e. if it is connected to one incoming layer. 

385 

386 Returns: 

387 Output tensor or list of output tensors. 

388 

389 Raises: 

390 AttributeError: if the layer is connected to more than one incoming

391 layer.

392 RuntimeError: if called in Eager mode. 

393 """ 

394 return self._nested_outputs 

395 

396 @property 

397 def output_shape(self): 

398 """Retrieves the output shape(s) of a layer. 

399 

400 Only applicable if the layer has one output, 

401 or if all outputs have the same shape. 

402 

403 Returns: 

404 Output shape, as an integer shape tuple 

405 (or list of shape tuples, one tuple per output tensor). 

406 

407 Raises: 

408 AttributeError: if the layer has no defined output shape. 

409 RuntimeError: if called in Eager mode. 

410 """ 

411 return tf.nest.map_structure(backend.int_shape, self.output) 

412 

413 def _set_output_names(self): 

414 """Assigns unique names to the Network's outputs. 

415 

416 Output layers with multiple output tensors would otherwise lead to 

417 duplicate names in self.output_names. 

418 """ 

419 uniquified = [] 

420 output_names = set() 

421 prefix_count = {} 

422 for layer in self._output_layers: 

423 proposal = layer.name 

424 while proposal in output_names: 

425 existing_count = prefix_count.get(layer.name, 1) 

426 proposal = f"{layer.name}_{existing_count}" 

427 prefix_count[layer.name] = existing_count + 1 

428 output_names.add(proposal) 

429 uniquified.append(proposal) 

430 self.output_names = uniquified 

431 

432 @property 

433 def _layer_checkpoint_dependencies(self): 

434 """Dictionary of layer dependencies to be included in the checkpoint.""" 

435 weight_layer_index = 0 

436 

437 dependencies = collections.OrderedDict() 

438 for layer_index, layer in enumerate(self.layers): 

439 try: 

440 if layer.weights: 

441 # Keep a separate index for layers which have weights. This 

442 # allows users to insert Layers without weights anywhere in 

443 # the network without breaking checkpoints. 

444 dependencies[ 

445 "layer_with_weights-%d" % weight_layer_index 

446 ] = layer 

447 weight_layer_index += 1 

448 except ValueError: 

449 # The layer might have weights, but may not be built yet. We 

450 # just treat it as a layer without weights.

451 pass 

452 

453 # Even if it doesn't have weights, we should still track everything 

454 # in case it has/will have Trackable dependencies. 

455 dependencies["layer-%d" % layer_index] = layer 

456 return dependencies 

457 

458 def _trackable_children(self, save_type="checkpoint", **kwargs): 

459 dependencies = self._layer_checkpoint_dependencies 

460 dependencies.update(super()._trackable_children(save_type, **kwargs)) 

461 return dependencies 

462 

463 def _lookup_dependency(self, name, cached_dependencies=None): 

464 if cached_dependencies: 

465 return cached_dependencies.get(name) 

466 # Fall back to slow lookup (`layer_checkpoint_dependencies` does a 

467 # thorough check of all layers to see if they contain weights.)

468 layer_dependencies = self._layer_checkpoint_dependencies 

469 if name in layer_dependencies: 

470 return layer_dependencies[name] 

471 return super()._lookup_dependency(name) 

472 

473 def _handle_deferred_layer_dependencies(self, layers): 

474 """Handles layer checkpoint dependencies that are added after init.""" 

475 layer_checkpoint_dependencies = self._layer_checkpoint_dependencies 

476 layer_to_name = {v: k for k, v in layer_checkpoint_dependencies.items()} 

477 for layer in layers: 

478 if layer in layer_to_name: 

479 self._handle_deferred_dependencies( 

480 name=layer_to_name[layer], trackable=layer 

481 ) 

482 

483 @property 

484 def _should_compute_mask(self): 

485 return True 

486 

487 def compute_mask(self, inputs, mask): 

488 # TODO(omalleyt): b/123540974 This function is not really safe to call 

489 # by itself because it will duplicate any updates and losses in graph 

490 # mode by `call`ing the Layers again. 

491 output_tensors = self._run_internal_graph(inputs, mask=mask) 

492 return tf.nest.map_structure( 

493 lambda t: getattr(t, "_keras_mask", None), output_tensors 

494 ) 

495 

496 @doc_controls.do_not_doc_inheritable 

497 def call(self, inputs, training=None, mask=None): 

498 """Calls the model on new inputs. 

499 

500 In this case `call` just reapplies 

501 all ops in the graph to the new inputs 

502 (e.g. build a new computational graph from the provided inputs). 

503 

504 Args: 

505 inputs: A tensor or list of tensors. 

506 training: Boolean or boolean scalar tensor, indicating whether to 

507 run the `Network` in training mode or inference mode. 

508 mask: A mask or list of masks. A mask can be 

509 either a tensor or None (no mask). 

510 

511 Returns: 

512 A tensor if there is a single output, or 

513 a list of tensors if there is more than one output.

514 """ 

515 return self._run_internal_graph(inputs, training=training, mask=mask) 
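# Minimal usage sketch (assumes a single-input model `model`, not defined
# in this file):
#     outputs = model(tf.ones((8, 10)), training=False)
# which re-runs the stored graph of layer calls on the new inputs via
# _run_internal_graph above.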

516 

517 def compute_output_shape(self, input_shape): 

518 # Convert any shapes in tuple format to TensorShapes. 

519 input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False) 

520 

521 if len(tf.nest.flatten(input_shape)) != len( 

522 tf.nest.flatten(self._input_layers) 

523 ): 

524 raise ValueError( 

525 f"Invalid `input_shape` argument {input_shape}: " 

526 f"the model expects {len(self._input_layers)} " 

527 "input tensors." 

528 ) 

529 

530 # Use the tuple of TensorShape as the cache key, since tuple is hashable 

531 # and can be used as a hash key.

532 try: 

533 cache_key = tuple( 

534 tf_utils.convert_shapes(input_shape, to_tuples=True) 

535 ) 

536 if cache_key in self._output_shape_cache: 

537 # Cache hit. Return shapes as TensorShapes. 

538 return self._output_shape_cache[cache_key] 

539 except ValueError: 

540 # In case there are unknown TensorShapes, e.g. for sparse tensor inputs,

541 # we skip the caching since the shape is unknown.

542 pass 

543 

544 layers_to_output_shapes = {} 

545 for layer, shape in zip( 

546 self._input_layers, tf.nest.flatten(input_shape) 

547 ): 

548 # It's an input layer: then `compute_output_shape` is identity, 

549 # and there is only one node and one tensor.

550 shape_key = layer.name + "_0_0" 

551 layers_to_output_shapes[shape_key] = shape 

552 

553 depth_keys = list(self._nodes_by_depth.keys()) 

554 depth_keys.sort(reverse=True) 

555 # Iterate over nodes, by depth level. 

556 if len(depth_keys) > 1: 

557 for depth in depth_keys: 

558 nodes = self._nodes_by_depth[depth] 

559 for node in nodes: 

560 layer = node.layer 

561 if layer in self._input_layers: 

562 # We've already covered the input layers 

563 # a few lines above. 

564 continue 

565 # Get the input shapes for the first argument of the node 

566 layer_input_shapes = [] 

567 layer_inputs = node.call_args[0] 

568 for layer_input in tf.nest.flatten(layer_inputs): 

569 kh = layer_input._keras_history 

570 input_layer_key = kh.layer.name + "_%s_%s" % ( 

571 kh.node_index, 

572 kh.tensor_index, 

573 ) 

574 layer_input_shapes.append( 

575 layers_to_output_shapes[input_layer_key] 

576 ) 

577 layer_input_shapes = tf.nest.pack_sequence_as( 

578 layer_inputs, layer_input_shapes 

579 ) 

580 # Layers expect shapes to be tuples for 

581 # `compute_output_shape`. 

582 layer_input_shapes = tf_utils.convert_shapes( 

583 layer_input_shapes, to_tuples=True 

584 ) 

585 layer_output_shapes = layer.compute_output_shape( 

586 layer_input_shapes 

587 ) 

588 # Convert back to TensorShapes. 

589 layer_output_shapes = tf_utils.convert_shapes( 

590 layer_output_shapes, to_tuples=False 

591 ) 

592 

593 node_index = layer._inbound_nodes.index(node) 

594 for j, shape in enumerate( 

595 tf.nest.flatten(layer_output_shapes) 

596 ): 

597 shape_key = layer.name + f"_{node_index}_{j}" 

598 layers_to_output_shapes[shape_key] = shape 

599 

600 # Read final output shapes from layers_to_output_shapes. 

601 output_shapes = [] 

602 for i in range(len(self._output_layers)): 

603 layer, node_index, tensor_index = self._output_coordinates[i] 

604 shape_key = layer.name + f"_{node_index}_{tensor_index}" 

605 output_shapes.append(layers_to_output_shapes[shape_key]) 

606 output_shapes = tf.nest.pack_sequence_as( 

607 self._nested_outputs, output_shapes 

608 ) 

609 # Store in cache. 

610 self._output_shape_cache[cache_key] = output_shapes 

611 

612 # Return shapes as TensorShapes. 

613 return output_shapes 

614 

615 def _init_set_name(self, name, zero_based=True): 

616 if not name: 

617 cls_name = self.__class__.__name__ 

618 if self.__class__ == Functional: 

619 # Hide the Functional class name from the user, since it's not a

620 # publicly visible class. Use "Model" instead.

621 cls_name = "Model" 

622 self._name = backend.unique_object_name( 

623 generic_utils.to_snake_case(cls_name), zero_based=zero_based 

624 ) 

625 else: 

626 self._name = name 

627 

628 def _run_internal_graph(self, inputs, training=None, mask=None): 

629 """Computes output tensors for new inputs. 

630 

631 # Note: 

632 - Can be run on non-Keras tensors. 

633 

634 Args: 

635 inputs: Tensor or nested structure of Tensors. 

636 training: Boolean learning phase. 

637 mask: (Optional) Tensor or nested structure of Tensors. 

638 

639 Returns: 

640 output_tensors 

641 """ 

642 inputs = self._flatten_to_reference_inputs(inputs) 

643 if mask is None: 

644 masks = [None] * len(inputs) 

645 else: 

646 masks = self._flatten_to_reference_inputs(mask) 

647 for input_t, mask in zip(inputs, masks): 

648 input_t._keras_mask = mask 

649 

650 # Dictionary mapping reference tensors to computed tensors. 

651 tensor_dict = {} 

652 tensor_usage_count = self._tensor_usage_count 

653 for x, y in zip(self.inputs, inputs): 

654 y = self._conform_to_reference_input(y, ref_input=x) 

655 x_id = str(id(x)) 

656 tensor_dict[x_id] = [y] * tensor_usage_count[x_id] 

657 

658 nodes_by_depth = self._nodes_by_depth 

659 depth_keys = list(nodes_by_depth.keys()) 

660 depth_keys.sort(reverse=True) 

661 

662 for depth in depth_keys: 

663 nodes = nodes_by_depth[depth] 

664 for node in nodes: 

665 if node.is_input: 

666 continue # Input tensors already exist. 

667 

668 if any(t_id not in tensor_dict for t_id in node.flat_input_ids): 

669 continue # Node is not computable, try skipping. 

670 

671 args, kwargs = node.map_arguments(tensor_dict) 

672 outputs = node.layer(*args, **kwargs) 

673 

674 # Update tensor_dict. 

675 for x_id, y in zip( 

676 node.flat_output_ids, tf.nest.flatten(outputs) 

677 ): 

678 tensor_dict[x_id] = [y] * tensor_usage_count[x_id] 

679 

680 output_tensors = [] 

681 for x in self.outputs: 

682 x_id = str(id(x)) 

683 assert x_id in tensor_dict, "Could not compute output " + str(x) 

684 output_tensors.append(tensor_dict[x_id].pop()) 

685 

686 return tf.nest.pack_sequence_as(self._nested_outputs, output_tensors) 

687 

688 def _flatten_to_reference_inputs(self, tensors): 

689 """Maps `tensors` to their respective `keras.Input`.""" 

690 if self._enable_dict_to_input_mapping and isinstance(tensors, dict): 

691 ref_inputs = self._nested_inputs 

692 if not tf.nest.is_nested(ref_inputs): 

693 ref_inputs = [self._nested_inputs] 

694 if isinstance(ref_inputs, dict): 

695 # In the case that the graph is constructed with dict input

696 # tensors, we use the original dict keys to map to the

697 # keys in the input data. Note that model.inputs uses

698 # nest.flatten to process the input tensors, which means the

699 # dict input tensors are ordered by their keys.

700 ref_input_names = sorted(ref_inputs.keys()) 

701 else: 

702 ref_input_names = [ 

703 inp._keras_history.layer.name for inp in ref_inputs 

704 ] 

705 

706 # Raise a warning if there are more entries in the input data than

707 # input tensors.

708 if len(tensors) > len(ref_input_names): 

709 warnings.warn( 

710 "Input dict contained keys {} which did not match any " 

711 "model input. They will be ignored by the model.".format( 

712 [n for n in tensors.keys() if n not in ref_input_names] 

713 ), 

714 stacklevel=2, 

715 ) 

716 

717 try: 

718 # Flatten in the order `Input`s were passed during Model 

719 # construction. 

720 return [tensors[n] for n in ref_input_names] 

721 except KeyError: 

722 # TODO(b/151582614) 

723 return tf.nest.flatten(tensors) 

724 

725 # Otherwise both self.inputs and tensors will already be in same order. 

726 return tf.nest.flatten(tensors) 

727 

728 def _conform_to_reference_input(self, tensor, ref_input): 

729 """Set shape and dtype based on `keras.Input`s.""" 

730 if isinstance(tensor, tf.Tensor): 

731 # Allow (None,) and (None, 1) Tensors to be passed interchangeably. 

732 # Use the shape specified by the `keras.Input`. 

733 t_shape = tensor.shape 

734 t_rank = t_shape.rank 

735 ref_shape = ref_input.shape 

736 ref_rank = ref_shape.rank 

737 keras_history = getattr(tensor, "_keras_history", None) 

738 if t_rank is not None and ref_rank is not None: 

739 # Should squeeze last dimension. True if tensor is (BATCH, ..., 

740 # 1) and reference is (BATCH, ...). 

741 if t_rank == ref_rank + 1 and t_shape[-1] == 1: 

742 tensor = tf.squeeze(tensor, axis=-1) 

743 # Should expand last_dimension. True if tensor is (BATCH, ...) 

744 # and reference is (BATCH, ..., 1). 

745 elif t_rank == ref_rank - 1 and ref_shape[-1] == 1: 

746 tensor = tf.expand_dims(tensor, axis=-1) 

747 if keras_history is not None: # Restore keras history. 

748 tensor._keras_history = keras_history 

749 

750 # Dtype casting. 

751 tensor = tf.cast(tensor, dtype=ref_input.dtype) 

752 elif tf_utils.is_extension_type(tensor): 

753 # Dtype casting (If the extension type has a non-variant dtype and 

754 # supports being cast). Only cast if necessary (since some 

755 # extension types may not implement tf.cast). 

756 tensor_dtype = getattr(tensor, "dtype", None) 

757 ref_input_dtype = getattr(ref_input, "dtype", None) 

758 if ( 

759 ref_input_dtype is not None 

760 and tensor_dtype is not None 

761 and tensor_dtype != ref_input_dtype 

762 and ref_input_dtype != tf.variant 

763 ): 

764 tensor = tf.cast(tensor, dtype=ref_input_dtype) 

765 

766 return tensor 

767 

768 @generic_utils.default 

769 def get_config(self): 

770 # Prepare base arguments 

771 config = { 

772 "name": self.name, 

773 "trainable": self.trainable, 

774 } 

775 

776 if saved_model_utils.in_tf_saved_model_scope(): 

777 # SavedModel special case: need to preserve legacy (potentially 

778 # incorrect) behavior. 

779 return copy.deepcopy(get_network_config(self, config=config)) 

780 

781 # Check whether the class has a constructor compatible with a Functional 

782 # model or if it has a custom constructor. 

783 if has_functional_like_constructor(self.__class__): 

784 # Only return a Functional config if the constructor is the same 

785 # as that of a Functional model. This excludes subclassed Functional 

786 # models with a custom __init__. 

787 config = copy.deepcopy(get_network_config(self, config=config)) 

788 else: 

789 # Try to autogenerate config 

790 xtra_args = set(config.keys()) 

791 if getattr(self, "_auto_get_config", False): 

792 config.update(self._auto_config.config) 

793 # Remove args not explicitly supported.

794 argspec = tf_inspect.getfullargspec(self.__init__) 

795 if argspec.varkw != "kwargs": 

796 for key in xtra_args - xtra_args.intersection(argspec.args[1:]): 

797 config.pop(key, None) 

798 return config 

799 

800 def get_weight_paths(self): 

801 result = {} 

802 for layer in self.layers: 

803 ( 

804 descendants, 

805 object_paths_dict, 

806 ) = tf.__internal__.tracking.ObjectGraphView( 

807 layer 

808 ).breadth_first_traversal() 

809 for descendant in descendants: 

810 if isinstance(descendant, tf.Variable): 

811 trackable_references = object_paths_dict[descendant] 

812 object_path = ".".join( 

813 [t.name for t in trackable_references] 

814 ) 

815 result[layer.name + "." + object_path] = descendant 

816 return result 
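# Illustrative sketch of the returned mapping (layer/variable names are
# hypothetical): keys join the layer name with the trackable attribute
# path, e.g. {"dense.kernel": <tf.Variable>, "dense.bias": <tf.Variable>}.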

817 

818 def _validate_graph_inputs_and_outputs(self): 

819 """Validates the inputs and outputs of a Graph Network.""" 

820 # Check for redundancy in inputs. 

821 if len({id(i) for i in self.inputs}) != len(self.inputs): 

822 raise ValueError( 

823 "The list of inputs passed to the model " 

824 "contains the same input multiple times. " 

825 "All inputs should only appear once." 

826 f"Received inputs={self.inputs}" 

827 ) 

828 

829 for x in self.inputs: 

830 # Check that x has appropriate `_keras_history` metadata. 

831 if not hasattr(x, "_keras_history"): 

832 cls_name = self.__class__.__name__ 

833 raise ValueError( 

834 f"Input tensors to a {cls_name} model " 

835 "must come from `tf.keras.Input`. " 

836 f"Received inputs={x} (missing previous layer metadata)." 

837 ) 

838 # Check that x is an input tensor. 

839 

840 layer = x._keras_history.layer 

841 if len(layer._inbound_nodes) > 1 or ( 

842 layer._inbound_nodes and not layer._inbound_nodes[0].is_input 

843 ): 

844 cls_name = self.__class__.__name__ 

845 logging.warning( 

846 f"{cls_name} model inputs must come from " 

847 "`tf.keras.Input` (thus holding past layer metadata). " 

848 "They cannot be the output of " 

849 "a previous non-Input layer. " 

850 "Here, a tensor specified as " 

851 f'input to "{self.name}" was not an Input tensor, ' 

852 f'it was generated by layer "{layer.name}".\n' 

853 "Note that input tensors are " 

854 "instantiated via `tensor = tf.keras.Input(shape)`.\n" 

855 f"The tensor that caused the issue was: {x}" 

856 ) 

857 

858 # Check compatibility of batch sizes of Input Layers. 

859 input_batch_sizes = set( 

860 [ 

861 training_utils.get_static_batch_size(x._keras_history.layer) 

862 for x in self.inputs 

863 ] 

864 ) 

865 input_batch_sizes.discard(None) 

866 if len(input_batch_sizes) > 1: 

867 logging.warning( 

868 "Found incompatible static batch sizes among the " 

869 f"inputs. Batch sizes: {sorted(input_batch_sizes)}" 

870 ) 

871 

872 for x in self.outputs: 

873 if not hasattr(x, "_keras_history"): 

874 cls_name = self.__class__.__name__ 

875 raise ValueError( 

876 f"Output tensors of a {cls_name} model must be " 

877 "the output of a TensorFlow `Layer` " 

878 f"(thus holding past layer metadata). Found: {x}" 

879 ) 

880 

881 def _insert_layers(self, layers, relevant_nodes=None): 

882 """Inserts Layers into the Network after Network creation. 

883 

884 This is only valid for Keras Graph Networks. Layers added via this 

885 function will be included in the `call` computation and `get_config` of 

886 this Network. They will not be added to the Network's outputs. 

887 

888 Args: 

889 layers: Arbitrary nested structure of Layers. Layers must be reachable 

890 from one or more of the `keras.Input` Tensors that correspond to 

891 this Network's inputs. 

892 relevant_nodes: Nodes from the Layers that should be considered part 

893 of this Network. If `None`, all Nodes will be considered part of 

894 this Network. 

895 

896 Raises: 

897 ValueError: If the layers depend on `Input`s not found in this Model. 

898 """ 

899 layers = tf.nest.flatten(layers) 

900 tf_utils.assert_no_legacy_layers(layers) 

901 node_to_depth = {} 

902 for depth, nodes in self._nodes_by_depth.items(): 

903 node_to_depth.update({node: depth for node in nodes}) 

904 # The nodes of these Layers that are relevant to this Network. If not 

905 # provided, assume all Nodes are relevant 

906 if not relevant_nodes: 

907 relevant_nodes = tf.nest.flatten( 

908 [layer._inbound_nodes for layer in layers] 

909 ) 

910 network_nodes = set(relevant_nodes + list(node_to_depth.keys())) 

911 

912 def _get_min_depth(node): 

913 """Gets the minimum depth at which node can be computed.""" 

914 min_depth = 0 

915 for layer, node_id, _, _ in node.iterate_inbound(): 

916 inbound_node = layer._inbound_nodes[node_id] 

917 if inbound_node in node_to_depth: 

918 min_depth = min(min_depth, node_to_depth[inbound_node]) 

919 elif inbound_node not in network_nodes: 

920 continue 

921 else: 

922 # Previous relevant nodes haven't been processed yet. 

923 return None 

924 # New node is one shallower than its shallowest input. 

925 return min_depth - 1 

926 

927 # Insert nodes into `_nodes_by_depth` and other node attrs. 

928 unprocessed_nodes = copy.copy(relevant_nodes) 

929 i = 0 

930 while unprocessed_nodes: 

931 i += 1 

932 # Do a sanity check. This can occur if `Input`s from outside this 

933 # Model are being relied on. 

934 if i > 10000: 

935 raise ValueError( 

936 "Layers could not be added due to missing dependencies." 

937 ) 

938 

939 node = unprocessed_nodes.pop(0) 

940 depth = _get_min_depth(node) 

941 if depth is None: # Defer until inbound nodes are processed. 

942 unprocessed_nodes.append(node) 

943 continue 

944 node_key = _make_node_key( 

945 node.layer.name, node.layer._inbound_nodes.index(node) 

946 ) 

947 if node_key not in self._network_nodes: 

948 node_to_depth[node] = depth 

949 self._network_nodes.add(node_key) 

950 self._nodes_by_depth[depth].append(node) 

951 

952 # Insert layers and update other layer attrs. 

953 layer_set = set(self._self_tracked_trackables) 

954 deferred_layers = [] 

955 for layer in layers: 

956 if layer not in layer_set: 

957 self._self_tracked_trackables.append(layer) 

958 deferred_layers.append(layer) 

959 self._layer_call_argspecs[layer] = tf_inspect.getfullargspec( 

960 layer.call 

961 ) 

962 layer_set.add(layer) 

963 self._handle_deferred_layer_dependencies(deferred_layers) 

964 

965 self._compute_tensor_usage_count() 

966 

967 def _compute_tensor_usage_count(self): 

968 """Compute the #. of tensor usages for all the output tensors of layers. 

969 

970 The computed tensor usage count is saved as `self._tensor_usage_count`. 

971 This is later used for saving memory in eager computation by releasing 

972 no-longer-needed tensors as early as possible. 

973 """ 

974 tensor_usage_count = collections.Counter() 

975 available_tensors = set(str(id(tensor)) for tensor in self.inputs) 

976 

977 depth_keys = list(self._nodes_by_depth.keys()) 

978 depth_keys.sort(reverse=True) 

979 depth_keys = depth_keys[1:] 

980 

981 for depth in depth_keys: 

982 for node in self._nodes_by_depth[depth]: 

983 input_tensors = { 

984 str(id(tensor)) 

985 for tensor in tf.nest.flatten(node.keras_inputs) 

986 } 

987 if input_tensors.issubset(available_tensors): 

988 for tensor in tf.nest.flatten(node.keras_inputs): 

989 tensor_usage_count[str(id(tensor))] += 1 

990 

991 for output_tensor in tf.nest.flatten(node.outputs): 

992 available_tensors.add(str(id(output_tensor))) 

993 

994 for tensor in self.outputs: 

995 tensor_usage_count[str(id(tensor))] += 1 

996 

997 self._tensor_usage_count = tensor_usage_count 

998 

999 def _assert_weights_created(self): 

1000 # Override the implementation in Model. 

1001 # The Functional model should always have weight created already. 

1002 return 

1003 

1004 def _graph_network_add_loss(self, symbolic_loss): 

1005 new_nodes, new_layers = _map_subgraph_network( 

1006 self.inputs, [symbolic_loss] 

1007 ) 

1008 # Losses must be keyed on inputs no matter what in order to be supported 

1009 # in DistributionStrategy. 

1010 add_loss_layer = base_layer.AddLoss( 

1011 unconditional=False, dtype=symbolic_loss.dtype 

1012 ) 

1013 add_loss_layer(symbolic_loss) 

1014 new_nodes.extend(add_loss_layer.inbound_nodes) 

1015 new_layers.append(add_loss_layer) 

1016 self._insert_layers(new_layers, new_nodes) 

1017 

1018 def _graph_network_add_metric(self, value, aggregation, name): 

1019 new_nodes, new_layers = _map_subgraph_network(self.inputs, [value]) 

1020 add_metric_layer = base_layer.AddMetric( 

1021 aggregation, name, dtype=value.dtype 

1022 ) 

1023 add_metric_layer(value) 

1024 new_nodes.extend(add_metric_layer.inbound_nodes) 

1025 new_layers.append(add_metric_layer) 

1026 self._insert_layers(new_layers, new_nodes) 

1027 

1028 @property 

1029 def _trackable_saved_model_saver(self): 

1030 return network_serialization.NetworkSavedModelSaver(self) 

1031 

1032 def _get_save_spec(self, dynamic_batch=True, inputs_only=True): 

1033 if getattr(self, "_has_explicit_input_shape", True): 

1034 # Functional models and Sequential models that have an explicit 

1035 # input shape should use the batch size set by the input layer. 

1036 dynamic_batch = False 

1037 return super()._get_save_spec(dynamic_batch, inputs_only) 

1038 

1039 

1040def _make_node_key(layer_name, node_index): 

1041 return layer_name + "_ib-" + str(node_index) 
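# For example, _make_node_key("dense", 0) evaluates to "dense_ib-0"; these
# keys are how `_network_nodes` refers to a specific inbound node of a
# layer ("dense" is an illustrative layer name).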

1042 

1043 

1044def _map_graph_network(inputs, outputs): 

1045 """Validates a network's topology and gather its layers and nodes. 

1046 

1047 Args: 

1048 inputs: List of input tensors. 

1049 outputs: List of outputs tensors. 

1050 

1051 Returns: 

1052 A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`. 

1053 - nodes: list of Node instances. 

1054 - nodes_by_depth: dict mapping ints (depth) to lists of node instances. 

1055 - layers: list of Layer instances. 

1056 - layers_by_depth: dict mapping ints (depth) to lists of layer instances. 

1057 

1058 Raises: 

1059 ValueError: In case the network is not valid (e.g. disconnected graph). 

1060 """ 

1061 # "depth" is number of layers between output Node and the Node. 

1062 # Nodes are ordered from inputs -> outputs. 

1063 nodes_in_decreasing_depth, layer_indices = _build_map(outputs) 

1064 network_nodes = { 

1065 _make_node_key(node.layer.name, node.layer._inbound_nodes.index(node)) 

1066 for node in nodes_in_decreasing_depth 

1067 } 

1068 

1069 nodes_depths = {} # dict {node: depth value} 

1070 layers_depths = {} # dict {layer: depth value} 

1071 

1072 for node in reversed(nodes_in_decreasing_depth): 

1073 # If the depth is not set, the node has no outbound nodes (depth 0). 

1074 depth = nodes_depths.setdefault(node, 0) 

1075 

1076 # Update the depth of the corresponding layer 

1077 previous_depth = layers_depths.get(node.layer, 0) 

1078 # If we've seen this layer before at a higher depth, 

1079 # we should use that depth instead of the node depth. 

1080 # This is necessary for shared layers that have inputs at different 

1081 # depth levels in the graph. 

1082 depth = max(depth, previous_depth) 

1083 layers_depths[node.layer] = depth 

1084 nodes_depths[node] = depth 

1085 

1086 # Update the depth of inbound nodes. 

1087 # The "depth" of a node is the max of the depths 

1088 # of all nodes it is connected to + 1. 

1089 for node_dep in node.parent_nodes: 

1090 previous_depth = nodes_depths.get(node_dep, 0) 

1091 nodes_depths[node_dep] = max(depth + 1, previous_depth) 

1092 

1093 # Handle inputs that are not connected to outputs. 

1094 # We do not error out here because the inputs may be used to compute losses 

1095 # and metrics. 

1096 for input_t in inputs: 

1097 input_layer = input_t._keras_history[0] 

1098 if input_layer not in layers_depths: 

1099 layers_depths[input_layer] = 0 

1100 layer_indices[input_layer] = -1 

1101 nodes_depths[input_layer._inbound_nodes[0]] = 0 

1102 network_nodes.add(_make_node_key(input_layer.name, 0)) 

1103 

1104 # Build a dict {depth: list of nodes with this depth} 

1105 nodes_by_depth = collections.defaultdict(list) 

1106 for node, depth in nodes_depths.items(): 

1107 nodes_by_depth[depth].append(node) 

1108 

1109 # Build a dict {depth: list of layers with this depth} 

1110 layers_by_depth = collections.defaultdict(list) 

1111 for layer, depth in layers_depths.items(): 

1112 layers_by_depth[depth].append(layer) 

1113 

1114 # Get sorted list of layer depths. 

1115 depth_keys = list(layers_by_depth.keys()) 

1116 depth_keys.sort(reverse=True) 

1117 

1118 # Set self.layers ordered by depth. 

1119 layers = [] 

1120 for depth in depth_keys: 

1121 layers_for_depth = layers_by_depth[depth] 

1122 # Network.layers needs to have a deterministic order: 

1123 # here we order them by traversal order. 

1124 layers_for_depth.sort(key=lambda x: layer_indices[x]) 

1125 layers.extend(layers_for_depth) 

1126 

1127 # Get sorted list of node depths. 

1128 depth_keys = list(nodes_by_depth.keys()) 

1129 depth_keys.sort(reverse=True) 

1130 

1131 # Check that all tensors required are computable. 

1132 # computable_tensors: all tensors in the graph 

1133 # that can be computed from the inputs provided. 

1134 computable_tensors = set() 

1135 for x in inputs: 

1136 computable_tensors.add(id(x)) 

1137 

1138 layers_with_complete_input = [] # To provide a better error msg. 

1139 for depth in depth_keys: 

1140 for node in nodes_by_depth[depth]: 

1141 layer = node.layer 

1142 if layer and not node.is_input: 

1143 for x in tf.nest.flatten(node.keras_inputs): 

1144 if id(x) not in computable_tensors: 

1145 raise ValueError( 

1146 "Graph disconnected: cannot obtain value for " 

1147 f'tensor {x} at layer "{layer.name}". ' 

1148 "The following previous layers were accessed " 

1149 f"without issue: {layers_with_complete_input}" 

1150 ) 

1151 for x in tf.nest.flatten(node.outputs): 

1152 computable_tensors.add(id(x)) 

1153 layers_with_complete_input.append(layer.name) 

1154 

1155 # Ensure name unicity, which will be crucial for serialization 

1156 # (since serialized nodes refer to layers by their name). 

1157 all_names = [layer.name for layer in layers] 

1158 for name in all_names: 

1159 if all_names.count(name) != 1: 

1160 raise ValueError( 

1161 f'The name "{name}" is used {all_names.count(name)} ' 

1162 "times in the model. All layer names should be unique." 

1163 ) 

1164 return network_nodes, nodes_by_depth, layers, layers_by_depth 

1165 

1166 

1167def _build_map(outputs): 

1168 """This method topologically sorts nodes in order from inputs to outputs. 

1169 

1170 It uses a depth-first search to topologically sort nodes that appear in the 

1171 _keras_history connectivity metadata of `outputs`. 

1172 

1173 Args: 

1174 outputs: the output tensors whose _keras_history metadata should be 

1175 walked. This may be an arbitrary nested structure. 

1176 

1177 Returns: 

1178 A tuple like (ordered_nodes, layer_to_first_traversal_index) 

1179 ordered_nodes: list of nodes appearing in the keras history, topologically 

1180 sorted from original inputs to the `outputs`. 

1181 (If outputs have different sets of ancestors, the inputs to one output 

1182 may appear after a different output). 

1183 layer_to_first_traversal_index: 

1184 A dict mapping layer to the traversal index in the DFS where it is 

1185 seen. Note: if a layer is shared by several nodes, the dict will only 

1186 store the index corresponding to the *first* time the layer is seen.

1187 """ 

1188 finished_nodes = set() 

1189 nodes_in_progress = set() 

1190 nodes_in_decreasing_depth = [] # nodes from inputs -> outputs. 

1191 layer_indices = {} # layer -> in traversal order. 

1192 for output in tf.nest.flatten(outputs): 

1193 _build_map_helper( 

1194 output, 

1195 finished_nodes, 

1196 nodes_in_progress, 

1197 nodes_in_decreasing_depth, 

1198 layer_indices, 

1199 ) 

1200 return nodes_in_decreasing_depth, layer_indices 

1201 

1202 

1203def _build_map_helper( 

1204 tensor, 

1205 finished_nodes, 

1206 nodes_in_progress, 

1207 nodes_in_decreasing_depth, 

1208 layer_indices, 

1209): 

1210 """Recursive helper for `_build_map`.""" 

1211 ( 

1212 layer, 

1213 node_index, 

1214 _, 

1215 ) = tensor._keras_history 

1216 node = layer._inbound_nodes[node_index] 

1217 

1218 # Don't repeat work for shared subgraphs 

1219 if node in finished_nodes: 

1220 return 

1221 

1222 # Prevent cycles. 

1223 if node in nodes_in_progress: 

1224 raise ValueError( 

1225 f'Tensor {tensor} from layer "{layer.name}" is part of a cycle.' 

1226 ) 

1227 

1228 # Store the traversal order for layer sorting. 

1229 if layer not in layer_indices: 

1230 layer_indices[layer] = len(layer_indices) 

1231 

1232 # Propagate to all previous tensors connected to this node. 

1233 nodes_in_progress.add(node) 

1234 if not node.is_input: 

1235 for tensor in node.keras_inputs: 

1236 _build_map_helper( 

1237 tensor, 

1238 finished_nodes, 

1239 nodes_in_progress, 

1240 nodes_in_decreasing_depth, 

1241 layer_indices, 

1242 ) 

1243 

1244 finished_nodes.add(node) 

1245 nodes_in_progress.remove(node) 

1246 nodes_in_decreasing_depth.append(node) 

1247 

1248 

1249def _map_subgraph_network(inputs, outputs): 

1250 """Returns the nodes and layers in the topology from `inputs` to `outputs`. 

1251 

1252 Args: 

1253 inputs: List of input tensors. 

1254 outputs: List of output tensors. 

1255 

1256 Returns: 

1257 A tuple of (List[Node], List[Layer]).

1258 """ 

1259 if not tf.compat.v1.executing_eagerly_outside_functions(): 

1260 base_layer_utils.create_keras_history(outputs) 

1261 # Keep only nodes and layers in the topology between inputs and outputs. 

1262 _, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs) 

1263 return tf.nest.flatten([nodes for nodes in nodes_by_depth.values()]), layers 

1264 

1265 

1266def _should_skip_first_node(layer): 

1267 """Returns True if the first layer node should not be saved or loaded.""" 

1268 # Networks that are constructed with an Input layer/shape start with a 

1269 # pre-existing node linking their input to output. This node is excluded 

1270 # from the network config. 

1271 if not hasattr(layer, "_self_tracked_trackables"): 

1272 # Special case for serialization of Functional models without 

1273 # defined input shape argument. 

1274 return isinstance(layer, Functional) 

1275 if layer._self_tracked_trackables: 

1276 return ( 

1277 isinstance(layer, Functional) 

1278 # Filter out Sequential models without an input shape. 

1279 and isinstance( 

1280 layer._self_tracked_trackables[0], input_layer_module.InputLayer 

1281 ) 

1282 ) 

1283 else: 

1284 return isinstance(layer, Functional) 

1285 

1286 

1287def connect_ancillary_layers(model, created_layers): 

1288 """Adds layers that are not connected to the outputs to the model.""" 

1289 # Layers not connected to outputs, such as those added in `add_loss`. 

1290 ancillary_layers = [ 

1291 layer for layer in created_layers.values() if layer not in model.layers 

1292 ] 

1293 if ancillary_layers: 

1294 relevant_nodes = tf.nest.flatten( 

1295 [ 

1296 layer.inbound_nodes[1:] 

1297 if _should_skip_first_node(layer) 

1298 else layer.inbound_nodes 

1299 for layer in created_layers.values() 

1300 ] 

1301 ) 

1302 model._insert_layers(ancillary_layers, relevant_nodes) 

1303 return model 

1304 

1305 

1306def reconstruct_from_config(config, custom_objects=None, created_layers=None): 

1307 """Reconstructs graph from config object. 

1308 

1309 Args: 

1310 config: Dictionary returned from Network.get_config() 

1311 custom_objects: Optional dictionary mapping names (strings) to custom 

1312 classes or functions to be considered during deserialization. 

1313 created_layers: Optional dictionary mapping names to Layer objects. Any 

1314 layer not in this dictionary will be created and added to the dict. 

1315 This function will add new nodes to all layers (excluding InputLayers), 

1316 instead of re-using pre-existing nodes in the layers. 

1317 

1318 Returns: 

1319 Tuple of (input tensors, output tensors, dictionary of created layers) 

1320 """ 

1321 # Layer instances created during the graph reconstruction process. 

1322 created_layers = created_layers or collections.OrderedDict() 

1323 

1324 # Maps input data (tuple of inbound layer name, node index) from the config 

1325 # to node indices in the newly generated model. The node indices may be 

1326 # different if the layers have already been called previously. 

1327 node_index_map = {} 

1328 node_count_by_layer = {} 

1329 

1330 # Dictionary mapping layer instances to 

1331 # node data that specifies a layer call. 

1332 # It acts as a queue that maintains any unprocessed 

1333 # layer call until it becomes possible to process it 

1334 # (i.e. until the input tensors to the call all exist). 

1335 unprocessed_nodes = collections.defaultdict(list) 

1336 

1337 def get_node_index(layer, config_node_index): 

1338 """Returns node index in layer (might differ from config_node_index).""" 

1339 if isinstance(layer, input_layer_module.InputLayer): 

1340 return 0 

1341 return node_index_map.get((layer.name, config_node_index), None) 

1342 

1343 def _deserialize_keras_tensors(kwargs, layer_map): 

1344 """Deserializes Keras Tensors passed to `call`..""" 

1345 

1346 def _deserialize_keras_tensor(t): 

1347 """Deserializes a single Keras Tensor passed to `call`.""" 

1348 if isinstance(t, tf_utils.ListWrapper): 

1349 t = t.as_list() 

1350 layer_name = t[0] 

1351 node_index = t[1] 

1352 tensor_index = t[2] 

1353 

1354 layer = layer_map[layer_name] 

1355 new_node_index = get_node_index(layer, node_index) 

1356 if new_node_index is None: 

1357 # The inbound node may not have been processed yet, 

1358 # (This can happen e.g. if it depends on a different set 

1359 # of inputs than those that have been processed already). 

1360 # raise an IndexError so that the current node puts itself 

1361 # back on the unprocessed queue. 

1362 # Caution: This may lead to infinite loops for malformed 

1363 # network configurations! (or when there is a bug in 

1364 # the network config loading code). 

1365 raise IndexError 

1366 node = layer._inbound_nodes[new_node_index] 

1367 return tf.nest.flatten(node.outputs)[tensor_index] 

1368 return t 

1369 

1370 kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True) 

1371 return tf.nest.map_structure(_deserialize_keras_tensor, kwargs) 

1372 

1373 def process_node(layer, node_data): 

1374 """Deserialize a node. 

1375 

1376 Args: 

1377 layer: layer instance. 

1378 node_data: Nested structure of `ListWrapper`. 

1379 

1380 Returns: 

1381 Whether the node was processed (i.e. the layer was called on the 

1382 inputs specified by the node data) 

1383 

1384 Raises: 

1385 ValueError: In case of improperly formatted `node_data`. 

1386 """ 

1387 input_tensors = [] 

1388 for input_data in tf.nest.flatten(node_data): 

1389 input_data = input_data.as_list() 

1390 if len(input_data) == 3: 

1391 kwargs = {} 

1392 elif len(input_data) == 4: 

1393 kwargs = input_data[3] 

1394 try: 

1395 kwargs = _deserialize_keras_tensors(kwargs, created_layers) 

1396 except IndexError: 

1397 # Happens if keras tensors in kwargs are still unprocessed 

1398 return False 

1399 else: 

1400 raise ValueError("Improperly formatted model config.") 

1401 

1402 if input_data[0] != node_module._CONSTANT_VALUE: 

1403 inbound_layer_name = input_data[0] 

1404 inbound_node_index = input_data[1] 

1405 inbound_tensor_index = input_data[2] 

1406 inbound_layer = created_layers[inbound_layer_name] 

1407 inbound_node_index = get_node_index( 

1408 inbound_layer, inbound_node_index 

1409 ) 

1410 

1411 if inbound_node_index is None: 

1412 return False 

1413 inbound_node = inbound_layer._inbound_nodes[inbound_node_index] 

1414 input_tensors.append( 

1415 tf.nest.flatten(inbound_node.outputs)[inbound_tensor_index] 

1416 ) 

1417 else: 

1418 # We received a constant w/ no Keras history attached, 

1419 # which means it is a constant tensor input. 

1420 # Input is a constant value. 

1421 # Format = [_CONSTANT_VALUE, -1, const_val, kwargs] 

1422 assert input_data[1] == -1 

1423 assert len(input_data) >= 3 

1424 const_val = input_data[2] 

1425 if ( 

1426 isinstance(const_val, tuple) 

1427 and len(const_val) == 2 

1428 and const_val[0] == node_module._COMPOSITE_TYPE 

1429 ): 

1430 # It is a composite tensor. 

1431 input_tensors.append(json_utils.decode(const_val[1])) 

1432 else: 

1433 input_tensors.append(const_val) 

1434 input_tensors = tf.nest.pack_sequence_as(node_data, input_tensors) 

1435 # Call layer on its inputs, thus creating the node 

1436 # and building the layer if needed. 

1437 if input_tensors is not None: 

1438 if ( 

1439 not hasattr(layer, "_preserve_input_structure_in_config") 

1440 or not layer._preserve_input_structure_in_config 

1441 ): 

1442 input_tensors = base_layer_utils.unnest_if_single_tensor( 

1443 input_tensors 

1444 ) 

1445 output_tensors = layer(input_tensors, **kwargs) 

1446 

1447 # Update node index map. 

1448 output_index = tf.nest.flatten(output_tensors)[ 

1449 0 

1450 ]._keras_history.node_index 

1451 node_index_map[ 

1452 (layer.name, node_count_by_layer[layer]) 

1453 ] = output_index 

1454 node_count_by_layer[layer] += 1 

1455 return True 

1456 

1457 def process_layer(layer_data): 

1458 """Deserializes a layer, then call it on appropriate inputs. 

1459 

1460 Args: 

1461 layer_data: layer config dict. 

1462 

1463 Raises: 

1464 ValueError: In case of improperly formatted `layer_data` dict. 

1465 """ 

1466 layer_name = layer_data["name"] 

1467 

1468 if layer_name in created_layers: 

1469 layer = created_layers[layer_name] 

1470 else: 

1471 # Instantiate layer. 

1472 from keras.src.layers import deserialize as deserialize_layer 

1473 

1474 layer = deserialize_layer(layer_data, custom_objects=custom_objects) 

1475 created_layers[layer_name] = layer 

1476 

1477 node_count_by_layer[layer] = int(_should_skip_first_node(layer)) 

1478 

1479 # Gather layer inputs and convert to `ListWrapper` objects. 

1480 inbound_nodes_data = layer_data["inbound_nodes"] 

1481 inbound_nodes_data = tf_utils.convert_inner_node_data( 

1482 inbound_nodes_data, wrap=True 

1483 ) 

1484 for node_data in inbound_nodes_data: 

1485 # We don't process nodes (i.e. make layer calls) 

1486 # on the fly because the inbound node may not yet exist, 

1487 # in the case of a layer shared at different topological depths 

1488 # (e.g. a model such as A(B(A(B(x))))) 

1489 unprocessed_nodes[layer].append(node_data) 

1490 

1491 # First, we create all layers and enqueue nodes to be processed 

1492 for layer_data in config["layers"]: 

1493 process_layer(layer_data) 

1494 # Then we process nodes in order of layer depth. 

1495 # Nodes that cannot yet be processed (if the inbound node 

1496 # does not yet exist) are re-enqueued, and the process 

1497 # is repeated until all nodes are processed. 

1498 while unprocessed_nodes: 

1499 for layer_data in config["layers"]: 

1500 layer = created_layers[layer_data["name"]] 

1501 if layer in unprocessed_nodes: 

1502 layer_nodes = unprocessed_nodes.pop(layer) 

1503 while layer_nodes: 

1504 node_data = layer_nodes[0] 

1505 if process_node(layer, node_data): 

1506 layer_nodes.pop(0) 

1507 else: 

1508 # If a node can't be processed, stop processing the 

1509 # nodes of the current layer to maintain node ordering. 

1510 unprocessed_nodes[layer] = layer_nodes 

1511 break 

1512 

1513 input_tensors = [] 

1514 output_tensors = [] 

1515 

1516 input_layers = tf_utils.convert_inner_node_data( 

1517 config["input_layers"], wrap=True 

1518 ) 

1519 for layer_data in tf.nest.flatten(input_layers): 

1520 layer_name, node_index, tensor_index = layer_data.as_list() 

1521 assert layer_name in created_layers 

1522 layer = created_layers[layer_name] 

1523 node_index = get_node_index(layer, node_index) 

1524 layer_output_tensors = layer._inbound_nodes[node_index].output_tensors 

1525 input_tensors.append( 

1526 tf.nest.flatten(layer_output_tensors)[tensor_index] 

1527 ) 

1528 

1529 output_layers = tf_utils.convert_inner_node_data( 

1530 config["output_layers"], wrap=True 

1531 ) 

1532 for layer_data in tf.nest.flatten(output_layers): 

1533 layer_name, node_index, tensor_index = layer_data.as_list() 

1534 assert layer_name in created_layers 

1535 layer = created_layers[layer_name] 

1536 node_index = get_node_index(layer, node_index) 

1537 layer_output_tensors = layer._inbound_nodes[node_index].output_tensors 

1538 output_tensors.append( 

1539 tf.nest.flatten(layer_output_tensors)[tensor_index] 

1540 ) 

1541 

1542 input_tensors = tf.nest.pack_sequence_as(input_layers, input_tensors) 

1543 output_tensors = tf.nest.pack_sequence_as(output_layers, output_tensors) 

1544 return input_tensors, output_tensors, created_layers 

1545 

1546 
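The reconstruction logic that ends above is what backs `Model.from_config` for Functional models: all layers are instantiated first, then each serialized node is replayed as a layer call once its inbound nodes exist. A minimal roundtrip sketch of that user-facing path (layer names such as "hidden" are illustrative only):

```
from tensorflow import keras

# Build a small Functional model.
inputs = keras.Input(shape=(4,), name="features")
x = keras.layers.Dense(8, activation="relu", name="hidden")(inputs)
outputs = keras.layers.Dense(1, name="score")(x)
model = keras.Model(inputs, outputs, name="toy_model")

# `get_config()` produces the node graph consumed by the code above;
# `from_config()` re-creates the layers and replays the node calls.
config = model.get_config()
clone = keras.Model.from_config(config)

assert [layer.name for layer in clone.layers] == [
    layer.name for layer in model.layers
]
```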

1547def get_network_config(network, serialize_layer_fn=None, config=None): 

1548 """Build the config, which consists of the node graph and serialized layers. 

1549 

1550 Args: 

1551 network: A Network object. 

1552 serialize_layer_fn: Function used to serialize layers. 

1553 config: A dict to append more config entries into. If None, start with a 

1554 new dict for the config. 

1555 

1556 Returns: 

1557 Config dictionary. 

1558 """ 

1559 config = config or {} 

1560 serialize_obj_fn = serialization_lib.serialize_keras_object 

1561 set_layers_legacy = False 

1562 # To be removed once all affected g3 users have migrated to Keras V3 saving. 

1563 if getattr(network, "use_legacy_config", False): 

1564 serialize_obj_fn = serialization.serialize_keras_object 

1565 set_layers_legacy = True 

1566 serialize_layer_fn = serialize_layer_fn or serialize_obj_fn 

1567 config["name"] = network.name 

1568 node_conversion_map = {} 

1569 for layer in network.layers: 

1570 kept_nodes = 1 if _should_skip_first_node(layer) else 0 

1571 for original_node_index, node in enumerate(layer._inbound_nodes): 

1572 node_key = _make_node_key(layer.name, original_node_index) 

1573 if node_key in network._network_nodes: 

1574 node_conversion_map[node_key] = kept_nodes 

1575 kept_nodes += 1 

1576 layer_configs = [] 

1577 

1578 with serialization.SharedObjectSavingScope(): 

1579 for layer in network.layers: # From the earliest layers on. 

1580 filtered_inbound_nodes = [] 

1581 for original_node_index, node in enumerate(layer._inbound_nodes): 

1582 node_key = _make_node_key(layer.name, original_node_index) 

1583 if node_key in network._network_nodes and not node.is_input: 

1584 # The node is relevant to the model: 

1585 # add to filtered_inbound_nodes. 

1586 node_data = node.serialize( 

1587 _make_node_key, node_conversion_map 

1588 ) 

1589 filtered_inbound_nodes.append(node_data) 

1590 

1591 if isinstance(layer, Functional) and set_layers_legacy: 

1592 layer.use_legacy_config = True 

1593 layer_config = serialize_layer_fn(layer) 

1594 layer_config["name"] = layer.name 

1595 layer_config["inbound_nodes"] = filtered_inbound_nodes 

1596 layer_configs.append(layer_config) 

1597 config["layers"] = layer_configs 

1598 

1599 # Gather info about inputs and outputs. 

1600 model_inputs = [] 

1601 for i in range(len(network._input_layers)): 

1602 layer, node_index, tensor_index = network._input_coordinates[i] 

1603 node_key = _make_node_key(layer.name, node_index) 

1604 if node_key not in network._network_nodes: 

1605 continue 

1606 new_node_index = node_conversion_map[node_key] 

1607 model_inputs.append( 

1608 tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]) 

1609 ) 

1610 model_inputs = tf.nest.pack_sequence_as( 

1611 network._nested_inputs, model_inputs 

1612 ) 

1613 # Preserve external Keras compat for Models with single input. 

1614 if not tf.nest.is_nested(model_inputs): 

1615 model_inputs = [model_inputs] 

1616 model_inputs = tf_utils.convert_inner_node_data(model_inputs) 

1617 config["input_layers"] = model_inputs 

1618 

1619 model_outputs = [] 

1620 for i in range(len(network._output_layers)): 

1621 layer, node_index, tensor_index = network._output_coordinates[i] 

1622 node_key = _make_node_key(layer.name, node_index) 

1623 if node_key not in network._network_nodes: 

1624 continue 

1625 new_node_index = node_conversion_map[node_key] 

1626 model_outputs.append( 

1627 tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]) 

1628 ) 

1629 model_outputs = tf.nest.pack_sequence_as( 

1630 network._nested_outputs, model_outputs 

1631 ) 

1632 # Preserve external Keras compat for Models with single output. 

1633 if not tf.nest.is_nested(model_outputs): 

1634 model_outputs = [model_outputs] 

1635 model_outputs = tf_utils.convert_inner_node_data(model_outputs) 

1636 config["output_layers"] = model_outputs 

1637 return config 

1638 

1639 
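For orientation, the dictionary returned by `get_network_config` is the same structure exposed through `model.get_config()` on a Functional model. A small sketch that inspects the top-level keys (the exact per-layer contents vary across Keras versions):

```
from tensorflow import keras

inputs = keras.Input(shape=(2,))
outputs = keras.layers.Dense(3)(inputs)
model = keras.Model(inputs, outputs)

config = model.get_config()
# The config carries the serialized node graph:
#   "name"          -> model name
#   "layers"        -> per-layer configs, each with "inbound_nodes"
#   "input_layers"  -> [layer_name, node_index, tensor_index] entries
#   "output_layers" -> the same triple format for the outputs
print(sorted(config.keys()))
print(config["output_layers"])
```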

1640def shape_with_no_batch_size(x): 

1641 if x.shape.rank is None: 

1642 return None 

1643 shape = x.shape.as_list() 

1644 if shape: 

1645 shape[0] = None 

1646 return shape 

1647 

1648 
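A quick sketch of what `shape_with_no_batch_size` computes, re-derived inline rather than calling the private helper: the static shape as a list with the leading batch dimension cleared.

```
import tensorflow as tf

x = tf.zeros([32, 10, 3])
shape = x.shape.as_list()  # [32, 10, 3]
if shape:
    shape[0] = None        # drop the batch size, mirroring the helper above
print(shape)               # [None, 10, 3]
```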

1649class ModuleWrapper(base_layer.Layer): 

1650 """Wrapper for `tf.Module`s to support the Functional and Sequential API.""" 

1651 

1652 def __init__(self, module, method_name=None, **kwargs): 

1653 """Initializes the wrapper Layer for this module. 

1654 

1655 Args: 

1656 module: The `tf.Module` instance to be wrapped. 

1657 method_name: (Optional) str. The name of the method to use as the 

1658 forward pass of the module. If not set, becomes '__call__' if 

1659 defined, or 'call'. Defaults to `None`. 

1660 **kwargs: Additional keyword arguments. See `tf.keras.layers.Layer`. 

1661 

1662 Raises: 

1663 ValueError: If `method_name` is not defined on `module`. 

1664 """ 

1665 super().__init__(**kwargs) 

1666 if method_name is None: 

1667 if hasattr(module, "__call__"): 

1668 method_name = "__call__" 

1669 elif hasattr(module, "call"): 

1670 method_name = "call" 

1671 if method_name is None or not hasattr(module, method_name): 

1672 raise ValueError(f"{method_name} is not defined on object {module}") 

1673 

1674 self._module = module 

1675 self._method_name = method_name 

1676 

1677 # Check if module.__call__ has a `training` arg or accepts `**kwargs`. 

1678 method = getattr(module, method_name) 

1679 method_arg_spec = tf_inspect.getfullargspec(method) 

1680 self._call_spec.expects_training_arg = ( 

1681 "training" in method_arg_spec.args 

1682 or method_arg_spec.varkw is not None 

1683 ) 

1684 self._call_spec.expects_mask_arg = ( 

1685 "mask" in method_arg_spec.args or method_arg_spec.varkw is not None 

1686 ) 

1687 

1688 def call(self, *args, **kwargs): 

1689 if "training" in kwargs and not self._expects_training_arg: 

1690 kwargs.pop("training") 

1691 if "mask" in kwargs and not self._expects_mask_arg: 

1692 kwargs.pop("mask") 

1693 return getattr(self._module, self._method_name)(*args, **kwargs) 

1694 

1695 
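`ModuleWrapper` is the adapter that lets a plain `tf.Module` participate in the Sequential and Functional APIs by picking `__call__` (or `call`) as its forward method. A hedged sketch, assuming the Sequential API routes non-Layer modules through this wrapper; the `Scale` module below is illustrative and behavior may differ across Keras versions:

```
import tensorflow as tf
from tensorflow import keras


class Scale(tf.Module):
    """A plain tf.Module (not a Keras Layer) whose forward pass is __call__."""

    def __init__(self, factor):
        super().__init__()
        self.factor = tf.Variable(float(factor))

    def __call__(self, x):
        return x * self.factor


# A non-Layer tf.Module placed in a Sequential model is wrapped so that
# `__call__` is used as its forward method, per the logic above.
model = keras.Sequential([keras.layers.Dense(4), Scale(2.0)])
y = model(tf.zeros([1, 3]))
print(y.shape)  # (1, 4)
```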

1696def has_functional_like_constructor(cls): 

1697 init_args = tf_inspect.getfullargspec(cls.__init__).args[1:] 

1698 functional_init_args = tf_inspect.getfullargspec(Functional.__init__).args[ 

1699 1: 

1700 ] 

1701 if init_args == functional_init_args: 

1702 return True 

1703 return False 

1704
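The signature comparison performed by `has_functional_like_constructor` can be reproduced with the standard `inspect` module; the classes below are illustrative stand-ins rather than Keras types:

```
import inspect


def same_positional_args(cls, reference_cls):
    """Mirror the check above: compare constructor argument names."""
    cls_args = inspect.getfullargspec(cls.__init__).args[1:]
    ref_args = inspect.getfullargspec(reference_cls.__init__).args[1:]
    return cls_args == ref_args


class Reference:
    def __init__(self, inputs, outputs, name=None):
        self.inputs, self.outputs, self.name = inputs, outputs, name


class Matching(Reference):
    pass  # inherits __init__, so the argument names match


class NotMatching(Reference):
    def __init__(self, units):
        super().__init__(None, None)
        self.units = units


print(same_positional_args(Matching, Reference))     # True
print(same_positional_args(NotMatching, Reference))  # False
```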