Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/input_spec.py: 15%
116 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
17"""Contains the InputSpec class."""
19import tensorflow.compat.v2 as tf
21from keras.src import backend
23# isort: off
24from tensorflow.python.util.tf_export import keras_export
25from tensorflow.python.util.tf_export import tf_export
28@keras_export(
29 "keras.layers.InputSpec",
30 v1=["keras.layers.InputSpec", "keras.__internal__.legacy.layers.InputSpec"],
31)
32@tf_export(v1=["layers.InputSpec"])
class InputSpec:
    """Describes the expected rank, dtype and shape of a layer's input.

    Layers may expose an `input_spec` attribute holding one `InputSpec`
    (or a nested structure of them, one per input tensor). These specs
    drive the input-compatibility checks performed on each call: structure,
    rank, shape and dtype.

    A `None` entry inside a shape matches any dimension; a `None` shape
    matches any shape.

    Args:
        dtype: Expected DataType of the input.
        shape: Shape tuple, expected shape of the input
            (may include None for unchecked axes). Includes the batch size.
        ndim: Integer, expected rank of the input.
        max_ndim: Integer, maximum rank of the input.
        min_ndim: Integer, minimum rank of the input.
        axes: Dictionary mapping integer axes to a specific dimension value.
        allow_last_axis_squeeze: If True, then allow inputs of rank N+1 as
            long as the last axis of the input is 1, as well as inputs of
            rank N-1 as long as the last axis of the spec is 1.
        name: Expected key corresponding to this input when passing data as
            a dictionary.

    Example:

    ```python
    class MyLayer(Layer):
        def __init__(self):
            super(MyLayer, self).__init__()
            # The layer will accept inputs with
            # shape (?, 28, 28) & (?, 28, 28, 1)
            # and raise an appropriate error message otherwise.
            self.input_spec = InputSpec(
                shape=(None, 28, 28, 1),
                allow_last_axis_squeeze=True)
    ```
    """

    def __init__(
        self,
        dtype=None,
        shape=None,
        ndim=None,
        max_ndim=None,
        min_ndim=None,
        axes=None,
        allow_last_axis_squeeze=False,
        name=None,
    ):
        # Canonicalize the dtype to its string name (e.g. "float32").
        if dtype is not None:
            self.dtype = tf.as_dtype(dtype).name
        else:
            self.dtype = None
        # A shape with a known rank wins over any explicit `ndim` argument:
        # the rank is then derived from the shape itself.
        normalized_shape = tf.TensorShape(shape)
        if normalized_shape.rank is None:
            self.shape = None
            self.ndim = ndim
        else:
            self.shape = tuple(normalized_shape.as_list())
            self.ndim = len(self.shape)
        self.max_ndim = max_ndim
        self.min_ndim = min_ndim
        self.name = name
        self.allow_last_axis_squeeze = allow_last_axis_squeeze
        try:
            axes = axes or {}
            # Keys must be coercible to int; values are taken as-is.
            self.axes = {int(axis): axes[axis] for axis in axes}
        except (ValueError, TypeError):
            raise TypeError(
                "Argument `axes` must be a dict with integer keys. "
                f"Received: axes={axes}"
            )

        # Sanity-check that no requested axis exceeds the declared rank.
        if self.axes and (self.ndim is not None or self.max_ndim is not None):
            max_dim = (self.ndim if self.ndim else self.max_ndim) - 1
            max_axis = max(self.axes)
            if max_axis > max_dim:
                raise ValueError(
                    "Axis {} is greater than the maximum "
                    "allowed value: {}".format(max_axis, max_dim)
                )

    def __repr__(self):
        """Return a compact string listing only the constraints that are set."""
        fields = (
            ("dtype", self.dtype),
            ("shape", self.shape),
            ("ndim", self.ndim),
            ("max_ndim", self.max_ndim),
            ("min_ndim", self.min_ndim),
            ("axes", self.axes),
        )
        rendered = [f"{label}={value}" for label, value in fields if value]
        return f"InputSpec({', '.join(rendered)})"

    def get_config(self):
        """Return a serializable dict of the spec's constraints."""
        return dict(
            dtype=self.dtype,
            shape=self.shape,
            ndim=self.ndim,
            max_ndim=self.max_ndim,
            min_ndim=self.min_ndim,
            axes=self.axes,
        )

    @classmethod
    def from_config(cls, config):
        """Rebuild an `InputSpec` from a `get_config()` dict."""
        return cls(**config)
