Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/functional_utils.py: 14%
85 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for keras functional model."""
17import tensorflow.compat.v2 as tf
19from keras.src import backend
20from keras.src.engine import input_layer as input_layer_module
21from keras.src.engine import keras_tensor
22from keras.src.engine import node as node_module
# Message template for the ValueError raised when an object that is not a
# KerasTensor shows up among the tensors handed to functional-model utilities.
# The single `{}` placeholder is filled with the offending object.
_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG = (
    "Found unexpected instance while processing input tensors "
    "for keras functional model. "
    "Expecting KerasTensor which is from tf.keras.Input() "
    "or output from keras layer call(). "
    "Got: {}"
)
def is_input_keras_tensor(tensor):
    """Report whether `tensor` comes straight from `tf.keras.Input`.

    Functional-model construction relies on this: when a model is being built
    from tensors that are not direct `tf.keras.Input` outputs, its Nodes and
    KerasTensors have to be cloned.

    Args:
      tensor: A `KerasTensor` supplied as an input to the functional model.

    Returns:
      bool. The `is_input` flag of the node that produced `tensor`.

    Raises:
      ValueError: if `tensor` is not a KerasTensor instance.
    """
    if node_module.is_keras_tensor(tensor):
        return tensor.node.is_input
    raise ValueError(_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG.format(tensor))
def find_nodes_by_inputs_and_outputs(inputs, outputs):
    """Fetch all Nodes in the graph defined by "inputs" and "outputs".

    This method is used to find and then clone Nodes when creating a new
    sub-model from an existing functional model.

    Args:
      inputs: A nested structure of KerasTensor to use as model inputs.
      outputs: A nested structure of KerasTensor to use as model outputs.

    Returns:
      A list of Nodes that are connected to the inputs and outputs. Nodes
      appear in the order they are dequeued by a FIFO walk that starts at the
      nodes producing `outputs` and moves upstream. Each node appears once.

    Raises:
      ValueError: when inputs and outputs are disconnected or in case of
        unexpected objects in the inputs/outputs.
    """
    import collections  # Only needed for the FIFO work queue below.

    # We walk the graph bottom up. We start from the output nodes and keep
    # tracing upstream nodes until we find all the input nodes.
    # We don't search top down because we don't know whether a given node lies
    # between inputs and outputs. For example, a functional graph could have
    # multiple outputs, and the user could choose a subset of them to build
    # the model. Walking bottom up ensures every node we visit is actually
    # in use.
    # If we reach the top without finding the nodes in `inputs`, that's an
    # error, since the user didn't specify the correct inputs.
    start_keras_tensors = tf.nest.flatten(outputs)
    end_keras_tensors = tf.nest.flatten(inputs)

    for t in start_keras_tensors + end_keras_tensors:
        if not node_module.is_keras_tensor(t):
            raise ValueError(_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG.format(t))
    end_ids = {id(kt) for kt in end_keras_tensors}
    # Track every end tensor found so far. If some user-specified keras inputs
    # are still missing when the search finishes, those inputs are
    # disconnected from the outputs, which is an error.
    end_ids_found = set()

    # FIFO work queue: `deque.popleft()` is O(1), whereas `list.pop(0)` is
    # O(n) because it shifts every remaining element.
    nodes_to_visit = collections.deque(t.node for t in start_keras_tensors)
    nodes_in_graph = []
    node_id_visited = set()

    while nodes_to_visit:
        node = nodes_to_visit.popleft()
        # The same node can be enqueued several times, once per consumer that
        # reaches it, so only record and expand it the first time it is seen.
        if id(node) in node_id_visited:
            continue
        node_id_visited.add(id(node))
        nodes_in_graph.append(node)
        # Walk every input keras_tensor consumed by the current node.
        for kt in node.keras_inputs:
            if id(kt) in end_ids:
                # This is one of the model inputs; stop tracing upstream.
                end_ids_found.add(id(kt))
                continue

            inbound_node = kt.node
            # Reaching a tf.keras.Input node means we hit the top of the graph
            # without meeting `kt` among the user-specified inputs. Tracing
            # further is pointless, so report the inconsistency.
            if inbound_node.is_input:
                raise ValueError(
                    "Found input tensor cannot be reached given provided "
                    "output tensors. Please make sure the tensor {} is "
                    "included in the model inputs when building "
                    "functional model.".format(kt)
                )
            nodes_to_visit.append(inbound_node)

    # Final check: every user-specified input must have been reached.
    if end_ids != end_ids_found:
        unvisited_inputs = [
            kt for kt in end_keras_tensors if id(kt) not in end_ids_found
        ]
        raise ValueError(
            "Found unvisited input tensors that are disconnected from "
            "the outputs: {}".format(unvisited_inputs)
        )
    return nodes_in_graph
def clone_graph_nodes(inputs, outputs):
    """Clone the `Node` between the inputs and output tensors.

    This function is used to create a new functional model from any
    intermediate keras tensors. Cloning the nodes mimics reconstructing the
    functional graph network by re-executing all the `__call__` methods. The
    cloned nodes are appended to the layers.

    Items in `inputs` that already come from `tf.keras.Input` are reused as
    is. A new `tf.keras.Input` is created for every other (intermediate) item.

    Args:
      inputs: A nested structure of keras_tensors.
      outputs: A nested structure of keras_tensors.

    Returns:
      A pair of inputs and outputs, with cloned keras_tensors. They can be used
      to create a new functional model.

    Raises:
      ValueError: propagated from `find_nodes_by_inputs_and_outputs` when the
        inputs/outputs are invalid or disconnected.
    """
    nodes_to_clone = find_nodes_by_inputs_and_outputs(inputs, outputs)
    cloned_inputs = []
    cloned_outputs = []
    # Besides copying the Nodes (mimicking the calls), we must also clone the
    # keras_tensors. Otherwise the `_keras_history` attached to the originals
    # would be overwritten.
    # This dict maps `id()` of each original keras tensor to the keras_tensor
    # instance that replaces it in the cloned graph.
    kt_id_mapping = {}

    for kt_input in tf.nest.flatten(inputs):
        if kt_input.node.is_input:
            # An existing keras_tensor from tf.keras.Input is reused as is.
            cloned_inputs.append(kt_input)
            kt_id_mapping[id(kt_input)] = kt_input
        else:
            # An intermediate keras_tensor gets a brand new tf.keras.Input.
            cpy = _clone_keras_tensor(kt_input)
            cloned_input = input_layer_module.Input(tensor=cpy)
            cloned_inputs.append(cloned_input)
            kt_id_mapping[id(kt_input)] = cloned_input
    cloned_inputs = tf.nest.pack_sequence_as(inputs, cloned_inputs)

    # The same tensor may appear several times in `outputs` (e.g. `[x, x]`),
    # so clone each distinct output only once. Otherwise each later copy would
    # overwrite `kt_id_mapping`, and only the last copy would be rewired by the
    # new Node below. The earlier copies would keep a `_keras_history` that
    # still points into the original graph.
    output_clones = {}
    for kt_output in tf.nest.flatten(outputs):
        key = id(kt_output)
        if key not in output_clones:
            cpy = _clone_keras_tensor(kt_output)
            # Reuse the old _keras_history: the Node constructor checks
            # "is_keras_tensor()" on it. The Node constructor for the
            # corresponding layer output overwrites it anyway.
            cpy._keras_history = kt_output._keras_history
            output_clones[key] = cpy
            kt_id_mapping[key] = cpy
        cloned_outputs.append(output_clones[key])
    cloned_outputs = tf.nest.pack_sequence_as(outputs, cloned_outputs)

    for node in nodes_to_clone:
        # Clone any keras_tensors to avoid overriding their _keras_history, or
        # reuse a keras_tensor that has already been cloned.
        output_copy = clone_keras_tensors(node.output_tensors, kt_id_mapping)
        call_args_copy = clone_keras_tensors(node.call_args, kt_id_mapping)
        call_kwargs_copy = clone_keras_tensors(node.call_kwargs, kt_id_mapping)
        # Create a new node from the existing node's information. The Node
        # constructor wires itself into the graph:
        # - it updates the layer's `_inbound_nodes`;
        # - it sets `_keras_history` on the outputs;
        # - it adds itself to the `_outbound_nodes` of the layers that produced
        #   this call's inputs.
        node_module.Node(
            node.layer,
            call_args=call_args_copy,
            call_kwargs=call_kwargs_copy,
            outputs=output_copy,
        )
    return cloned_inputs, cloned_outputs
def clone_keras_tensors(args, keras_tensor_mapping):
    """Return `args` with every KerasTensor replaced by its clone.

    Each KerasTensor found in the flattened `args` is looked up in
    `keras_tensor_mapping` by `id()`. If a clone is already recorded there, it
    is reused. Otherwise a fresh copy is made, and that copy inherits the
    original's `_keras_history`. Every non-KerasTensor value passes through
    untouched. Cloning is needed because a KerasTensor cannot be shared
    between models.

    Args:
      args: A nested structure of objects, which may contain KerasTensors.
      keras_tensor_mapping: A dict from `id()` of an original KerasTensor to
        its cloned instance. New clones are added to it in place.

    Returns:
      A structure shaped like `args`, with its KerasTensors swapped for
      clones.
    """

    def _clone_or_reuse(item):
        # Anything that is not a KerasTensor is returned unchanged.
        if not node_module.is_keras_tensor(item):
            return item
        item_id = id(item)
        if item_id not in keras_tensor_mapping:
            # First time we see this tensor: copy it and remember the copy.
            fresh = _clone_keras_tensor(item)
            fresh._keras_history = item._keras_history
            keras_tensor_mapping[item_id] = fresh
        return keras_tensor_mapping[item_id]

    flat_args = tf.nest.flatten(args)
    return tf.nest.pack_sequence_as(
        args, [_clone_or_reuse(item) for item in flat_args]
    )
def _clone_keras_tensor(kt):
    """Return a new KerasTensor equivalent to `kt`.

    The copy round-trips through a placeholder: `keras_tensor_to_placeholder`
    is followed by `keras_tensor_from_tensor`, so inferred shape information
    is not lost along the way.

    Args:
      kt: the KerasTensor to copy.

    Returns:
      A fresh KerasTensor built from a placeholder that mirrors `kt`.
    """
    # The placeholder is only a temporary vehicle for the copy, so it is
    # created inside a scratch graph made the default for the duration.
    with backend._scratch_graph() as scratch, scratch.as_default():
        return keras_tensor.keras_tensor_from_tensor(
            keras_tensor.keras_tensor_to_placeholder(kt)
        )