Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/node.py: 24%
151 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
17"""Contains the `Node` class."""
19import collections
20import copy
21import json
23import numpy as np
24import tensorflow.compat.v2 as tf
26from keras.src import backend
27from keras.src.engine import base_layer_utils
28from keras.src.saving.legacy.saved_model import json_utils
29from keras.src.utils import tf_utils
# Marker written in the first slot of a serialized first-argument entry when
# the value did not come from a Keras tensor (see `Node.serialize`).
_CONSTANT_VALUE = "_CONSTANT_VALUE"
# Marker paired with a JSON-encoded CompositeTensor in serialized arguments.
# It is a dict rather than a string so it cannot collide with a constant
# string tensor.
_COMPOSITE_TYPE = dict(_TYPE="COMPOSITE")
class Node:
    """A `Node` describes a layer `__call__()` event.

    A Functional model is a DAG with `Node` instances as nodes, and
    `KerasTensor` instances as edges. Nodes aren't `Layer` instances, because a
    single layer could be called multiple times, which would result in graph
    cycles.

    A `__call__()` event involves input tensors (and other input arguments),
    the layer that was called, and the resulting output tensors.
    A `Node` will include all this information.

    Since a single `Layer` could be called multiple times, the `Node` instances
    are stored on layers as a list. Each time a layer is called a node is added
    to `layer._inbound_nodes`. Each time the output of a layer is used by
    another layer, a node is added to `layer._outbound_nodes`.

    Every `KerasTensor` instance has a `KerasHistory` object attached,
    which tracks the `Node` that records the `__call__()` event that created
    the tensor. By recursively walking through `Node` instances
    via the `KerasHistory` metadata of `KerasTensor` instances, one can
    retrieve the entire DAG of a Functional model.

    A node whose `call_args` and `call_kwargs` are both empty (or omitted) is
    treated as an input node (`is_input` is True).

    Args:
        layer: The layer that was called in the `Layer.__call__()`
            event that this node represents.
        call_args: The positional arguments the layer was called with.
            Defaults to an empty list.
        call_kwargs: The keyword arguments the layer was called with.
            Defaults to an empty dict.
        outputs: The output tensors of the `Layer.__call__()` event.
            Defaults to an empty list.
    """

    def __init__(self, layer, call_args=None, call_kwargs=None, outputs=None):
        call_args = [] if call_args is None else call_args
        call_kwargs = {} if call_kwargs is None else call_kwargs
        outputs = [] if outputs is None else outputs

        self.layer = layer
        self.is_input = not call_args and not call_kwargs

        # These arguments are user-provided. Copy the structures here so that
        # future user modifications do not affect the node's metadata.
        # We copy using map_structure rather than python's shallow or deep
        # copy, because the args can be data structures (so shallow copy is
        # insufficient), but individual values might not support copy.copy
        # or be too expensive to deep copy.
        self.call_args = tf.nest.map_structure(lambda t: t, call_args)
        self.call_kwargs = tf.nest.map_structure(lambda t: t, call_kwargs)
        self.outputs = tf.nest.map_structure(lambda t: t, outputs)

        # Cached for performance.
        self._flat_arguments = tf.nest.flatten(
            (self.call_args, self.call_kwargs)
        )
        # Used to avoid expensive `nest` operations in the most common case.
        self._single_positional_tensor_passed = (
            not self.call_kwargs
            and len(self.call_args) == 1
            and tf.is_tensor(self.call_args[0])
        )

        if not tf.compat.v1.executing_eagerly_outside_functions():
            # Create TensorFlowOpLayers if needed (in TF1)
            for obj in self._flat_arguments:
                if isinstance(
                    obj, tf.Tensor
                ) and base_layer_utils.needs_keras_history(
                    obj, ignore_call_context=True
                ):
                    base_layer_utils.create_keras_history(obj)

        # Record each Keras tensor among the flattened arguments, together
        # with its string id and flat position, so that `map_arguments` can
        # substitute computed values back into the same slots. This must run
        # after the TF1 history creation above, which may turn plain tensors
        # into Keras tensors.
        self._keras_inputs = []
        self._keras_inputs_ids_and_indices = []
        for flat_index, ele in enumerate(self._flat_arguments):
            if is_keras_tensor(ele):
                self._keras_inputs.append(ele)
                self._keras_inputs_ids_and_indices.append(
                    (str(id(ele)), flat_index)
                )

        # Wire up Node to Layers.
        self.layer._inbound_nodes.append(self)
        for kt in self.keras_inputs:
            inbound_layer = kt._keras_history.layer
            if inbound_layer is not None:  # `None` for `Input` tensors.
                inbound_layer._outbound_nodes.append(self)

        # Set metadata on outputs. This node was just appended, so its index
        # is the last position in `layer._inbound_nodes`.
        node_index = len(self.layer._inbound_nodes) - 1
        for i, tensor in enumerate(tf.nest.flatten(outputs)):
            tensor._keras_history = KerasHistory(
                layer=layer, node_index=node_index, tensor_index=i
            )

        # Cached for performance.
        self.flat_input_ids = [str(id(t)) for t in self._keras_inputs]
        self.flat_output_ids = [
            str(id(t)) for t in tf.nest.flatten(self.outputs)
        ]

    @property
    def keras_inputs(self):
        """Tensors input to this node that can be traced back to a
        `keras.Input`."""
        return self._keras_inputs

    @property
    def parent_nodes(self):
        """Returns all the `Node`s whose output this node immediately depends
        on."""
        node_deps = []
        for kt in self.keras_inputs:
            layer = kt._keras_history.layer
            node_index = kt._keras_history.node_index
            if layer is not None:  # `None` for `Input` tensors.
                node_deps.append(layer._inbound_nodes[node_index])
        return node_deps

    def iterate_inbound(self):
        """Yields tuples representing the data inbound from other nodes.

        Yields:
            tuples like: (inbound_layer, node_index, tensor_index, tensor).
        """
        for kt in self.keras_inputs:
            keras_history = kt._keras_history
            layer = keras_history.layer
            node_index = keras_history.node_index
            tensor_index = keras_history.tensor_index
            yield layer, node_index, tensor_index, kt

    def map_arguments(self, tensor_dict):
        """Maps Keras Tensors to computed Tensors using `tensor_dict`.

        Args:
            tensor_dict: Mapping from `str(id(keras_tensor))` to a list-like
                container of computed values. One value is consumed with
                `.pop()` for each use of that Keras tensor.

        Returns:
            An `(args, kwargs)` tuple for calling the layer, in which every
            Keras tensor argument has been replaced by its computed value.
        """
        if self._single_positional_tensor_passed:
            # Performance optimization for most common case.
            kt_id, _ = self._keras_inputs_ids_and_indices[0]
            return (tensor_dict[kt_id].pop(),), {}

        flat_arguments = copy.copy(self._flat_arguments)
        for kt_id, kt_index in self._keras_inputs_ids_and_indices:
            flat_arguments[kt_index] = tensor_dict[kt_id].pop()

        args, kwargs = tf.nest.pack_sequence_as(
            (self.call_args, self.call_kwargs), flat_arguments
        )
        return args, kwargs

    def serialize(self, make_node_key, node_conversion_map):
        """Serializes `Node` for Functional API's `get_config`.

        Args:
            make_node_key: Callable invoked as
                `make_node_key(layer_name, node_index)` to build the key used
                to look up `node_conversion_map`.
            node_conversion_map: Mapping from node keys to the node index
                written into the config. Keys that are absent map to 0.

        Returns:
            The serialized first call argument. Each element is either
            `[layer_name, node_index, tensor_index, kwargs]` for a Keras
            tensor or `[_CONSTANT_VALUE, -1, value, kwargs]` otherwise. A
            non-nested result is wrapped in a list unless the layer preserves
            input structure in its config. The result is finally passed
            through `tf_utils.convert_inner_node_data`.

        Raises:
            TypeError: If the arguments other than the first one cannot be
                JSON-serialized.
        """
        # Serialization still special-cases first argument.
        inputs, args, kwargs = self.layer._call_spec.split_out_first_arg(
            self.call_args, self.call_kwargs
        )

        # Treat everything other than first argument as a kwarg.
        arguments = dict(zip(self.layer._call_spec.arg_names[1:], args))
        arguments.update(kwargs)
        kwargs = arguments

        def _serialize_keras_tensor(t):
            """Serializes a single value passed to `call`."""
            if is_keras_tensor(t):
                kh = t._keras_history
                node_key = make_node_key(kh.layer.name, kh.node_index)
                new_node_index = node_conversion_map.get(node_key, 0)
                return [kh.layer.name, new_node_index, kh.tensor_index]

            if isinstance(t, np.ndarray):
                return t.tolist()

            if isinstance(t, tf.Tensor):
                return backend.get_value(t).tolist()

            # Not using json_utils to serialize both constant Tensor and
            # constant CompositeTensor for saving format backward
            # compatibility.
            if isinstance(t, tf.__internal__.CompositeTensor):
                return (_COMPOSITE_TYPE, json_utils.Encoder().encode(t))

            return t

        kwargs = tf.nest.map_structure(_serialize_keras_tensor, kwargs)
        try:
            json.dumps(kwargs, default=json_utils.get_json_type)
        except TypeError as err:
            kwarg_types = tf.nest.map_structure(type, kwargs)
            raise TypeError(
                f"Layer {self.layer.name} was passed non-JSON-serializable "
                f"arguments. Arguments had types: {kwarg_types}. They cannot "
                "be serialized out when saving the model."
            ) from err

        # `kwargs` is added to each Tensor in the first arg. This should be
        # changed in a future version of the serialization format.
        def serialize_first_arg_tensor(t):
            if is_keras_tensor(t):
                # [layer_name, node_index, tensor_index] + [kwargs]
                data = _serialize_keras_tensor(t) + [kwargs]
            else:
                # If an element in the first call argument did not originate
                # as a keras tensor and is a constant value, we save it using
                # the format ['_CONSTANT_VALUE', -1,
                # serialized_tensor_or_python_constant] (potentially
                # including serialized kwargs in an optional 4th argument).
                data = [_CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs]
            return tf_utils.ListWrapper(data)

        data = tf.nest.map_structure(serialize_first_arg_tensor, inputs)
        if (
            not tf.nest.is_nested(data)
            and not self.layer._preserve_input_structure_in_config
        ):
            data = [data]
        return tf_utils.convert_inner_node_data(data)

    #############################################################
    # Properties for Backwards compatibility.
    # These only check the first input argument
    # As nodes are internal, they may be removed in the future.
    #############################################################

    @property
    def input_tensors(self):
        """The first positional call argument, or `[outputs]` for input
        nodes."""
        if self.is_input:
            return [self.outputs]  # Used in `Layer.input`.
        return self.call_args[0]

    @property
    def output_tensors(self):
        """The call outputs, wrapped in a list for input nodes."""
        if self.is_input:
            # Input nodes wrap their outputs in a list, mirroring
            # `input_tensors`.
            return [self.outputs]
        return self.outputs

    @property
    def input_shapes(self):
        """Shapes of `input_tensors`, as computed by `backend.int_shape`.

        For non-input nodes whose first argument is a one-element list or
        tuple of tensors, the lone shape is returned unwrapped.
        """
        input_tensors = self.input_tensors
        input_shapes = tf.nest.map_structure(backend.int_shape, input_tensors)
        # Only unwrap a genuine one-element *sequence of inputs*. A single
        # rank-1 tensor also yields a length-1 shape tuple, and a one-key
        # dict cannot be indexed with `[0]`, so neither may be unwrapped.
        if (
            not self.is_input
            and isinstance(input_tensors, (list, tuple))
            and len(input_shapes) == 1
        ):
            return input_shapes[0]
        return input_shapes

    @property
    def output_shapes(self):
        """Shapes of `output_tensors`, as computed by `backend.int_shape`."""
        return tf.nest.map_structure(backend.int_shape, self.output_tensors)

    @property
    def outbound_layer(self):
        """The layer whose call this node records."""
        return self.layer

    @property
    def inbound_layers(self):
        """Return all layers that feed into the current node.

        Returns an empty list for input nodes, the layer itself when exactly
        one Keras tensor argument is present, and a list of layers otherwise.
        """
        if self.is_input:
            return []
        inbound_layers = [
            x._keras_history.layer
            for x in self._flat_arguments
            if tf.is_tensor(x) and is_keras_tensor(x)
        ]
        if len(inbound_layers) == 1:
            return inbound_layers[0]
        return inbound_layers
