Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/reshaping/cropping2d.py: 24%
50 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras cropping layer for 2D input."""
18import tensorflow.compat.v2 as tf
20from keras.src.engine.base_layer import Layer
21from keras.src.engine.input_spec import InputSpec
22from keras.src.utils import conv_utils
24# isort: off
25from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.Cropping2D")
class Cropping2D(Layer):
    """Cropping layer for 2D input (e.g. picture).

    It crops along spatial dimensions, i.e. height and width.

    Examples:

    >>> input_shape = (2, 28, 28, 3)
    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
    >>> y = tf.keras.layers.Cropping2D(cropping=((2, 2), (4, 4)))(x)
    >>> print(y.shape)
    (2, 24, 20, 3)

    Args:
      cropping: Int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
        Values must be non-negative.
        - If int: the same symmetric cropping
          is applied to height and width.
        - If tuple of 2 ints:
          interpreted as two different
          symmetric cropping values for height and width:
          `(symmetric_height_crop, symmetric_width_crop)`.
        - If tuple of 2 tuples of 2 ints:
          interpreted as
          `((top_crop, bottom_crop), (left_crop, right_crop))`
      data_format: A string,
        one of `channels_last` (default) or `channels_first`.
        The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape
        `(batch_size, height, width, channels)` while `channels_first`
        corresponds to inputs with shape
        `(batch_size, channels, height, width)`.
        When unspecified, uses
        `image_data_format` value found in your Keras config file at
        `~/.keras/keras.json` (if exists) else 'channels_last'.
        Defaults to 'channels_last'.

    Input shape:
      4D tensor with shape:
      - If `data_format` is `"channels_last"`:
        `(batch_size, rows, cols, channels)`
      - If `data_format` is `"channels_first"`:
        `(batch_size, channels, rows, cols)`

    Output shape:
      4D tensor with shape:
      - If `data_format` is `"channels_last"`:
        `(batch_size, cropped_rows, cropped_cols, channels)`
      - If `data_format` is `"channels_first"`:
        `(batch_size, channels, cropped_rows, cropped_cols)`

    Raises:
      ValueError: At construction, if `cropping` is not one of the accepted
        forms or contains negative values (entries are validated with
        `conv_utils.normalize_tuple(..., allow_zero=True)`). When called, if
        the total cropping along height or width is not strictly smaller
        than that (statically known) input dimension.
    """

    def __init__(self, cropping=((0, 0), (0, 0)), data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        if isinstance(cropping, int):
            # Normalize through the same helper as the tuple form, so that a
            # negative int is rejected here. Otherwise it would silently
            # produce a wrong slice in `call`.
            symmetric_cropping = conv_utils.normalize_tuple(
                cropping, 2, "cropping", allow_zero=True
            )
            self.cropping = (symmetric_cropping, symmetric_cropping)
        elif hasattr(cropping, "__len__"):
            if len(cropping) != 2:
                raise ValueError(
                    "`cropping` should have two elements. "
                    f"Received: {cropping}."
                )
            height_cropping = conv_utils.normalize_tuple(
                cropping[0], 2, "1st entry of cropping", allow_zero=True
            )
            width_cropping = conv_utils.normalize_tuple(
                cropping[1], 2, "2nd entry of cropping", allow_zero=True
            )
            self.cropping = (height_cropping, width_cropping)
        else:
            raise ValueError(
                "`cropping` should be either an int, "
                "a tuple of 2 ints "
                "(symmetric_height_crop, symmetric_width_crop), "
                "or a tuple of 2 tuples of 2 ints "
                "((top_crop, bottom_crop), (left_crop, right_crop)). "
                f"Received: {cropping}."
            )
        self.input_spec = InputSpec(ndim=4)

    def _spatial_axes(self):
        """Return the (height, width) axis indices for `self.data_format`."""
        if self.data_format == "channels_first":
            return (2, 3)
        return (1, 2)

    def compute_output_shape(self, input_shape):
        """Return the static output shape for a 4D `input_shape`.

        Batch and channel dimensions pass through unchanged. Each known
        spatial dimension shrinks by its total (start + end) crop, and
        unknown (`None`) spatial dimensions stay unknown.
        """
        output_shape = tf.TensorShape(input_shape).as_list()
        for axis, (start, end) in zip(self._spatial_axes(), self.cropping):
            # Compare against None explicitly: a known size of 0 is still a
            # known size and must not be reported as unknown.
            if output_shape[axis] is not None:
                output_shape[axis] -= start + end
        return tf.TensorShape(output_shape)

    def call(self, inputs):
        spatial_axes = self._spatial_axes()

        # Reject croppings that would leave an empty (or negative) spatial
        # extent. This is only checkable when the dimension is statically
        # known.
        for axis, crop in zip(spatial_axes, self.cropping):
            length = inputs.shape[axis]
            if length is not None and sum(crop) >= length:
                raise ValueError(
                    "Argument `cropping` must be "
                    "smaller than the input shape. Received: inputs.shape="
                    f"{inputs.shape}, and cropping={self.cropping}"
                )

        slices = [slice(None)] * 4
        for axis, (start, end) in zip(spatial_axes, self.cropping):
            # A stop index of `-0` is just `0` and would yield an empty
            # slice, so a zero end crop becomes an open-ended slice instead.
            slices[axis] = slice(start, -end if end else None)
        return inputs[tuple(slices)]

    def get_config(self):
        config = {"cropping": self.cropping, "data_format": self.data_format}
        base_config = super().get_config()
        return {**base_config, **config}