Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/functional.py: 1%
737 statements
coverage.py v7.3.2, created at 2023-10-05 06:32 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""A `Network` is way to compose layers: the topological form of a `Model`."""
18import collections
19import copy
20import itertools
21import warnings
23import tensorflow.compat.v2 as tf
25from keras.src import backend
26from keras.src.dtensor import layout_map as layout_map_lib
27from keras.src.engine import base_layer
28from keras.src.engine import base_layer_utils
29from keras.src.engine import functional_utils
30from keras.src.engine import input_layer as input_layer_module
31from keras.src.engine import input_spec
32from keras.src.engine import node as node_module
33from keras.src.engine import training as training_lib
34from keras.src.engine import training_utils
35from keras.src.saving import serialization_lib
36from keras.src.saving.legacy import serialization
37from keras.src.saving.legacy.saved_model import json_utils
38from keras.src.saving.legacy.saved_model import network_serialization
39from keras.src.saving.legacy.saved_model import utils as saved_model_utils
40from keras.src.utils import generic_utils
41from keras.src.utils import tf_inspect
42from keras.src.utils import tf_utils
44# isort: off
45from tensorflow.python.platform import tf_logging as logging
46from tensorflow.tools.docs import doc_controls
49class Functional(training_lib.Model):
50 """A `Functional` model is a `Model` defined as a directed graph of layers.
52 Three types of `Model` exist: subclassed `Model`, `Functional` model,
53 and `Sequential` (a special case of `Functional`).
54 In general, more Keras features are supported with `Functional`
55 than with subclassed `Model`s, specifically:
57 - Model cloning (`keras.models.clone`)
58 - Serialization (`model.get_config()/from_config`, `model.to_json()`)
59 - Whole-model saving (`model.save()`)
61 A `Functional` model can be instantiated by passing two arguments to
62 `__init__`. The first argument is the `keras.Input` Tensors that represent
63 the inputs to the model. The second argument specifies the output
64 tensors that represent the outputs of this model. Both arguments can be a
65 nested structure of tensors.
67 Example:
69 ```python
70 inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))}
71 t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
72 outputs = keras.layers.Add()([t, inputs['x2']])
73 model = keras.Model(inputs, outputs)
74 ```
76 A `Functional` model constructed using the Functional API can also include
77 raw TensorFlow functions, with the exception of functions that create
78 Variables or assign ops.
80 Example:
82 ```python
83 inputs = keras.Input(shape=(10,))
84 x = keras.layers.Dense(1)(inputs)
85 outputs = tf.nn.relu(x)
86 model = keras.Model(inputs, outputs)
87 ```
89 A new `Functional` model can also be created by using the
90 intermediate tensors. This enables you to quickly extract sub-components
91 of the model.
93 Example:
95 ```python
96 inputs = keras.Input(shape=(None, None, 3))
97 processed = keras.layers.RandomCrop(width=32, height=32)(inputs)
98 conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed)
99 pooling = keras.layers.GlobalAveragePooling2D()(conv)
100 feature = keras.layers.Dense(10)(pooling)
102 full_model = keras.Model(inputs, feature)
103 backbone = keras.Model(processed, conv)
104 activations = keras.Model(conv, feature)
105 ```
107 Note that the `backbone` and `activations` models are not
108 created with `keras.Input` objects, but with the tensors that originate
109 from `keras.Input` objects. Under the hood, the layers and weights will
110 be shared across these models, so that the user can train the `full_model` and
111 use `backbone` or `activations` for feature extraction.
112 The inputs and outputs of the model can be nested structures of tensors as
113 well, and the resulting models are standard `Functional` models that support
114 all the existing APIs.
116 Args:
117 inputs: List of input tensors (must be created via `tf.keras.Input()` or
118 originate from `tf.keras.Input()`).
119 outputs: List of output tensors.
120 name: String, optional. Name of the model.
121 trainable: Boolean, optional. Whether the model's variables should be
122 trainable.
123 """
125 # See tf.Module for the usage of this property.
126 # The key of _layer_call_argspecs is a layer. tf.Module._flatten will fail
127 # to flatten the key since it is trying to convert Trackable/Layer to a
128 # string.
129 _TF_MODULE_IGNORED_PROPERTIES = frozenset(
130 itertools.chain(
131 (
132 "_layer_call_argspecs",
133 "_output_mask_cache",
134 "_output_tensor_cache",
135 "_output_shape_cache",
136 ),
137 training_lib.Model._TF_MODULE_IGNORED_PROPERTIES,
138 )
139 )
141 @tf.__internal__.tracking.no_automatic_dependency_tracking
142 def __init__(self, inputs, outputs, name=None, trainable=True, **kwargs):
143 # This is used by the Model class, since we have some logic to swap the
144 # class in the __new__ method, which will lead to __init__ getting invoked
145 # twice. Use skip_init to skip one of the invocations of __init__
146 # to avoid any side effects.
147 skip_init = kwargs.pop("skip_init", False)
148 if skip_init:
149 return
150 generic_utils.validate_kwargs(kwargs, {})
151 super().__init__(name=name, trainable=trainable)
152 # Check if the inputs contain any intermediate `KerasTensor` (not
153 # created by tf.keras.Input()). In this case we need to clone the `Node`
154 # and `KerasTensor` objects to mimic rebuilding a new model from new
155 # inputs. This feature is only enabled in TF2 not in v1 graph mode.
156 if tf.compat.v1.executing_eagerly_outside_functions():
157 if not all(
158 [
159 functional_utils.is_input_keras_tensor(t)
160 for t in tf.nest.flatten(inputs)
161 ]
162 ):
163 inputs, outputs = functional_utils.clone_graph_nodes(
164 inputs, outputs
165 )
166 self._init_graph_network(inputs, outputs)
168 @tf.__internal__.tracking.no_automatic_dependency_tracking
169 def _init_graph_network(self, inputs, outputs):
170 # This method is needed for Sequential to reinitialize graph network
171 # when layer is added or removed.
173 base_layer.keras_api_gauge.get_cell("Functional").set(True)
174 self._is_graph_network = True
176 # Normalize and set self.inputs, self.outputs.
177 if isinstance(inputs, list) and len(tf.nest.flatten(inputs)) == 1:
178 inputs = inputs[0]
179 if isinstance(outputs, list) and len(tf.nest.flatten(outputs)) == 1:
180 outputs = outputs[0]
181 self._nested_inputs = inputs
182 self._nested_outputs = outputs
183 self.inputs = tf.nest.flatten(inputs)
184 self.outputs = tf.nest.flatten(outputs)
186 # Models constructed with a single Tensor or list of Tensors can
187 # be called with a dict, where the keys of the dict are the names
188 of the `Input` objects. Extra keys are ignored with a warning (sketch below).
189 if not tf.nest.is_nested(self._nested_inputs):
190 self._enable_dict_to_input_mapping = True
191 elif isinstance(self._nested_inputs, (list, tuple)) and not any(
192 tf.nest.is_nested(t) for t in self._nested_inputs
193 ):
194 self._enable_dict_to_input_mapping = True
195 elif isinstance(self._nested_inputs, dict) and not any(
196 tf.nest.is_nested(t) for t in self._nested_inputs.values()
197 ):
198 self._enable_dict_to_input_mapping = True
199 else:
200 self._enable_dict_to_input_mapping = False
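        # Illustrative sketch of the dict-to-input mapping described above
        # (assumes the public `keras` API; input names are hypothetical):
        #
        #     a = keras.Input(shape=(1,), name="a")
        #     b = keras.Input(shape=(1,), name="b")
        #     model = keras.Model([a, b], keras.layers.Add()([a, b]))
        #     # A flat-input model may also be fed a dict keyed by the
        #     # `Input` names; extra keys are ignored with a warning.
        #     model({"a": tf.ones((2, 1)), "b": tf.zeros((2, 1))})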
202 if not tf.compat.v1.executing_eagerly_outside_functions():
203 if any(
204 not hasattr(tensor, "_keras_history") for tensor in self.outputs
205 ):
206 base_layer_utils.create_keras_history(self._nested_outputs)
208 self._validate_graph_inputs_and_outputs()
210 # A Network does not create weights of its own, thus it is already
211 # built.
212 self.built = True
213 self._build_input_shape = tf.nest.map_structure(
214 lambda x: x.shape, inputs
215 )
216 self._compute_output_and_mask_jointly = True
217 # `_expects_training_arg` is True since the `training` argument is
218 # always present in the signature of the `call` method of a graph
219 # network.
220 self._call_spec.expects_training_arg = True
221 self._call_spec.expects_mask_arg = True
222 # A graph network does not autocast inputs, as its layers will cast them
223 # instead.
224 self._autocast = False
226 self._input_layers = []
227 self._output_layers = []
228 self._input_coordinates = []
229 self._output_coordinates = []
231 # This is for performance optimization when calling the Network on new
232 inputs. Every time the Network is called on a set of input tensors, we
233 # compute the output tensors, output masks and output shapes in one
234 # pass, then cache them here. When any of these outputs is queried
235 # later, we retrieve it from there instead of recomputing it.
236 self._output_mask_cache = {}
237 self._output_tensor_cache = {}
238 self._output_shape_cache = {}
240 # Build self._output_layers:
241 for x in self.outputs:
242 (
243 layer,
244 node_index,
245 tensor_index,
246 ) = x._keras_history
247 self._output_layers.append(layer)
248 self._output_coordinates.append((layer, node_index, tensor_index))
250 # Build self._input_layers:
251 for x in self.inputs:
252 (
253 layer,
254 node_index,
255 tensor_index,
256 ) = x._keras_history
257 # It's supposed to be an input layer, so only one node
258 # and one tensor output.
259 assert node_index == 0
260 assert tensor_index == 0
261 self._input_layers.append(layer)
262 self._input_coordinates.append((layer, node_index, tensor_index))
264 # Keep track of the network's nodes and layers.
265 nodes, nodes_by_depth, layers, _ = _map_graph_network(
266 self.inputs, self.outputs
267 )
268 self._network_nodes = nodes
269 self._nodes_by_depth = nodes_by_depth
270 self._self_tracked_trackables = layers
271 self._layer_call_argspecs = {}
272 for layer in self._self_tracked_trackables:
273 self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(
274 layer.call
275 )
277 # Build self.input_names and self.output_names.
278 self._set_output_names()
279 self.input_names = []
280 self._feed_input_names = []
281 self._feed_inputs = []
282 self._feed_input_shapes = []
283 for layer in self._input_layers:
284 self.input_names.append(layer.name)
285 if layer.is_placeholder:
286 self._feed_input_names.append(layer.name)
287 # Use batch_input_shape here because non-eager composite tensors
288 # may not have a shape attribute that's meaningful (sparse, for
289 # instance, has a tensor that's non-constant and needs to be
290 # fed). This means that input layers that create placeholders
291 # will need to have the batch_input_shape attr to allow for
292 # input shape validation.
293 self._feed_input_shapes.append(layer._batch_input_shape)
294 self._feed_inputs.append(layer.input)
296 self._compute_tensor_usage_count()
297 self._set_save_spec(self._nested_inputs)
298 tf_utils.assert_no_legacy_layers(self.layers)
300 # Note that this logic is used by both functional and sequential
301 # models, so we can't just put it in functional.__init__,
302 # which would not cover the sequential model.
303 if self._layout_map is not None:
304 layout_map_lib._map_functional_model_variable(
305 self, self._layout_map
306 )
308 @property
309 def input(self):
310 """Retrieves the input tensor(s) of a layer.
312 Only applicable if the layer has exactly one input,
313 i.e. if it is connected to one incoming layer.
315 Returns:
316 Input tensor or list of input tensors.
318 Raises:
319 RuntimeError: If called in Eager mode.
320 AttributeError: If no inbound nodes are found.
321 """
322 return self._nested_inputs
324 @property
325 def input_shape(self):
326 """Retrieves the input shape(s) of a layer.
328 Only applicable if the layer has exactly one input,
329 i.e. if it is connected to one incoming layer, or if all inputs
330 have the same shape.
332 Returns:
333 Input shape, as an integer shape tuple
334 (or list of shape tuples, one tuple per input tensor).
336 Raises:
337 AttributeError: if the layer has no defined input_shape.
338 RuntimeError: if called in Eager mode.
339 """
340 return tf.nest.map_structure(backend.int_shape, self.input)
342 @property
343 def input_spec(self):
344 if hasattr(self, "_manual_input_spec"):
345 return self._manual_input_spec
346 if isinstance(self._nested_inputs, (dict, list, tuple)) and len(
347 self._nested_inputs
348 ) != len(self.inputs):
349 # Case where we have a nested structure.
350 # In such a case we can't safely run any checks.
351 return None
352 if isinstance(self._nested_inputs, dict):
353 # Case where `_nested_inputs` is a plain dict of Inputs.
354 names = sorted(self._nested_inputs.keys())
355 return [
356 input_spec.InputSpec(
357 shape=shape_with_no_batch_size(self._nested_inputs[name]),
358 allow_last_axis_squeeze=True,
359 name=name,
360 )
361 for name in names
362 ]
363 else:
364 # Single input, or list / tuple of inputs.
365 # The data may be passed as a dict keyed by input name.
366 return [
367 input_spec.InputSpec(
368 shape=shape_with_no_batch_size(x),
369 allow_last_axis_squeeze=True,
370 name=x._keras_history.layer.name,
371 )
372 for x in self.inputs
373 ]
375 @input_spec.setter
376 def input_spec(self, value):
377 self._manual_input_spec = value
379 @property
380 def output(self):
381 """Retrieves the output tensor(s) of a layer.
383 Only applicable if the layer has exactly one output,
384 i.e. if it is connected to one incoming layer.
386 Returns:
387 Output tensor or list of output tensors.
389 Raises:
390 AttributeError: if the layer is connected to more than one incoming
391 layers.
392 RuntimeError: if called in Eager mode.
393 """
394 return self._nested_outputs
396 @property
397 def output_shape(self):
398 """Retrieves the output shape(s) of a layer.
400 Only applicable if the layer has one output,
401 or if all outputs have the same shape.
403 Returns:
404 Output shape, as an integer shape tuple
405 (or list of shape tuples, one tuple per output tensor).
407 Raises:
408 AttributeError: if the layer has no defined output shape.
409 RuntimeError: if called in Eager mode.
410 """
411 return tf.nest.map_structure(backend.int_shape, self.output)
413 def _set_output_names(self):
414 """Assigns unique names to the Network's outputs.
416 Output layers with multiple output tensors would otherwise lead to
417 duplicate names in self.output_names.
418 """
419 uniquified = []
420 output_names = set()
421 prefix_count = {}
422 for layer in self._output_layers:
423 proposal = layer.name
424 while proposal in output_names:
425 existing_count = prefix_count.get(layer.name, 1)
426 proposal = f"{layer.name}_{existing_count}"
427 prefix_count[layer.name] = existing_count + 1
428 output_names.add(proposal)
429 uniquified.append(proposal)
430 self.output_names = uniquified
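    # For example (hypothetical layer names): a model whose two outputs both
    # come from a layer named "dense" would get
    # `output_names == ["dense", "dense_1"]`, keeping the names unique.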
432 @property
433 def _layer_checkpoint_dependencies(self):
434 """Dictionary of layer dependencies to be included in the checkpoint."""
435 weight_layer_index = 0
437 dependencies = collections.OrderedDict()
438 for layer_index, layer in enumerate(self.layers):
439 try:
440 if layer.weights:
441 # Keep a separate index for layers which have weights. This
442 # allows users to insert Layers without weights anywhere in
443 # the network without breaking checkpoints.
444 dependencies[
445 "layer_with_weights-%d" % weight_layer_index
446 ] = layer
447 weight_layer_index += 1
448 except ValueError:
449 # The layer might have weights, but may not be built yet. We
450 # just treat it as a layer without weights.
451 pass
453 # Even if it doesn't have weights, we should still track everything
454 # in case it has/will have Trackable dependencies.
455 dependencies["layer-%d" % layer_index] = layer
456 return dependencies
458 def _trackable_children(self, save_type="checkpoint", **kwargs):
459 dependencies = self._layer_checkpoint_dependencies
460 dependencies.update(super()._trackable_children(save_type, **kwargs))
461 return dependencies
463 def _lookup_dependency(self, name, cached_dependencies=None):
464 if cached_dependencies:
465 return cached_dependencies.get(name)
466 # Fall back to slow lookup (`_layer_checkpoint_dependencies` does a
467 # thorough check of all layers to see if they contain weights).
468 layer_dependencies = self._layer_checkpoint_dependencies
469 if name in layer_dependencies:
470 return layer_dependencies[name]
471 return super()._lookup_dependency(name)
473 def _handle_deferred_layer_dependencies(self, layers):
474 """Handles layer checkpoint dependencies that are added after init."""
475 layer_checkpoint_dependencies = self._layer_checkpoint_dependencies
476 layer_to_name = {v: k for k, v in layer_checkpoint_dependencies.items()}
477 for layer in layers:
478 if layer in layer_to_name:
479 self._handle_deferred_dependencies(
480 name=layer_to_name[layer], trackable=layer
481 )
483 @property
484 def _should_compute_mask(self):
485 return True
487 def compute_mask(self, inputs, mask):
488 # TODO(omalleyt): b/123540974 This function is not really safe to call
489 # by itself because it will duplicate any updates and losses in graph
490 # mode by `call`ing the Layers again.
491 output_tensors = self._run_internal_graph(inputs, mask=mask)
492 return tf.nest.map_structure(
493 lambda t: getattr(t, "_keras_mask", None), output_tensors
494 )
496 @doc_controls.do_not_doc_inheritable
497 def call(self, inputs, training=None, mask=None):
498 """Calls the model on new inputs.
500 In this case `call` just reapplies
501 all ops in the graph to the new inputs
502 (e.g. builds a new computational graph from the provided inputs).
504 Args:
505 inputs: A tensor or list of tensors.
506 training: Boolean or boolean scalar tensor, indicating whether to
507 run the `Network` in training mode or inference mode.
508 mask: A mask or list of masks. A mask can be
509 either a tensor or None (no mask).
511 Returns:
512 A tensor if there is a single output, or
513 a list of tensors if there is more than one output.
514 """
515 return self._run_internal_graph(inputs, training=training, mask=mask)
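    # Illustrative sketch of the behaviour documented in `call` (assumes the
    # public `keras` API): calling a functional model on fresh symbolic
    # inputs replays the graph, which is how nested models are composed.
    #
    #     inputs = keras.Input(shape=(4,))
    #     outputs = keras.layers.Dense(2)(inputs)
    #     inner = keras.Model(inputs, outputs)
    #
    #     new_inputs = keras.Input(shape=(4,))
    #     new_outputs = inner(new_inputs)  # re-applies all ops to new_inputs
    #     outer = keras.Model(new_inputs, new_outputs)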
517 def compute_output_shape(self, input_shape):
518 # Convert any shapes in tuple format to TensorShapes.
519 input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
521 if len(tf.nest.flatten(input_shape)) != len(
522 tf.nest.flatten(self._input_layers)
523 ):
524 raise ValueError(
525 f"Invalid `input_shape` argument {input_shape}: "
526 f"the model expects {len(self._input_layers)} "
527 "input tensors."
528 )
530 # Use a tuple of TensorShapes as the cache key, since tuples are hashable
531 # and can be used as dict keys.
532 try:
533 cache_key = tuple(
534 tf_utils.convert_shapes(input_shape, to_tuples=True)
535 )
536 if cache_key in self._output_shape_cache:
537 # Cache hit. Return shapes as TensorShapes.
538 return self._output_shape_cache[cache_key]
539 except ValueError:
540 # In case there are unknown TensorShapes, e.g. for sparse tensor inputs,
541 # we skip caching since the shape is unknown.
542 pass
544 layers_to_output_shapes = {}
545 for layer, shape in zip(
546 self._input_layers, tf.nest.flatten(input_shape)
547 ):
548 # It's an input layer: then `compute_output_shape` is identity,
549 # and there is only one node and one tensor.
550 shape_key = layer.name + "_0_0"
551 layers_to_output_shapes[shape_key] = shape
553 depth_keys = list(self._nodes_by_depth.keys())
554 depth_keys.sort(reverse=True)
555 # Iterate over nodes, by depth level.
556 if len(depth_keys) > 1:
557 for depth in depth_keys:
558 nodes = self._nodes_by_depth[depth]
559 for node in nodes:
560 layer = node.layer
561 if layer in self._input_layers:
562 # We've already covered the input layers
563 # a few lines above.
564 continue
565 # Get the input shapes for the first argument of the node
566 layer_input_shapes = []
567 layer_inputs = node.call_args[0]
568 for layer_input in tf.nest.flatten(layer_inputs):
569 kh = layer_input._keras_history
570 input_layer_key = kh.layer.name + "_%s_%s" % (
571 kh.node_index,
572 kh.tensor_index,
573 )
574 layer_input_shapes.append(
575 layers_to_output_shapes[input_layer_key]
576 )
577 layer_input_shapes = tf.nest.pack_sequence_as(
578 layer_inputs, layer_input_shapes
579 )
580 # Layers expect shapes to be tuples for
581 # `compute_output_shape`.
582 layer_input_shapes = tf_utils.convert_shapes(
583 layer_input_shapes, to_tuples=True
584 )
585 layer_output_shapes = layer.compute_output_shape(
586 layer_input_shapes
587 )
588 # Convert back to TensorShapes.
589 layer_output_shapes = tf_utils.convert_shapes(
590 layer_output_shapes, to_tuples=False
591 )
593 node_index = layer._inbound_nodes.index(node)
594 for j, shape in enumerate(
595 tf.nest.flatten(layer_output_shapes)
596 ):
597 shape_key = layer.name + f"_{node_index}_{j}"
598 layers_to_output_shapes[shape_key] = shape
600 # Read final output shapes from layers_to_output_shapes.
601 output_shapes = []
602 for i in range(len(self._output_layers)):
603 layer, node_index, tensor_index = self._output_coordinates[i]
604 shape_key = layer.name + f"_{node_index}_{tensor_index}"
605 output_shapes.append(layers_to_output_shapes[shape_key])
606 output_shapes = tf.nest.pack_sequence_as(
607 self._nested_outputs, output_shapes
608 )
609 # Store in cache.
610 self._output_shape_cache[cache_key] = output_shapes
612 # Return shapes as TensorShapes.
613 return output_shapes
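    # Illustrative sketch (public `keras` API assumed): shape inference runs
    # without building any new tensors.
    #
    #     inputs = keras.Input(shape=(8,))
    #     outputs = keras.layers.Dense(3)(inputs)
    #     model = keras.Model(inputs, outputs)
    #     model.compute_output_shape((None, 8))  # -> TensorShape([None, 3])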
615 def _init_set_name(self, name, zero_based=True):
616 if not name:
617 cls_name = self.__class__.__name__
618 if self.__class__ == Functional:
619 # Hide the Functional class name from the user, since it's not a
620 # publicly visible class. Use "Model" instead.
621 cls_name = "Model"
622 self._name = backend.unique_object_name(
623 generic_utils.to_snake_case(cls_name), zero_based=zero_based
624 )
625 else:
626 self._name = name
628 def _run_internal_graph(self, inputs, training=None, mask=None):
629 """Computes output tensors for new inputs.
631 # Note:
632 - Can be run on non-Keras tensors.
634 Args:
635 inputs: Tensor or nested structure of Tensors.
636 training: Boolean learning phase.
637 mask: (Optional) Tensor or nested structure of Tensors.
639 Returns:
640 output_tensors
641 """
642 inputs = self._flatten_to_reference_inputs(inputs)
643 if mask is None:
644 masks = [None] * len(inputs)
645 else:
646 masks = self._flatten_to_reference_inputs(mask)
647 for input_t, mask in zip(inputs, masks):
648 input_t._keras_mask = mask
650 # Dictionary mapping reference tensors to computed tensors.
651 tensor_dict = {}
652 tensor_usage_count = self._tensor_usage_count
653 for x, y in zip(self.inputs, inputs):
654 y = self._conform_to_reference_input(y, ref_input=x)
655 x_id = str(id(x))
656 tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
658 nodes_by_depth = self._nodes_by_depth
659 depth_keys = list(nodes_by_depth.keys())
660 depth_keys.sort(reverse=True)
662 for depth in depth_keys:
663 nodes = nodes_by_depth[depth]
664 for node in nodes:
665 if node.is_input:
666 continue # Input tensors already exist.
668 if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
669 continue # Node is not computable, try skipping.
671 args, kwargs = node.map_arguments(tensor_dict)
672 outputs = node.layer(*args, **kwargs)
674 # Update tensor_dict.
675 for x_id, y in zip(
676 node.flat_output_ids, tf.nest.flatten(outputs)
677 ):
678 tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
680 output_tensors = []
681 for x in self.outputs:
682 x_id = str(id(x))
683 assert x_id in tensor_dict, "Could not compute output " + str(x)
684 output_tensors.append(tensor_dict[x_id].pop())
686 return tf.nest.pack_sequence_as(self._nested_outputs, output_tensors)
688 def _flatten_to_reference_inputs(self, tensors):
689 """Maps `tensors` to their respective `keras.Input`."""
690 if self._enable_dict_to_input_mapping and isinstance(tensors, dict):
691 ref_inputs = self._nested_inputs
692 if not tf.nest.is_nested(ref_inputs):
693 ref_inputs = [self._nested_inputs]
694 if isinstance(ref_inputs, dict):
695 # In the case that the graph is constructed with dict input
696 # tensors, we will use the original dict keys to match the
697 # keys in the input data. Note that model.inputs is using
698 # nest.flatten to process the input tensors, which means the
699 # dict input tensors are ordered by their keys.
700 ref_input_names = sorted(ref_inputs.keys())
701 else:
702 ref_input_names = [
703 inp._keras_history.layer.name for inp in ref_inputs
704 ]
706 # Raise a warning if there are more input data keys than input
707 # tensors.
708 if len(tensors) > len(ref_input_names):
709 warnings.warn(
710 "Input dict contained keys {} which did not match any "
711 "model input. They will be ignored by the model.".format(
712 [n for n in tensors.keys() if n not in ref_input_names]
713 ),
714 stacklevel=2,
715 )
717 try:
718 # Flatten in the order `Input`s were passed during Model
719 # construction.
720 return [tensors[n] for n in ref_input_names]
721 except KeyError:
722 # TODO(b/151582614)
723 return tf.nest.flatten(tensors)
725 # Otherwise both self.inputs and tensors will already be in same order.
726 return tf.nest.flatten(tensors)
728 def _conform_to_reference_input(self, tensor, ref_input):
729 """Set shape and dtype based on `keras.Input`s."""
730 if isinstance(tensor, tf.Tensor):
731 # Allow (None,) and (None, 1) Tensors to be passed interchangeably.
732 # Use the shape specified by the `keras.Input`.
733 t_shape = tensor.shape
734 t_rank = t_shape.rank
735 ref_shape = ref_input.shape
736 ref_rank = ref_shape.rank
737 keras_history = getattr(tensor, "_keras_history", None)
738 if t_rank is not None and ref_rank is not None:
739 # Should squeeze last dimension. True if tensor is (BATCH, ...,
740 # 1) and reference is (BATCH, ...).
741 if t_rank == ref_rank + 1 and t_shape[-1] == 1:
742 tensor = tf.squeeze(tensor, axis=-1)
743 # Should expand last_dimension. True if tensor is (BATCH, ...)
744 # and reference is (BATCH, ..., 1).
745 elif t_rank == ref_rank - 1 and ref_shape[-1] == 1:
746 tensor = tf.expand_dims(tensor, axis=-1)
747 if keras_history is not None: # Restore keras history.
748 tensor._keras_history = keras_history
750 # Dtype casting.
751 tensor = tf.cast(tensor, dtype=ref_input.dtype)
752 elif tf_utils.is_extension_type(tensor):
753 # Dtype casting (If the extension type has a non-variant dtype and
754 # supports being cast). Only cast if necessary (since some
755 # extension types may not implement tf.cast).
756 tensor_dtype = getattr(tensor, "dtype", None)
757 ref_input_dtype = getattr(ref_input, "dtype", None)
758 if (
759 ref_input_dtype is not None
760 and tensor_dtype is not None
761 and tensor_dtype != ref_input_dtype
762 and ref_input_dtype != tf.variant
763 ):
764 tensor = tf.cast(tensor, dtype=ref_input_dtype)
766 return tensor
768 @generic_utils.default
769 def get_config(self):
770 # Prepare base arguments
771 config = {
772 "name": self.name,
773 "trainable": self.trainable,
774 }
776 if saved_model_utils.in_tf_saved_model_scope():
777 # SavedModel special case: need to preserve legacy (potentially
778 # incorrect) behavior.
779 return copy.deepcopy(get_network_config(self, config=config))
781 # Check whether the class has a constructor compatible with a Functional
782 # model or if it has a custom constructor.
783 if has_functional_like_constructor(self.__class__):
784 # Only return a Functional config if the constructor is the same
785 # as that of a Functional model. This excludes subclassed Functional
786 # models with a custom __init__.
787 config = copy.deepcopy(get_network_config(self, config=config))
788 else:
789 # Try to autogenerate config
790 xtra_args = set(config.keys())
791 if getattr(self, "_auto_get_config", False):
792 config.update(self._auto_config.config)
793 # Remove args not explicitly supported
794 argspec = tf_inspect.getfullargspec(self.__init__)
795 if argspec.varkw != "kwargs":
796 for key in xtra_args - xtra_args.intersection(argspec.args[1:]):
797 config.pop(key, None)
798 return config
800 def get_weight_paths(self):
801 result = {}
802 for layer in self.layers:
803 (
804 descendants,
805 object_paths_dict,
806 ) = tf.__internal__.tracking.ObjectGraphView(
807 layer
808 ).breadth_first_traversal()
809 for descendant in descendants:
810 if isinstance(descendant, tf.Variable):
811 trackable_references = object_paths_dict[descendant]
812 object_path = ".".join(
813 [t.name for t in trackable_references]
814 )
815 result[layer.name + "." + object_path] = descendant
816 return result
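    # The returned dict maps dotted variable paths to the variables
    # themselves, e.g. (hypothetical layer names):
    #
    #     model.get_weight_paths()
    #     # {"dense.kernel": <tf.Variable ...>, "dense.bias": <tf.Variable ...>}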
818 def _validate_graph_inputs_and_outputs(self):
819 """Validates the inputs and outputs of a Graph Network."""
820 # Check for redundancy in inputs.
821 if len({id(i) for i in self.inputs}) != len(self.inputs):
822 raise ValueError(
823 "The list of inputs passed to the model "
824 "contains the same input multiple times. "
825 "All inputs should only appear once."
826 f"Received inputs={self.inputs}"
827 )
829 for x in self.inputs:
830 # Check that x has appropriate `_keras_history` metadata.
831 if not hasattr(x, "_keras_history"):
832 cls_name = self.__class__.__name__
833 raise ValueError(
834 f"Input tensors to a {cls_name} model "
835 "must come from `tf.keras.Input`. "
836 f"Received inputs={x} (missing previous layer metadata)."
837 )
838 # Check that x is an input tensor.
840 layer = x._keras_history.layer
841 if len(layer._inbound_nodes) > 1 or (
842 layer._inbound_nodes and not layer._inbound_nodes[0].is_input
843 ):
844 cls_name = self.__class__.__name__
845 logging.warning(
846 f"{cls_name} model inputs must come from "
847 "`tf.keras.Input` (thus holding past layer metadata). "
848 "They cannot be the output of "
849 "a previous non-Input layer. "
850 "Here, a tensor specified as "
851 f'input to "{self.name}" was not an Input tensor, '
852 f'it was generated by layer "{layer.name}".\n'
853 "Note that input tensors are "
854 "instantiated via `tensor = tf.keras.Input(shape)`.\n"
855 f"The tensor that caused the issue was: {x}"
856 )
858 # Check compatibility of batch sizes of Input Layers.
859 input_batch_sizes = set(
860 [
861 training_utils.get_static_batch_size(x._keras_history.layer)
862 for x in self.inputs
863 ]
864 )
865 input_batch_sizes.discard(None)
866 if len(input_batch_sizes) > 1:
867 logging.warning(
868 "Found incompatible static batch sizes among the "
869 f"inputs. Batch sizes: {sorted(input_batch_sizes)}"
870 )
872 for x in self.outputs:
873 if not hasattr(x, "_keras_history"):
874 cls_name = self.__class__.__name__
875 raise ValueError(
876 f"Output tensors of a {cls_name} model must be "
877 "the output of a TensorFlow `Layer` "
878 f"(thus holding past layer metadata). Found: {x}"
879 )
881 def _insert_layers(self, layers, relevant_nodes=None):
882 """Inserts Layers into the Network after Network creation.
884 This is only valid for Keras Graph Networks. Layers added via this
885 function will be included in the `call` computation and `get_config` of
886 this Network. They will not be added to the Network's outputs.
888 Args:
889 layers: Arbitrary nested structure of Layers. Layers must be reachable
890 from one or more of the `keras.Input` Tensors that correspond to
891 this Network's inputs.
892 relevant_nodes: Nodes from the Layers that should be considered part
893 of this Network. If `None`, all Nodes will be considered part of
894 this Network.
896 Raises:
897 ValueError: If the layers depend on `Input`s not found in this Model.
898 """
899 layers = tf.nest.flatten(layers)
900 tf_utils.assert_no_legacy_layers(layers)
901 node_to_depth = {}
902 for depth, nodes in self._nodes_by_depth.items():
903 node_to_depth.update({node: depth for node in nodes})
904 # The nodes of these Layers that are relevant to this Network. If not
905 # provided, assume all Nodes are relevant
906 if not relevant_nodes:
907 relevant_nodes = tf.nest.flatten(
908 [layer._inbound_nodes for layer in layers]
909 )
910 network_nodes = set(relevant_nodes + list(node_to_depth.keys()))
912 def _get_min_depth(node):
913 """Gets the minimum depth at which node can be computed."""
914 min_depth = 0
915 for layer, node_id, _, _ in node.iterate_inbound():
916 inbound_node = layer._inbound_nodes[node_id]
917 if inbound_node in node_to_depth:
918 min_depth = min(min_depth, node_to_depth[inbound_node])
919 elif inbound_node not in network_nodes:
920 continue
921 else:
922 # Previous relevant nodes haven't been processed yet.
923 return None
924 # New node is one shallower than its shallowest input.
925 return min_depth - 1
927 # Insert nodes into `_nodes_by_depth` and other node attrs.
928 unprocessed_nodes = copy.copy(relevant_nodes)
929 i = 0
930 while unprocessed_nodes:
931 i += 1
932 # Do a sanity check. This can occur if `Input`s from outside this
933 # Model are being relied on.
934 if i > 10000:
935 raise ValueError(
936 "Layers could not be added due to missing dependencies."
937 )
939 node = unprocessed_nodes.pop(0)
940 depth = _get_min_depth(node)
941 if depth is None: # Defer until inbound nodes are processed.
942 unprocessed_nodes.append(node)
943 continue
944 node_key = _make_node_key(
945 node.layer.name, node.layer._inbound_nodes.index(node)
946 )
947 if node_key not in self._network_nodes:
948 node_to_depth[node] = depth
949 self._network_nodes.add(node_key)
950 self._nodes_by_depth[depth].append(node)
952 # Insert layers and update other layer attrs.
953 layer_set = set(self._self_tracked_trackables)
954 deferred_layers = []
955 for layer in layers:
956 if layer not in layer_set:
957 self._self_tracked_trackables.append(layer)
958 deferred_layers.append(layer)
959 self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(
960 layer.call
961 )
962 layer_set.add(layer)
963 self._handle_deferred_layer_dependencies(deferred_layers)
965 self._compute_tensor_usage_count()
967 def _compute_tensor_usage_count(self):
968 """Compute the #. of tensor usages for all the output tensors of layers.
970 The computed tensor usage count is saved as `self._tensor_usage_count`.
971 This is later used for saving memory in eager computation by releasing
972 no-longer-needed tensors as early as possible.
973 """
974 tensor_usage_count = collections.Counter()
975 available_tensors = set(str(id(tensor)) for tensor in self.inputs)
977 depth_keys = list(self._nodes_by_depth.keys())
978 depth_keys.sort(reverse=True)
979 depth_keys = depth_keys[1:]
981 for depth in depth_keys:
982 for node in self._nodes_by_depth[depth]:
983 input_tensors = {
984 str(id(tensor))
985 for tensor in tf.nest.flatten(node.keras_inputs)
986 }
987 if input_tensors.issubset(available_tensors):
988 for tensor in tf.nest.flatten(node.keras_inputs):
989 tensor_usage_count[str(id(tensor))] += 1
991 for output_tensor in tf.nest.flatten(node.outputs):
992 available_tensors.add(str(id(output_tensor)))
994 for tensor in self.outputs:
995 tensor_usage_count[str(id(tensor))] += 1
997 self._tensor_usage_count = tensor_usage_count
999 def _assert_weights_created(self):
1000 # Override the implementation in Model.
1001 # The Functional model should always have weight created already.
1002 return
1004 def _graph_network_add_loss(self, symbolic_loss):
1005 new_nodes, new_layers = _map_subgraph_network(
1006 self.inputs, [symbolic_loss]
1007 )
1008 # Losses must be keyed on inputs no matter what in order to be supported
1009 # in DistributionStrategy.
1010 add_loss_layer = base_layer.AddLoss(
1011 unconditional=False, dtype=symbolic_loss.dtype
1012 )
1013 add_loss_layer(symbolic_loss)
1014 new_nodes.extend(add_loss_layer.inbound_nodes)
1015 new_layers.append(add_loss_layer)
1016 self._insert_layers(new_layers, new_nodes)
1018 def _graph_network_add_metric(self, value, aggregation, name):
1019 new_nodes, new_layers = _map_subgraph_network(self.inputs, [value])
1020 add_metric_layer = base_layer.AddMetric(
1021 aggregation, name, dtype=value.dtype
1022 )
1023 add_metric_layer(value)
1024 new_nodes.extend(add_metric_layer.inbound_nodes)
1025 new_layers.append(add_metric_layer)
1026 self._insert_layers(new_layers, new_nodes)
1028 @property
1029 def _trackable_saved_model_saver(self):
1030 return network_serialization.NetworkSavedModelSaver(self)
1032 def _get_save_spec(self, dynamic_batch=True, inputs_only=True):
1033 if getattr(self, "_has_explicit_input_shape", True):
1034 # Functional models and Sequential models that have an explicit
1035 # input shape should use the batch size set by the input layer.
1036 dynamic_batch = False
1037 return super()._get_save_spec(dynamic_batch, inputs_only)
1040def _make_node_key(layer_name, node_index):
1041 return layer_name + "_ib-" + str(node_index)
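# For example, _make_node_key("dense", 0) returns "dense_ib-0"; these keys
# identify (layer, inbound-node-index) pairs in `_network_nodes` and in the
# serialized config.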
1044def _map_graph_network(inputs, outputs):
1045 """Validates a network's topology and gather its layers and nodes.
1047 Args:
1048 inputs: List of input tensors.
1049 outputs: List of output tensors.
1051 Returns:
1052 A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`.
1053 - nodes: list of Node instances.
1054 - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
1055 - layers: list of Layer instances.
1056 - layers_by_depth: dict mapping ints (depth) to lists of layer instances.
1058 Raises:
1059 ValueError: In case the network is not valid (e.g. disconnected graph).
1060 """
1061 # "depth" is number of layers between output Node and the Node.
1062 # Nodes are ordered from inputs -> outputs.
1063 nodes_in_decreasing_depth, layer_indices = _build_map(outputs)
1064 network_nodes = {
1065 _make_node_key(node.layer.name, node.layer._inbound_nodes.index(node))
1066 for node in nodes_in_decreasing_depth
1067 }
1069 nodes_depths = {} # dict {node: depth value}
1070 layers_depths = {} # dict {layer: depth value}
1072 for node in reversed(nodes_in_decreasing_depth):
1073 # If the depth is not set, the node has no outbound nodes (depth 0).
1074 depth = nodes_depths.setdefault(node, 0)
1076 # Update the depth of the corresponding layer
1077 previous_depth = layers_depths.get(node.layer, 0)
1078 # If we've seen this layer before at a higher depth,
1079 # we should use that depth instead of the node depth.
1080 # This is necessary for shared layers that have inputs at different
1081 # depth levels in the graph.
1082 depth = max(depth, previous_depth)
1083 layers_depths[node.layer] = depth
1084 nodes_depths[node] = depth
1086 # Update the depth of inbound nodes.
1087 # The "depth" of a node is the max of the depths
1088 # of all nodes it is connected to + 1.
1089 for node_dep in node.parent_nodes:
1090 previous_depth = nodes_depths.get(node_dep, 0)
1091 nodes_depths[node_dep] = max(depth + 1, previous_depth)
1093 # Handle inputs that are not connected to outputs.
1094 # We do not error out here because the inputs may be used to compute losses
1095 # and metrics.
1096 for input_t in inputs:
1097 input_layer = input_t._keras_history[0]
1098 if input_layer not in layers_depths:
1099 layers_depths[input_layer] = 0
1100 layer_indices[input_layer] = -1
1101 nodes_depths[input_layer._inbound_nodes[0]] = 0
1102 network_nodes.add(_make_node_key(input_layer.name, 0))
1104 # Build a dict {depth: list of nodes with this depth}
1105 nodes_by_depth = collections.defaultdict(list)
1106 for node, depth in nodes_depths.items():
1107 nodes_by_depth[depth].append(node)
1109 # Build a dict {depth: list of layers with this depth}
1110 layers_by_depth = collections.defaultdict(list)
1111 for layer, depth in layers_depths.items():
1112 layers_by_depth[depth].append(layer)
1114 # Get sorted list of layer depths.
1115 depth_keys = list(layers_by_depth.keys())
1116 depth_keys.sort(reverse=True)
1118 # Set self.layers ordered by depth.
1119 layers = []
1120 for depth in depth_keys:
1121 layers_for_depth = layers_by_depth[depth]
1122 # Network.layers needs to have a deterministic order:
1123 # here we order them by traversal order.
1124 layers_for_depth.sort(key=lambda x: layer_indices[x])
1125 layers.extend(layers_for_depth)
1127 # Get sorted list of node depths.
1128 depth_keys = list(nodes_by_depth.keys())
1129 depth_keys.sort(reverse=True)
1131 # Check that all tensors required are computable.
1132 # computable_tensors: all tensors in the graph
1133 # that can be computed from the inputs provided.
1134 computable_tensors = set()
1135 for x in inputs:
1136 computable_tensors.add(id(x))
1138 layers_with_complete_input = [] # To provide a better error msg.
1139 for depth in depth_keys:
1140 for node in nodes_by_depth[depth]:
1141 layer = node.layer
1142 if layer and not node.is_input:
1143 for x in tf.nest.flatten(node.keras_inputs):
1144 if id(x) not in computable_tensors:
1145 raise ValueError(
1146 "Graph disconnected: cannot obtain value for "
1147 f'tensor {x} at layer "{layer.name}". '
1148 "The following previous layers were accessed "
1149 f"without issue: {layers_with_complete_input}"
1150 )
1151 for x in tf.nest.flatten(node.outputs):
1152 computable_tensors.add(id(x))
1153 layers_with_complete_input.append(layer.name)
1155 # Ensure name unicity, which will be crucial for serialization
1156 # (since serialized nodes refer to layers by their name).
1157 all_names = [layer.name for layer in layers]
1158 for name in all_names:
1159 if all_names.count(name) != 1:
1160 raise ValueError(
1161 f'The name "{name}" is used {all_names.count(name)} '
1162 "times in the model. All layer names should be unique."
1163 )
1164 return network_nodes, nodes_by_depth, layers, layers_by_depth
1167def _build_map(outputs):
1168 """This method topologically sorts nodes in order from inputs to outputs.
1170 It uses a depth-first search to topologically sort nodes that appear in the
1171 _keras_history connectivity metadata of `outputs`.
1173 Args:
1174 outputs: the output tensors whose _keras_history metadata should be
1175 walked. This may be an arbitrary nested structure.
1177 Returns:
1178 A tuple like (ordered_nodes, layer_to_first_traversal_index)
1179 ordered_nodes: list of nodes appearing in the keras history, topologically
1180 sorted from original inputs to the `outputs`.
1181 (If outputs have different sets of ancestors, the inputs to one output
1182 may appear after a different output).
1183 layer_to_first_traversal_index:
1184 A dict mapping layer to the traversal index in the DFS where it is
1185 seen. Note: if a layer is shared by several nodes, the dict will only
1186 store the index corresponding to the *first* time the layer is seen.
1187 """
1188 finished_nodes = set()
1189 nodes_in_progress = set()
1190 nodes_in_decreasing_depth = [] # nodes from inputs -> outputs.
1191 layer_indices = {} # layer -> in traversal order.
1192 for output in tf.nest.flatten(outputs):
1193 _build_map_helper(
1194 output,
1195 finished_nodes,
1196 nodes_in_progress,
1197 nodes_in_decreasing_depth,
1198 layer_indices,
1199 )
1200 return nodes_in_decreasing_depth, layer_indices
1203def _build_map_helper(
1204 tensor,
1205 finished_nodes,
1206 nodes_in_progress,
1207 nodes_in_decreasing_depth,
1208 layer_indices,
1209):
1210 """Recursive helper for `_build_map`."""
1211 (
1212 layer,
1213 node_index,
1214 _,
1215 ) = tensor._keras_history
1216 node = layer._inbound_nodes[node_index]
1218 # Don't repeat work for shared subgraphs
1219 if node in finished_nodes:
1220 return
1222 # Prevent cycles.
1223 if node in nodes_in_progress:
1224 raise ValueError(
1225 f'Tensor {tensor} from layer "{layer.name}" is part of a cycle.'
1226 )
1228 # Store the traversal order for layer sorting.
1229 if layer not in layer_indices:
1230 layer_indices[layer] = len(layer_indices)
1232 # Propagate to all previous tensors connected to this node.
1233 nodes_in_progress.add(node)
1234 if not node.is_input:
1235 for tensor in node.keras_inputs:
1236 _build_map_helper(
1237 tensor,
1238 finished_nodes,
1239 nodes_in_progress,
1240 nodes_in_decreasing_depth,
1241 layer_indices,
1242 )
1244 finished_nodes.add(node)
1245 nodes_in_progress.remove(node)
1246 nodes_in_decreasing_depth.append(node)
1249def _map_subgraph_network(inputs, outputs):
1250 """Returns the nodes and layers in the topology from `inputs` to `outputs`.
1252 Args:
1253 inputs: List of input tensors.
1254 outputs: List of output tensors.
1256 Returns:
1257 A tuple of List[Node] and List[Layer].
1258 """
1259 if not tf.compat.v1.executing_eagerly_outside_functions():
1260 base_layer_utils.create_keras_history(outputs)
1261 # Keep only nodes and layers in the topology between inputs and outputs.
1262 _, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs)
1263 return tf.nest.flatten([nodes for nodes in nodes_by_depth.values()]), layers
1266def _should_skip_first_node(layer):
1267 """Returns True if the first layer node should not be saved or loaded."""
1268 # Networks that are constructed with an Input layer/shape start with a
1269 # pre-existing node linking their input to output. This node is excluded
1270 # from the network config.
1271 if not hasattr(layer, "_self_tracked_trackables"):
1272 # Special case for serialization of Functional models without
1273 # defined input shape argument.
1274 return isinstance(layer, Functional)
1275 if layer._self_tracked_trackables:
1276 return (
1277 isinstance(layer, Functional)
1278 # Filter out Sequential models without an input shape.
1279 and isinstance(
1280 layer._self_tracked_trackables[0], input_layer_module.InputLayer
1281 )
1282 )
1283 else:
1284 return isinstance(layer, Functional)
1287def connect_ancillary_layers(model, created_layers):
1288 """Adds layers that are not connected to the outputs to the model."""
1289 # Layers not connected to outputs, such as those added in `add_loss`.
1290 ancillary_layers = [
1291 layer for layer in created_layers.values() if layer not in model.layers
1292 ]
1293 if ancillary_layers:
1294 relevant_nodes = tf.nest.flatten(
1295 [
1296 layer.inbound_nodes[1:]
1297 if _should_skip_first_node(layer)
1298 else layer.inbound_nodes
1299 for layer in created_layers.values()
1300 ]
1301 )
1302 model._insert_layers(ancillary_layers, relevant_nodes)
1303 return model
1306def reconstruct_from_config(config, custom_objects=None, created_layers=None):
1307 """Reconstructs graph from config object.
1309 Args:
1310 config: Dictionary returned from Network.get_config()
1311 custom_objects: Optional dictionary mapping names (strings) to custom
1312 classes or functions to be considered during deserialization.
1313 created_layers: Optional dictionary mapping names to Layer objects. Any
1314 layer not in this dictionary will be created and added to the dict.
1315 This function will add new nodes to all layers (excluding InputLayers),
1316 instead of re-using pre-existing nodes in the layers.
1318 Returns:
1319 Tuple of (input tensors, output tensors, dictionary of created layers)
1320 """
1321 # Layer instances created during the graph reconstruction process.
1322 created_layers = created_layers or collections.OrderedDict()
1324 # Maps input data (tuple of inbound layer name, node index) from the config
1325 # to node indices in the newly generated model. The node indices may be
1326 # different if the layers have already been called previously.
1327 node_index_map = {}
1328 node_count_by_layer = {}
1330 # Dictionary mapping layer instances to
1331 # node data that specifies a layer call.
1332 # It acts as a queue that maintains any unprocessed
1333 # layer call until it becomes possible to process it
1334 # (i.e. until the input tensors to the call all exist).
1335 unprocessed_nodes = collections.defaultdict(list)
1337 def get_node_index(layer, config_node_index):
1338 """Returns node index in layer (might differ from config_node_index)."""
1339 if isinstance(layer, input_layer_module.InputLayer):
1340 return 0
1341 return node_index_map.get((layer.name, config_node_index), None)
1343 def _deserialize_keras_tensors(kwargs, layer_map):
1344 """Deserializes Keras Tensors passed to `call`.."""
1346 def _deserialize_keras_tensor(t):
1347 """Deserializes a single Keras Tensor passed to `call`."""
1348 if isinstance(t, tf_utils.ListWrapper):
1349 t = t.as_list()
1350 layer_name = t[0]
1351 node_index = t[1]
1352 tensor_index = t[2]
1354 layer = layer_map[layer_name]
1355 new_node_index = get_node_index(layer, node_index)
1356 if new_node_index is None:
1357 # The inbound node may not have been processed yet,
1358 # (This can happen e.g. if it depends on a different set
1359 # of inputs than those that have been processed already).
1360 # raise an IndexError so that the current node puts itself
1361 # back on the unprocessed queue.
1362 # Caution: This may lead to infinite loops for malformed
1363 # network configurations! (or when there is a bug in
1364 # the network config loading code).
1365 raise IndexError
1366 node = layer._inbound_nodes[new_node_index]
1367 return tf.nest.flatten(node.outputs)[tensor_index]
1368 return t
1370 kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True)
1371 return tf.nest.map_structure(_deserialize_keras_tensor, kwargs)
1373 def process_node(layer, node_data):
1374 """Deserialize a node.
1376 Args:
1377 layer: layer instance.
1378 node_data: Nested structure of `ListWrapper`.
1380 Returns:
1381 Whether the node was processed (i.e. the layer was called on the
1382 inputs specified by the node data)
1384 Raises:
1385 ValueError: In case of improperly formatted `node_data`.
1386 """
1387 input_tensors = []
1388 for input_data in tf.nest.flatten(node_data):
1389 input_data = input_data.as_list()
1390 if len(input_data) == 3:
1391 kwargs = {}
1392 elif len(input_data) == 4:
1393 kwargs = input_data[3]
1394 try:
1395 kwargs = _deserialize_keras_tensors(kwargs, created_layers)
1396 except IndexError:
1397 # Happens if keras tensors in kwargs are still unprocessed
1398 return False
1399 else:
1400 raise ValueError("Improperly formatted model config.")
1402 if input_data[0] != node_module._CONSTANT_VALUE:
1403 inbound_layer_name = input_data[0]
1404 inbound_node_index = input_data[1]
1405 inbound_tensor_index = input_data[2]
1406 inbound_layer = created_layers[inbound_layer_name]
1407 inbound_node_index = get_node_index(
1408 inbound_layer, inbound_node_index
1409 )
1411 if inbound_node_index is None:
1412 return False
1413 inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
1414 input_tensors.append(
1415 tf.nest.flatten(inbound_node.outputs)[inbound_tensor_index]
1416 )
1417 else:
1418 # We received a constant w/ no Keras history attached,
1419 # which means it is a constant tensor input.
1420 # Input is a constant value.
1421 # Format = [_CONSTANT_VALUE, -1, const_val, kwargs]
1422 assert input_data[1] == -1
1423 assert len(input_data) >= 3
1424 const_val = input_data[2]
1425 if (
1426 isinstance(const_val, tuple)
1427 and len(const_val) == 2
1428 and const_val[0] == node_module._COMPOSITE_TYPE
1429 ):
1430 # It is a composite tensor.
1431 input_tensors.append(json_utils.decode(const_val[1]))
1432 else:
1433 input_tensors.append(const_val)
1434 input_tensors = tf.nest.pack_sequence_as(node_data, input_tensors)
1435 # Call layer on its inputs, thus creating the node
1436 # and building the layer if needed.
1437 if input_tensors is not None:
1438 if (
1439 not hasattr(layer, "_preserve_input_structure_in_config")
1440 or not layer._preserve_input_structure_in_config
1441 ):
1442 input_tensors = base_layer_utils.unnest_if_single_tensor(
1443 input_tensors
1444 )
1445 output_tensors = layer(input_tensors, **kwargs)
1447 # Update node index map.
1448 output_index = tf.nest.flatten(output_tensors)[
1449 0
1450 ]._keras_history.node_index
1451 node_index_map[
1452 (layer.name, node_count_by_layer[layer])
1453 ] = output_index
1454 node_count_by_layer[layer] += 1
1455 return True
1457 def process_layer(layer_data):
1458 """Deserializes a layer, then call it on appropriate inputs.
1460 Args:
1461 layer_data: layer config dict.
1463 Raises:
1464 ValueError: In case of improperly formatted `layer_data` dict.
1465 """
1466 layer_name = layer_data["name"]
1468 if layer_name in created_layers:
1469 layer = created_layers[layer_name]
1470 else:
1471 # Instantiate layer.
1472 from keras.src.layers import deserialize as deserialize_layer
1474 layer = deserialize_layer(layer_data, custom_objects=custom_objects)
1475 created_layers[layer_name] = layer
1477 node_count_by_layer[layer] = int(_should_skip_first_node(layer))
1479 # Gather layer inputs and convert to `ListWrapper` objects.
1480 inbound_nodes_data = layer_data["inbound_nodes"]
1481 inbound_nodes_data = tf_utils.convert_inner_node_data(
1482 inbound_nodes_data, wrap=True
1483 )
1484 for node_data in inbound_nodes_data:
1485 # We don't process nodes (i.e. make layer calls)
1486 # on the fly because the inbound node may not yet exist,
1487 # in case of layers shared at different topological depths
1488 # (e.g. a model such as A(B(A(B(x)))))
1489 unprocessed_nodes[layer].append(node_data)
1491 # First, we create all layers and enqueue nodes to be processed
1492 for layer_data in config["layers"]:
1493 process_layer(layer_data)
1494 # Then we process nodes in order of layer depth.
1495 # Nodes that cannot yet be processed (if the inbound node
1496 # does not yet exist) are re-enqueued, and the process
1497 # is repeated until all nodes are processed.
1498 while unprocessed_nodes:
1499 for layer_data in config["layers"]:
1500 layer = created_layers[layer_data["name"]]
1501 if layer in unprocessed_nodes:
1502 layer_nodes = unprocessed_nodes.pop(layer)
1503 while layer_nodes:
1504 node_data = layer_nodes[0]
1505 if process_node(layer, node_data):
1506 layer_nodes.pop(0)
1507 else:
1508 # If a node can't be processed, stop processing the
1509 # nodes of the current layer to maintain node ordering.
1510 unprocessed_nodes[layer] = layer_nodes
1511 break
1513 input_tensors = []
1514 output_tensors = []
1516 input_layers = tf_utils.convert_inner_node_data(
1517 config["input_layers"], wrap=True
1518 )
1519 for layer_data in tf.nest.flatten(input_layers):
1520 layer_name, node_index, tensor_index = layer_data.as_list()
1521 assert layer_name in created_layers
1522 layer = created_layers[layer_name]
1523 node_index = get_node_index(layer, node_index)
1524 layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
1525 input_tensors.append(
1526 tf.nest.flatten(layer_output_tensors)[tensor_index]
1527 )
1529 output_layers = tf_utils.convert_inner_node_data(
1530 config["output_layers"], wrap=True
1531 )
1532 for layer_data in tf.nest.flatten(output_layers):
1533 layer_name, node_index, tensor_index = layer_data.as_list()
1534 assert layer_name in created_layers
1535 layer = created_layers[layer_name]
1536 node_index = get_node_index(layer, node_index)
1537 layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
1538 output_tensors.append(
1539 tf.nest.flatten(layer_output_tensors)[tensor_index]
1540 )
1542 input_tensors = tf.nest.pack_sequence_as(input_layers, input_tensors)
1543 output_tensors = tf.nest.pack_sequence_as(output_layers, output_tensors)
1544 return input_tensors, output_tensors, created_layers
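# Illustrative sketch of the round trip supported by the helper above and
# `get_network_config` below (assumes the public `keras` API; `model` is any
# functional model):
#
#     config = get_network_config(model)
#     inputs, outputs, created_layers = reconstruct_from_config(config)
#     clone = keras.Model(inputs=inputs, outputs=outputs, name=config["name"])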
1547def get_network_config(network, serialize_layer_fn=None, config=None):
1548 """Build the config, which consists of the node graph and serialized layers.
1550 Args:
1551 network: A Network object.
1552 serialize_layer_fn: Function used to serialize layers.
1553 config: A dict to append more config entries into. If None, start with a
1554 new dict for the config.
1556 Returns:
1557 Config dictionary.
1558 """
1559 config = config or {}
1560 serialize_obj_fn = serialization_lib.serialize_keras_object
1561 set_layers_legacy = False
1562 # To be removed after full affected g3 user migration to Keras V3 Saving.
1563 if getattr(network, "use_legacy_config", False):
1564 serialize_obj_fn = serialization.serialize_keras_object
1565 set_layers_legacy = True
1566 serialize_layer_fn = serialize_layer_fn or serialize_obj_fn
1567 config["name"] = network.name
1568 node_conversion_map = {}
1569 for layer in network.layers:
1570 kept_nodes = 1 if _should_skip_first_node(layer) else 0
1571 for original_node_index, node in enumerate(layer._inbound_nodes):
1572 node_key = _make_node_key(layer.name, original_node_index)
1573 if node_key in network._network_nodes:
1574 node_conversion_map[node_key] = kept_nodes
1575 kept_nodes += 1
1576 layer_configs = []
1578 with serialization.SharedObjectSavingScope():
1579 for layer in network.layers: # From the earliest layers on.
1580 filtered_inbound_nodes = []
1581 for original_node_index, node in enumerate(layer._inbound_nodes):
1582 node_key = _make_node_key(layer.name, original_node_index)
1583 if node_key in network._network_nodes and not node.is_input:
1584 # The node is relevant to the model:
1585 # add to filtered_inbound_nodes.
1586 node_data = node.serialize(
1587 _make_node_key, node_conversion_map
1588 )
1589 filtered_inbound_nodes.append(node_data)
1591 if isinstance(layer, Functional) and set_layers_legacy:
1592 layer.use_legacy_config = True
1593 layer_config = serialize_layer_fn(layer)
1594 layer_config["name"] = layer.name
1595 layer_config["inbound_nodes"] = filtered_inbound_nodes
1596 layer_configs.append(layer_config)
1597 config["layers"] = layer_configs
1599 # Gather info about inputs and outputs.
1600 model_inputs = []
1601 for i in range(len(network._input_layers)):
1602 layer, node_index, tensor_index = network._input_coordinates[i]
1603 node_key = _make_node_key(layer.name, node_index)
1604 if node_key not in network._network_nodes:
1605 continue
1606 new_node_index = node_conversion_map[node_key]
1607 model_inputs.append(
1608 tf_utils.ListWrapper([layer.name, new_node_index, tensor_index])
1609 )
1610 model_inputs = tf.nest.pack_sequence_as(
1611 network._nested_inputs, model_inputs
1612 )
1613 # Preserve external Keras compat for Models with single input.
1614 if not tf.nest.is_nested(model_inputs):
1615 model_inputs = [model_inputs]
1616 model_inputs = tf_utils.convert_inner_node_data(model_inputs)
1617 config["input_layers"] = model_inputs
1619 model_outputs = []
1620 for i in range(len(network._output_layers)):
1621 layer, node_index, tensor_index = network._output_coordinates[i]
1622 node_key = _make_node_key(layer.name, node_index)
1623 if node_key not in network._network_nodes:
1624 continue
1625 new_node_index = node_conversion_map[node_key]
1626 model_outputs.append(
1627 tf_utils.ListWrapper([layer.name, new_node_index, tensor_index])
1628 )
1629 model_outputs = tf.nest.pack_sequence_as(
1630 network._nested_outputs, model_outputs
1631 )
1632 # Preserve external Keras compat for Models with single output.
1633 if not tf.nest.is_nested(model_outputs):
1634 model_outputs = [model_outputs]
1635 model_outputs = tf_utils.convert_inner_node_data(model_outputs)
1636 config["output_layers"] = model_outputs
1637 return config
1640def shape_with_no_batch_size(x):
1641 if x.shape.rank is None:
1642 return None
1643 shape = x.shape.as_list()
1644 if shape:
1645 shape[0] = None
1646 return shape
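# For example, an input `x` with static shape (32, 10) yields [None, 10]:
# the batch dimension is dropped so the generated `InputSpec`s do not pin a
# batch size.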
1649class ModuleWrapper(base_layer.Layer):
1650 """Wrapper for `tf.Module`s to support the Functional and Sequential API."""
1652 def __init__(self, module, method_name=None, **kwargs):
1653 """Initializes the wrapper Layer for this module.
1655 Args:
1656 module: The `tf.Module` instance to be wrapped.
1657 method_name: (Optional) str. The name of the method to use as the
1658 forward pass of the module. If not set, becomes '__call__' if
1659 defined, or 'call'. Defaults to `None`.
1660 **kwargs: Additional keyword arguments. See `tf.keras.layers.Layer`.
1662 Raises:
1663 ValueError: If `method_name` is not defined on `module`.
1664 """
1665 super().__init__(**kwargs)
1666 if method_name is None:
1667 if hasattr(module, "__call__"):
1668 method_name = "__call__"
1669 elif hasattr(module, "call"):
1670 method_name = "call"
1671 if method_name is None or not hasattr(module, method_name):
1672 raise ValueError(f"{method_name} is not defined on object {module}")
1674 self._module = module
1675 self._method_name = method_name
1677 # Check if module.__call__ has a `training` arg or accepts `**kwargs`.
1678 method = getattr(module, method_name)
1679 method_arg_spec = tf_inspect.getfullargspec(method)
1680 self._call_spec.expects_training_arg = (
1681 "training" in method_arg_spec.args
1682 or method_arg_spec.varkw is not None
1683 )
1684 self._call_spec.expects_mask_arg = (
1685 "mask" in method_arg_spec.args or method_arg_spec.varkw is not None
1686 )
1688 def call(self, *args, **kwargs):
1689 if "training" in kwargs and not self._expects_training_arg:
1690 kwargs.pop("training")
1691 if "mask" in kwargs and not self._expects_mask_arg:
1692 kwargs.pop("mask")
1693 return getattr(self._module, self._method_name)(*args, **kwargs)
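    # Illustrative sketch (hypothetical module; this wrapper is normally
    # created internally when a plain `tf.Module` is used with the
    # Functional/Sequential API):
    #
    #     class Scale(tf.Module):
    #         def __init__(self):
    #             super().__init__()
    #             self.factor = tf.Variable(2.0)
    #
    #         def __call__(self, x):
    #             return x * self.factor
    #
    #     layer = ModuleWrapper(Scale())     # forwards to Scale.__call__
    #     y = layer(tf.ones((2, 3)))         # tensor of 2.0s, shape (2, 3)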
1696def has_functional_like_constructor(cls):
1697 init_args = tf_inspect.getfullargspec(cls.__init__).args[1:]
1698 functional_init_args = tf_inspect.getfullargspec(Functional.__init__).args[
1699 1:
1700 ]
1701 if init_args == functional_init_args:
1702 return True
1703 return False