# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
16"""A `Network` is way to compose layers: the topological form of a `Model`."""

import collections
import copy
import itertools
import warnings

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_layer as input_layer_module
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.saving.saved_model import network_serialization
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import base as trackable
from tensorflow.python.util import nest
from tensorflow.tools.docs import doc_controls


# pylint: disable=g-classes-have-attributes
class Functional(training_lib.Model):
48 """A `Functional` model is a `Model` defined as a directed graph of layers.
50 Three types of `Model` exist: subclassed `Model`, `Functional` model,
51 and `Sequential` (a special case of `Functional`).
52 In general, more Keras features are supported with `Functional`
53 than with subclassed `Model`s, specifically:
55 - Model cloning (`keras.models.clone`)
56 - Serialization (`model.get_config()/from_config`, `model.to_json()`
57 - Whole-model saving (`model.save()`)
59 A `Functional` model can be instantiated by passing two arguments to
60 `__init__`. The first argument is the `keras.Input` Tensors that represent
61 the inputs to the model. The second argument specifies the output
62 tensors that represent the outputs of this model. Both arguments can be a
63 nested structure of tensors.
65 Example:
67 ```
68 inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))}
69 t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
70 outputs = keras.layers.Add()([t, inputs['x2'])
71 model = keras.Model(inputs, outputs)
72 ```
74 A `Functional` model constructed using the Functional API can also include raw
75 TensorFlow functions, with the exception of functions that create Variables
76 or assign ops.
78 Example:
80 ```
81 inputs = keras.Input(shape=(10,))
82 x = keras.layers.Dense(1)(inputs)
83 outputs = tf.nn.relu(x)
84 model = keras.Model(inputs, outputs)
85 ```
87 Args:
88 inputs: List of input tensors (must be created via `tf.keras.Input()`).
89 outputs: List of output tensors.
90 name: String, optional. Name of the model.
91 trainable: Boolean, optional. If the model's variables should be trainable.
92 """

  # See tf.Module for the usage of this property.
  # The key of _layer_call_argspecs is a layer. tf.Module._flatten will fail to
  # flatten the key since it is trying to convert Trackable/Layer to a string.
  _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
      ('_layer_call_argspecs', '_compiled_trainable_state',
       '_output_mask_cache', '_output_tensor_cache', '_output_shape_cache'),
      training_lib.Model._TF_MODULE_IGNORED_PROPERTIES
  ))

  @trackable.no_automatic_dependency_tracking
  def __init__(self, inputs, outputs, name=None, trainable=True,
               **kwargs):
    # This is used by the Model class, since we have some logic to swap the
    # class in the __new__ method, which leads to __init__ being invoked
    # twice. Use skip_init to skip one of the invocations of __init__ to
    # avoid any side effects.
    skip_init = kwargs.pop('skip_init', False)
    if skip_init:
      return
    generic_utils.validate_kwargs(kwargs, {})
    super(Functional, self).__init__(name=name, trainable=trainable)
    self._init_graph_network(inputs, outputs)

  @trackable.no_automatic_dependency_tracking
  def _init_graph_network(self, inputs, outputs):
    # This method is needed for Sequential to reinitialize graph network when
    # layer is added or removed.
    self._is_graph_network = True

    # Normalize and set self.inputs, self.outputs.
    if isinstance(inputs, list) and len(nest.flatten(inputs)) == 1:
      inputs = inputs[0]
    if isinstance(outputs, list) and len(nest.flatten(outputs)) == 1:
      outputs = outputs[0]
    self._nested_inputs = inputs
    self._nested_outputs = outputs
    self.inputs = nest.flatten(inputs)
    self.outputs = nest.flatten(outputs)

    # Models constructed with a single Tensor or list of Tensors can
    # be called with a dict, where the keys of the dict are the names
    # of the `Input` objects. Extra keys are ignored with a warning.
    if not nest.is_nested(self._nested_inputs):
      self._enable_dict_to_input_mapping = True
    elif (isinstance(self._nested_inputs, (list, tuple)) and
          not any(nest.is_nested(t) for t in self._nested_inputs)):
      self._enable_dict_to_input_mapping = True
    elif (isinstance(self._nested_inputs, dict) and
          not any(nest.is_nested(t) for t in self._nested_inputs.values())):
      self._enable_dict_to_input_mapping = True
    else:
      self._enable_dict_to_input_mapping = False
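
    # Hedged illustration (not part of the original source): when the mapping
    # above is enabled, a model built as
    #     inputs = keras.Input(shape=(10,), name='digits')
    #     model = keras.Model(inputs, keras.layers.Dense(1)(inputs))
    # can be called as model({'digits': data}); the dict key is matched
    # against the `Input` layer name, and extra keys are dropped with a
    # warning by `_flatten_to_reference_inputs` below.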

    if not ops.executing_eagerly_outside_functions():
      if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
        base_layer_utils.create_keras_history(self._nested_outputs)

    self._validate_graph_inputs_and_outputs()

    # A Network does not create weights of its own, thus it is already
    # built.
    self.built = True
    self._build_input_shape = nest.map_structure(lambda x: x.shape, inputs)
    self._compute_output_and_mask_jointly = True
    # `_expects_training_arg` is True since the `training` argument is always
    # present in the signature of the `call` method of a graph network.
    self._expects_training_arg = True
    self._expects_mask_arg = True
    # A graph network does not autocast inputs, as its layers will cast them
    # instead.
    self._autocast = False

    self._input_layers = []
    self._output_layers = []
    self._input_coordinates = []
    self._output_coordinates = []

    # This is for performance optimization when calling the Network on new
    # inputs. Every time the Network is called on a set of input tensors,
    # we compute the output tensors, output masks and output shapes in one
    # pass, then cache them here. When any of these outputs is queried later,
    # we retrieve it from there instead of recomputing it.
    self._output_mask_cache = {}
    self._output_tensor_cache = {}
    self._output_shape_cache = {}

    # Build self._output_layers:
    for x in self.outputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      self._output_layers.append(layer)
      self._output_coordinates.append((layer, node_index, tensor_index))

    # Build self._input_layers:
    for x in self.inputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      # It's supposed to be an input layer, so only one node
      # and one tensor output.
      assert node_index == 0
      assert tensor_index == 0
      self._input_layers.append(layer)
      self._input_coordinates.append((layer, node_index, tensor_index))

    # Keep track of the network's nodes and layers.
    nodes, nodes_by_depth, layers, _ = _map_graph_network(
        self.inputs, self.outputs)
    self._network_nodes = nodes
    self._nodes_by_depth = nodes_by_depth
    self._self_tracked_trackables = layers
    self._layer_call_argspecs = {}
    for layer in self._self_tracked_trackables:
      self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)

    # Build self.input_names and self.output_names.
    self._set_output_names()
    self.input_names = []
    self._feed_input_names = []
    self._feed_inputs = []
    self._feed_input_shapes = []
    for layer in self._input_layers:
      self.input_names.append(layer.name)
      if layer.is_placeholder:
        self._feed_input_names.append(layer.name)
        # Use batch_input_shape here because non-eager composite tensors may
        # not have a shape attribute that's meaningful (sparse, for instance,
        # has a tensor that's non-constant and needs to be fed). This means
        # that input layers that create placeholders will need to have the
        # batch_input_shape attr to allow for input shape validation.
        self._feed_input_shapes.append(layer._batch_input_shape)
        self._feed_inputs.append(layer.input)

    self._compute_tensor_usage_count()
    self._set_save_spec(self._nested_inputs)
    tf_utils.assert_no_legacy_layers(self.layers)

  @property
  def input(self):
    """Retrieves the input tensor(s) of a layer.

    Only applicable if the layer has exactly one input,
    i.e. if it is connected to one incoming layer.

    Returns:
      Input tensor or list of input tensors.

    Raises:
      RuntimeError: If called in Eager mode.
      AttributeError: If no inbound nodes are found.
    """
    return self._nested_inputs

  @property
  def input_shape(self):
    """Retrieves the input shape(s) of a layer.

    Only applicable if the layer has exactly one input,
    i.e. if it is connected to one incoming layer, or if all inputs
    have the same shape.

    Returns:
      Input shape, as an integer shape tuple
      (or list of shape tuples, one tuple per input tensor).

    Raises:
      AttributeError: if the layer has no defined input_shape.
      RuntimeError: if called in Eager mode.
    """
    return nest.map_structure(backend.int_shape, self.input)

  @property
  def input_spec(self):
    if hasattr(self, '_manual_input_spec'):
      return self._manual_input_spec
    if (isinstance(self._nested_inputs, (dict, list, tuple)) and
        len(self._nested_inputs) != len(self.inputs)):
      # Case where we have a nested structure.
      # In such a case we can't safely run any checks.
      return None
    if isinstance(self._nested_inputs, dict):
      # Case where `_nested_inputs` is a plain dict of Inputs.
      names = sorted(self._nested_inputs.keys())
      return [input_spec.InputSpec(
          shape=shape_with_no_batch_size(self._nested_inputs[name]),
          allow_last_axis_squeeze=True, name=name) for name in names]
    else:
      # Single input, or list / tuple of inputs.
      # The data may be passed as a dict keyed by input name.
      return [input_spec.InputSpec(
          shape=shape_with_no_batch_size(x), allow_last_axis_squeeze=True,
          name=x._keras_history.layer.name) for x in self.inputs]

  @input_spec.setter
  def input_spec(self, value):
    self._manual_input_spec = value

  @property
  def output(self):
    """Retrieves the output tensor(s) of a layer.

    Only applicable if the layer has exactly one output,
    i.e. if it is connected to one incoming layer.

    Returns:
      Output tensor or list of output tensors.

    Raises:
      AttributeError: if the layer is connected to more than one incoming
        layer.
      RuntimeError: if called in Eager mode.
    """
    return self._nested_outputs

  @property
  def output_shape(self):
    """Retrieves the output shape(s) of a layer.

    Only applicable if the layer has one output,
    or if all outputs have the same shape.

    Returns:
      Output shape, as an integer shape tuple
      (or list of shape tuples, one tuple per output tensor).

    Raises:
      AttributeError: if the layer has no defined output shape.
      RuntimeError: if called in Eager mode.
    """
    return nest.map_structure(backend.int_shape, self.output)

  def _set_output_names(self):
    """Assigns unique names to the Network's outputs.

    Output layers with multiple output tensors would otherwise lead to
    duplicate names in self.output_names.
    """
    uniquified = []
    output_names = set()
    prefix_count = {}
    for layer in self._output_layers:
      proposal = layer.name
      while proposal in output_names:
        existing_count = prefix_count.get(layer.name, 1)
        proposal = '{}_{}'.format(layer.name, existing_count)
        prefix_count[layer.name] = existing_count + 1
      output_names.add(proposal)
      uniquified.append(proposal)
    self.output_names = uniquified
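
    # Hedged illustration (not part of the original source): if two outputs
    # come from layers that are both named 'dense', the loop above yields
    # output_names == ['dense', 'dense_1'].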

  @property
  def _layer_checkpoint_dependencies(self):
    """Dictionary of layer dependencies to be included in the checkpoint."""
    weight_layer_index = 0

    dependencies = collections.OrderedDict()
    for layer_index, layer in enumerate(self.layers):
      try:
        if layer.weights:
          # Keep a separate index for layers which have weights. This allows
          # users to insert Layers without weights anywhere in the network
          # without breaking checkpoints.
          dependencies['layer_with_weights-%d' % weight_layer_index] = layer
          weight_layer_index += 1
      except ValueError:
        # The layer might have weights, but may not be built yet. We just
        # treat it as a layer without weights.
        pass

      # Even if it doesn't have weights, we should still track everything in
      # case it has/will have Trackable dependencies.
      dependencies['layer-%d' % layer_index] = layer
    return dependencies

  def _trackable_children(self,
                          save_type=trackable.SaveType.CHECKPOINT,
                          **kwargs):
    dependencies = self._layer_checkpoint_dependencies
    dependencies.update(
        super(Functional, self)._trackable_children(save_type, **kwargs))
    return dependencies

  def _lookup_dependency(self, name):
    layer_dependencies = self._layer_checkpoint_dependencies
    if name in layer_dependencies:
      return layer_dependencies[name]
    return super(Functional, self)._lookup_dependency(name)

  def _handle_deferred_layer_dependencies(self, layers):
    """Handles layer checkpoint dependencies that are added after init."""
    layer_checkpoint_dependencies = self._layer_checkpoint_dependencies
    layer_to_name = {v: k for k, v in layer_checkpoint_dependencies.items()}
    for layer in layers:
      if layer in layer_to_name:
        self._handle_deferred_dependencies(name=layer_to_name[layer],
                                           trackable=layer)

  @property
  def _should_compute_mask(self):
    return True

  def compute_mask(self, inputs, mask):
    # TODO(omalleyt): b/123540974 This function is not really safe to call
    # by itself because it will duplicate any updates and losses in graph
    # mode by `call`ing the Layers again.
    output_tensors = self._run_internal_graph(inputs, mask=mask)
    return nest.map_structure(lambda t: getattr(t, '_keras_mask', None),
                              output_tensors)

  @doc_controls.do_not_doc_inheritable
  def call(self, inputs, training=None, mask=None):
    """Calls the model on new inputs.

    In this case `call` just reapplies
    all ops in the graph to the new inputs
    (i.e. it builds a new computational graph from the provided inputs).

    Args:
      inputs: A tensor or list of tensors.
      training: Boolean or boolean scalar tensor, indicating whether to run
        the `Network` in training mode or inference mode.
      mask: A mask or list of masks. A mask can be
        either a tensor or None (no mask).

    Returns:
      A tensor if there is a single output, or
      a list of tensors if there is more than one output.
    """
    return self._run_internal_graph(
        inputs, training=training, mask=mask)

  def compute_output_shape(self, input_shape):
    # Convert any shapes in tuple format to TensorShapes.
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)

    if len(nest.flatten(input_shape)) != len(nest.flatten(self._input_layers)):
      raise ValueError('Invalid input_shape argument ' + str(input_shape) +
                       ': model has ' + str(len(self._input_layers)) +
                       ' tensor inputs.')

    # Use the tuple of TensorShape as the cache key, since tuple is hashable
    # and can be used as a hash key.
    try:
      cache_key = tuple(tf_utils.convert_shapes(input_shape, to_tuples=True))
      if cache_key in self._output_shape_cache:
        # Cache hit. Return shapes as TensorShapes.
        return self._output_shape_cache[cache_key]
    except ValueError:
      # In case there are unknown TensorShapes, e.g. for sparse tensor
      # inputs, we skip caching since the shape is unknown.
      pass

    layers_to_output_shapes = {}
    for layer, shape in zip(self._input_layers, nest.flatten(input_shape)):
      # It's an input layer: then `compute_output_shape` is identity,
      # and there is only one node and one tensor.
      shape_key = layer.name + '_0_0'
      layers_to_output_shapes[shape_key] = shape

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    # Iterate over nodes, by depth level.
    if len(depth_keys) > 1:
      for depth in depth_keys:
        nodes = self._nodes_by_depth[depth]
        for node in nodes:
          layer = node.layer
          if layer in self._input_layers:
            # We've already covered the input layers
            # a few lines above.
            continue
          # Get the input shapes for the first argument of the node.
          layer_input_shapes = []
          layer_inputs = node.call_args[0]
          for layer_input in nest.flatten(layer_inputs):
            kh = layer_input._keras_history
            input_layer_key = kh.layer.name + '_%s_%s' % (kh.node_index,
                                                          kh.tensor_index)
            layer_input_shapes.append(layers_to_output_shapes[input_layer_key])
          layer_input_shapes = nest.pack_sequence_as(layer_inputs,
                                                     layer_input_shapes)
          # Layers expect shapes to be tuples for `compute_output_shape`.
          layer_input_shapes = tf_utils.convert_shapes(
              layer_input_shapes, to_tuples=True)
          layer_output_shapes = layer.compute_output_shape(layer_input_shapes)
          # Convert back to TensorShapes.
          layer_output_shapes = tf_utils.convert_shapes(
              layer_output_shapes, to_tuples=False)

          node_index = layer._inbound_nodes.index(node)  # pylint: disable=protected-access
          for j, shape in enumerate(nest.flatten(layer_output_shapes)):
            shape_key = layer.name + '_%s_%s' % (node_index, j)
            layers_to_output_shapes[shape_key] = shape

    # Read final output shapes from layers_to_output_shapes.
    output_shapes = []
    for i in range(len(self._output_layers)):
      layer, node_index, tensor_index = self._output_coordinates[i]
      shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
      output_shapes.append(layers_to_output_shapes[shape_key])
    output_shapes = nest.pack_sequence_as(self._nested_outputs, output_shapes)
    # Store in cache.
    self._output_shape_cache[cache_key] = output_shapes

    # Return shapes as TensorShapes.
    return output_shapes
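
  # Hedged illustration (not part of the original source): for
  #     inputs = keras.Input(shape=(10,))
  #     model = keras.Model(inputs, keras.layers.Dense(4)(inputs))
  # model.compute_output_shape((None, 10)) walks the nodes by depth as above
  # and returns TensorShape([None, 4]).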

  def _init_set_name(self, name, zero_based=True):
    if not name:
      cls_name = self.__class__.__name__
      if self.__class__ == Functional:
        # Hide the functional class name from the user, since it's not a
        # publicly visible class. Use "Model" instead.
        cls_name = 'Model'
      self._name = backend.unique_object_name(
          generic_utils.to_snake_case(cls_name),
          zero_based=zero_based)
    else:
      self._name = name

  def _run_internal_graph(self, inputs, training=None, mask=None):
    """Computes output tensors for new inputs.

    Note:
      - Can be run on non-Keras tensors.

    Args:
      inputs: Tensor or nested structure of Tensors.
      training: Boolean learning phase.
      mask: (Optional) Tensor or nested structure of Tensors.

    Returns:
      output_tensors
    """
    inputs = self._flatten_to_reference_inputs(inputs)
    if mask is None:
      masks = [None] * len(inputs)
    else:
      masks = self._flatten_to_reference_inputs(mask)
    for input_t, mask in zip(inputs, masks):
      input_t._keras_mask = mask

    # Dictionary mapping reference tensors to computed tensors.
    tensor_dict = {}
    tensor_usage_count = self._tensor_usage_count
    for x, y in zip(self.inputs, inputs):
      y = self._conform_to_reference_input(y, ref_input=x)
      x_id = str(id(x))
      tensor_dict[x_id] = [y] * tensor_usage_count[x_id]

    nodes_by_depth = self._nodes_by_depth
    depth_keys = list(nodes_by_depth.keys())
    depth_keys.sort(reverse=True)

    for depth in depth_keys:
      nodes = nodes_by_depth[depth]
      for node in nodes:
        if node.is_input:
          continue  # Input tensors already exist.

        if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
          continue  # Node is not computable, try skipping.

        args, kwargs = node.map_arguments(tensor_dict)
        outputs = node.layer(*args, **kwargs)

        # Update tensor_dict.
        for x_id, y in zip(node.flat_output_ids, nest.flatten(outputs)):
          tensor_dict[x_id] = [y] * tensor_usage_count[x_id]

    output_tensors = []
    for x in self.outputs:
      x_id = str(id(x))
      assert x_id in tensor_dict, 'Could not compute output ' + str(x)
      output_tensors.append(tensor_dict[x_id].pop())

    return nest.pack_sequence_as(self._nested_outputs, output_tensors)

  def _flatten_to_reference_inputs(self, tensors):
    """Maps `tensors` to their respective `keras.Input`."""
    if self._enable_dict_to_input_mapping and isinstance(tensors, dict):
      ref_inputs = self._nested_inputs
      if not nest.is_nested(ref_inputs):
        ref_inputs = [self._nested_inputs]
      if isinstance(ref_inputs, dict):
        # In the case that the graph is constructed with dict input tensors,
        # we will use the original dict key to map with the keys in the input
        # data. Note that model.inputs uses nest.flatten to process the input
        # tensors, which means the dict input tensors are ordered by their
        # keys.
        ref_input_names = sorted(ref_inputs.keys())
      else:
        ref_input_names = [inp._keras_history.layer.name for inp in ref_inputs]

      # Raise a warning if there is more input data than there are input
      # tensors.
      if len(tensors) > len(ref_input_names):
        warnings.warn(
            'Input dict contained keys {} which did not match any model input. '
            'They will be ignored by the model.'.format(
                [n for n in tensors.keys() if n not in ref_input_names])
        )

      try:
        # Flatten in the order `Input`s were passed during Model construction.
        return [tensors[n] for n in ref_input_names]
      except KeyError:
        # TODO(b/151582614)
        return nest.flatten(tensors)

    # Otherwise both self.inputs and tensors will already be in same order.
    return nest.flatten(tensors)
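
  # Hedged illustration (not part of the original source): for a model with
  # two inputs named 'a' and 'b',
  #     model._flatten_to_reference_inputs({'a': t1, 'b': t2, 'extra': t3})
  # warns about 'extra' and returns [t1, t2], ordered like self.inputs.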

  def _conform_to_reference_input(self, tensor, ref_input):
    """Set shape and dtype based on `keras.Input`s."""
    if isinstance(tensor, ops.Tensor):
      # Allow (None,) and (None, 1) Tensors to be passed interchangeably. Use
      # the shape specified by the `keras.Input`.
      t_shape = tensor.shape
      t_rank = t_shape.rank
      ref_shape = ref_input.shape
      ref_rank = ref_shape.rank
      keras_history = getattr(tensor, '_keras_history', None)
      if t_rank is not None and ref_rank is not None:
        # Should squeeze last dimension.
        # True if tensor is (BATCH, ..., 1) and reference is (BATCH, ...).
        if (t_rank == ref_rank + 1 and t_shape[-1] == 1):
          tensor = array_ops.squeeze_v2(tensor, axis=-1)
        # Should expand last dimension.
        # True if tensor is (BATCH, ...) and reference is (BATCH, ..., 1).
        elif (t_rank == ref_rank - 1 and ref_shape[-1] == 1):
          tensor = array_ops.expand_dims_v2(tensor, axis=-1)
      if keras_history is not None:  # Restore keras history.
        tensor._keras_history = keras_history

      # Add shape hints to Tensors that may have None shape dims but have
      # shapes defined by the `keras.Input` (not applicable in eager mode).
      if not context.executing_eagerly():
        try:
          tensor.set_shape(tensor.shape.merge_with(ref_input.shape))
        except ValueError:
          logging.warning(
              'Model was constructed with shape {} for input {}, but it was '
              'called on an input with incompatible shape {}.'.format(
                  ref_input.shape, ref_input, tensor.shape))

      # Dtype casting.
      tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
    elif tf_utils.is_extension_type(tensor):
      # Dtype casting (if the extension type has a non-variant dtype and
      # supports being cast).
      ref_input_dtype = getattr(ref_input, 'dtype', None)
      if ref_input_dtype is not None and ref_input_dtype != dtypes.variant:
        tensor = math_ops.cast(tensor, dtype=ref_input_dtype)

    return tensor
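
  # Hedged illustration (not part of the original source): if ref_input has
  # shape (None, 10) and a float64 tensor of shape (32, 10, 1) is passed,
  # the code above squeezes the trailing 1-sized dimension and casts to the
  # reference dtype, so the layer graph receives a (32, 10) tensor of the
  # `keras.Input` dtype (float32 by default).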

  def get_config(self):
    return copy.deepcopy(get_network_config(self))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Instantiates a Model from its config (output of `get_config()`).

    Args:
      config: Model config dictionary.
      custom_objects: Optional dictionary mapping names
        (strings) to custom classes or functions to be
        considered during deserialization.

    Returns:
      A model instance.

    Raises:
      ValueError: In case of improperly formatted config dict.
    """
    with generic_utils.SharedObjectLoadingScope():
      input_tensors, output_tensors, created_layers = reconstruct_from_config(
          config, custom_objects)
      model = cls(inputs=input_tensors, outputs=output_tensors,
                  name=config.get('name'))
      connect_ancillary_layers(model, created_layers)
      return model

  def _validate_graph_inputs_and_outputs(self):
    """Validates the inputs and outputs of a Graph Network."""
    # Check for redundancy in inputs.
    if len({id(i) for i in self.inputs}) != len(self.inputs):
      raise ValueError('The list of inputs passed to the model '
                       'is redundant. '
                       'All inputs should only appear once.'
                       ' Found: ' + str(self.inputs))

    for x in self.inputs:
      # Check that x has appropriate `_keras_history` metadata.
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise ValueError('Input tensors to a ' + cls_name + ' ' +
                         'must come from `tf.keras.Input`. '
                         'Received: ' + str(x) +
                         ' (missing previous layer metadata).')
      # Check that x is an input tensor.
      # pylint: disable=protected-access
      layer = x._keras_history.layer
      if len(layer._inbound_nodes) > 1 or (
          layer._inbound_nodes and not layer._inbound_nodes[0].is_input):
        cls_name = self.__class__.__name__
        logging.warning(cls_name + ' model inputs must come from '
                        '`tf.keras.Input` (thus holding past layer metadata), '
                        'they cannot be the output of '
                        'a previous non-Input layer. '
                        'Here, a tensor specified as '
                        'input to "' + self.name + '" was not an Input tensor, '
                        'it was generated by layer ' + layer.name + '.\n'
                        'Note that input tensors are '
                        'instantiated via `tensor = tf.keras.Input(shape)`.\n'
                        'The tensor that caused the issue was: ' + str(x.name))

    # Check compatibility of batch sizes of Input Layers.
    input_batch_sizes = [
        training_utils.get_static_batch_size(x._keras_history.layer)
        for x in self.inputs
    ]
    consistent_batch_size = None
    for batch_size in input_batch_sizes:
      if batch_size is not None:
        if (consistent_batch_size is not None and
            batch_size != consistent_batch_size):
          raise ValueError('The specified batch sizes of the Input Layers'
                           ' are incompatible. Found batch sizes: {}'.format(
                               input_batch_sizes))
        consistent_batch_size = batch_size

    for x in self.outputs:
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise ValueError('Output tensors of a ' + cls_name + ' model must be '
                         'the output of a TensorFlow `Layer` '
                         '(thus holding past layer metadata). Found: ' + str(x))

  def _insert_layers(self, layers, relevant_nodes=None):
    """Inserts Layers into the Network after Network creation.

    This is only valid for Keras Graph Networks. Layers added via this function
    will be included in the `call` computation and `get_config` of this Network.
    They will not be added to the Network's outputs.

    Args:
      layers: Arbitrary nested structure of Layers. Layers must be reachable
        from one or more of the `keras.Input` Tensors that correspond to this
        Network's inputs.
      relevant_nodes: Nodes from the Layers that should be considered part of
        this Network. If `None`, all Nodes will be considered part of this
        Network.

    Raises:
      ValueError: If the layers depend on `Input`s not found in this Model.
    """
    layers = nest.flatten(layers)
    tf_utils.assert_no_legacy_layers(layers)
    node_to_depth = {}
    for depth, nodes in self._nodes_by_depth.items():
      node_to_depth.update({node: depth for node in nodes})
    # The nodes of these Layers that are relevant to this Network. If not
    # provided, assume all Nodes are relevant.
    if not relevant_nodes:
      relevant_nodes = nest.flatten([layer._inbound_nodes for layer in layers])
    network_nodes = set(relevant_nodes + list(node_to_depth.keys()))

    def _get_min_depth(node):
      """Gets the minimum depth at which node can be computed."""
      min_depth = 0
      for layer, node_id, _, _ in node.iterate_inbound():
        inbound_node = layer._inbound_nodes[node_id]
        if inbound_node in node_to_depth:
          min_depth = min(min_depth, node_to_depth[inbound_node])
        elif inbound_node not in network_nodes:
          continue
        else:
          # Previous relevant nodes haven't been processed yet.
          return None
      # New node is one shallower than its shallowest input.
      return min_depth - 1

    # Insert nodes into `_nodes_by_depth` and other node attrs.
    unprocessed_nodes = copy.copy(relevant_nodes)
    i = 0
    while unprocessed_nodes:
      i += 1
      # Do a sanity check. This can occur if `Input`s from outside this Model
      # are being relied on.
      if i > 10000:
        raise ValueError('Layers could not be added due to missing '
                         'dependencies.')

      node = unprocessed_nodes.pop(0)
      depth = _get_min_depth(node)
      if depth is None:  # Defer until inbound nodes are processed.
        unprocessed_nodes.append(node)
        continue
      node_key = _make_node_key(node.layer.name,
                                node.layer._inbound_nodes.index(node))
      if node_key not in self._network_nodes:
        node_to_depth[node] = depth
        self._network_nodes.add(node_key)
        self._nodes_by_depth[depth].append(node)

    # Insert layers and update other layer attrs.
    layer_set = set(self._self_tracked_trackables)
    deferred_layers = []
    for layer in layers:
      if layer not in layer_set:
        self._self_tracked_trackables.append(layer)
        deferred_layers.append(layer)
        self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
        layer_set.add(layer)
    self._handle_deferred_layer_dependencies(deferred_layers)

    self._compute_tensor_usage_count()

  def _compute_tensor_usage_count(self):
    """Computes the number of usages for all the output tensors of layers.

    The computed tensor usage count is saved as `self._tensor_usage_count`.
    This is later used for saving memory in eager computation by releasing
    no-longer-needed tensors as early as possible.
    """
    tensor_usage_count = collections.Counter()
    available_tensors = set(str(id(tensor)) for tensor in self.inputs)

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    depth_keys = depth_keys[1:]

    for depth in depth_keys:
      for node in self._nodes_by_depth[depth]:
        input_tensors = {
            str(id(tensor)) for tensor in nest.flatten(node.keras_inputs)
        }
        if input_tensors.issubset(available_tensors):
          for tensor in nest.flatten(node.keras_inputs):
            tensor_usage_count[str(id(tensor))] += 1

          for output_tensor in nest.flatten(node.outputs):
            available_tensors.add(str(id(output_tensor)))

    for tensor in self.outputs:
      tensor_usage_count[str(id(tensor))] += 1

    self._tensor_usage_count = tensor_usage_count

  def _assert_weights_created(self):
    # Override the implementation in Model.
    # The Functional model should always have weights created already.
    return

  def _graph_network_add_loss(self, symbolic_loss):
    new_nodes, new_layers = _map_subgraph_network(self.inputs, [symbolic_loss])
    # Losses must be keyed on inputs no matter what in order to be supported in
    # DistributionStrategy.
    add_loss_layer = base_layer.AddLoss(
        unconditional=False, dtype=symbolic_loss.dtype)
    add_loss_layer(symbolic_loss)
    new_nodes.extend(add_loss_layer.inbound_nodes)
    new_layers.append(add_loss_layer)
    self._insert_layers(new_layers, new_nodes)

  def _graph_network_add_metric(self, value, aggregation, name):
    new_nodes, new_layers = _map_subgraph_network(self.inputs, [value])
    add_metric_layer = base_layer.AddMetric(
        aggregation, name, dtype=value.dtype)
    add_metric_layer(value)
    new_nodes.extend(add_metric_layer.inbound_nodes)
    new_layers.append(add_metric_layer)
    self._insert_layers(new_layers, new_nodes)

  @property
  def _trackable_saved_model_saver(self):
    return network_serialization.NetworkSavedModelSaver(self)

  def _get_save_spec(self, dynamic_batch=True):
    if getattr(self, '_has_explicit_input_shape', True):
      # Functional models and Sequential models that have an explicit input
      # shape should use the batch size set by the input layer.
      dynamic_batch = False
    return super(Functional, self)._get_save_spec(dynamic_batch)


def _make_node_key(layer_name, node_index):
  return layer_name + '_ib-' + str(node_index)
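
# Hedged illustration (not part of the original source):
# _make_node_key('dense', 0) == 'dense_ib-0', where the index is the
# position of the node in the layer's inbound nodes ('ib' presumably
# abbreviates "inbound").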


def _map_graph_network(inputs, outputs):
  """Validates a network's topology and gathers its layers and nodes.

  Args:
    inputs: List of input tensors.
    outputs: List of output tensors.

  Returns:
    A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`.
    - nodes: list of Node instances.
    - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
    - layers: list of Layer instances.
    - layers_by_depth: dict mapping ints (depth) to lists of layer instances.

  Raises:
    ValueError: In case the network is not valid (e.g. disconnected graph).
  """
  # "depth" is the number of layers between the output Node and the Node.
  # Nodes are ordered from inputs -> outputs.
  nodes_in_decreasing_depth, layer_indices = _build_map(outputs)
  network_nodes = {
      _make_node_key(node.layer.name, node.layer._inbound_nodes.index(node))
      for node in nodes_in_decreasing_depth
  }

  nodes_depths = {}  # dict {node: depth value}
  layers_depths = {}  # dict {layer: depth value}

  for node in reversed(nodes_in_decreasing_depth):
    # If the depth is not set, the node has no outbound nodes (depth 0).
    depth = nodes_depths.setdefault(node, 0)

    # Update the depth of the corresponding layer.
    previous_depth = layers_depths.get(node.layer, 0)
    # If we've seen this layer before at a higher depth,
    # we should use that depth instead of the node depth.
    # This is necessary for shared layers that have inputs at different
    # depth levels in the graph.
    depth = max(depth, previous_depth)
    layers_depths[node.layer] = depth
    nodes_depths[node] = depth

    # Update the depth of inbound nodes.
    # The "depth" of a node is the max of the depths
    # of all nodes it is connected to + 1.
    for node_dep in node.parent_nodes:
      previous_depth = nodes_depths.get(node_dep, 0)
      nodes_depths[node_dep] = max(depth + 1, previous_depth)

  # Handle inputs that are not connected to outputs.
  # We do not error out here because the inputs may be used to compute losses
  # and metrics.
  for input_t in inputs:
    input_layer = input_t._keras_history[0]
    if input_layer not in layers_depths:
      layers_depths[input_layer] = 0
      layer_indices[input_layer] = -1
      nodes_depths[input_layer._inbound_nodes[0]] = 0
      network_nodes.add(_make_node_key(input_layer.name, 0))

  # Build a dict {depth: list of nodes with this depth}
  nodes_by_depth = collections.defaultdict(list)
  for node, depth in nodes_depths.items():
    nodes_by_depth[depth].append(node)

  # Build a dict {depth: list of layers with this depth}
  layers_by_depth = collections.defaultdict(list)
  for layer, depth in layers_depths.items():
    layers_by_depth[depth].append(layer)

  # Get sorted list of layer depths.
  depth_keys = list(layers_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Set self.layers ordered by depth.
  layers = []
  for depth in depth_keys:
    layers_for_depth = layers_by_depth[depth]
    # Network.layers needs to have a deterministic order:
    # here we order them by traversal order.
    layers_for_depth.sort(key=lambda x: layer_indices[x])
    layers.extend(layers_for_depth)

  # Get sorted list of node depths.
  depth_keys = list(nodes_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Check that all tensors required are computable.
  # computable_tensors: all tensors in the graph
  # that can be computed from the inputs provided.
  computable_tensors = set()
  for x in inputs:
    computable_tensors.add(id(x))

  layers_with_complete_input = []  # To provide a better error msg.
  for depth in depth_keys:
    for node in nodes_by_depth[depth]:
      layer = node.layer
      if layer and not node.is_input:
        for x in nest.flatten(node.keras_inputs):
          if id(x) not in computable_tensors:
            raise ValueError('Graph disconnected: '
                             'cannot obtain value for tensor ' + str(x) +
                             ' at layer "' + layer.name + '". '
                             'The following previous layers '
                             'were accessed without issue: ' +
                             str(layers_with_complete_input))
        for x in nest.flatten(node.outputs):
          computable_tensors.add(id(x))
        layers_with_complete_input.append(layer.name)

  # Ensure name unicity, which will be crucial for serialization
  # (since serialized nodes refer to layers by their name).
  all_names = [layer.name for layer in layers]
  for name in all_names:
    if all_names.count(name) != 1:
      raise ValueError('The name "' + name + '" is used ' +
                       str(all_names.count(name)) + ' times in the model. '
                       'All layer names should be unique.')
  return network_nodes, nodes_by_depth, layers, layers_by_depth
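
# Hedged illustration (not part of the original source): for a chain
# Input -> Dense -> Dense, the output Dense node gets depth 0, the first
# Dense node depth 1, and the Input node depth 2, so iterating depth keys
# in reverse order (2, 1, 0) visits nodes from inputs to outputs.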


def _build_map(outputs):
  """This method topologically sorts nodes in order from inputs to outputs.

  It uses a depth-first search to topologically sort nodes that appear in the
  _keras_history connectivity metadata of `outputs`.

  Args:
    outputs: the output tensors whose _keras_history metadata should be walked.
      This may be an arbitrary nested structure.

  Returns:
    A tuple like (ordered_nodes, layer_to_first_traversal_index)
    ordered_nodes: list of nodes appearing in the keras history, topologically
      sorted from original inputs to the `outputs`.
      (If outputs have different sets of ancestors, the inputs to one output
      may appear after a different output).
    layer_to_first_traversal_index:
      A dict mapping layer to the traversal index in the DFS where it is
      seen. Note: if a layer is shared by several nodes, the dict will only
      store the index corresponding to the *first* time the layer is seen.
  """
  finished_nodes = set()
  nodes_in_progress = set()
  nodes_in_decreasing_depth = []  # nodes from inputs -> outputs.
  layer_indices = {}  # layer -> in traversal order.
  for output in nest.flatten(outputs):
    _build_map_helper(output, finished_nodes, nodes_in_progress,
                      nodes_in_decreasing_depth, layer_indices)
  return nodes_in_decreasing_depth, layer_indices
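
# Hedged illustration (not part of the original source): for
# y = dense_b(dense_a(x)), the post-order DFS in `_build_map_helper` below
# appends the dense_a node before the dense_b node, so
# `nodes_in_decreasing_depth` reads inputs-first, as the docstring above
# describes.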


def _build_map_helper(tensor, finished_nodes, nodes_in_progress,
                      nodes_in_decreasing_depth, layer_indices):
  """Recursive helper for `_build_map`."""
  layer, node_index, _ = tensor._keras_history  # pylint: disable=protected-access
  node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access

  # Don't repeat work for shared subgraphs.
  if node in finished_nodes:
    return

  # Prevent cycles.
  if node in nodes_in_progress:
    raise ValueError('The tensor ' + str(tensor) + ' at layer "' + layer.name +
                     '" is part of a cycle.')

  # Store the traversal order for layer sorting.
  if layer not in layer_indices:
    layer_indices[layer] = len(layer_indices)

  # Propagate to all previous tensors connected to this node.
  nodes_in_progress.add(node)
  if not node.is_input:
    for tensor in node.keras_inputs:
      _build_map_helper(tensor, finished_nodes, nodes_in_progress,
                        nodes_in_decreasing_depth, layer_indices)

  finished_nodes.add(node)
  nodes_in_progress.remove(node)
  nodes_in_decreasing_depth.append(node)


def _map_subgraph_network(inputs, outputs):
  """Returns the nodes and layers in the topology from `inputs` to `outputs`.

  Args:
    inputs: List of input tensors.
    outputs: List of output tensors.

  Returns:
    A tuple of List[Node] and List[Layer].
  """
  if not ops.executing_eagerly_outside_functions():
    base_layer_utils.create_keras_history(outputs)
  # Keep only nodes and layers in the topology between inputs and outputs.
  _, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs)
  return nest.flatten([nodes for nodes in nodes_by_depth.values()]), layers


def _should_skip_first_node(layer):
  """Returns True if the first layer node should not be saved or loaded."""
  # Networks that are constructed with an Input layer/shape start with a
  # pre-existing node linking their input to output. This node is excluded
  # from the network config.
  if layer._self_tracked_trackables:
    return (isinstance(layer, Functional) and
            # Filter out Sequential models without an input shape.
            isinstance(layer._self_tracked_trackables[0],
                       input_layer_module.InputLayer))
  else:
    return isinstance(layer, Functional)


def connect_ancillary_layers(model, created_layers):
  """Adds layers that are not connected to the outputs to the model."""
  # Layers not connected to outputs, such as those added in `add_loss`.
  ancillary_layers = [
      layer for layer in created_layers.values() if layer not in model.layers
  ]
  if ancillary_layers:
    relevant_nodes = nest.flatten([
        layer.inbound_nodes[1:]
        if _should_skip_first_node(layer) else layer.inbound_nodes
        for layer in created_layers.values()
    ])
    model._insert_layers(ancillary_layers, relevant_nodes)
  return model


def reconstruct_from_config(config, custom_objects=None, created_layers=None):
  """Reconstructs graph from config object.

  Args:
    config: Dictionary returned from Network.get_config()
    custom_objects: Optional dictionary mapping names (strings) to custom
      classes or functions to be considered during deserialization.
    created_layers: Optional dictionary mapping names to Layer objects. Any
      layer not in this dictionary will be created and added to the dict.
      This function will add new nodes to all layers (excluding InputLayers),
      instead of re-using pre-existing nodes in the layers.

  Returns:
    Tuple of (input tensors, output tensors, dictionary of created layers)
  """
  # Layer instances created during the graph reconstruction process.
  created_layers = created_layers or collections.OrderedDict()

  # Maps input data (tuple of inbound layer name, node index) from the config
  # to node indices in the newly generated model. The node indices may be
  # different if the layers have already been called previously.
  node_index_map = {}
  node_count_by_layer = {}

  # Dictionary mapping layer instances to
  # node data that specifies a layer call.
  # It acts as a queue that maintains any unprocessed
  # layer call until it becomes possible to process it
  # (i.e. until the input tensors to the call all exist).
  unprocessed_nodes = {}

  def add_unprocessed_node(layer, node_data):
    if layer not in unprocessed_nodes:
      unprocessed_nodes[layer] = [node_data]
    else:
      unprocessed_nodes[layer].append(node_data)

  def get_node_index(layer, config_node_index):
    """Returns node index in layer (might differ from config_node_index)."""
    if isinstance(layer, input_layer_module.InputLayer):
      return 0
    return node_index_map.get((layer.name, config_node_index), None)

  def _deserialize_keras_tensors(kwargs, layer_map):
    """Deserializes Keras Tensors passed to `call`."""

    def _deserialize_keras_tensor(t):
      """Deserializes a single Keras Tensor passed to `call`."""
      if isinstance(t, tf_utils.ListWrapper):
        t = t.as_list()
        layer_name = t[0]
        node_index = t[1]
        tensor_index = t[2]

        layer = layer_map[layer_name]
        new_node_index = get_node_index(layer, node_index)
        if new_node_index is None:
          # The inbound node may not have been processed yet
          # (this can happen e.g. if it depends on a different set
          # of inputs than those that have been processed already).
          # Raise an IndexError so that the current node puts itself
          # back on the unprocessed queue.
          # Caution: this may lead to infinite loops for malformed
          # network configurations (or when there is a bug in
          # the network config loading code)!
          raise IndexError
        node = layer._inbound_nodes[new_node_index]
        return nest.flatten(node.outputs)[tensor_index]
      return t

    kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True)
    return nest.map_structure(_deserialize_keras_tensor, kwargs)

  def process_node(layer, node_data):
    """Deserializes a node.

    Args:
      layer: layer instance.
      node_data: Nested structure of `ListWrapper`.

    Raises:
      ValueError: In case of improperly formatted `node_data`.
    """
    input_tensors = []
    for input_data in nest.flatten(node_data):
      input_data = input_data.as_list()
      inbound_layer_name = input_data[0]
      inbound_node_index = input_data[1]
      inbound_tensor_index = input_data[2]
      if len(input_data) == 3:
        kwargs = {}
      elif len(input_data) == 4:
        kwargs = input_data[3]
        try:
          kwargs = _deserialize_keras_tensors(kwargs, created_layers)
        except IndexError:
          # Happens if keras tensors in kwargs are still unprocessed.
          add_unprocessed_node(layer, node_data)
          return
      else:
        raise ValueError('Improperly formatted model config.')

      if inbound_layer_name != node_module._CONSTANT_VALUE:
        inbound_layer = created_layers[inbound_layer_name]
        inbound_node_index = get_node_index(inbound_layer, inbound_node_index)

        if inbound_node_index is None:
          add_unprocessed_node(layer, node_data)
          return
        inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
        input_tensors.append(
            nest.flatten(inbound_node.outputs)[inbound_tensor_index])
      else:
        # We received a constant with no Keras history attached.
        input_tensors.append(inbound_tensor_index)
    input_tensors = nest.pack_sequence_as(node_data, input_tensors)
    # Call layer on its inputs, thus creating the node
    # and building the layer if needed.
    if input_tensors is not None:
      if not layer._preserve_input_structure_in_config:
        input_tensors = (
            base_layer_utils.unnest_if_single_tensor(input_tensors))
      output_tensors = layer(input_tensors, **kwargs)

      # Update node index map.
      output_index = nest.flatten(output_tensors)[0]._keras_history.node_index
      node_index_map[(layer.name, node_count_by_layer[layer])] = output_index
      node_count_by_layer[layer] += 1

  def process_layer(layer_data):
    """Deserializes a layer, then calls it on appropriate inputs.

    Args:
      layer_data: layer config dict.

    Raises:
      ValueError: In case of improperly formatted `layer_data` dict.
    """
    layer_name = layer_data['name']

    if layer_name in created_layers:
      layer = created_layers[layer_name]
    else:
      # Instantiate layer.
      from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top

      layer = deserialize_layer(layer_data, custom_objects=custom_objects)
      created_layers[layer_name] = layer

    node_count_by_layer[layer] = int(_should_skip_first_node(layer))

    # Gather layer inputs and convert to `ListWrapper` objects.
    inbound_nodes_data = layer_data['inbound_nodes']
    inbound_nodes_data = tf_utils.convert_inner_node_data(
        inbound_nodes_data, wrap=True)
    for node_data in inbound_nodes_data:
      # We don't process nodes (i.e. make layer calls)
      # on the fly because the inbound node may not yet exist,
      # in the case of layers shared at different topological depths
      # (e.g. a model such as A(B(A(B(x))))).
      add_unprocessed_node(layer, node_data)

  # First, we create all layers and enqueue nodes to be processed.
  for layer_data in config['layers']:
    process_layer(layer_data)
  # Then we process nodes in order of layer depth.
  # Nodes that cannot yet be processed (if the inbound node
  # does not yet exist) are re-enqueued, and the process
  # is repeated until all nodes are processed.
  while unprocessed_nodes:
    for layer_data in config['layers']:
      layer = created_layers[layer_data['name']]
      if layer in unprocessed_nodes:
        for node_data in unprocessed_nodes.pop(layer):
          process_node(layer, node_data)

  input_tensors = []
  output_tensors = []

  input_layers = tf_utils.convert_inner_node_data(
      config['input_layers'], wrap=True)
  for layer_data in nest.flatten(input_layers):
    layer_name, node_index, tensor_index = layer_data.as_list()
    assert layer_name in created_layers
    layer = created_layers[layer_name]
    node_index = get_node_index(layer, node_index)
    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
    input_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

  output_layers = tf_utils.convert_inner_node_data(
      config['output_layers'], wrap=True)
  for layer_data in nest.flatten(output_layers):
    layer_name, node_index, tensor_index = layer_data.as_list()
    assert layer_name in created_layers
    layer = created_layers[layer_name]
    node_index = get_node_index(layer, node_index)
    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
    output_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

  input_tensors = nest.pack_sequence_as(input_layers, input_tensors)
  output_tensors = nest.pack_sequence_as(output_layers, output_tensors)
  return input_tensors, output_tensors, created_layers


def get_network_config(network, serialize_layer_fn=None):
  """Builds the config, which consists of the node graph and serialized layers.

  Args:
    network: A Network object.
    serialize_layer_fn: Function used to serialize layers.

  Returns:
    Config dictionary.
  """
  serialize_layer_fn = (
      serialize_layer_fn or generic_utils.serialize_keras_object)
  config = {
      'name': network.name,
  }
  node_conversion_map = {}
  for layer in network.layers:
    kept_nodes = 1 if _should_skip_first_node(layer) else 0
    for original_node_index, node in enumerate(layer._inbound_nodes):
      node_key = _make_node_key(layer.name, original_node_index)
      if node_key in network._network_nodes:
        node_conversion_map[node_key] = kept_nodes
        kept_nodes += 1
  layer_configs = []

  with generic_utils.SharedObjectSavingScope():
    for layer in network.layers:  # From the earliest layers on.
      filtered_inbound_nodes = []
      for original_node_index, node in enumerate(layer._inbound_nodes):
        node_key = _make_node_key(layer.name, original_node_index)
        if node_key in network._network_nodes and not node.is_input:
          # The node is relevant to the model:
          # add to filtered_inbound_nodes.
          node_data = node.serialize(_make_node_key, node_conversion_map)
          filtered_inbound_nodes.append(node_data)

      layer_config = serialize_layer_fn(layer)
      layer_config['name'] = layer.name
      layer_config['inbound_nodes'] = filtered_inbound_nodes
      layer_configs.append(layer_config)
  config['layers'] = layer_configs

  # Gather info about inputs and outputs.
  model_inputs = []
  for i in range(len(network._input_layers)):
    layer, node_index, tensor_index = network._input_coordinates[i]
    node_key = _make_node_key(layer.name, node_index)
    if node_key not in network._network_nodes:
      continue
    new_node_index = node_conversion_map[node_key]
    model_inputs.append(
        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
  model_inputs = nest.pack_sequence_as(network._nested_inputs, model_inputs)
  # Preserve external Keras compat for Models with single input.
  if not nest.is_nested(model_inputs):
    model_inputs = [model_inputs]
  model_inputs = tf_utils.convert_inner_node_data(model_inputs)
  config['input_layers'] = model_inputs

  model_outputs = []
  for i in range(len(network._output_layers)):
    layer, node_index, tensor_index = network._output_coordinates[i]
    node_key = _make_node_key(layer.name, node_index)
    if node_key not in network._network_nodes:
      continue
    new_node_index = node_conversion_map[node_key]
    model_outputs.append(
        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
  model_outputs = nest.pack_sequence_as(network._nested_outputs, model_outputs)
  # Preserve external Keras compat for Models with single output.
  if not nest.is_nested(model_outputs):
    model_outputs = [model_outputs]
  model_outputs = tf_utils.convert_inner_node_data(model_outputs)
  config['output_layers'] = model_outputs
  return config


def shape_with_no_batch_size(x):
  if x.shape.rank is None:
    return None
  shape = x.shape.as_list()
  if shape:
    shape[0] = None
  return shape
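
# Hedged illustration (not part of the original source): for a tensor with
# static shape (32, 10), shape_with_no_batch_size returns [None, 10]; the
# batch dimension is erased so the resulting InputSpec accepts any batch
# size.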


class ModuleWrapper(base_layer.Layer):
  """Wrapper for `tf.Module`s to support the Functional and Sequential API."""

  def __init__(self, module, method_name=None, **kwargs):
    """Initializes the wrapper Layer for this module.

    Args:
      module: The `tf.Module` instance to be wrapped.
      method_name: (Optional) str. The name of the method to use as the
        forward pass of the module. If not set, defaults to '__call__' if
        defined, or 'call'.
      **kwargs: Additional keyword arguments. See `tf.keras.layers.Layer`.

    Raises:
      ValueError: If `method_name` is not defined on `module`.
    """
    super(ModuleWrapper, self).__init__(**kwargs)
    if method_name is None:
      if hasattr(module, '__call__'):
        method_name = '__call__'
      elif hasattr(module, 'call'):
        method_name = 'call'
    if method_name is None or not hasattr(module, method_name):
      raise ValueError('{} is not defined on object {}'.format(
          method_name, module))

    self._module = module
    self._method_name = method_name

    # Check if module.__call__ has a `training` arg or accepts `**kwargs`.
    method = getattr(module, method_name)
    method_arg_spec = tf_inspect.getfullargspec(method)
    self._expects_training_arg = ('training' in method_arg_spec.args or
                                  method_arg_spec.varkw is not None)
    self._expects_mask_arg = ('mask' in method_arg_spec.args or
                              method_arg_spec.varkw is not None)

  def call(self, *args, **kwargs):
    if 'training' in kwargs and not self._expects_training_arg:
      kwargs.pop('training')
    if 'mask' in kwargs and not self._expects_mask_arg:
      kwargs.pop('mask')
    return getattr(self._module, self._method_name)(*args, **kwargs)
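

# Hedged illustration (not part of the original source): a plain `tf.Module`
# can participate in the Functional API once wrapped, e.g.
#
#     class Scale(tf.Module):
#
#       def __call__(self, x):
#         return x * 2.0
#
#     inputs = keras.Input(shape=(4,))
#     outputs = ModuleWrapper(Scale())(inputs)
#     model = keras.Model(inputs, outputs)
#
# The wrapper strips `training`/`mask` kwargs that the module's forward
# method does not accept.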