Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/reshaping/cropping2d.py: 24%

50 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras cropping layer for 2D input.""" 

16 

17 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src.engine.base_layer import Layer 

21from keras.src.engine.input_spec import InputSpec 

22from keras.src.utils import conv_utils 

23 

24# isort: off 

25from tensorflow.python.util.tf_export import keras_export 

26 

27 

@keras_export("keras.layers.Cropping2D")
class Cropping2D(Layer):
    """Cropping layer for 2D input (e.g. picture).

    It crops along spatial dimensions, i.e. height and width.

    Examples:

    >>> input_shape = (2, 28, 28, 3)
    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
    >>> y = tf.keras.layers.Cropping2D(cropping=((2, 2), (4, 4)))(x)
    >>> print(y.shape)
    (2, 24, 20, 3)

    Args:
      cropping: Int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
        - If int: the same symmetric cropping
          is applied to height and width.
        - If tuple of 2 ints:
          interpreted as two different
          symmetric cropping values for height and width:
          `(symmetric_height_crop, symmetric_width_crop)`.
        - If tuple of 2 tuples of 2 ints:
          interpreted as
          `((top_crop, bottom_crop), (left_crop, right_crop))`
      data_format: A string,
        one of `channels_last` (default) or `channels_first`.
        The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape
        `(batch_size, height, width, channels)` while `channels_first`
        corresponds to inputs with shape
        `(batch_size, channels, height, width)`.
        When unspecified, uses
        `image_data_format` value found in your Keras config file at
        `~/.keras/keras.json` (if exists) else 'channels_last'.
        Defaults to 'channels_last'.

    Input shape:
      4D tensor with shape:
      - If `data_format` is `"channels_last"`:
        `(batch_size, rows, cols, channels)`
      - If `data_format` is `"channels_first"`:
        `(batch_size, channels, rows, cols)`

    Output shape:
      4D tensor with shape:
      - If `data_format` is `"channels_last"`:
        `(batch_size, cropped_rows, cropped_cols, channels)`
      - If `data_format` is `"channels_first"`:
        `(batch_size, channels, cropped_rows, cropped_cols)`
    """

    def __init__(self, cropping=((0, 0), (0, 0)), data_format=None, **kwargs):
        """Normalizes `cropping` into `((top, bottom), (left, right))` form.

        Raises:
          ValueError: If `cropping` is not an int, a length-2 sequence of
            ints, or a length-2 sequence of length-2 sequences of ints.
        """
        super().__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        if isinstance(cropping, int):
            # A single int means the same symmetric crop on both spatial
            # dimensions.
            self.cropping = ((cropping, cropping), (cropping, cropping))
        elif hasattr(cropping, "__len__"):
            if len(cropping) != 2:
                raise ValueError(
                    "`cropping` should have two elements. "
                    f"Received: {cropping}."
                )
            # Each entry may itself be an int (symmetric) or a pair;
            # normalize_tuple expands either form to a 2-tuple.
            height_cropping = conv_utils.normalize_tuple(
                cropping[0], 2, "1st entry of cropping", allow_zero=True
            )
            width_cropping = conv_utils.normalize_tuple(
                cropping[1], 2, "2nd entry of cropping", allow_zero=True
            )
            self.cropping = (height_cropping, width_cropping)
        else:
            raise ValueError(
                "`cropping` should be either an int, "
                "a tuple of 2 ints "
                "(symmetric_height_crop, symmetric_width_crop), "
                "or a tuple of 2 tuples of 2 ints "
                "((top_crop, bottom_crop), (left_crop, right_crop)). "
                f"Received: {cropping}."
            )
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        """Returns the output shape: spatial dims shrink by the crop totals.

        Unknown (None) spatial dimensions stay unknown.
        """
        input_shape = tf.TensorShape(input_shape).as_list()

        if self.data_format == "channels_first":
            return tf.TensorShape(
                [
                    input_shape[0],
                    input_shape[1],
                    input_shape[2] - self.cropping[0][0] - self.cropping[0][1]
                    if input_shape[2]
                    else None,
                    input_shape[3] - self.cropping[1][0] - self.cropping[1][1]
                    if input_shape[3]
                    else None,
                ]
            )
        else:
            return tf.TensorShape(
                [
                    input_shape[0],
                    input_shape[1] - self.cropping[0][0] - self.cropping[0][1]
                    if input_shape[1]
                    else None,
                    input_shape[2] - self.cropping[1][0] - self.cropping[1][1]
                    if input_shape[2]
                    else None,
                    input_shape[3],
                ]
            )

    def call(self, inputs):
        """Crops `inputs` along the spatial axes.

        Raises:
          ValueError: If the total crop along a statically-known spatial
            dimension is as large as that dimension (nothing would remain).
        """
        # Resolve which axes of the 4D input hold (height, width).
        if self.data_format == "channels_first":
            height_axis, width_axis = 2, 3
        else:
            height_axis, width_axis = 1, 2
        (top, bottom), (left, right) = self.cropping

        # Validate only statically-known dimensions; dynamic (None) dims
        # cannot be checked at graph-construction time.
        if (
            inputs.shape[height_axis] is not None
            and top + bottom >= inputs.shape[height_axis]
        ) or (
            inputs.shape[width_axis] is not None
            and left + right >= inputs.shape[width_axis]
        ):
            raise ValueError(
                "Argument `cropping` must be "
                "less than the input shape. Received: inputs.shape="
                f"{inputs.shape}, and cropping={self.cropping}"
            )

        # Build one slice per axis. A trailing crop of 0 must become a
        # stop of None (not -0, which would yield an empty slice).
        slices = [slice(None)] * 4
        slices[height_axis] = slice(top, -bottom if bottom else None)
        slices[width_axis] = slice(left, -right if right else None)
        return inputs[tuple(slices)]

    def get_config(self):
        """Returns the layer config for serialization."""
        config = {"cropping": self.cropping, "data_format": self.data_format}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

220