Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/conv_utils.py: 9%
190 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities used by convolution layers."""
17import itertools
19import numpy as np
20import tensorflow.compat.v2 as tf
22from keras.src import backend
def convert_data_format(data_format, ndim):
    """Translates a Keras data format into a TensorFlow layout string.

    Args:
      data_format: `"channels_last"` or `"channels_first"`.
      ndim: Rank of the input tensor, counting the batch and channel axes.
        Only 3, 4 and 5 are supported.

    Returns:
      `"NWC"`, `"NHWC"` or `"NDHWC"` for `"channels_last"`;
      `"NCW"`, `"NCHW"` or `"NCDHW"` for `"channels_first"`.

    Raises:
      ValueError: If `data_format` is not one of the two supported values,
        or if `ndim` is not 3, 4 or 5.
    """
    if data_format == "channels_last":
        candidates = ("NWC", "NHWC", "NDHWC")
    elif data_format == "channels_first":
        candidates = ("NCW", "NCHW", "NCDHW")
    else:
        raise ValueError(
            f"Invalid data_format: {data_format}. "
            'Expected values are ["channels_first", "channels_last"]'
        )

    # `candidates` is ordered by input rank 3, 4, 5.
    for rank, layout in zip((3, 4, 5), candidates):
        if ndim == rank:
            return layout
    raise ValueError(
        f"Input rank not supported: {ndim}. "
        "Expected values are [3, 4, 5]"
    )
def normalize_tuple(value, n, name, allow_zero=False):
    """Transforms non-negative/positive integer/integers into an integer tuple.

    Args:
      value: The value to validate and convert. Could be an int, or any
        iterable of ints.
      n: The size of the tuple to be returned.
      name: The name of the argument being validated, e.g. "strides" or
        "kernel_size". This is only used to format error messages.
      allow_zero: Defaults to False. A ValueError will be raised if zero is
        received and this param is False.

    Returns:
      A tuple of `n` elements. An int `value` is repeated `n` times;
      otherwise the elements of `value` are returned as-is (they are
      validated with `int()` but not converted).

    Raises:
      ValueError: If `value` is neither an int nor an iterable of exactly
        `n` int-convertible elements, or if any element is negative (or
        zero, when `allow_zero` is False).
    """
    error_msg = (
        f"The `{name}` argument must be a tuple of {n} "
        f"integers. Received: {value}"
    )

    if isinstance(value, int):
        value_tuple = (value,) * n
    else:
        try:
            value_tuple = tuple(value)
        except TypeError as err:
            raise ValueError(error_msg) from err
    if len(value_tuple) != n:
        raise ValueError(error_msg)
    for single_value in value_tuple:
        try:
            int(single_value)
        except (ValueError, TypeError) as err:
            # Leading space keeps this clause separate from "Received: ...".
            error_msg += (
                f" including element {single_value} of "
                f"type {type(single_value)}"
            )
            raise ValueError(error_msg) from err

    if allow_zero:
        unqualified_values = {v for v in value_tuple if v < 0}
        req_msg = ">= 0"
    else:
        unqualified_values = {v for v in value_tuple if v <= 0}
        req_msg = "> 0"

    if unqualified_values:
        error_msg += (
            f" including {unqualified_values}"
            f" that does not satisfy the requirement `{req_msg}`."
        )
        raise ValueError(error_msg)

    return value_tuple
def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
    """Computes the length of a convolution's output along one axis.

    Args:
      input_length: Integer length of the input axis, or None if unknown.
      filter_size: Integer kernel size along the axis.
      padding: One of "same", "valid", "full", "causal".
      stride: Integer stride.
      dilation: Integer dilation rate.

    Returns:
      The output length (integer), or None when `input_length` is None.
    """
    if input_length is None:
        return None
    assert padding in {"same", "valid", "full", "causal"}

    # Footprint of the kernel once the dilation gaps are inserted.
    dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)

    if padding == "valid":
        unstrided = input_length - dilated_filter_size + 1
    elif padding == "full":
        unstrided = input_length + dilated_filter_size - 1
    elif padding in ("same", "causal"):
        unstrided = input_length

    # Ceiling division by the stride.
    return (unstrided + stride - 1) // stride
def conv_input_length(output_length, filter_size, padding, stride):
    """Recovers the input length that yields `output_length` after convolution.

    Args:
      output_length: Integer output length, or None if unknown.
      filter_size: Integer kernel size.
      padding: One of "same", "valid", "full".
      stride: Integer stride.

    Returns:
      The input length (integer), or None when `output_length` is None.
    """
    if output_length is None:
        return None
    assert padding in {"same", "valid", "full"}

    # Per-side padding implied by each mode.
    if padding == "full":
        pad = filter_size - 1
    elif padding == "valid":
        pad = 0
    elif padding == "same":
        pad = filter_size // 2

    return (output_length - 1) * stride + filter_size - 2 * pad
def deconv_output_length(
    input_length,
    filter_size,
    padding,
    output_padding=None,
    stride=0,
    dilation=1,
):
    """Computes the output length of a transposed convolution along one axis.

    Args:
      input_length: Integer input length, or None if unknown.
      filter_size: Integer kernel size (before dilation).
      padding: One of `"same"`, `"valid"`, `"full"`.
      output_padding: Integer extra padding on the output axis, or `None`
        to infer the output length from `padding` alone.
      stride: Integer stride.
      dilation: Integer dilation rate.

    Returns:
      The output length (integer), or None when `input_length` is None.
    """
    assert padding in {"same", "valid", "full"}
    if input_length is None:
        return None

    # Footprint of the kernel once the dilation gaps are inserted.
    kernel_extent = filter_size + (filter_size - 1) * (dilation - 1)

    if output_padding is not None:
        # Exact length: undo the forward conv's padding, then add the
        # requested output padding.
        if padding == "same":
            pad = kernel_extent // 2
        elif padding == "valid":
            pad = 0
        elif padding == "full":
            pad = kernel_extent - 1
        length = (
            (input_length - 1) * stride + kernel_extent - 2 * pad + output_padding
        )
    elif padding == "valid":
        length = input_length * stride + max(kernel_extent - stride, 0)
    elif padding == "full":
        length = input_length * stride - (stride + kernel_extent - 2)
    elif padding == "same":
        length = input_length * stride
    return length
def normalize_data_format(value):
    """Validates a `data_format` argument and returns it lower-cased.

    Args:
      value: `"channels_first"` or `"channels_last"` (case-insensitive), or
        None to use `backend.image_data_format()`.

    Returns:
      The lower-cased data format string.

    Raises:
      ValueError: If the (resolved) value is not one of the two formats.
    """
    resolved = backend.image_data_format() if value is None else value
    data_format = resolved.lower()
    if data_format in {"channels_first", "channels_last"}:
        return data_format
    raise ValueError(
        "The `data_format` argument must be one of "
        f'"channels_first", "channels_last". Received: {resolved}'
    )
def normalize_padding(value):
    """Validates and canonicalizes a `padding` argument.

    Args:
      value: Either a list/tuple of explicit padding values, which is
        returned unchanged without inspection, or a string naming a padding
        mode (case-insensitive).

    Returns:
      `value` itself if it is a list/tuple, otherwise the lower-cased
      padding string: one of `"valid"`, `"same"` or `"causal"`.

    Raises:
      ValueError: If `value` is a string other than the three supported
        modes.
    """
    if isinstance(value, (list, tuple)):
        return value
    padding = value.lower()
    if padding not in {"valid", "same", "causal"}:
        raise ValueError(
            "The `padding` argument must be a list/tuple or one of "
            '"valid", "same" (or "causal", only for `Conv1D`). '
            f"Received: {padding}"
        )
    return padding
def conv_kernel_mask(input_shape, kernel_shape, strides, padding):
    """Compute a mask representing the connectivity of a convolution operation.

    Assume a convolution with given parameters is applied to an input having N
    spatial dimensions with `input_shape = (d_in1, ..., d_inN)` to produce an
    output with shape `(d_out1, ..., d_outN)`. This method returns a boolean
    array of shape `(d_in1, ..., d_inN, d_out1, ..., d_outN)` with `True`
    entries indicating pairs of input and output locations that are connected
    by a weight.

    Example:

    >>> input_shape = (4,)
    >>> kernel_shape = (2,)
    >>> strides = (1,)
    >>> padding = "valid"
    >>> conv_kernel_mask(input_shape, kernel_shape, strides, padding)
    array([[ True, False, False],
           [ True,  True, False],
           [False,  True,  True],
           [False, False,  True]])

    where rows and columns correspond to inputs and outputs respectively.

    Args:
      input_shape: tuple (or list) of size N: `(d_in1, ..., d_inN)`, spatial
        shape of the input.
      kernel_shape: tuple of size N, spatial shape of the convolutional
        kernel / receptive field. A single int is broadcast to all N
        dimensions.
      strides: tuple of size N, strides along each spatial dimension. A
        single int is broadcast to all N dimensions.
      padding: type of padding, string `"same"` or `"valid"`.
        `"valid"` means no padding. `"same"` results in padding evenly to
        the left/right or up/down of the input such that output has the same
        height/width dimension as the input.

    Returns:
      A boolean 2N-D `np.ndarray` of shape
      `(d_in1, ..., d_inN, d_out1, ..., d_outN)`, where `(d_out1, ..., d_outN)`
      is the spatial shape of the output. `True` entries in the mask represent
      pairs of input-output locations that are connected by a weight.

    Raises:
      ValueError: if `input_shape`, `kernel_shape` and `strides` don't have the
        same number of dimensions.
      NotImplementedError: if `padding` is not in {`"same"`, `"valid"`}.
    """
    if padding not in {"same", "valid"}:
        raise NotImplementedError(
            f"Padding type {padding} not supported. "
            'Only "valid" and "same" are implemented.'
        )

    # Normalize to a tuple so it can be concatenated with the (tuple) output
    # shape below; a list `input_shape` would otherwise raise TypeError.
    input_shape = tuple(input_shape)
    in_dims = len(input_shape)
    if isinstance(kernel_shape, int):
        kernel_shape = (kernel_shape,) * in_dims
    if isinstance(strides, int):
        strides = (strides,) * in_dims

    kernel_dims = len(kernel_shape)
    stride_dims = len(strides)
    if kernel_dims != in_dims or stride_dims != in_dims:
        raise ValueError(
            "Number of strides, input and kernel dimensions must all "
            f"match. Received: stride_dims={stride_dims}, "
            f"in_dims={in_dims}, kernel_dims={kernel_dims}"
        )

    output_shape = conv_output_shape(
        input_shape, kernel_shape, strides, padding
    )

    # Input axes first, then output axes.
    mask = np.zeros(input_shape + output_shape, dtype=bool)

    output_axes_ticks = [range(dim) for dim in output_shape]
    for output_position in itertools.product(*output_axes_ticks):
        input_axes_ticks = conv_connected_inputs(
            input_shape, kernel_shape, output_position, strides, padding
        )
        for input_position in itertools.product(*input_axes_ticks):
            mask[input_position + output_position] = True

    return mask
def conv_kernel_idxs(
    input_shape,
    kernel_shape,
    strides,
    padding,
    filters_in,
    filters_out,
    data_format,
):
    """Yields output-input tuples of indices in a CNN layer.

    The generator iterates over all `(output_idx, input_idx)` tuples, where
    `output_idx` is an integer index in a flattened tensor representing a
    single output image of a convolutional layer that is connected (via the
    layer weights) to the respective single input image at `input_idx`.

    Example:

    >>> input_shape = (2, 2)
    >>> kernel_shape = (2, 1)
    >>> strides = (1, 1)
    >>> padding = "valid"
    >>> filters_in = 1
    >>> filters_out = 1
    >>> data_format = "channels_last"
    >>> list(conv_kernel_idxs(input_shape, kernel_shape, strides, padding,
    ...                       filters_in, filters_out, data_format))
    [(0, 0), (0, 2), (1, 1), (1, 3)]

    Args:
      input_shape: tuple (or list) of size N: `(d_in1, ..., d_inN)`, spatial
        shape of the input.
      kernel_shape: tuple of size N, spatial shape of the convolutional
        kernel / receptive field. A single int is broadcast to all N
        dimensions.
      strides: tuple of size N, strides along each spatial dimension. A
        single int is broadcast to all N dimensions.
      padding: type of padding, string `"same"` or `"valid"`.
        `"valid"` means no padding. `"same"` results in padding evenly to
        the left/right or up/down of the input such that output has the same
        height/width dimension as the input.
      filters_in: `int`, number of filters in the input to the layer.
      filters_out: `int`, number of filters in the output of the layer.
      data_format: string, "channels_first" or "channels_last".

    Yields:
      The next tuple `(output_idx, input_idx)`, where `output_idx` is an
      integer index in a flattened tensor representing a single output image
      of a convolutional layer that is connected (via the layer weights) to
      the respective single input image at `input_idx`.

    Raises:
      ValueError: if `data_format` is neither `"channels_last"` nor
        `"channels_first"`, or if number of strides, input, and kernel number
        of dimensions do not match.

      NotImplementedError: if `padding` is neither `"same"` nor `"valid"`.

    Because this is a generator, the errors above are raised on the first
    iteration, not when the function is called.
    """
    if padding not in ("same", "valid"):
        raise NotImplementedError(
            f"Padding type {padding} not supported. "
            'Only "valid" and "same" are implemented.'
        )

    # Normalize to a tuple so it can be concatenated with filter indices.
    input_shape = tuple(input_shape)
    in_dims = len(input_shape)
    if isinstance(kernel_shape, int):
        kernel_shape = (kernel_shape,) * in_dims
    if isinstance(strides, int):
        strides = (strides,) * in_dims

    kernel_dims = len(kernel_shape)
    stride_dims = len(strides)
    if kernel_dims != in_dims or stride_dims != in_dims:
        raise ValueError(
            "Number of strides, input and kernel dimensions must all "
            f"match. Received: stride_dims={stride_dims}, "
            f"in_dims={in_dims}, kernel_dims={kernel_dims}"
        )

    output_shape = conv_output_shape(
        input_shape, kernel_shape, strides, padding
    )
    output_axes_ticks = [range(dim) for dim in output_shape]

    # Places the filter axis before or after the spatial axes.
    if data_format == "channels_first":

        def concat_idxs(spatial_idx, filter_idx):
            return (filter_idx,) + spatial_idx

    elif data_format == "channels_last":

        def concat_idxs(spatial_idx, filter_idx):
            return spatial_idx + (filter_idx,)

    else:
        raise ValueError(
            f"Data format `{data_format}` not recognized. "
            '`data_format` must be "channels_first" or "channels_last".'
        )

    # The full (spatial + filter) dims used for flattening are loop-invariant.
    flat_out_dims = concat_idxs(output_shape, filters_out)
    flat_in_dims = concat_idxs(input_shape, filters_in)

    for output_position in itertools.product(*output_axes_ticks):
        # Output indices depend only on (output_position, f_out), so they are
        # shared by every connected input position and input filter.
        out_idxs = [
            np.ravel_multi_index(
                multi_index=concat_idxs(output_position, f_out),
                dims=flat_out_dims,
            )
            for f_out in range(filters_out)
        ]
        input_axes_ticks = conv_connected_inputs(
            input_shape, kernel_shape, output_position, strides, padding
        )
        for input_position in itertools.product(*input_axes_ticks):
            for f_in in range(filters_in):
                in_idx = np.ravel_multi_index(
                    multi_index=concat_idxs(input_position, f_in),
                    dims=flat_in_dims,
                )
                for out_idx in out_idxs:
                    yield (out_idx, in_idx)
def conv_connected_inputs(
    input_shape, kernel_shape, output_position, strides, padding
):
    """Return locations of the input connected to an output position.

    For a convolution over N spatial dimensions with
    `input_shape = (d_in1, ..., d_inN)`, returns one `range` per dimension
    describing the input window that the kernel covered when producing the
    output at `output_position = (p_out1, ..., p_outN)`. Windows are clipped
    to the bounds of the input.

    Example:

    >>> input_shape = (4, 4)
    >>> kernel_shape = (2, 1)
    >>> output_position = (1, 1)
    >>> strides = (1, 1)
    >>> padding = "valid"
    >>> conv_connected_inputs(input_shape, kernel_shape, output_position,
    ...                       strides, padding)
    [range(1, 3), range(1, 2)]

    Args:
      input_shape: tuple of size N, spatial shape of the input.
      kernel_shape: tuple of size N, spatial shape of the convolutional
        kernel / receptive field.
      output_position: tuple of size N, a single position in the output of
        the convolution.
      strides: tuple of size N, strides along each spatial dimension.
      padding: `"valid"` (no padding) or `"same"` (input padded so the
        output keeps the input's spatial size). Any value other than
        `"valid"` is treated like `"same"`.

    Returns:
      A list of N `range` objects, one per spatial dimension, giving the
      input region connected to `output_position`.
    """
    connected = []
    for axis, axis_size in enumerate(input_shape):
        kernel_size = kernel_shape[axis]
        before = int(kernel_size / 2)
        after = kernel_size - before

        anchor = output_position[axis] * strides[axis]
        # Without padding the window starts at the anchor, so its centre is
        # shifted `before` positions into the input.
        if padding == "valid":
            anchor += before

        connected.append(
            range(max(0, anchor - before), min(axis_size, anchor + after))
        )
    return connected
def conv_output_shape(input_shape, kernel_shape, strides, padding):
    """Return the spatial output shape of an N-D convolution.

    Any dimension whose input size is 0 is forced to output size 0.

    Args:
      input_shape: tuple of size N, spatial shape of the input.
      kernel_shape: tuple of size N, spatial shape of the convolutional
        kernel / receptive field.
      strides: tuple of size N, strides along each spatial dimension.
      padding: `"valid"` (no padding) or `"same"` (input padded so the
        output keeps the input's spatial size).

    Returns:
      tuple of size N: `(d_out1, ..., d_outN)`, spatial shape of the output.
    """
    num_spatial = len(kernel_shape)
    raw_lengths = [
        conv_output_length(
            input_shape[axis], kernel_shape[axis], padding, strides[axis]
        )
        for axis in range(num_spatial)
    ]
    # An empty input axis stays empty regardless of padding.
    return tuple(
        0 if in_size == 0 else out_size
        for in_size, out_size in zip(input_shape, raw_lengths)
    )
def squeeze_batch_dims(inp, op, inner_rank):
    """Returns `unsqueeze_batch(op(squeeze_batch(inp)))`.

    Where `squeeze_batch` reshapes `inp` to shape
    `[prod(inp.shape[:-inner_rank])] + inp.shape[-inner_rank:]`
    and `unsqueeze_batch` does the reverse reshape but on the output.

    Args:
      inp: A tensor with dims `batch_shape + inner_shape` where `inner_shape`
        is length `inner_rank`.
      op: A callable that takes a single input tensor and returns a single
        output tensor. It is assumed (not checked) to preserve the flattened
        leading axis, since its output is reshaped back to `batch_shape`.
      inner_rank: A python integer. NOTE(review): presumably expected to be
        >= 1; with 0 the slices `[-0:]` / `[:-0]` select the full shape and
        an empty batch shape respectively — confirm callers never pass 0.

    Returns:
      `unsqueeze_batch(op(squeeze_batch(inp)))`: the output of `op` with its
      leading axis reshaped back into the batch dims of `inp`.
    """
    with tf.name_scope("squeeze_batch_dims"):
        shape = inp.shape

        # Prefer the static inner shape; fall back to the dynamic `tf.shape`
        # when any inner dim is unknown at graph-construction time.
        inner_shape = shape[-inner_rank:]
        if not inner_shape.is_fully_defined():
            inner_shape = tf.shape(inp)[-inner_rank:]

        batch_shape = shape[:-inner_rank]
        if not batch_shape.is_fully_defined():
            batch_shape = tf.shape(inp)[:-inner_rank]

        # Collapse all batch dims into one (-1). A static TensorShape can be
        # turned into a Python list; a dynamic shape tensor needs tf.concat.
        if isinstance(inner_shape, tf.TensorShape):
            inp_reshaped = tf.reshape(inp, [-1] + inner_shape.as_list())
        else:
            inp_reshaped = tf.reshape(
                inp, tf.concat(([-1], inner_shape), axis=-1)
            )

        out_reshaped = op(inp_reshaped)

        out_inner_shape = out_reshaped.shape[-inner_rank:]
        if not out_inner_shape.is_fully_defined():
            out_inner_shape = tf.shape(out_reshaped)[-inner_rank:]

        # Restore the original batch dims in front of op's inner output dims.
        out = tf.reshape(
            out_reshaped, tf.concat((batch_shape, out_inner_shape), axis=-1)
        )

        # Re-attach static shape info that the dynamic reshape may have lost.
        out.set_shape(inp.shape[:-inner_rank] + out.shape[-inner_rank:])
        return out