Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/reshaping/cropping1d.py: 43%
28 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras cropping layer for 1D input."""
18import tensorflow.compat.v2 as tf
20from keras.src.engine.base_layer import Layer
21from keras.src.engine.input_spec import InputSpec
22from keras.src.utils import conv_utils
24# isort: off
25from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.Cropping1D")
class Cropping1D(Layer):
    """Cropping layer for 1D input (e.g. temporal sequence).

    It crops along the time dimension (axis 1).

    Examples:

    >>> input_shape = (2, 3, 2)
    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
    >>> print(x)
    [[[ 0  1]
      [ 2  3]
      [ 4  5]]
     [[ 6  7]
      [ 8  9]
      [10 11]]]
    >>> y = tf.keras.layers.Cropping1D(cropping=1)(x)
    >>> print(y)
    tf.Tensor(
      [[[2 3]]
       [[8 9]]], shape=(2, 1, 2), dtype=int64)

    Args:
      cropping: Int or tuple of int (length 2).
        How many units should be trimmed off at the beginning and end of
        the cropping dimension (axis 1).
        If a single int is provided, the same value will be used for both.
        Defaults to `(1, 1)`.

    Input shape:
      3D tensor with shape `(batch_size, axis_to_crop, features)`

    Output shape:
      3D tensor with shape `(batch_size, cropped_axis, features)`

    Raises:
      ValueError: From `call` or `compute_output_shape`, when the length of
        axis 1 is statically known and `cropping[0] + cropping[1]` is greater
        than or equal to it (the cropped axis would be empty or negative).
    """

    def __init__(self, cropping=(1, 1), **kwargs):
        super().__init__(**kwargs)
        # Normalize an int or a pair into a 2-tuple (start, end). `allow_zero`
        # is passed so that either side may be left uncropped.
        self.cropping = conv_utils.normalize_tuple(
            cropping, 2, "cropping", allow_zero=True
        )
        # The layer only accepts rank-3 inputs.
        self.input_spec = InputSpec(ndim=3)

    def compute_output_shape(self, input_shape):
        """Return the output shape after cropping axis 1.

        Args:
          input_shape: A rank-3 shape, anything accepted by `tf.TensorShape`.

        Returns:
          A `tf.TensorShape` of `[batch, cropped_length, features]`. The
          cropped length is `None` when the input length is unknown.

        Raises:
          ValueError: If the input length is known and the total cropping
            would leave zero or fewer elements along axis 1.
        """
        input_shape = tf.TensorShape(input_shape).as_list()
        if input_shape[1] is not None:
            length = input_shape[1] - self.cropping[0] - self.cropping[1]
            # Keep the static shape inference consistent with `call`, which
            # rejects the same configuration.
            if length <= 0:
                raise ValueError(
                    "cropping parameter of Cropping layer must be "
                    "less than the input length along axis 1. Received: "
                    f"input_shape={input_shape}, and cropping={self.cropping}"
                )
        else:
            length = None
        return tf.TensorShape([input_shape[0], length, input_shape[2]])

    def call(self, inputs):
        """Crop `inputs` along axis 1.

        Args:
          inputs: A rank-3 tensor.

        Returns:
          `inputs` with `cropping[0]` leading and `cropping[1]` trailing
          entries removed along axis 1.

        Raises:
          ValueError: If the static length of axis 1 is known and is not
            strictly greater than `cropping[0] + cropping[1]`.
        """
        # NOTE(review): when inputs.shape[1] is None the check is skipped, and
        # an over-large cropping would presumably just produce an empty slice
        # at runtime -- confirm this is the intended behavior.
        if (
            inputs.shape[1] is not None
            and sum(self.cropping) >= inputs.shape[1]
        ):
            raise ValueError(
                "cropping parameter of Cropping layer must be "
                "less than the input length along axis 1. Received: "
                f"inputs.shape={inputs.shape}, and cropping={self.cropping}"
            )
        # A trailing crop of 0 cannot be written as `-0`, because that slice
        # end would mean index 0 and yield an empty result. Slice to the end
        # instead.
        if self.cropping[1] == 0:
            return inputs[:, self.cropping[0] :, :]
        return inputs[:, self.cropping[0] : -self.cropping[1], :]

    def get_config(self):
        """Return the layer config, adding `cropping` to the base config."""
        config = {"cropping": self.cropping}
        base_config = super().get_config()
        # Entries from this layer take precedence over the base entries,
        # matching the original merge order.
        return {**base_config, **config}