Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/input_spec.py: 17%
119 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16# pylint: disable=g-classes-have-attributes
17"""Contains the InputSpec class."""
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.framework import tensor_spec
22from tensorflow.python.keras import backend
23from tensorflow.python.util import nest
24from tensorflow.python.util.tf_export import keras_export
25from tensorflow.python.util.tf_export import tf_export
@keras_export('keras.layers.InputSpec')
@tf_export(v1=['layers.InputSpec'])
class InputSpec(object):
  """Specifies the rank, dtype and shape of every input to a layer.

  Layers can expose (if appropriate) an `input_spec` attribute:
  an instance of `InputSpec`, or a nested structure of `InputSpec` instances
  (one per input tensor). These objects enable the layer to run input
  compatibility checks for input structure, input rank, input shape, and
  input dtype.

  A None entry in a shape is compatible with any dimension,
  a None shape is compatible with any shape.

  Args:
    dtype: Expected DataType of the input.
    shape: Shape tuple, expected shape of the input
      (may include None for unchecked axes). Includes the batch size.
      When a shape of known rank is given, `ndim` is derived from it and the
      `ndim` argument is ignored.
    ndim: Integer, expected rank of the input.
    max_ndim: Integer, maximum rank of the input.
    min_ndim: Integer, minimum rank of the input.
    axes: Dictionary mapping integer axes to
      a specific dimension value.
    allow_last_axis_squeeze: If True, then allow inputs of rank N+1 as long
      as the last axis of the input is 1, as well as inputs of rank N-1
      as long as the last axis of the spec is 1.
    name: Expected key corresponding to this input when passing data as
      a dictionary.

  Raises:
    TypeError: If a key of `axes` cannot be converted to an integer.
    ValueError: If the largest key of `axes` is not a valid axis for the
      rank given by `shape`/`ndim` (or by `max_ndim` when no rank is given).

  Example:

  ```python
  class MyLayer(Layer):
    def __init__(self):
      super(MyLayer, self).__init__()
      # The layer will accept inputs with shape (?, 28, 28) & (?, 28, 28, 1)
      # and raise an appropriate error message otherwise.
      self.input_spec = InputSpec(
          shape=(None, 28, 28, 1),
          allow_last_axis_squeeze=True)
  ```
  """

  def __init__(self,
               dtype=None,
               shape=None,
               ndim=None,
               max_ndim=None,
               min_ndim=None,
               axes=None,
               allow_last_axis_squeeze=False,
               name=None):
    # Store the dtype by its canonical name so comparisons are string-based.
    self.dtype = dtypes.as_dtype(dtype).name if dtype is not None else None
    shape = tensor_shape.TensorShape(shape)
    if shape.rank is None:
      shape = None
    else:
      shape = tuple(shape.as_list())
    if shape is not None:
      # A shape of known rank fixes the rank; the `ndim` argument is ignored.
      self.ndim = len(shape)
      self.shape = shape
    else:
      self.ndim = ndim
      self.shape = None
    self.max_ndim = max_ndim
    self.min_ndim = min_ndim
    self.name = name
    self.allow_last_axis_squeeze = allow_last_axis_squeeze
    try:
      axes = axes or {}
      self.axes = {int(k): axes[k] for k in axes}
    except (ValueError, TypeError):
      raise TypeError('The keys in axes must be integers.')

    if self.axes and (self.ndim is not None or self.max_ndim is not None):
      # Compare against None explicitly: a rank-0 spec (ndim == 0) must not
      # fall back to `max_ndim`, which may be None and would make the
      # subtraction below raise a TypeError instead of the ValueError.
      max_dim = (self.ndim if self.ndim is not None else self.max_ndim) - 1
      max_axis = max(self.axes)
      if max_axis > max_dim:
        raise ValueError('Axis {} is greater than the maximum allowed value: {}'
                         .format(max_axis, max_dim))

  def __repr__(self):
    spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
            ('shape=' + str(self.shape)) if self.shape else '',
            ('ndim=' + str(self.ndim)) if self.ndim else '',
            ('max_ndim=' + str(self.max_ndim)) if self.max_ndim else '',
            ('min_ndim=' + str(self.min_ndim)) if self.min_ndim else '',
            ('axes=' + str(self.axes)) if self.axes else '']
    return 'InputSpec(%s)' % ', '.join(x for x in spec if x)

  def get_config(self):
    """Returns a dict of constructor arguments that `from_config` accepts.

    `allow_last_axis_squeeze` and `name` are only emitted when they differ
    from their constructor defaults. Configs of specs that do not use them
    therefore keep exactly the same keys as before, while specs that do use
    them now survive a `from_config(get_config())` round trip.
    """
    config = {
        'dtype': self.dtype,
        'shape': self.shape,
        'ndim': self.ndim,
        'max_ndim': self.max_ndim,
        'min_ndim': self.min_ndim,
        'axes': self.axes}
    if self.allow_last_axis_squeeze:
      config['allow_last_axis_squeeze'] = self.allow_last_axis_squeeze
    if self.name is not None:
      config['name'] = self.name
    return config

  @classmethod
  def from_config(cls, config):
    """Rebuilds an `InputSpec` from a dict produced by `get_config`."""
    return cls(**config)