class KerasHistory(
    collections.namedtuple("KerasHistory", "layer node_index tensor_index")
):
    """Records which Layer call produced a Tensor in a Keras graph network.

    While a Keras graph network is being built, one of these records is
    attached to every Tensor that a Layer outputs, starting from the
    `InputLayer`. Keras later walks these records backwards to reconstruct
    the network from the Tensors it produced.

    Attributes:
        layer: The Layer whose call produced the Tensor.
        node_index: Which call of `layer` produced the Tensor. A Layer can be
            called several times (for example to share weights), and each
            call creates a new node; the node for this Tensor is
            `layer._inbound_nodes[node_index]`.
        tensor_index: Position of the Tensor among the call's outputs after
            they are flattened with `nest.flatten`. It is always zero when
            the Layer produces a single output.
    """

    # An empty `__slots__` keeps subclass instances as compact and fast as a
    # plain `namedtuple`; without it every instance would get a `__dict__`.
    __slots__ = ()
def is_keras_tensor(obj):
    """Return whether `obj` carries Keras history metadata.

    The presence of a `_keras_history` attribute, whatever its value, is the
    marker. Only `AttributeError` counts as "absent"; any other exception
    raised while reading the attribute propagates to the caller.
    """
    try:
        obj._keras_history  # noqa: B018 -- probing for the attribute only.
    except AttributeError:
        return False
    return True