Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/functional.py: 12%
735 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""A `Network` is a way to compose layers: the topological form of a `Model`."""
18import collections
19import copy
20import itertools
21import warnings
23import tensorflow.compat.v2 as tf
25from keras.src import backend
26from keras.src.dtensor import layout_map as layout_map_lib
27from keras.src.engine import base_layer
28from keras.src.engine import base_layer_utils
29from keras.src.engine import functional_utils
30from keras.src.engine import input_layer as input_layer_module
31from keras.src.engine import input_spec
32from keras.src.engine import node as node_module
33from keras.src.engine import training as training_lib
34from keras.src.engine import training_utils
35from keras.src.saving import serialization_lib
36from keras.src.saving.legacy import serialization
37from keras.src.saving.legacy.saved_model import json_utils
38from keras.src.saving.legacy.saved_model import network_serialization
39from keras.src.saving.legacy.saved_model import utils as saved_model_utils
40from keras.src.utils import generic_utils
41from keras.src.utils import tf_inspect
42from keras.src.utils import tf_utils
44# isort: off
45from tensorflow.python.platform import tf_logging as logging
46from tensorflow.tools.docs import doc_controls
49class Functional(training_lib.Model):
50 """A `Functional` model is a `Model` defined as a directed graph of layers.
52 Three types of `Model` exist: subclassed `Model`, `Functional` model,
53 and `Sequential` (a special case of `Functional`).
54 In general, more Keras features are supported with `Functional`
55 than with subclassed `Model`s, specifically:
57 - Model cloning (`keras.models.clone`)
58 - Serialization (`model.get_config()/from_config`, `model.to_json()`)
59 - Whole-model saving (`model.save()`)
61 A `Functional` model can be instantiated by passing two arguments to
62 `__init__`. The first argument is the `keras.Input` tensor(s) that represent
63 the inputs to the model. The second argument is the output
64 tensor(s) that represent the outputs of this model. Both arguments can be
65 nested structures of tensors.
67 Example:
69 ```python
70 inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))}
71 t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
72 outputs = keras.layers.Add()([t, inputs['x2']])
73 model = keras.Model(inputs, outputs)
74 ```
76 A `Functional` model constructed using the Functional API can also include
77 raw TensorFlow functions, with the exception of functions that create
78 Variables or assign ops.
80 Example:
82 ```python
83 inputs = keras.Input(shape=(10,))
84 x = keras.layers.Dense(1)(inputs)
85 outputs = tf.nn.relu(x)
86 model = keras.Model(inputs, outputs)
87 ```
89 A new `Functional` model can also be created by using the
90 intermediate tensors. This enables you to quickly extract sub-components
91 of the model.
93 Example:
95 ```python
96 inputs = keras.Input(shape=(None, None, 3))
97 processed = keras.layers.RandomCrop(width=32, height=32)(inputs)
98 conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed)
99 pooling = keras.layers.GlobalAveragePooling2D()(conv)
100 feature = keras.layers.Dense(10)(pooling)
102 full_model = keras.Model(inputs, feature)
103 backbone = keras.Model(processed, conv)
104 activations = keras.Model(conv, feature)
105 ```
107 Note that the `backbone` and `activations` models are not
108 created with `keras.Input` objects, but with the tensors that originate
109 from `keras.Input` objects. Under the hood, the layers and weights are
110 shared across these models, so that the user can train the `full_model`, and
111 use `backbone` or `activations` for feature extraction.
112 The inputs and outputs of the model can be nested structures of tensors as
113 well, and the created models are standard `Functional` models that support
114 all the existing APIs.
116 Args:
117 inputs: List of input tensors (must be created via `tf.keras.Input()` or
118 originated from `tf.keras.Input()`).
119 outputs: List of output tensors.
120 name: String, optional. Name of the model.
121 trainable: Boolean, optional. Whether the model's variables should be
122 trainable.
123 """
125 # See tf.Module for the usage of this property.
126 # The key of _layer_call_argspecs is a layer. tf.Module._flatten will fail
127 # to flatten the key since it is trying to convert Trackable/Layer to a
128 # string.
129 _TF_MODULE_IGNORED_PROPERTIES = frozenset(
130 itertools.chain(
131 (
132 "_layer_call_argspecs",
133 "_compiled_trainable_state",
134 "_output_mask_cache",
135 "_output_tensor_cache",
136 "_output_shape_cache",
137 ),
138 training_lib.Model._TF_MODULE_IGNORED_PROPERTIES,
139 )
140 )
142 @tf.__internal__.tracking.no_automatic_dependency_tracking
143 def __init__(self, inputs, outputs, name=None, trainable=True, **kwargs):
144 # This is used by the Model class, since we have some logic to swap the
145 # class in the __new__ method, which will lead to __init__ being invoked
146 # twice. Use skip_init to skip one of the invocations of __init__
147 # to avoid any side effects.
148 skip_init = kwargs.pop("skip_init", False)
149 if skip_init:
150 return
151 generic_utils.validate_kwargs(kwargs, {})
152 super().__init__(name=name, trainable=trainable)
153 # Check if the inputs contain any intermediate `KerasTensor` (not
154 # created by tf.keras.Input()). In this case we need to clone the `Node`
155 # and `KerasTensor` objects to mimic rebuilding a new model from new
156 # inputs. This feature is only enabled in TF2 not in v1 graph mode.
157 if tf.compat.v1.executing_eagerly_outside_functions():
158 if not all(
159 [
160 functional_utils.is_input_keras_tensor(t)
161 for t in tf.nest.flatten(inputs)
162 ]
163 ):
164 inputs, outputs = functional_utils.clone_graph_nodes(
165 inputs, outputs
166 )
167 self._init_graph_network(inputs, outputs)
169 @tf.__internal__.tracking.no_automatic_dependency_tracking
170 def _init_graph_network(self, inputs, outputs):
171 # This method is needed for Sequential to reinitialize the graph
172 # network when a layer is added or removed.
174 base_layer.keras_api_gauge.get_cell("Functional").set(True)
175 self._is_graph_network = True
177 # Normalize and set self.inputs, self.outputs.
178 if isinstance(inputs, list) and len(tf.nest.flatten(inputs)) == 1:
179 inputs = inputs[0]
180 if isinstance(outputs, list) and len(tf.nest.flatten(outputs)) == 1:
181 outputs = outputs[0]
182 self._nested_inputs = inputs
183 self._nested_outputs = outputs
184 self.inputs = tf.nest.flatten(inputs)
185 self.outputs = tf.nest.flatten(outputs)
187 # Models constructed with a single Tensor or list of Tensors can
188 # be called with a dict, where the keys of the dict are the names
189 # of the `Input` objects. Extra keys are ignored with a warning.
190 if not tf.nest.is_nested(self._nested_inputs):
191 self._enable_dict_to_input_mapping = True
192 elif isinstance(self._nested_inputs, (list, tuple)) and not any(
193 tf.nest.is_nested(t) for t in self._nested_inputs
194 ):
195 self._enable_dict_to_input_mapping = True
196 elif isinstance(self._nested_inputs, dict) and not any(
197 tf.nest.is_nested(t) for t in self._nested_inputs.values()
198 ):
199 self._enable_dict_to_input_mapping = True
200 else:
201 self._enable_dict_to_input_mapping = False
203 if not tf.compat.v1.executing_eagerly_outside_functions():
204 if any(
205 not hasattr(tensor, "_keras_history") for tensor in self.outputs
206 ):
207 base_layer_utils.create_keras_history(self._nested_outputs)
209 self._validate_graph_inputs_and_outputs()
211 # A Network does not create weights of its own, thus it is already
212 # built.
213 self.built = True
214 self._build_input_shape = tf.nest.map_structure(
215 lambda x: x.shape, inputs
216 )
217 self._compute_output_and_mask_jointly = True
218 # `_expects_training_arg` is True since the `training` argument is
219 # always present in the signature of the `call` method of a graph
220 # network.
221 self._call_spec.expects_training_arg = True
222 self._call_spec.expects_mask_arg = True
223 # A graph network does not autocast inputs, as its layers will cast them
224 # instead.
225 self._autocast = False
227 self._input_layers = []
228 self._output_layers = []
229 self._input_coordinates = []
230 self._output_coordinates = []
232 # This is for performance optimization when calling the Network on new
233 # inputs. Every time the Network is called on a set of input tensors, we
234 # compute the output tensors, output masks and output shapes in one
235 # pass, then cache them here. When any of these outputs is queried
236 # later, we retrieve it from there instead of recomputing it.
237 self._output_mask_cache = {}
238 self._output_tensor_cache = {}
239 self._output_shape_cache = {}
241 # Build self._output_layers:
242 for x in self.outputs:
243 (
244 layer,
245 node_index,
246 tensor_index,
247 ) = x._keras_history
248 self._output_layers.append(layer)
249 self._output_coordinates.append((layer, node_index, tensor_index))
251 # Build self._input_layers:
252 for x in self.inputs:
253 (
254 layer,
255 node_index,
256 tensor_index,
257 ) = x._keras_history
258 # It's supposed to be an input layer, so only one node
259 # and one tensor output.
260 assert node_index == 0
261 assert tensor_index == 0
262 self._input_layers.append(layer)
263 self._input_coordinates.append((layer, node_index, tensor_index))
265 # Keep track of the network's nodes and layers.
266 nodes, nodes_by_depth, layers, _ = _map_graph_network(
267 self.inputs, self.outputs
268 )
269 self._network_nodes = nodes
270 self._nodes_by_depth = nodes_by_depth
271 self._self_tracked_trackables = layers
272 self._layer_call_argspecs = {}
273 for layer in self._self_tracked_trackables:
274 self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(
275 layer.call
276 )
278 # Build self.input_names and self.output_names.
279 self._set_output_names()
280 self.input_names = []
281 self._feed_input_names = []
282 self._feed_inputs = []
283 self._feed_input_shapes = []
284 for layer in self._input_layers:
285 self.input_names.append(layer.name)
286 if layer.is_placeholder:
287 self._feed_input_names.append(layer.name)
288 # Use batch_input_shape here because non-eager composite tensors
289 # may not have a shape attribute that's meaningful (sparse, for
290 # instance, has a tensor that's non-constant and needs to be
291 # fed). This means that input layers that create placeholders
292 # will need to have the batch_input_shape attr to allow for
293 # input shape validation.
294 self._feed_input_shapes.append(layer._batch_input_shape)
295 self._feed_inputs.append(layer.input)
297 self._compute_tensor_usage_count()
298 self._set_save_spec(self._nested_inputs)
299 tf_utils.assert_no_legacy_layers(self.layers)
301 # Note that this method is used by both functional and sequential
302 # models, so we can't just have this method in functional.__init__,
303 # which would then miss the Sequential model case.
304 if self._layout_map is not None:
305 layout_map_lib._map_functional_model_variable(
306 self, self._layout_map
307 )
309 @property
310 def input(self):
311 """Retrieves the input tensor(s) of a layer.
313 Only applicable if the layer has exactly one input,
314 i.e. if it is connected to one incoming layer.
316 Returns:
317 Input tensor or list of input tensors.
319 Raises:
320 RuntimeError: If called in Eager mode.
321 AttributeError: If no inbound nodes are found.
322 """
323 return self._nested_inputs
325 @property
326 def input_shape(self):
327 """Retrieves the input shape(s) of a layer.
329 Only applicable if the layer has exactly one input,
330 i.e. if it is connected to one incoming layer, or if all inputs
331 have the same shape.
333 Returns:
334 Input shape, as an integer shape tuple
335 (or list of shape tuples, one tuple per input tensor).
337 Raises:
338 AttributeError: if the layer has no defined input_shape.
339 RuntimeError: if called in Eager mode.
340 """
341 return tf.nest.map_structure(backend.int_shape, self.input)
343 @property
344 def input_spec(self):
345 if hasattr(self, "_manual_input_spec"):
346 return self._manual_input_spec
347 if isinstance(self._nested_inputs, (dict, list, tuple)) and len(
348 self._nested_inputs
349 ) != len(self.inputs):
350 # Case where we have a nested structure.
351 # In such a case we can't safely run any checks.
352 return None
353 if isinstance(self._nested_inputs, dict):
354 # Case where `_nested_inputs` is a plain dict of Inputs.
355 names = sorted(self._nested_inputs.keys())
356 return [
357 input_spec.InputSpec(
358 shape=shape_with_no_batch_size(self._nested_inputs[name]),
359 allow_last_axis_squeeze=True,
360 name=name,
361 )
362 for name in names
363 ]
364 else:
365 # Single input, or list / tuple of inputs.
366 # The data may be passed as a dict keyed by input name.
367 return [
368 input_spec.InputSpec(
369 shape=shape_with_no_batch_size(x),
370 allow_last_axis_squeeze=True,
371 name=x._keras_history.layer.name,
372 )
373 for x in self.inputs
374 ]
376 @input_spec.setter
377 def input_spec(self, value):
378 self._manual_input_spec = value
380 @property
381 def output(self):
382 """Retrieves the output tensor(s) of a layer.
384 Only applicable if the layer has exactly one output,
385 i.e. if it is connected to one incoming layer.
387 Returns:
388 Output tensor or list of output tensors.
390 Raises:
391 AttributeError: if the layer is connected to more than one incoming
392 layers.
393 RuntimeError: if called in Eager mode.
394 """
395 return self._nested_outputs
397 @property
398 def output_shape(self):
399 """Retrieves the output shape(s) of a layer.
401 Only applicable if the layer has one output,
402 or if all outputs have the same shape.
404 Returns:
405 Output shape, as an integer shape tuple
406 (or list of shape tuples, one tuple per output tensor).
408 Raises:
409 AttributeError: if the layer has no defined output shape.
410 RuntimeError: if called in Eager mode.
411 """
412 return tf.nest.map_structure(backend.int_shape, self.output)
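# Usage sketch for the `input_shape`/`output_shape` properties above, assuming
# TensorFlow is available; the layer sizes are illustrative only.
from tensorflow import keras

example_inputs = keras.Input(shape=(10,))
example_outputs = keras.layers.Dense(4)(example_inputs)
example_model = keras.Model(example_inputs, example_outputs)
# Both properties map `backend.int_shape` over the nested structure, so they
# return plain tuples with `None` in the batch dimension.
assert example_model.input_shape == (None, 10)
assert example_model.output_shape == (None, 4)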
414 def _set_output_names(self):
415 """Assigns unique names to the Network's outputs.
417 Output layers with multiple output tensors would otherwise lead to
418 duplicate names in self.output_names.
419 """
420 uniquified = []
421 output_names = set()
422 prefix_count = {}
423 for layer in self._output_layers:
424 proposal = layer.name
425 while proposal in output_names:
426 existing_count = prefix_count.get(layer.name, 1)
427 proposal = f"{layer.name}_{existing_count}"
428 prefix_count[layer.name] = existing_count + 1
429 output_names.add(proposal)
430 uniquified.append(proposal)
431 self.output_names = uniquified
433 @property
434 def _layer_checkpoint_dependencies(self):
435 """Dictionary of layer dependencies to be included in the checkpoint."""
436 weight_layer_index = 0
438 dependencies = collections.OrderedDict()
439 for layer_index, layer in enumerate(self.layers):
440 try:
441 if layer.weights:
442 # Keep a separate index for layers which have weights. This
443 # allows users to insert Layers without weights anywhere in
444 # the network without breaking checkpoints.
445 dependencies[
446 "layer_with_weights-%d" % weight_layer_index
447 ] = layer
448 weight_layer_index += 1
449 except ValueError:
450 # The layer might have weights, but may not be built yet. We
451 # just treat it as a layer without weights.
452 pass
454 # Even if it doesn't have weights, we should still track everything
455 # in case it has/will have Trackable dependencies.
456 dependencies["layer-%d" % layer_index] = layer
457 return dependencies
459 def _trackable_children(self, save_type="checkpoint", **kwargs):
460 dependencies = self._layer_checkpoint_dependencies
461 dependencies.update(super()._trackable_children(save_type, **kwargs))
462 return dependencies
464 def _lookup_dependency(self, name):
465 layer_dependencies = self._layer_checkpoint_dependencies
466 if name in layer_dependencies:
467 return layer_dependencies[name]
468 return super()._lookup_dependency(name)
470 def _handle_deferred_layer_dependencies(self, layers):
471 """Handles layer checkpoint dependencies that are added after init."""
472 layer_checkpoint_dependencies = self._layer_checkpoint_dependencies
473 layer_to_name = {v: k for k, v in layer_checkpoint_dependencies.items()}
474 for layer in layers:
475 if layer in layer_to_name:
476 self._handle_deferred_dependencies(
477 name=layer_to_name[layer], trackable=layer
478 )
480 @property
481 def _should_compute_mask(self):
482 return True
484 def compute_mask(self, inputs, mask):
485 # TODO(omalleyt): b/123540974 This function is not really safe to call
486 # by itself because it will duplicate any updates and losses in graph
487 # mode by `call`ing the Layers again.
488 output_tensors = self._run_internal_graph(inputs, mask=mask)
489 return tf.nest.map_structure(
490 lambda t: getattr(t, "_keras_mask", None), output_tensors
491 )
493 @doc_controls.do_not_doc_inheritable
494 def call(self, inputs, training=None, mask=None):
495 """Calls the model on new inputs.
497 In this case `call` just reapplies
498 all ops in the graph to the new inputs
499 (i.e. it builds a new computational graph from the provided inputs).
501 Args:
502 inputs: A tensor or list of tensors.
503 training: Boolean or boolean scalar tensor, indicating whether to
504 run the `Network` in training mode or inference mode.
505 mask: A mask or list of masks. A mask can be
506 either a tensor or None (no mask).
508 Returns:
509 A tensor if there is a single output, or
510 a list of tensors if there is more than one output.
511 """
512 return self._run_internal_graph(inputs, training=training, mask=mask)
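# A minimal sketch of what `call` does in practice: invoking the model object
# on a fresh tensor re-applies every layer of the stored graph to that tensor
# (shapes below are illustrative).
import tensorflow as tf
from tensorflow import keras

inputs = keras.Input(shape=(3,))
outputs = keras.layers.Dense(2)(inputs)
model = keras.Model(inputs, outputs)

# `model(...)` routes through `call`, which delegates to `_run_internal_graph`.
y = model(tf.ones((5, 3)), training=False)
print(y.shape)  # (5, 2)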
514 def compute_output_shape(self, input_shape):
515 # Convert any shapes in tuple format to TensorShapes.
516 input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
518 if len(tf.nest.flatten(input_shape)) != len(
519 tf.nest.flatten(self._input_layers)
520 ):
521 raise ValueError(
522 f"Invalid `input_shape` argument {input_shape}: "
523 f"the model expects {len(self._input_layers)} "
524 "input tensors."
525 )
527 # Use the tuple of TensorShape as the cache key, since tuple is hashable
528 # and can be used as a hash key.
529 try:
530 cache_key = tuple(
531 tf_utils.convert_shapes(input_shape, to_tuples=True)
532 )
533 if cache_key in self._output_shape_cache:
534 # Cache hit. Return shapes as TensorShapes.
535 return self._output_shape_cache[cache_key]
536 except ValueError:
537 # In case there are unknown TensorShapes, e.g. for sparse tensor input,
538 # we skip the caching since the shape is unknown.
539 pass
541 layers_to_output_shapes = {}
542 for layer, shape in zip(
543 self._input_layers, tf.nest.flatten(input_shape)
544 ):
545 # It's an input layer: then `compute_output_shape` is identity,
546 # and there is only one node and one tensor.
547 shape_key = layer.name + "_0_0"
548 layers_to_output_shapes[shape_key] = shape
550 depth_keys = list(self._nodes_by_depth.keys())
551 depth_keys.sort(reverse=True)
552 # Iterate over nodes, by depth level.
553 if len(depth_keys) > 1:
554 for depth in depth_keys:
555 nodes = self._nodes_by_depth[depth]
556 for node in nodes:
557 layer = node.layer
558 if layer in self._input_layers:
559 # We've already covered the input layers
560 # a few lines above.
561 continue
562 # Get the input shapes for the first argument of the node
563 layer_input_shapes = []
564 layer_inputs = node.call_args[0]
565 for layer_input in tf.nest.flatten(layer_inputs):
566 kh = layer_input._keras_history
567 input_layer_key = kh.layer.name + "_%s_%s" % (
568 kh.node_index,
569 kh.tensor_index,
570 )
571 layer_input_shapes.append(
572 layers_to_output_shapes[input_layer_key]
573 )
574 layer_input_shapes = tf.nest.pack_sequence_as(
575 layer_inputs, layer_input_shapes
576 )
577 # Layers expect shapes to be tuples for
578 # `compute_output_shape`.
579 layer_input_shapes = tf_utils.convert_shapes(
580 layer_input_shapes, to_tuples=True
581 )
582 layer_output_shapes = layer.compute_output_shape(
583 layer_input_shapes
584 )
585 # Convert back to TensorShapes.
586 layer_output_shapes = tf_utils.convert_shapes(
587 layer_output_shapes, to_tuples=False
588 )
590 node_index = layer._inbound_nodes.index(node)
591 for j, shape in enumerate(
592 tf.nest.flatten(layer_output_shapes)
593 ):
594 shape_key = layer.name + f"_{node_index}_{j}"
595 layers_to_output_shapes[shape_key] = shape
597 # Read final output shapes from layers_to_output_shapes.
598 output_shapes = []
599 for i in range(len(self._output_layers)):
600 layer, node_index, tensor_index = self._output_coordinates[i]
601 shape_key = layer.name + f"_{node_index}_{tensor_index}"
602 output_shapes.append(layers_to_output_shapes[shape_key])
603 output_shapes = tf.nest.pack_sequence_as(
604 self._nested_outputs, output_shapes
605 )
606 # Store in cache.
607 self._output_shape_cache[cache_key] = output_shapes
609 # Return shapes as TensorShapes.
610 return output_shapes
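# Sketch of `compute_output_shape`, which infers output shapes purely from the
# graph without running any data (the sizes here are arbitrary examples).
from tensorflow import keras

inputs = keras.Input(shape=(8,))
outputs = keras.layers.Dense(16)(inputs)
model = keras.Model(inputs, outputs)

print(model.compute_output_shape((None, 8)))  # TensorShape([None, 16])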
612 def _init_set_name(self, name, zero_based=True):
613 if not name:
614 cls_name = self.__class__.__name__
615 if self.__class__ == Functional:
616 # Hide the functional class name from the user, since it's not a
617 # publicly visible class. Use "Model" instead.
618 cls_name = "Model"
619 self._name = backend.unique_object_name(
620 generic_utils.to_snake_case(cls_name), zero_based=zero_based
621 )
622 else:
623 self._name = name
625 def _run_internal_graph(self, inputs, training=None, mask=None):
626 """Computes output tensors for new inputs.
628 # Note:
629 - Can be run on non-Keras tensors.
631 Args:
632 inputs: Tensor or nested structure of Tensors.
633 training: Boolean learning phase.
634 mask: (Optional) Tensor or nested structure of Tensors.
636 Returns:
637 output_tensors
638 """
639 inputs = self._flatten_to_reference_inputs(inputs)
640 if mask is None:
641 masks = [None] * len(inputs)
642 else:
643 masks = self._flatten_to_reference_inputs(mask)
644 for input_t, mask in zip(inputs, masks):
645 input_t._keras_mask = mask
647 # Dictionary mapping reference tensors to computed tensors.
648 tensor_dict = {}
649 tensor_usage_count = self._tensor_usage_count
650 for x, y in zip(self.inputs, inputs):
651 y = self._conform_to_reference_input(y, ref_input=x)
652 x_id = str(id(x))
653 tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
655 nodes_by_depth = self._nodes_by_depth
656 depth_keys = list(nodes_by_depth.keys())
657 depth_keys.sort(reverse=True)
659 for depth in depth_keys:
660 nodes = nodes_by_depth[depth]
661 for node in nodes:
662 if node.is_input:
663 continue # Input tensors already exist.
665 if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
666 continue # Node is not computable, try skipping.
668 args, kwargs = node.map_arguments(tensor_dict)
669 outputs = node.layer(*args, **kwargs)
671 # Update tensor_dict.
672 for x_id, y in zip(
673 node.flat_output_ids, tf.nest.flatten(outputs)
674 ):
675 tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
677 output_tensors = []
678 for x in self.outputs:
679 x_id = str(id(x))
680 assert x_id in tensor_dict, "Could not compute output " + str(x)
681 output_tensors.append(tensor_dict[x_id].pop())
683 return tf.nest.pack_sequence_as(self._nested_outputs, output_tensors)
685 def _flatten_to_reference_inputs(self, tensors):
686 """Maps `tensors` to their respective `keras.Input`."""
687 if self._enable_dict_to_input_mapping and isinstance(tensors, dict):
688 ref_inputs = self._nested_inputs
689 if not tf.nest.is_nested(ref_inputs):
690 ref_inputs = [self._nested_inputs]
691 if isinstance(ref_inputs, dict):
692 # In the case that the graph is constructed with dict input
693 # tensors, we will use the original dict keys to match the
694 # keys in the input data. Note that model.inputs uses
695 # nest.flatten to process the input tensors, which means the
696 # dict input tensors are ordered by their keys.
697 ref_input_names = sorted(ref_inputs.keys())
698 else:
699 ref_input_names = [
700 inp._keras_history.layer.name for inp in ref_inputs
701 ]
703 # Raise a warning if there are more input data entries than input
704 # tensors.
705 if len(tensors) > len(ref_input_names):
706 warnings.warn(
707 "Input dict contained keys {} which did not match any "
708 "model input. They will be ignored by the model.".format(
709 [n for n in tensors.keys() if n not in ref_input_names]
710 ),
711 stacklevel=2,
712 )
714 try:
715 # Flatten in the order `Input`s were passed during Model
716 # construction.
717 return [tensors[n] for n in ref_input_names]
718 except KeyError:
719 # TODO(b/151582614)
720 return tf.nest.flatten(tensors)
722 # Otherwise both self.inputs and tensors will already be in same order.
723 return tf.nest.flatten(tensors)
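# Hedged illustration of the dict-to-input mapping handled above: a model
# built from named `keras.Input`s can be fed a dict keyed by those names, and
# unknown keys are ignored with a warning (names and shapes are illustrative).
import tensorflow as tf
from tensorflow import keras

x1 = keras.Input(shape=(4,), name="x1")
x2 = keras.Input(shape=(4,), name="x2")
added = keras.layers.Add()([x1, x2])
model = keras.Model({"x1": x1, "x2": x2}, added)

y = model({"x1": tf.ones((2, 4)), "x2": tf.zeros((2, 4))})
print(y.shape)  # (2, 4)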
725 def _conform_to_reference_input(self, tensor, ref_input):
726 """Set shape and dtype based on `keras.Input`s."""
727 if isinstance(tensor, tf.Tensor):
728 # Allow (None,) and (None, 1) Tensors to be passed interchangeably.
729 # Use the shape specified by the `keras.Input`.
730 t_shape = tensor.shape
731 t_rank = t_shape.rank
732 ref_shape = ref_input.shape
733 ref_rank = ref_shape.rank
734 keras_history = getattr(tensor, "_keras_history", None)
735 if t_rank is not None and ref_rank is not None:
736 # Should squeeze last dimension. True if tensor is (BATCH, ...,
737 # 1) and reference is (BATCH, ...).
738 if t_rank == ref_rank + 1 and t_shape[-1] == 1:
739 tensor = tf.squeeze(tensor, axis=-1)
740 # Should expand last_dimension. True if tensor is (BATCH, ...)
741 # and reference is (BATCH, ..., 1).
742 elif t_rank == ref_rank - 1 and ref_shape[-1] == 1:
743 tensor = tf.expand_dims(tensor, axis=-1)
744 if keras_history is not None: # Restore keras history.
745 tensor._keras_history = keras_history
747 # Dtype casting.
748 tensor = tf.cast(tensor, dtype=ref_input.dtype)
749 elif tf_utils.is_extension_type(tensor):
750 # Dtype casting (If the extension type has a non-variant dtype and
751 # supports being cast). Only cast if necessary (since some
752 # extension types may not implement tf.cast).
753 tensor_dtype = getattr(tensor, "dtype", None)
754 ref_input_dtype = getattr(ref_input, "dtype", None)
755 if (
756 ref_input_dtype is not None
757 and tensor_dtype is not None
758 and tensor_dtype != ref_input_dtype
759 and ref_input_dtype != tf.variant
760 ):
761 tensor = tf.cast(tensor, dtype=ref_input_dtype)
763 return tensor
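# A small sketch of the shape conformance above: a rank-(N-1) tensor can be
# fed to an input declared with a trailing dimension of 1, and it is expanded
# before reaching the layers (the (None,)/(None, 1) interchangeability noted
# in the comments; shapes are illustrative).
import tensorflow as tf
from tensorflow import keras

inp = keras.Input(shape=(1,))
out = keras.layers.Dense(1)(inp)
model = keras.Model(inp, out)

y = model(tf.ones((4,)))  # rank-1 input is expanded to shape (4, 1)
print(y.shape)  # (4, 1)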
765 @generic_utils.default
766 def get_config(self):
767 # Prepare base arguments
768 config = {
769 "name": self.name,
770 "trainable": self.trainable,
771 }
773 if saved_model_utils.in_tf_saved_model_scope():
774 # SavedModel special case: need to preserve legacy (potentially
775 # incorrect) behavior.
776 return copy.deepcopy(get_network_config(self, config=config))
778 # Check whether the class has a constructor compatible with a Functional
779 # model or if it has a custom constructor.
780 if has_functional_like_constructor(self.__class__):
781 # Only return a Functional config if the constructor is the same
782 # as that of a Functional model. This excludes subclassed Functional
783 # models with a custom __init__.
784 config = copy.deepcopy(get_network_config(self, config=config))
785 else:
786 # Try to autogenerate config
787 xtra_args = set(config.keys())
788 if getattr(self, "_auto_get_config", False):
789 config.update(self._auto_config.config)
790 # Remove args not explicitly supported
791 argspec = tf_inspect.getfullargspec(self.__init__)
792 if argspec.varkw != "kwargs":
793 for key in xtra_args - xtra_args.intersection(argspec.args[1:]):
794 config.pop(key, None)
795 return config
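# Sketch of the round trip that `get_config` supports for Functional models:
# the returned config contains the full node graph and can rebuild the model.
from tensorflow import keras

inputs = keras.Input(shape=(3,))
outputs = keras.layers.Dense(2)(inputs)
model = keras.Model(inputs, outputs)

config = model.get_config()
clone = keras.Model.from_config(config)  # reconstructs the same graph of layers
print([layer.name for layer in clone.layers])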
797 def get_weight_paths(self):
798 result = {}
799 for layer in self.layers:
800 (
801 descendants,
802 object_paths_dict,
803 ) = tf.__internal__.tracking.ObjectGraphView(
804 layer
805 ).breadth_first_traversal()
806 for descendant in descendants:
807 if isinstance(descendant, tf.Variable):
808 trackable_references = object_paths_dict[descendant]
809 object_path = ".".join(
810 [t.name for t in trackable_references]
811 )
812 result[layer.name + "." + object_path] = descendant
813 return result
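# Usage sketch for `get_weight_paths`: it maps dotted trackable paths to the
# underlying `tf.Variable`s (the layer name and sizes below are illustrative).
from tensorflow import keras

inputs = keras.Input(shape=(3,))
outputs = keras.layers.Dense(2, name="dense")(inputs)
model = keras.Model(inputs, outputs)

print(sorted(model.get_weight_paths()))  # e.g. ['dense.bias', 'dense.kernel']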
815 def _validate_graph_inputs_and_outputs(self):
816 """Validates the inputs and outputs of a Graph Network."""
817 # Check for redundancy in inputs.
818 if len({id(i) for i in self.inputs}) != len(self.inputs):
819 raise ValueError(
820 "The list of inputs passed to the model "
821 "contains the same input multiple times. "
822 "All inputs should only appear once. "
823 f"Received inputs={self.inputs}"
824 )
826 for x in self.inputs:
827 # Check that x has appropriate `_keras_history` metadata.
828 if not hasattr(x, "_keras_history"):
829 cls_name = self.__class__.__name__
830 raise ValueError(
831 f"Input tensors to a {cls_name} model "
832 "must come from `tf.keras.Input`. "
833 f"Received inputs={x} (missing previous layer metadata)."
834 )
835 # Check that x is an input tensor.
837 layer = x._keras_history.layer
838 if len(layer._inbound_nodes) > 1 or (
839 layer._inbound_nodes and not layer._inbound_nodes[0].is_input
840 ):
841 cls_name = self.__class__.__name__
842 logging.warning(
843 f"{cls_name} model inputs must come from "
844 "`tf.keras.Input` (thus holding past layer metadata). "
845 "They cannot be the output of "
846 "a previous non-Input layer. "
847 "Here, a tensor specified as "
848 f'input to "{self.name}" was not an Input tensor, '
849 f'it was generated by layer "{layer.name}".\n'
850 "Note that input tensors are "
851 "instantiated via `tensor = tf.keras.Input(shape)`.\n"
852 f"The tensor that caused the issue was: {x}"
853 )
855 # Check compatibility of batch sizes of Input Layers.
856 input_batch_sizes = set(
857 [
858 training_utils.get_static_batch_size(x._keras_history.layer)
859 for x in self.inputs
860 ]
861 )
862 input_batch_sizes.discard(None)
863 if len(input_batch_sizes) > 1:
864 logging.warning(
865 "Found incompatible static batch sizes among the "
866 f"inputs. Batch sizes: {sorted(input_batch_sizes)}"
867 )
869 for x in self.outputs:
870 if not hasattr(x, "_keras_history"):
871 cls_name = self.__class__.__name__
872 raise ValueError(
873 f"Output tensors of a {cls_name} model must be "
874 "the output of a TensorFlow `Layer` "
875 f"(thus holding past layer metadata). Found: {x}"
876 )
878 def _insert_layers(self, layers, relevant_nodes=None):
879 """Inserts Layers into the Network after Network creation.
881 This is only valid for Keras Graph Networks. Layers added via this
882 function will be included in the `call` computation and `get_config` of
883 this Network. They will not be added to the Network's outputs.
885 Args:
886 layers: Arbitrary nested structure of Layers. Layers must be reachable
887 from one or more of the `keras.Input` Tensors that correspond to
888 this Network's inputs.
889 relevant_nodes: Nodes from the Layers that should be considered part
890 of this Network. If `None`, all Nodes will be considered part of
891 this Network.
893 Raises:
894 ValueError: If the layers depend on `Input`s not found in this Model.
895 """
896 layers = tf.nest.flatten(layers)
897 tf_utils.assert_no_legacy_layers(layers)
898 node_to_depth = {}
899 for depth, nodes in self._nodes_by_depth.items():
900 node_to_depth.update({node: depth for node in nodes})
901 # The nodes of these Layers that are relevant to this Network. If not
902 # provided, assume all Nodes are relevant.
903 if not relevant_nodes:
904 relevant_nodes = tf.nest.flatten(
905 [layer._inbound_nodes for layer in layers]
906 )
907 network_nodes = set(relevant_nodes + list(node_to_depth.keys()))
909 def _get_min_depth(node):
910 """Gets the minimum depth at which node can be computed."""
911 min_depth = 0
912 for layer, node_id, _, _ in node.iterate_inbound():
913 inbound_node = layer._inbound_nodes[node_id]
914 if inbound_node in node_to_depth:
915 min_depth = min(min_depth, node_to_depth[inbound_node])
916 elif inbound_node not in network_nodes:
917 continue
918 else:
919 # Previous relevant nodes haven't been processed yet.
920 return None
921 # New node is one shallower than its shallowest input.
922 return min_depth - 1
924 # Insert nodes into `_nodes_by_depth` and other node attrs.
925 unprocessed_nodes = copy.copy(relevant_nodes)
926 i = 0
927 while unprocessed_nodes:
928 i += 1
929 # Do a sanity check. This can occur if `Input`s from outside this
930 # Model are being relied on.
931 if i > 10000:
932 raise ValueError(
933 "Layers could not be added due to missing dependencies."
934 )
936 node = unprocessed_nodes.pop(0)
937 depth = _get_min_depth(node)
938 if depth is None: # Defer until inbound nodes are processed.
939 unprocessed_nodes.append(node)
940 continue
941 node_key = _make_node_key(
942 node.layer.name, node.layer._inbound_nodes.index(node)
943 )
944 if node_key not in self._network_nodes:
945 node_to_depth[node] = depth
946 self._network_nodes.add(node_key)
947 self._nodes_by_depth[depth].append(node)
949 # Insert layers and update other layer attrs.
950 layer_set = set(self._self_tracked_trackables)
951 deferred_layers = []
952 for layer in layers:
953 if layer not in layer_set:
954 self._self_tracked_trackables.append(layer)
955 deferred_layers.append(layer)
956 self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(
957 layer.call
958 )
959 layer_set.add(layer)
960 self._handle_deferred_layer_dependencies(deferred_layers)
962 self._compute_tensor_usage_count()
964 def _compute_tensor_usage_count(self):
965 """Compute the number of tensor usages for all the output tensors of layers.
967 The computed tensor usage count is saved as `self._tensor_usage_count`.
968 This is later used for saving memory in eager computation by releasing
969 no-longer-needed tensors as early as possible.
970 """
971 tensor_usage_count = collections.Counter()
972 available_tensors = set(str(id(tensor)) for tensor in self.inputs)
974 depth_keys = list(self._nodes_by_depth.keys())
975 depth_keys.sort(reverse=True)
976 depth_keys = depth_keys[1:]
978 for depth in depth_keys:
979 for node in self._nodes_by_depth[depth]:
980 input_tensors = {
981 str(id(tensor))
982 for tensor in tf.nest.flatten(node.keras_inputs)
983 }
984 if input_tensors.issubset(available_tensors):
985 for tensor in tf.nest.flatten(node.keras_inputs):
986 tensor_usage_count[str(id(tensor))] += 1
988 for output_tensor in tf.nest.flatten(node.outputs):
989 available_tensors.add(str(id(output_tensor)))
991 for tensor in self.outputs:
992 tensor_usage_count[str(id(tensor))] += 1
994 self._tensor_usage_count = tensor_usage_count
996 def _assert_weights_created(self):
997 # Override the implementation in Model.
998 # The Functional model should always have its weights created already.
999 return
1001 def _graph_network_add_loss(self, symbolic_loss):
1002 new_nodes, new_layers = _map_subgraph_network(
1003 self.inputs, [symbolic_loss]
1004 )
1005 # Losses must be keyed on inputs no matter what in order to be supported
1006 # in DistributionStrategy.
1007 add_loss_layer = base_layer.AddLoss(
1008 unconditional=False, dtype=symbolic_loss.dtype
1009 )
1010 add_loss_layer(symbolic_loss)
1011 new_nodes.extend(add_loss_layer.inbound_nodes)
1012 new_layers.append(add_loss_layer)
1013 self._insert_layers(new_layers, new_nodes)
1015 def _graph_network_add_metric(self, value, aggregation, name):
1016 new_nodes, new_layers = _map_subgraph_network(self.inputs, [value])
1017 add_metric_layer = base_layer.AddMetric(
1018 aggregation, name, dtype=value.dtype
1019 )
1020 add_metric_layer(value)
1021 new_nodes.extend(add_metric_layer.inbound_nodes)
1022 new_layers.append(add_metric_layer)
1023 self._insert_layers(new_layers, new_nodes)
1025 @property
1026 def _trackable_saved_model_saver(self):
1027 return network_serialization.NetworkSavedModelSaver(self)
1029 def _get_save_spec(self, dynamic_batch=True, inputs_only=True):
1030 if getattr(self, "_has_explicit_input_shape", True):
1031 # Functional models and Sequential models that have an explicit
1032 # input shape should use the batch size set by the input layer.
1033 dynamic_batch = False
1034 return super()._get_save_spec(dynamic_batch, inputs_only)
1037def _make_node_key(layer_name, node_index):
1038 return layer_name + "_ib-" + str(node_index)
1041def _map_graph_network(inputs, outputs):
1042 """Validates a network's topology and gathers its layers and nodes.
1044 Args:
1045 inputs: List of input tensors.
1046 outputs: List of output tensors.
1048 Returns:
1049 A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`.
1050 - nodes: list of Node instances.
1051 - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
1052 - layers: list of Layer instances.
1053 - layers_by_depth: dict mapping ints (depth) to lists of layer instances.
1055 Raises:
1056 ValueError: In case the network is not valid (e.g. disconnected graph).
1057 """
1058 # "depth" is the number of layers between the output Nodes and a given Node.
1059 # Nodes are ordered from inputs -> outputs.
1060 nodes_in_decreasing_depth, layer_indices = _build_map(outputs)
1061 network_nodes = {
1062 _make_node_key(node.layer.name, node.layer._inbound_nodes.index(node))
1063 for node in nodes_in_decreasing_depth
1064 }
1066 nodes_depths = {} # dict {node: depth value}
1067 layers_depths = {} # dict {layer: depth value}
1069 for node in reversed(nodes_in_decreasing_depth):
1070 # If the depth is not set, the node has no outbound nodes (depth 0).
1071 depth = nodes_depths.setdefault(node, 0)
1073 # Update the depth of the corresponding layer
1074 previous_depth = layers_depths.get(node.layer, 0)
1075 # If we've seen this layer before at a higher depth,
1076 # we should use that depth instead of the node depth.
1077 # This is necessary for shared layers that have inputs at different
1078 # depth levels in the graph.
1079 depth = max(depth, previous_depth)
1080 layers_depths[node.layer] = depth
1081 nodes_depths[node] = depth
1083 # Update the depth of inbound nodes.
1084 # The "depth" of a node is the max of the depths
1085 # of all nodes it is connected to + 1.
1086 for node_dep in node.parent_nodes:
1087 previous_depth = nodes_depths.get(node_dep, 0)
1088 nodes_depths[node_dep] = max(depth + 1, previous_depth)
1090 # Handle inputs that are not connected to outputs.
1091 # We do not error out here because the inputs may be used to compute losses
1092 # and metrics.
1093 for input_t in inputs:
1094 input_layer = input_t._keras_history[0]
1095 if input_layer not in layers_depths:
1096 layers_depths[input_layer] = 0
1097 layer_indices[input_layer] = -1
1098 nodes_depths[input_layer._inbound_nodes[0]] = 0
1099 network_nodes.add(_make_node_key(input_layer.name, 0))
1101 # Build a dict {depth: list of nodes with this depth}
1102 nodes_by_depth = collections.defaultdict(list)
1103 for node, depth in nodes_depths.items():
1104 nodes_by_depth[depth].append(node)
1106 # Build a dict {depth: list of layers with this depth}
1107 layers_by_depth = collections.defaultdict(list)
1108 for layer, depth in layers_depths.items():
1109 layers_by_depth[depth].append(layer)
1111 # Get sorted list of layer depths.
1112 depth_keys = list(layers_by_depth.keys())
1113 depth_keys.sort(reverse=True)
1115 # Set self.layers ordered by depth.
1116 layers = []
1117 for depth in depth_keys:
1118 layers_for_depth = layers_by_depth[depth]
1119 # Network.layers needs to have a deterministic order:
1120 # here we order them by traversal order.
1121 layers_for_depth.sort(key=lambda x: layer_indices[x])
1122 layers.extend(layers_for_depth)
1124 # Get sorted list of node depths.
1125 depth_keys = list(nodes_by_depth.keys())
1126 depth_keys.sort(reverse=True)
1128 # Check that all tensors required are computable.
1129 # computable_tensors: all tensors in the graph
1130 # that can be computed from the inputs provided.
1131 computable_tensors = set()
1132 for x in inputs:
1133 computable_tensors.add(id(x))
1135 layers_with_complete_input = [] # To provide a better error msg.
1136 for depth in depth_keys:
1137 for node in nodes_by_depth[depth]:
1138 layer = node.layer
1139 if layer and not node.is_input:
1140 for x in tf.nest.flatten(node.keras_inputs):
1141 if id(x) not in computable_tensors:
1142 raise ValueError(
1143 "Graph disconnected: cannot obtain value for "
1144 f'tensor {x} at layer "{layer.name}". '
1145 "The following previous layers were accessed "
1146 f"without issue: {layers_with_complete_input}"
1147 )
1148 for x in tf.nest.flatten(node.outputs):
1149 computable_tensors.add(id(x))
1150 layers_with_complete_input.append(layer.name)
1152 # Ensure name uniqueness, which is crucial for serialization
1153 # (since serialized nodes refer to layers by their name).
1154 all_names = [layer.name for layer in layers]
1155 for name in all_names:
1156 if all_names.count(name) != 1:
1157 raise ValueError(
1158 f'The name "{name}" is used {all_names.count(name)} '
1159 "times in the model. All layer names should be unique."
1160 )
1161 return network_nodes, nodes_by_depth, layers, layers_by_depth
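# A self-contained toy sketch (plain Python, not Keras objects) of the depth
# bookkeeping used by `_map_graph_network` and `_build_map`: nodes are ordered
# by a DFS from the outputs, then each node's depth is the max depth of its
# consumers plus one, with output nodes at depth 0.
def toy_map_graph(parents, outputs):
    """parents: dict mapping a node to the list of nodes it consumes."""
    order = []  # analogous to nodes_in_decreasing_depth
    finished = set()

    def dfs(node):
        if node in finished:
            return
        for parent in parents.get(node, ()):
            dfs(parent)
        finished.add(node)
        order.append(node)

    for out in outputs:
        dfs(out)

    depths = {}
    for node in reversed(order):
        depth = depths.setdefault(node, 0)
        for parent in parents.get(node, ()):
            depths[parent] = max(depth + 1, depths.get(parent, 0))
    return depths

# For a chain a -> b -> c with c as the output:
assert toy_map_graph({"b": ["a"], "c": ["b"]}, ["c"]) == {"c": 0, "b": 1, "a": 2}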
1164def _build_map(outputs):
1165 """This method topologically sorts nodes in order from inputs to outputs.
1167 It uses a depth-first search to topologically sort nodes that appear in the
1168 _keras_history connectivity metadata of `outputs`.
1170 Args:
1171 outputs: the output tensors whose _keras_history metadata should be
1172 walked. This may be an arbitrary nested structure.
1174 Returns:
1175 A tuple like (ordered_nodes, layer_to_first_traversal_index)
1176 ordered_nodes: list of nodes appearing in the keras history, topologically
1177 sorted from original inputs to the `outputs`.
1178 (If outputs have different sets of ancestors, the inputs to one output
1179 may appear after a different output).
1180 layer_to_first_traversal_index:
1181 A dict mapping layer to the traversal index in the DFS where it is
1182 seen. Note: if a layer is shared by several nodes, the dict will only
1183 store the index corresponding to the *first* time the layer is seen.
1184 """
1185 finished_nodes = set()
1186 nodes_in_progress = set()
1187 nodes_in_decreasing_depth = [] # nodes from inputs -> outputs.
1188 layer_indices = {} # layer -> in traversal order.
1189 for output in tf.nest.flatten(outputs):
1190 _build_map_helper(
1191 output,
1192 finished_nodes,
1193 nodes_in_progress,
1194 nodes_in_decreasing_depth,
1195 layer_indices,
1196 )
1197 return nodes_in_decreasing_depth, layer_indices
1200def _build_map_helper(
1201 tensor,
1202 finished_nodes,
1203 nodes_in_progress,
1204 nodes_in_decreasing_depth,
1205 layer_indices,
1206):
1207 """Recursive helper for `_build_map`."""
1208 (
1209 layer,
1210 node_index,
1211 _,
1212 ) = tensor._keras_history
1213 node = layer._inbound_nodes[node_index]
1215 # Don't repeat work for shared subgraphs
1216 if node in finished_nodes:
1217 return
1219 # Prevent cycles.
1220 if node in nodes_in_progress:
1221 raise ValueError(
1222 f'Tensor {tensor} from layer "{layer.name}" is part of a cycle.'
1223 )
1225 # Store the traversal order for layer sorting.
1226 if layer not in layer_indices:
1227 layer_indices[layer] = len(layer_indices)
1229 # Propagate to all previous tensors connected to this node.
1230 nodes_in_progress.add(node)
1231 if not node.is_input:
1232 for tensor in node.keras_inputs:
1233 _build_map_helper(
1234 tensor,
1235 finished_nodes,
1236 nodes_in_progress,
1237 nodes_in_decreasing_depth,
1238 layer_indices,
1239 )
1241 finished_nodes.add(node)
1242 nodes_in_progress.remove(node)
1243 nodes_in_decreasing_depth.append(node)
1246def _map_subgraph_network(inputs, outputs):
1247 """Returns the nodes and layers in the topology from `inputs` to `outputs`.
1249 Args:
1250 inputs: List of input tensors.
1251 outputs: List of output tensors.
1253 Returns:
1254 A tuple of List[Node] and List[Layer].
1255 """
1256 if not tf.compat.v1.executing_eagerly_outside_functions():
1257 base_layer_utils.create_keras_history(outputs)
1258 # Keep only nodes and layers in the topology between inputs and outputs.
1259 _, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs)
1260 return tf.nest.flatten([nodes for nodes in nodes_by_depth.values()]), layers
1263def _should_skip_first_node(layer):
1264 """Returns True if the first layer node should not be saved or loaded."""
1265 # Networks that are constructed with an Input layer/shape start with a
1266 # pre-existing node linking their input to output. This node is excluded
1267 # from the network config.
1268 if not hasattr(layer, "_self_tracked_trackables"):
1269 # Special case for serialization of Functional models without
1270 # defined input shape argument.
1271 return isinstance(layer, Functional)
1272 if layer._self_tracked_trackables:
1273 return (
1274 isinstance(layer, Functional)
1275 # Filter out Sequential models without an input shape.
1276 and isinstance(
1277 layer._self_tracked_trackables[0], input_layer_module.InputLayer
1278 )
1279 )
1280 else:
1281 return isinstance(layer, Functional)
1284def connect_ancillary_layers(model, created_layers):
1285 """Adds layers that are not connected to the outputs to the model."""
1286 # Layers not connected to outputs, such as those added in `add_loss`.
1287 ancillary_layers = [
1288 layer for layer in created_layers.values() if layer not in model.layers
1289 ]
1290 if ancillary_layers:
1291 relevant_nodes = tf.nest.flatten(
1292 [
1293 layer.inbound_nodes[1:]
1294 if _should_skip_first_node(layer)
1295 else layer.inbound_nodes
1296 for layer in created_layers.values()
1297 ]
1298 )
1299 model._insert_layers(ancillary_layers, relevant_nodes)
1300 return model
1303def reconstruct_from_config(config, custom_objects=None, created_layers=None):
1304 """Reconstructs graph from config object.
1306 Args:
1307 config: Dictionary returned from Network.get_config()
1308 custom_objects: Optional dictionary mapping names (strings) to custom
1309 classes or functions to be considered during deserialization.
1310 created_layers: Optional dictionary mapping names to Layer objects. Any
1311 layer not in this dictionary will be created and added to the dict.
1312 This function will add new nodes to all layers (excluding InputLayers),
1313 instead of re-using pre-existing nodes in the layers.
1315 Returns:
1316 Tuple of (input tensors, output tensors, dictionary of created layers)
1317 """
1318 # Layer instances created during the graph reconstruction process.
1319 created_layers = created_layers or collections.OrderedDict()
1321 # Maps input data (tuple of inbound layer name, node index) from the config
1322 # to node indices in the newly generated model. The node indices may be
1323 # different if the layers have already been called previously.
1324 node_index_map = {}
1325 node_count_by_layer = {}
1327 # Dictionary mapping layer instances to
1328 # node data that specifies a layer call.
1329 # It acts as a queue that maintains any unprocessed
1330 # layer call until it becomes possible to process it
1331 # (i.e. until the input tensors to the call all exist).
1332 unprocessed_nodes = collections.defaultdict(list)
1334 def get_node_index(layer, config_node_index):
1335 """Returns node index in layer (might differ from config_node_index)."""
1336 if isinstance(layer, input_layer_module.InputLayer):
1337 return 0
1338 return node_index_map.get((layer.name, config_node_index), None)
1340 def _deserialize_keras_tensors(kwargs, layer_map):
1341 """Deserializes Keras Tensors passed to `call`."""
1343 def _deserialize_keras_tensor(t):
1344 """Deserializes a single Keras Tensor passed to `call`."""
1345 if isinstance(t, tf_utils.ListWrapper):
1346 t = t.as_list()
1347 layer_name = t[0]
1348 node_index = t[1]
1349 tensor_index = t[2]
1351 layer = layer_map[layer_name]
1352 new_node_index = get_node_index(layer, node_index)
1353 if new_node_index is None:
1354 # The inbound node may not have been processed yet
1355 # (this can happen e.g. if it depends on a different set
1356 # of inputs than those that have been processed already).
1357 # Raise an IndexError so that the current node puts itself
1358 # back on the unprocessed queue.
1359 # Caution: This may lead to infinite loops for malformed
1360 # network configurations! (or when there is a bug in
1361 # the network config loading code).
1362 raise IndexError
1363 node = layer._inbound_nodes[new_node_index]
1364 return tf.nest.flatten(node.outputs)[tensor_index]
1365 return t
1367 kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True)
1368 return tf.nest.map_structure(_deserialize_keras_tensor, kwargs)
1370 def process_node(layer, node_data):
1371 """Deserialize a node.
1373 Args:
1374 layer: layer instance.
1375 node_data: Nested structure of `ListWrapper`.
1377 Returns:
1378 Whether the node was processed (i.e. the layer was called on the
1379 inputs specified by the node data)
1381 Raises:
1382 ValueError: In case of improperly formatted `node_data`.
1383 """
1384 input_tensors = []
1385 for input_data in tf.nest.flatten(node_data):
1386 input_data = input_data.as_list()
1387 if len(input_data) == 3:
1388 kwargs = {}
1389 elif len(input_data) == 4:
1390 kwargs = input_data[3]
1391 try:
1392 kwargs = _deserialize_keras_tensors(kwargs, created_layers)
1393 except IndexError:
1394 # Happens if keras tensors in kwargs are still unprocessed
1395 return False
1396 else:
1397 raise ValueError("Improperly formatted model config.")
1399 if input_data[0] != node_module._CONSTANT_VALUE:
1400 inbound_layer_name = input_data[0]
1401 inbound_node_index = input_data[1]
1402 inbound_tensor_index = input_data[2]
1403 inbound_layer = created_layers[inbound_layer_name]
1404 inbound_node_index = get_node_index(
1405 inbound_layer, inbound_node_index
1406 )
1408 if inbound_node_index is None:
1409 return False
1410 inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
1411 input_tensors.append(
1412 tf.nest.flatten(inbound_node.outputs)[inbound_tensor_index]
1413 )
1414 else:
1415 # We received a constant w/ no Keras history attached,
1416 # which means it is a constant tensor input.
1417 # Input is a constant value.
1418 # Format = [_CONSTANT_VALUE, -1, const_val, kwargs]
1419 assert input_data[1] == -1
1420 assert len(input_data) >= 3
1421 const_val = input_data[2]
1422 if (
1423 isinstance(const_val, tuple)
1424 and len(const_val) == 2
1425 and const_val[0] == node_module._COMPOSITE_TYPE
1426 ):
1427 # It is a composite tensor.
1428 input_tensors.append(json_utils.decode(const_val[1]))
1429 else:
1430 input_tensors.append(const_val)
1431 input_tensors = tf.nest.pack_sequence_as(node_data, input_tensors)
1432 # Call layer on its inputs, thus creating the node
1433 # and building the layer if needed.
1434 if input_tensors is not None:
1435 if (
1436 not hasattr(layer, "_preserve_input_structure_in_config")
1437 or not layer._preserve_input_structure_in_config
1438 ):
1439 input_tensors = base_layer_utils.unnest_if_single_tensor(
1440 input_tensors
1441 )
1442 output_tensors = layer(input_tensors, **kwargs)
1444 # Update node index map.
1445 output_index = tf.nest.flatten(output_tensors)[
1446 0
1447 ]._keras_history.node_index
1448 node_index_map[
1449 (layer.name, node_count_by_layer[layer])
1450 ] = output_index
1451 node_count_by_layer[layer] += 1
1452 return True
1454 def process_layer(layer_data):
1455 """Deserializes a layer, then calls it on appropriate inputs.
1457 Args:
1458 layer_data: layer config dict.
1460 Raises:
1461 ValueError: In case of improperly formatted `layer_data` dict.
1462 """
1463 layer_name = layer_data["name"]
1465 if layer_name in created_layers:
1466 layer = created_layers[layer_name]
1467 else:
1468 # Instantiate layer.
1469 from keras.src.layers import deserialize as deserialize_layer
1471 layer = deserialize_layer(layer_data, custom_objects=custom_objects)
1472 created_layers[layer_name] = layer
1474 node_count_by_layer[layer] = int(_should_skip_first_node(layer))
1476 # Gather layer inputs and convert to `ListWrapper` objects.
1477 inbound_nodes_data = layer_data["inbound_nodes"]
1478 inbound_nodes_data = tf_utils.convert_inner_node_data(
1479 inbound_nodes_data, wrap=True
1480 )
1481 for node_data in inbound_nodes_data:
1482 # We don't process nodes (i.e. make layer calls)
1483 # on the fly because the inbound node may not yet exist,
1484 # in case of a layer shared at different topological depths
1485 # (e.g. a model such as A(B(A(B(x)))))
1486 unprocessed_nodes[layer].append(node_data)
1488 # First, we create all layers and enqueue nodes to be processed
1489 for layer_data in config["layers"]:
1490 process_layer(layer_data)
1491 # Then we process nodes in order of layer depth.
1492 # Nodes that cannot yet be processed (if the inbound node
1493 # does not yet exist) are re-enqueued, and the process
1494 # is repeated until all nodes are processed.
1495 while unprocessed_nodes:
1496 for layer_data in config["layers"]:
1497 layer = created_layers[layer_data["name"]]
1498 if layer in unprocessed_nodes:
1499 layer_nodes = unprocessed_nodes.pop(layer)
1500 while layer_nodes:
1501 node_data = layer_nodes[0]
1502 if process_node(layer, node_data):
1503 layer_nodes.pop(0)
1504 else:
1505 # If a node can't be processed, stop processing the
1506 # nodes of the current layer to maintain node ordering.
1507 unprocessed_nodes[layer] = layer_nodes
1508 break
1510 input_tensors = []
1511 output_tensors = []
1513 input_layers = tf_utils.convert_inner_node_data(
1514 config["input_layers"], wrap=True
1515 )
1516 for layer_data in tf.nest.flatten(input_layers):
1517 layer_name, node_index, tensor_index = layer_data.as_list()
1518 assert layer_name in created_layers
1519 layer = created_layers[layer_name]
1520 node_index = get_node_index(layer, node_index)
1521 layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
1522 input_tensors.append(
1523 tf.nest.flatten(layer_output_tensors)[tensor_index]
1524 )
1526 output_layers = tf_utils.convert_inner_node_data(
1527 config["output_layers"], wrap=True
1528 )
1529 for layer_data in tf.nest.flatten(output_layers):
1530 layer_name, node_index, tensor_index = layer_data.as_list()
1531 assert layer_name in created_layers
1532 layer = created_layers[layer_name]
1533 node_index = get_node_index(layer, node_index)
1534 layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
1535 output_tensors.append(
1536 tf.nest.flatten(layer_output_tensors)[tensor_index]
1537 )
1539 input_tensors = tf.nest.pack_sequence_as(input_layers, input_tensors)
1540 output_tensors = tf.nest.pack_sequence_as(output_layers, output_tensors)
1541 return input_tensors, output_tensors, created_layers
1544def get_network_config(network, serialize_layer_fn=None, config=None):
1545 """Build the config, which consists of the node graph and serialized layers.
1547 Args:
1548 network: A Network object.
1549 serialize_layer_fn: Function used to serialize layers.
1550 config: A dict to append more config entries into. If None, start with a
1551 new dict for the config.
1553 Returns:
1554 Config dictionary.
1555 """
1556 config = config or {}
1557 serialize_obj_fn = serialization_lib.serialize_keras_object
1558 set_layers_legacy = False
1559 # To be removed after full affected g3 user migration to Keras V3 Saving.
1560 if getattr(network, "use_legacy_config", False):
1561 serialize_obj_fn = serialization.serialize_keras_object
1562 set_layers_legacy = True
1563 serialize_layer_fn = serialize_layer_fn or serialize_obj_fn
1564 config["name"] = network.name
1565 node_conversion_map = {}
1566 for layer in network.layers:
1567 kept_nodes = 1 if _should_skip_first_node(layer) else 0
1568 for original_node_index, node in enumerate(layer._inbound_nodes):
1569 node_key = _make_node_key(layer.name, original_node_index)
1570 if node_key in network._network_nodes:
1571 node_conversion_map[node_key] = kept_nodes
1572 kept_nodes += 1
1573 layer_configs = []
1575 with serialization.SharedObjectSavingScope():
1576 for layer in network.layers: # From the earliest layers on.
1577 filtered_inbound_nodes = []
1578 for original_node_index, node in enumerate(layer._inbound_nodes):
1579 node_key = _make_node_key(layer.name, original_node_index)
1580 if node_key in network._network_nodes and not node.is_input:
1581 # The node is relevant to the model:
1582 # add to filtered_inbound_nodes.
1583 node_data = node.serialize(
1584 _make_node_key, node_conversion_map
1585 )
1586 filtered_inbound_nodes.append(node_data)
1588 if isinstance(layer, Functional) and set_layers_legacy:
1589 layer.use_legacy_config = True
1590 layer_config = serialize_layer_fn(layer)
1591 layer_config["name"] = layer.name
1592 layer_config["inbound_nodes"] = filtered_inbound_nodes
1593 layer_configs.append(layer_config)
1594 config["layers"] = layer_configs
1596 # Gather info about inputs and outputs.
1597 model_inputs = []
1598 for i in range(len(network._input_layers)):
1599 layer, node_index, tensor_index = network._input_coordinates[i]
1600 node_key = _make_node_key(layer.name, node_index)
1601 if node_key not in network._network_nodes:
1602 continue
1603 new_node_index = node_conversion_map[node_key]
1604 model_inputs.append(
1605 tf_utils.ListWrapper([layer.name, new_node_index, tensor_index])
1606 )
1607 model_inputs = tf.nest.pack_sequence_as(
1608 network._nested_inputs, model_inputs
1609 )
1610 # Preserve external Keras compat for Models with single input.
1611 if not tf.nest.is_nested(model_inputs):
1612 model_inputs = [model_inputs]
1613 model_inputs = tf_utils.convert_inner_node_data(model_inputs)
1614 config["input_layers"] = model_inputs
1616 model_outputs = []
1617 for i in range(len(network._output_layers)):
1618 layer, node_index, tensor_index = network._output_coordinates[i]
1619 node_key = _make_node_key(layer.name, node_index)
1620 if node_key not in network._network_nodes:
1621 continue
1622 new_node_index = node_conversion_map[node_key]
1623 model_outputs.append(
1624 tf_utils.ListWrapper([layer.name, new_node_index, tensor_index])
1625 )
1626 model_outputs = tf.nest.pack_sequence_as(
1627 network._nested_outputs, model_outputs
1628 )
1629 # Preserve external Keras compat for Models with single output.
1630 if not tf.nest.is_nested(model_outputs):
1631 model_outputs = [model_outputs]
1632 model_outputs = tf_utils.convert_inner_node_data(model_outputs)
1633 config["output_layers"] = model_outputs
1634 return config
1637def shape_with_no_batch_size(x):
1638 if x.shape.rank is None:
1639 return None
1640 shape = x.shape.as_list()
1641 if shape:
1642 shape[0] = None
1643 return shape
1646class ModuleWrapper(base_layer.Layer):
1647 """Wrapper for `tf.Module`s to support the Functional and Sequential API."""
1649 def __init__(self, module, method_name=None, **kwargs):
1650 """Initializes the wrapper Layer for this module.
1652 Args:
1653 module: The `tf.Module` instance to be wrapped.
1654 method_name: (Optional) str. The name of the method to use as the
1655 forward pass of the module. If not set, becomes '__call__' if
1656 defined, or 'call'. Defaults to `None`.
1657 **kwargs: Additional keyword arguments. See `tf.keras.layers.Layer`.
1659 Raises:
1660 ValueError: If `method` is not defined on `module`.
1661 """
1662 super().__init__(**kwargs)
1663 if method_name is None:
1664 if hasattr(module, "__call__"):
1665 method_name = "__call__"
1666 elif hasattr(module, "call"):
1667 method_name = "call"
1668 if method_name is None or not hasattr(module, method_name):
1669 raise ValueError(f"{method_name} is not defined on object {module}")
1671 self._module = module
1672 self._method_name = method_name
1674 # Check if module.__call__ has a `training` arg or accepts `**kwargs`.
1675 method = getattr(module, method_name)
1676 method_arg_spec = tf_inspect.getfullargspec(method)
1677 self._call_spec.expects_training_arg = (
1678 "training" in method_arg_spec.args
1679 or method_arg_spec.varkw is not None
1680 )
1681 self._call_spec.expects_mask_arg = (
1682 "mask" in method_arg_spec.args or method_arg_spec.varkw is not None
1683 )
1685 def call(self, *args, **kwargs):
1686 if "training" in kwargs and not self._expects_training_arg:
1687 kwargs.pop("training")
1688 if "mask" in kwargs and not self._expects_mask_arg:
1689 kwargs.pop("mask")
1690 return getattr(self._module, self._method_name)(*args, **kwargs)
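# A minimal sketch of using the `ModuleWrapper` defined above so that a plain
# `tf.Module` can participate in the Sequential/Functional APIs; the `Scale`
# module and shapes are illustrative.
import tensorflow as tf
from tensorflow import keras

class Scale(tf.Module):
    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(2.0)

    def __call__(self, x):
        return x * self.factor

wrapped = ModuleWrapper(Scale())  # method_name resolves to '__call__'
model = keras.Sequential([keras.Input(shape=(3,)), wrapped])
print(model(tf.ones((1, 3))))  # [[2. 2. 2.]]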
1693def has_functional_like_constructor(cls):
1694 init_args = tf_inspect.getfullargspec(cls.__init__).args[1:]
1695 functional_init_args = tf_inspect.getfullargspec(Functional.__init__).args[
1696 1:
1697 ]
1698 if init_args == functional_init_args:
1699 return True
1700 return False