def to_tensor_shape(spec):
  """Returns a tf.TensorShape object that matches the shape specifications.

  If the InputSpec's shape or ndim is defined, this method will return a fully
  or partially-known shape. Otherwise it returns a TensorShape of unknown
  rank, i.e. `TensorShape(None)`; the return value itself is never `None`.

  When only `ndim` is known, the dimensions listed in `spec.axes` are filled
  in and every other dimension is left as `None`.

  Args:
    spec: an InputSpec object.

  Returns:
    a tf.TensorShape object
  """
  if spec.shape is not None:
    return tensor_shape.TensorShape(spec.shape)
  if spec.ndim is None:
    return tensor_shape.TensorShape(None)
  shape = [None] * spec.ndim
  # Treat a missing/None `axes` mapping as "no pinned dimensions".
  for axis, value in (spec.axes or {}).items():
    shape[axis] = value
  return tensor_shape.TensorShape(shape)
def assert_input_compatibility(input_spec, inputs, layer_name):
  """Checks compatibility between the layer and provided inputs.

  This checks that the tensor(s) `inputs` verify the input assumptions
  of a layer (if any). If not, a clear and actionable exception gets raised.

  An input whose rank is unknown cannot be checked and is skipped, but the
  remaining inputs are still validated.

  Args:
    input_spec: An InputSpec instance, list of InputSpec instances, a nested
      structure of InputSpec instances, or None.
    inputs: Input tensor, list of input tensors, or a nested structure of
      input tensors.
    layer_name: String, name of the layer (for error message formatting).

  Raises:
    TypeError: if an element of `inputs` has no `shape` attribute.
    ValueError: in case of mismatch between
      the provided inputs and the expectations of the layer.
  """
  if not input_spec:
    return

  input_spec = nest.flatten(input_spec)
  if isinstance(inputs, dict):
    # Flatten `inputs` by reference order if input spec names are provided
    names = [spec.name for spec in input_spec]
    if all(names):
      list_inputs = []
      for name in names:
        if name not in inputs:
          raise ValueError('Missing data for input "%s". '
                           'You passed a data dictionary with keys %s. '
                           'Expected the following keys: %s' %
                           (name, list(inputs.keys()), names))
        list_inputs.append(inputs[name])
      inputs = list_inputs

  inputs = nest.flatten(inputs)
  for x in inputs:
    # Having a shape/dtype is the only commonality of the various tensor-like
    # objects that may be passed. The most common kind of invalid type we are
    # guarding for is a Layer instance (Functional API), which does not
    # have a `shape` attribute.
    if not hasattr(x, 'shape'):
      raise TypeError('Inputs to a layer should be tensors. Got: %s' % (x,))

  if len(inputs) != len(input_spec):
    raise ValueError('Layer ' + layer_name + ' expects ' +
                     str(len(input_spec)) + ' input(s), '
                     'but it received ' + str(len(inputs)) +
                     ' input tensors. Inputs received: ' + str(inputs))
  for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
    if spec is None:
      continue

    # Normalize once; every check below reads this TensorShape rather than
    # `x.shape`, which for some tensor-likes is a plain tuple without `.rank`.
    shape = tensor_shape.TensorShape(x.shape)
    if shape.rank is None:
      # Nothing can be verified for this input. `continue` (not `return`) so
      # that the remaining inputs are still checked.
      continue
    ndim = shape.rank

    # Check ndim.
    # NOTE(review): when allow_last_axis_squeeze is set, the rank is not
    # checked here, and the shape comparison below uses `zip`, which stops at
    # the shorter shape, so larger rank mismatches are not caught.
    if spec.ndim is not None and not spec.allow_last_axis_squeeze:
      if ndim != spec.ndim:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'expected ndim=' + str(spec.ndim) + ', found ndim=' +
                         str(ndim) + '. Full shape received: ' +
                         str(tuple(shape)))
    if spec.max_ndim is not None and ndim > spec.max_ndim:
      raise ValueError('Input ' + str(input_index) + ' of layer ' +
                       layer_name + ' is incompatible with the layer: '
                       'expected max_ndim=' + str(spec.max_ndim) +
                       ', found ndim=' + str(ndim))
    if spec.min_ndim is not None and ndim < spec.min_ndim:
      raise ValueError('Input ' + str(input_index) + ' of layer ' +
                       layer_name + ' is incompatible with the layer: '
                       'expected min_ndim=' + str(spec.min_ndim) +
                       ', found ndim=' + str(ndim) +
                       '. Full shape received: ' +
                       str(tuple(shape)))
    # Check dtype.
    if spec.dtype is not None:
      if x.dtype.name != spec.dtype:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'expected dtype=' + str(spec.dtype) +
                         ', found dtype=' + str(x.dtype))

    # Check specific shape axes.
    shape_as_list = shape.as_list()
    if spec.axes:
      for axis, value in spec.axes.items():
        # Unwrap legacy Dimension objects to their integer value.
        if hasattr(value, 'value'):
          value = value.value
        if value is None:
          continue
        axis = int(axis)
        # An axis the input does not have is a mismatch: report it as the
        # documented ValueError rather than letting indexing raise IndexError.
        axis_in_range = -ndim <= axis < ndim
        if not axis_in_range or shape_as_list[axis] not in {value, None}:
          raise ValueError(
              'Input ' + str(input_index) + ' of layer ' + layer_name + ' is'
              ' incompatible with the layer: expected axis ' + str(axis) +
              ' of input shape to have value ' + str(value) +
              ' but received input with shape ' + display_shape(shape))
    # Check shape.
    if spec.shape is not None:
      spec_shape = spec.shape
      if spec.allow_last_axis_squeeze:
        # Drop a trailing size-1 axis on either side before comparing.
        if shape_as_list and shape_as_list[-1] == 1:
          shape_as_list = shape_as_list[:-1]
        if spec_shape and spec_shape[-1] == 1:
          spec_shape = spec_shape[:-1]
      for spec_dim, dim in zip(spec_shape, shape_as_list):
        if spec_dim is not None and dim is not None:
          if spec_dim != dim:
            raise ValueError('Input ' + str(input_index) +
                             ' is incompatible with layer ' + layer_name +
                             ': expected shape=' + str(spec.shape) +
                             ', found shape=' + display_shape(shape))