def to_tensor_shape(spec):
    """Build a `tf.TensorShape` matching the constraints of `spec`.

    If the InputSpec's shape or ndim is defined, the result is a fully or
    partially known shape; otherwise the returned TensorShape is unknown.

    Args:
        spec: an InputSpec object.

    Returns:
        a tf.TensorShape object
    """
    # An explicit shape is the most specific constraint: use it directly.
    if spec.shape is not None:
        return tf.TensorShape(spec.shape)
    # Neither shape nor rank known: the shape is fully unspecified.
    if spec.ndim is None:
        return tf.TensorShape(None)
    # Rank known: start from all-unknown dims, then pin the constrained axes.
    dims = [None] * spec.ndim
    for axis, value in spec.axes.items():
        dims[axis] = value
    return tf.TensorShape(dims)
def assert_input_compatibility(input_spec, inputs, layer_name):
    """Checks compatibility between the layer and provided inputs.

    This checks that the tensor(s) `inputs` verify the input assumptions
    of a layer (if any). If not, a clear and actionable exception gets raised.

    Args:
        input_spec: An InputSpec instance, list of InputSpec instances, a nested
            structure of InputSpec instances, or None.
        inputs: Input tensor, list of input tensors, or a nested structure of
            input tensors.
        layer_name: String, name of the layer (for error message formatting).

    Raises:
        ValueError: in case of mismatch between
            the provided inputs and the expectations of the layer.
        TypeError: if an input is not a tensor-like object (has no `shape`),
            or if `axes` validation fails structurally.
    """
    # No spec means no constraints: nothing to check.
    if not input_spec:
        return

    input_spec = tf.nest.flatten(input_spec)
    if isinstance(inputs, dict):
        # Flatten `inputs` by reference order if input spec names are provided
        names = [spec.name for spec in input_spec]
        if all(names):
            list_inputs = []
            for name in names:
                if name not in inputs:
                    raise ValueError(
                        f'Missing data for input "{name}". '
                        "You passed a data dictionary with keys "
                        f"{list(inputs.keys())}. "
                        f"Expected the following keys: {names}"
                    )
                list_inputs.append(inputs[name])
            inputs = list_inputs

    inputs = tf.nest.flatten(inputs)
    for x in inputs:
        # Having a shape/dtype is the only commonality of the various
        # tensor-like objects that may be passed. The most common kind of
        # invalid type we are guarding for is a Layer instance (Functional API),
        # which does not have a `shape` attribute.
        if not hasattr(x, "shape"):
            raise TypeError(
                f"Inputs to a layer should be tensors. Got '{x}' "
                f"(of type {type(x)}) as input for layer '{layer_name}'."
            )

    # The flattened inputs and specs must pair up one-to-one.
    if len(inputs) != len(input_spec):
        raise ValueError(
            f'Layer "{layer_name}" expects {len(input_spec)} input(s),'
            f" but it received {len(inputs)} input tensors. "
            f"Inputs received: {inputs}"
        )
    for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
        # A None spec means this particular input is unconstrained.
        if spec is None:
            continue

        shape = tf.TensorShape(x.shape)
        # NOTE(review): a fully-unknown rank short-circuits the ENTIRE check,
        # skipping any remaining inputs as well (return, not continue) —
        # presumably intentional for graph-tracing contexts; confirm upstream.
        if shape.rank is None:
            return
        # Check ndim.
        if spec.ndim is not None and not spec.allow_last_axis_squeeze:
            ndim = shape.rank
            if ndim != spec.ndim:
                raise ValueError(
                    f'Input {input_index} of layer "{layer_name}" '
                    "is incompatible with the layer: "
                    f"expected ndim={spec.ndim}, found ndim={ndim}. "
                    f"Full shape received: {tuple(shape)}"
                )
        if spec.max_ndim is not None:
            ndim = x.shape.rank
            if ndim is not None and ndim > spec.max_ndim:
                raise ValueError(
                    f'Input {input_index} of layer "{layer_name}" '
                    "is incompatible with the layer: "
                    f"expected max_ndim={spec.max_ndim}, "
                    f"found ndim={ndim}"
                )
        if spec.min_ndim is not None:
            ndim = x.shape.rank
            if ndim is not None and ndim < spec.min_ndim:
                raise ValueError(
                    f'Input {input_index} of layer "{layer_name}" '
                    "is incompatible with the layer: "
                    f"expected min_ndim={spec.min_ndim}, "
                    f"found ndim={ndim}. "
                    f"Full shape received: {tuple(shape)}"
                )
        # Check dtype.
        if spec.dtype is not None:
            if x.dtype.name != spec.dtype:
                raise ValueError(
                    f'Input {input_index} of layer "{layer_name}" '
                    "is incompatible with the layer: "
                    f"expected dtype={spec.dtype}, "
                    f"found dtype={x.dtype}"
                )

        # Check specific shape axes.
        shape_as_list = shape.as_list()
        if spec.axes:
            for axis, value in spec.axes.items():
                # Unwrap legacy Dimension objects (which carry `.value`).
                if hasattr(value, "value"):
                    value = value.value
                # An unknown (None) dimension is always accepted.
                if value is not None and shape_as_list[int(axis)] not in {
                    value,
                    None,
                }:
                    raise ValueError(
                        f'Input {input_index} of layer "{layer_name}" is '
                        f"incompatible with the layer: expected axis {axis} "
                        f"of input shape to have value {value}, "
                        "but received input with "
                        f"shape {display_shape(x.shape)}"
                    )
        # Check shape.
        if spec.shape is not None and shape.rank is not None:
            spec_shape = spec.shape
            if spec.allow_last_axis_squeeze:
                # A trailing dimension of 1 may be dropped on either side
                # before the element-wise comparison.
                if shape_as_list and shape_as_list[-1] == 1:
                    shape_as_list = shape_as_list[:-1]
                if spec_shape and spec_shape[-1] == 1:
                    spec_shape = spec_shape[:-1]
            for spec_dim, dim in zip(spec_shape, shape_as_list):
                if spec_dim is not None and dim is not None:
                    if spec_dim != dim:
                        raise ValueError(
                            f'Input {input_index} of layer "{layer_name}" is '
                            "incompatible with the layer: "
                            f"expected shape={spec.shape}, "
                            f"found shape={display_shape(x.shape)}"
                        )
def display_shape(shape):
    """Render a TensorShape-like object as a plain tuple string."""
    dims = tuple(shape.as_list())
    return str(dims)
def to_tensor_spec(input_spec, default_dtype=None):
    """Converts a Keras InputSpec object to a TensorSpec.

    Args:
        input_spec: An InputSpec instance (anything else yields a fully
            unknown TensorSpec).
        default_dtype: Dtype used when the spec does not declare one;
            falls back to `backend.floatx()`.

    Returns:
        a tf.TensorSpec object.
    """
    dtype = default_dtype or backend.floatx()
    if isinstance(input_spec, InputSpec):
        # The spec's own dtype, when set, overrides the default.
        if input_spec.dtype:
            dtype = input_spec.dtype
        return tf.TensorSpec(to_tensor_shape(input_spec), dtype)
    # Not an InputSpec: unknown shape with the default dtype.
    return tf.TensorSpec(None, dtype)