Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/reshaping/cropping3d.py: 15%

82 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras cropping layer for 3D input.""" 

16 

17 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src.engine.base_layer import Layer 

21from keras.src.engine.input_spec import InputSpec 

22from keras.src.utils import conv_utils 

23 

24# isort: off 

25from tensorflow.python.util.tf_export import keras_export 

26 

27 

@keras_export("keras.layers.Cropping3D")
class Cropping3D(Layer):
    """Cropping layer for 3D data (e.g. spatial or spatio-temporal).

    Examples:

    >>> input_shape = (2, 28, 28, 10, 3)
    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
    >>> y = tf.keras.layers.Cropping3D(cropping=(2, 4, 2))(x)
    >>> print(y.shape)
    (2, 24, 20, 6, 3)

    Args:
      cropping: Int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.
        - If int: the same symmetric cropping
          is applied to depth, height, and width.
        - If tuple of 3 ints: interpreted as three different
          symmetric cropping values for depth, height, and width:
          `(symmetric_dim1_crop, symmetric_dim2_crop, symmetric_dim3_crop)`.
        - If tuple of 3 tuples of 2 ints: interpreted as
          `((left_dim1_crop, right_dim1_crop), (left_dim2_crop,
          right_dim2_crop), (left_dim3_crop, right_dim3_crop))`
        Cropping amounts must be non-negative; a negative int raises
        `ValueError`.
      data_format: A string,
        one of `channels_last` (default) or `channels_first`.
        The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape
        `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
        while `channels_first` corresponds to inputs with shape
        `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
        When unspecified, uses
        `image_data_format` value found in your Keras config file at
        `~/.keras/keras.json` (if exists) else 'channels_last'.
        Defaults to 'channels_last'.

    Input shape:
      5D tensor with shape:
      - If `data_format` is `"channels_last"`:
        `(batch_size, first_axis_to_crop, second_axis_to_crop,
        third_axis_to_crop, channels)`
      - If `data_format` is `"channels_first"`:
        `(batch_size, channels, first_axis_to_crop, second_axis_to_crop,
        third_axis_to_crop)`

    Output shape:
      5D tensor with shape:
      - If `data_format` is `"channels_last"`:
        `(batch_size, first_cropped_axis, second_cropped_axis,
        third_cropped_axis, channels)`
      - If `data_format` is `"channels_first"`:
        `(batch_size, channels, first_cropped_axis, second_cropped_axis,
        third_cropped_axis)`
    """

    def __init__(
        self, cropping=((1, 1), (1, 1), (1, 1)), data_format=None, **kwargs
    ):
        super().__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        if isinstance(cropping, int):
            # A negative amount would make `compute_output_shape` (which
            # subtracts it, i.e. grows the dim) disagree with `call` (which
            # would build a `slice(-k, k)`), so reject it up front.
            if cropping < 0:
                raise ValueError(
                    "`cropping` should be a non-negative int. "
                    f"Received: {cropping}."
                )
            # Same symmetric (left, right) pair for each of the 3 axes.
            self.cropping = ((cropping, cropping),) * 3
        elif hasattr(cropping, "__len__"):
            if len(cropping) != 3:
                raise ValueError(
                    f"`cropping` should have 3 elements. Received: {cropping}."
                )
            # Each entry may be an int (symmetric) or a pair (left, right);
            # normalize_tuple turns it into a 2-tuple, allowing zero.
            dim1_cropping = conv_utils.normalize_tuple(
                cropping[0], 2, "1st entry of cropping", allow_zero=True
            )
            dim2_cropping = conv_utils.normalize_tuple(
                cropping[1], 2, "2nd entry of cropping", allow_zero=True
            )
            dim3_cropping = conv_utils.normalize_tuple(
                cropping[2], 2, "3rd entry of cropping", allow_zero=True
            )
            self.cropping = (dim1_cropping, dim2_cropping, dim3_cropping)
        else:
            raise ValueError(
                "`cropping` should be either an int, "
                "a tuple of 3 ints "
                "(symmetric_dim1_crop, symmetric_dim2_crop, "
                "symmetric_dim3_crop), "
                "or a tuple of 3 tuples of 2 ints "
                "((left_dim1_crop, right_dim1_crop),"
                " (left_dim2_crop, right_dim2_crop),"
                " (left_dim3_crop, right_dim3_crop)). "
                f"Received: {cropping}."
            )
        self.input_spec = InputSpec(ndim=5)

    def _spatial_axes(self):
        """Return the indices of the three cropped axes for this layout."""
        if self.data_format == "channels_first":
            return (2, 3, 4)
        # Any other value is treated as channels_last, matching `call`.
        return (1, 2, 3)

    def compute_output_shape(self, input_shape):
        """Return the static output shape after cropping.

        Args:
          input_shape: Anything accepted by `tf.TensorShape`, expected to be
            of rank 5.

        Returns:
          A `tf.TensorShape` equal to `input_shape` with each spatial dim
          reduced by its left + right cropping. Unknown (`None`) dims stay
          `None`.
        """
        output_shape = tf.TensorShape(input_shape).as_list()
        for axis, (start, end) in zip(self._spatial_axes(), self.cropping):
            if output_shape[axis] is not None:
                output_shape[axis] = output_shape[axis] - start - end
        return tf.TensorShape(output_shape)

    def call(self, inputs):
        """Crop the three spatial axes of `inputs` by slicing.

        Args:
          inputs: A rank-5 tensor laid out according to `self.data_format`.

        Returns:
          `inputs` sliced so that, for each spatial axis, `left` leading and
          `right` trailing entries are removed.
        """
        # `x[a:-0]` would be empty, so a zero right-crop must map to an open
        # end (`None`) instead of `-0`.
        spatial_slices = [
            slice(start, None if end == 0 else -end)
            for start, end in self.cropping
        ]
        keep_all = slice(None)
        if self.data_format == "channels_first":
            index = (keep_all, keep_all, *spatial_slices)
        else:
            index = (keep_all, *spatial_slices, keep_all)
        return inputs[index]

    def get_config(self):
        """Return the layer config, including `cropping` and `data_format`."""
        config = {"cropping": self.cropping, "data_format": self.data_format}
        base_config = super().get_config()
        return {**base_config, **config}

314