def display_shape(shape):
  """Formats a shape as a tuple string, e.g. `(None, 28, 28)`.

  Args:
    shape: A `TensorShape` (or any object exposing `as_list()`), or a plain
      sequence of dimensions such as a tuple or list.

  Returns:
    The string form of the shape's dimensions as a tuple.
  """
  # Tensor-like objects may expose their shape as a plain tuple, which has no
  # `as_list()`; accept both forms instead of raising AttributeError.
  dims = shape.as_list() if hasattr(shape, 'as_list') else list(shape)
  return str(tuple(dims))
def to_tensor_spec(input_spec, default_dtype=None):
  """Converts a Keras InputSpec object to a TensorSpec.

  Args:
    input_spec: An `InputSpec`. Any other value is treated as carrying no
      shape or dtype information.
    default_dtype: dtype used when `input_spec` does not specify one. Any
      falsy value is replaced by `backend.floatx()`.

  Returns:
    A `TensorSpec`. For a non-`InputSpec` argument, its shape is `None` and
    its dtype is the default dtype.
  """
  if not default_dtype:
    default_dtype = backend.floatx()
  if not isinstance(input_spec, InputSpec):
    return tensor_spec.TensorSpec(None, default_dtype)
  spec_dtype = input_spec.dtype or default_dtype
  spec_shape = to_tensor_shape(input_spec)
  return tensor_spec.TensorSpec(spec_shape, spec_dtype)