Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/reshaping/cropping3d.py: 15%
82 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras cropping layer for 3D input."""
18import tensorflow.compat.v2 as tf
20from keras.src.engine.base_layer import Layer
21from keras.src.engine.input_spec import InputSpec
22from keras.src.utils import conv_utils
24# isort: off
25from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.Cropping3D")
class Cropping3D(Layer):
    """Cropping layer for 3D data (e.g. spatial or spatio-temporal).

    Examples:

    >>> input_shape = (2, 28, 28, 10, 3)
    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
    >>> y = tf.keras.layers.Cropping3D(cropping=(2, 4, 2))(x)
    >>> print(y.shape)
    (2, 24, 20, 6, 3)

    Args:
      cropping: Int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.
        - If int: the same symmetric cropping
          is applied to depth, height, and width.
        - If tuple of 3 ints: interpreted as three different
          symmetric cropping values for depth, height, and width:
          `(symmetric_dim1_crop, symmetric_dim2_crop, symmetric_dim3_crop)`.
        - If tuple of 3 tuples of 2 ints: interpreted as
          `((left_dim1_crop, right_dim1_crop), (left_dim2_crop,
            right_dim2_crop), (left_dim3_crop, right_dim3_crop))`
      data_format: A string,
        one of `channels_last` (default) or `channels_first`.
        The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape
        `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
        while `channels_first` corresponds to inputs with shape
        `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
        When unspecified, uses
        `image_data_format` value found in your Keras config file at
         `~/.keras/keras.json` (if exists) else 'channels_last'.
        Defaults to 'channels_last'.

    Input shape:
      5D tensor with shape:
      - If `data_format` is `"channels_last"`:
        `(batch_size, first_axis_to_crop, second_axis_to_crop,
          third_axis_to_crop, depth)`
      - If `data_format` is `"channels_first"`:
        `(batch_size, depth, first_axis_to_crop, second_axis_to_crop,
          third_axis_to_crop)`

    Output shape:
      5D tensor with shape:
      - If `data_format` is `"channels_last"`:
        `(batch_size, first_cropped_axis, second_cropped_axis,
          third_cropped_axis, depth)`
      - If `data_format` is `"channels_first"`:
        `(batch_size, depth, first_cropped_axis, second_cropped_axis,
          third_cropped_axis)`
    """

    def __init__(
        self, cropping=((1, 1), (1, 1), (1, 1)), data_format=None, **kwargs
    ):
        """Normalize `cropping` into three `(left, right)` pairs.

        Args:
          cropping: Int, tuple of 3 ints, or tuple of 3 tuples of 2 ints
            (see the class docstring for how each form is interpreted).
          data_format: `"channels_last"`, `"channels_first"` or `None`;
            passed through `conv_utils.normalize_data_format`.
          **kwargs: Forwarded to the base `Layer` constructor.

        Raises:
          ValueError: If `cropping` is a sized object whose length is not 3,
            or is neither an int nor a sized object.
        """
        super().__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        if isinstance(cropping, int):
            # One value: same crop on both sides of all three axes.
            self.cropping = (
                (cropping, cropping),
                (cropping, cropping),
                (cropping, cropping),
            )
        elif hasattr(cropping, "__len__"):
            if len(cropping) != 3:
                raise ValueError(
                    f"`cropping` should have 3 elements. Received: {cropping}."
                )
            dim1_cropping = conv_utils.normalize_tuple(
                cropping[0], 2, "1st entry of cropping", allow_zero=True
            )
            dim2_cropping = conv_utils.normalize_tuple(
                cropping[1], 2, "2nd entry of cropping", allow_zero=True
            )
            dim3_cropping = conv_utils.normalize_tuple(
                cropping[2], 2, "3rd entry of cropping", allow_zero=True
            )
            self.cropping = (dim1_cropping, dim2_cropping, dim3_cropping)
        else:
            raise ValueError(
                "`cropping` should be either an int, "
                "a tuple of 3 ints "
                "(symmetric_dim1_crop, symmetric_dim2_crop, "
                "symmetric_dim3_crop), "
                "or a tuple of 3 tuples of 2 ints "
                "((left_dim1_crop, right_dim1_crop),"
                " (left_dim2_crop, right_dim2_crop),"
                " (left_dim3_crop, right_dim3_crop)). "
                f"Received: {cropping}."
            )
        self.input_spec = InputSpec(ndim=5)

    @staticmethod
    def _cropped_dim(size, crop):
        """Return `size` minus both entries of `crop`, or `None` if unknown."""
        if size is None:
            return None
        return size - crop[0] - crop[1]

    def compute_output_shape(self, input_shape):
        """Compute the shape produced by cropping the three spatial axes.

        Args:
          input_shape: Anything accepted by `tf.TensorShape`; indexed as a
            5-element shape.

        Returns:
          A `tf.TensorShape` whose spatial entries are reduced by the
          corresponding left and right crops. Spatial entries that are
          `None` in the input stay `None`.
        """
        input_shape = tf.TensorShape(input_shape).as_list()
        channels_first = self.data_format == "channels_first"
        # Spatial axes are 2..4 for channels_first and 1..3 otherwise. Any
        # non-"channels_first" value is treated as channels-last, matching
        # the branch taken by `call`.
        first_spatial = 2 if channels_first else 1
        cropped = [
            self._cropped_dim(input_shape[first_spatial + offset], crop)
            for offset, crop in enumerate(self.cropping)
        ]
        if channels_first:
            return tf.TensorShape([input_shape[0], input_shape[1], *cropped])
        return tf.TensorShape([input_shape[0], *cropped, input_shape[4]])

    def call(self, inputs):
        """Slice the configured crops off the three spatial axes of `inputs`.

        Args:
          inputs: Object indexed with a 5-element tuple of slices.

        Returns:
          The result of indexing `inputs`; the batch and channel axes are
          taken in full.
        """
        # A zero right-crop needs an open-ended slice: `x[left:-0]` is
        # `x[left:0]`, which would be empty instead of "up to the end".
        spatial = tuple(
            slice(left, None if right == 0 else -right)
            for left, right in self.cropping
        )
        everything = slice(None)
        if self.data_format == "channels_first":
            return inputs[(everything, everything) + spatial]
        return inputs[(everything,) + spatial + (everything,)]

    def get_config(self):
        """Return the base layer config extended with this layer's options.

        Returns:
          A dict containing the base config plus the `"cropping"` and
          `"data_format"` entries.
        """
        config = {"cropping": self.cropping, "data_format": self.data_format}
        base_config = super().get_config()
        return {**base_config, **config}