Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/image/transform_ops.py: 18%

85 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Image transform ops.""" 

16 

17import tensorflow as tf 

18from tensorflow_addons.image import utils as img_utils 

19from tensorflow_addons.utils.types import TensorLike 

20from tensorflow_addons.image.utils import wrap, unwrap 

21 

22from typing import Optional 

23 

24 

# Element dtypes accepted by `transform` and `rotate`. Both functions raise a
# TypeError for an image tensor whose base dtype is not a member of this set.
_IMAGE_DTYPES = {
    # Integer image dtypes.
    tf.dtypes.uint8,
    tf.dtypes.int32,
    tf.dtypes.int64,
    # Floating-point image dtypes.
    tf.dtypes.float16,
    tf.dtypes.float32,
    tf.dtypes.float64,
}

33 

34 

def transform(
    images: TensorLike,
    transforms: TensorLike,
    interpolation: str = "nearest",
    fill_mode: str = "constant",
    output_shape: Optional[list] = None,
    name: Optional[str] = None,
    fill_value: TensorLike = 0.0,
) -> tf.Tensor:
    """Applies the given transform(s) to the image(s).

    Args:
      images: A tensor of shape (num_images, num_rows, num_columns,
        num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or
        (num_rows, num_columns) (HW).
      transforms: Projective transform matrix/matrices. A vector of length 8 or
        tensor of size N x 8. If one row of transforms is
        [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point
        `(x, y)` to a transformed *input* point
        `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
        where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
        the transform mapping input points to output points. Note that
        gradients are not backpropagated into transformation parameters.
        The value is converted to a float32 tensor and its rank must be
        statically known.
      interpolation: Interpolation mode.
        Supported values: "nearest", "bilinear".
      fill_mode: Points outside the boundaries of the input are filled according
        to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
        - *reflect*: `(d c b a | a b c d | d c b a)`
          The input is extended by reflecting about the edge of the last pixel.
        - *constant*: `(k k k k | a b c d | k k k k)`
          The input is extended by filling all values beyond the edge with the
          same constant value k = `fill_value`.
        - *wrap*: `(a b c d | a b c d | a b c d)`
          The input is extended by wrapping around to the opposite edge.
        - *nearest*: `(a a a a | a b c d | d d d d)`
          The input is extended by the nearest pixel.
      output_shape: Output dimension after the transform, [height, width].
        If None, output is the same size as input image.
      name: The name of the op.
      fill_value: A float representing the value to be filled outside the
        boundaries when `fill_mode` is "constant".

    Returns:
      Image(s) with the given transform(s) applied, restored to the rank of
      `images`. The spatial size is `output_shape` when it is given and the
      input's height and width otherwise. Coordinates that fall outside the
      input image are filled according to `fill_mode` and `fill_value`.

    Raises:
      TypeError: If `images` has a dtype that is not in `_IMAGE_DTYPES`.
      ValueError: If `output_shape` is not compatible with a 1-D tensor of
        2 elements, if the rank of `transforms` is not statically known, or
        if that rank is neither 1 nor 2.
    """
    with tf.name_scope(name or "transform"):
        image_or_images = tf.convert_to_tensor(images, name="images")
        transform_or_transforms = tf.convert_to_tensor(
            transforms, name="transforms", dtype=tf.dtypes.float32
        )
        if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
            raise TypeError("Invalid dtype %s." % image_or_images.dtype)
        images = img_utils.to_4D_image(image_or_images)
        # Remember the caller's rank so the result can be converted back.
        original_ndims = img_utils.get_ndims(image_or_images)

        if output_shape is None:
            # Default to the (height, width) of the 4-D input.
            output_shape = tf.shape(images)[1:3]

        output_shape = tf.convert_to_tensor(
            output_shape, tf.dtypes.int32, name="output_shape"
        )

        if not output_shape.get_shape().is_compatible_with([2]):
            raise ValueError(
                "output_shape must be a 1-D Tensor of 2 elements: "
                "new_height, new_width"
            )

        # Read the static rank once and test for "unknown" first: calling
        # len() on a TensorShape of unknown rank raises its own ValueError,
        # which previously made the dedicated error below unreachable.
        transforms_rank = transform_or_transforms.get_shape().ndims
        if transforms_rank is None:
            raise ValueError("transforms rank must be statically known")
        if transforms_rank == 1:
            # A single transform: add a leading batch dimension.
            transforms = transform_or_transforms[None]
        elif transforms_rank == 2:
            transforms = transform_or_transforms
        else:
            raise ValueError(
                "transforms should have rank 1 or 2, but got rank %d"
                % transforms_rank
            )

        fill_value = tf.convert_to_tensor(
            fill_value, dtype=tf.float32, name="fill_value"
        )
        # The raw op is called with upper-cased mode names.
        output = tf.raw_ops.ImageProjectiveTransformV3(
            images=images,
            transforms=transforms,
            output_shape=output_shape,
            interpolation=interpolation.upper(),
            fill_mode=fill_mode.upper(),
            fill_value=fill_value,
        )
        return img_utils.from_4D_image(output, original_ndims)

135 

136 

def compose_transforms(transforms: TensorLike, name: Optional[str] = None) -> tf.Tensor:
    """Composes a sequence of flat projective transforms into one.

    Each transform is expanded to its 3x3 matrix form, the matrices are
    multiplied left to right in the order given, and the product is flattened
    back to the 8-parameter form.

    Args:
      transforms: List of image projective transforms to be composed. Each
        transform is length 8 (single transform) or shape (N, 8) (batched
        transforms). The shapes of all inputs must be equal, and at least one
        input must be given.
      name: The name for the op.

    Returns:
      A composed transform tensor. When passed to `transform` op,
      equivalent to applying each of the given transforms to the image in
      order.
    """
    assert transforms, "transforms cannot be empty"
    with tf.name_scope(name or "compose_transforms"):
        first, remaining = transforms[0], transforms[1:]
        product = flat_transforms_to_matrices(first)
        for flat in remaining:
            # Batched matrix product: accumulate on the left, new factor on
            # the right.
            next_matrix = flat_transforms_to_matrices(flat)
            product = tf.matmul(product, next_matrix)
        return matrices_to_flat_transforms(product)

159 

160 

def flat_transforms_to_matrices(
    transforms: TensorLike, name: Optional[str] = None
) -> tf.Tensor:
    """Converts flat 8-parameter projective transforms to 3x3 matrices.

    Note that the output matrices map output coordinates to input coordinates.
    For the forward transformation matrix, call `tf.linalg.inv` on the result.

    Args:
      transforms: Vector of length 8, or batches of transforms with shape
        `(N, 8)`.
      name: The name for the op.

    Returns:
      3D tensor of matrices with shape `(N, 3, 3)` and the same dtype as
      `transforms` after tensor conversion. The output matrices map the
      *output coordinates* (in homogeneous coordinates) of each transform to
      the corresponding *input coordinates*.

    Raises:
      ValueError: If `transforms` have an invalid shape.
    """
    with tf.name_scope(name or "flat_transforms_to_matrices"):
        transforms = tf.convert_to_tensor(transforms, name="transforms")
        if transforms.shape.ndims not in (1, 2):
            raise ValueError("Transforms should be 1D or 2D, got: %s" % transforms)
        # Make the transform(s) 2D in case the input is a single transform.
        transforms = tf.reshape(transforms, tf.constant([-1, 8]))
        num_transforms = tf.shape(transforms)[0]
        # Column of ones for the implicit last entry of each matrix. It is
        # built in the dtype of `transforms` because `tf.concat` requires all
        # inputs to share one dtype; the float32 default of `tf.ones` would
        # otherwise fail for float64 or integer input.
        last_entries = tf.ones([num_transforms, 1], dtype=transforms.dtype)
        return tf.reshape(
            tf.concat([transforms, last_entries], axis=1),
            tf.constant([-1, 3, 3]),
        )

194 

195 

def matrices_to_flat_transforms(
    transform_matrices: TensorLike, name: Optional[str] = None
) -> tf.Tensor:
    """Flattens 3x3 transformation matrices into 8-parameter transforms.

    The matrices are expected to map output coordinates to input
    coordinates. To convert forward transformation matrices, call
    `tf.linalg.inv` on them first and pass the result here.

    Args:
      transform_matrices: One or more affine transformation matrices, for the
        reverse transformation in homogeneous coordinates. Shape `(3, 3)` or
        `(N, 3, 3)`.
      name: The name for the op.

    Returns:
      2D tensor of flat transforms with shape `(N, 8)`, which may be passed
      into `transform` op.

    Raises:
      ValueError: If `transform_matrices` have an invalid shape.
    """
    with tf.name_scope(name or "matrices_to_flat_transforms"):
        transform_matrices = tf.convert_to_tensor(
            transform_matrices, name="transform_matrices"
        )
        rank = transform_matrices.shape.ndims
        if rank not in (2, 3):
            raise ValueError(
                "Matrices should be 2D or 3D, got: %s" % transform_matrices
            )
        # One row of nine entries per matrix.
        flattened = tf.reshape(transform_matrices, tf.constant([-1, 9]))
        # Normalize by the bottom-right entry (normally 1) so that it becomes
        # the implicit trailing 1 of the flat representation.
        bottom_right = flattened[:, 8:9]
        normalized = flattened / bottom_right
        # Drop the now-implicit ninth entry.
        return normalized[:, :8]

231 

232 

def angles_to_projective_transforms(
    angles: TensorLike,
    image_height: TensorLike,
    image_width: TensorLike,
    name: Optional[str] = None,
) -> tf.Tensor:
    """Builds the flat projective transform(s) for rotating by the angle(s).

    Each row is `[cos, -sin, x_offset, sin, cos, y_offset, 0, 0]`. The offsets
    are chosen so that the point `((image_width - 1) / 2,
    (image_height - 1) / 2)` is mapped to itself.

    Args:
      angles: A scalar angle to rotate all images by, or (for batches of
        images) a vector with an angle to rotate each image in the batch. The
        rank must be statically known (the shape is not `TensorShape(None)`).
      image_height: Height of the image(s) to be transformed.
      image_width: Width of the image(s) to be transformed.
      name: The name for the op.

    Returns:
      A tensor of shape (num_images, 8). Projective transforms which can be
      given to `transform` op.

    Raises:
      ValueError: If `angles` has a rank other than 0 or 1.
    """
    with tf.name_scope(name or "angles_to_projective_transforms"):
        angle_tensor = tf.convert_to_tensor(
            angles, name="angles", dtype=tf.dtypes.float32
        )
        rank = len(angle_tensor.get_shape())
        if rank == 0:
            # Promote a scalar angle to a one-element vector.
            angles = angle_tensor[None]
        elif rank == 1:
            angles = angle_tensor
        else:
            raise ValueError("angles should have rank 0 or 1.")

        cosines = tf.math.cos(angles)
        sines = tf.math.sin(angles)

        # Largest valid column / row index of the image.
        width_span = image_width - 1
        height_span = image_height - 1
        x_offset = (
            width_span - (cosines * width_span - sines * height_span)
        ) / 2.0
        y_offset = (
            height_span - (sines * width_span + cosines * height_span)
        ) / 2.0

        num_angles = tf.shape(angles)[0]
        # One (num_angles, 1) column per coefficient, plus two zero columns
        # for the projective terms c0 and c1.
        columns = [
            cosines[:, None],
            -sines[:, None],
            x_offset[:, None],
            sines[:, None],
            cosines[:, None],
            y_offset[:, None],
            tf.zeros((num_angles, 2), tf.dtypes.float32),
        ]
        return tf.concat(values=columns, axis=1)

285 

286 

def rotate(
    images: TensorLike,
    angles: TensorLike,
    interpolation: str = "nearest",
    fill_mode: str = "constant",
    name: Optional[str] = None,
    fill_value: TensorLike = 0.0,
) -> tf.Tensor:
    """Rotate image(s) counterclockwise by the passed angle(s) in radians.

    The rotation is expressed as a projective transform built by
    `angles_to_projective_transforms` and applied with `transform`.

    Args:
      images: A tensor of shape
        `(num_images, num_rows, num_columns, num_channels)`
        (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or
        `(num_rows, num_columns)` (HW).
      angles: A scalar angle to rotate all images by, or (if `images` has rank 4)
        a vector of length num_images, with an angle for each image in the
        batch.
      interpolation: Interpolation mode. Supported values: "nearest",
        "bilinear".
      fill_mode: Points outside the boundaries of the input are filled according
        to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
        - *reflect*: `(d c b a | a b c d | d c b a)`
          The input is extended by reflecting about the edge of the last pixel.
        - *constant*: `(k k k k | a b c d | k k k k)`
          The input is extended by filling all values beyond the edge with the
          same constant value k = `fill_value`.
        - *wrap*: `(a b c d | a b c d | a b c d)`
          The input is extended by wrapping around to the opposite edge.
        - *nearest*: `(a a a a | a b c d | d d d d)`
          The input is extended by the nearest pixel.
      name: The name of the op.
      fill_value: A float representing the value to be filled outside the
        boundaries when `fill_mode` is "constant".

    Returns:
      Image(s) with the same type and shape as `images`, rotated by the given
      angle(s). Empty space due to the rotation is filled according to
      `fill_mode` and `fill_value`.

    Raises:
      TypeError: If `images` is an invalid type.
    """
    with tf.name_scope(name or "rotate"):
        source = tf.convert_to_tensor(images)
        if source.dtype.base_dtype not in _IMAGE_DTYPES:
            raise TypeError("Invalid dtype %s." % source.dtype)
        images = img_utils.to_4D_image(source)
        # Rank of the caller's input, used to undo the 4-D conversion.
        input_rank = img_utils.get_ndims(source)

        # Height and width as one-element float32 vectors.
        batch_shape = tf.shape(images)
        height = tf.cast(batch_shape[1], tf.dtypes.float32)[None]
        width = tf.cast(batch_shape[2], tf.dtypes.float32)[None]

        rotation = angles_to_projective_transforms(angles, height, width)
        rotated = transform(
            images,
            rotation,
            interpolation=interpolation,
            fill_mode=fill_mode,
            fill_value=fill_value,
        )
        return img_utils.from_4D_image(rotated, input_rank)

346 

347 

def shear_x(image: TensorLike, level: float, replace: TensorLike) -> TensorLike:
    """Perform shear operation on an image (x-axis).

    Args:
      image: A 3D image `Tensor`.
      level: A float shear coefficient. It is used as the `a1` entry of the
        flat transform, so the output point `(x, y)` samples the input at
        `(x + level * y, y)`.
      replace: A one or three value 1D tensor to fill empty pixels.
    Returns:
      Transformed image along X or Y axis, with space outside image
      filled with replace.
    """
    # Flat form of the shear matrix
    #   [1  level
    #    0  1],
    # with zero translation and zero projective terms.
    shear_row = [1.0, level, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
    sheared = transform(wrap(image), shear_row)
    return unwrap(sheared, replace)

365 

366 

def shear_y(image: TensorLike, level: float, replace: TensorLike) -> TensorLike:
    """Perform shear operation on an image (y-axis).

    Args:
      image: A 3D image `Tensor`.
      level: A float shear coefficient. It is used as the `b0` entry of the
        flat transform, so the output point `(x, y)` samples the input at
        `(x, level * x + y)`.
      replace: A one or three value 1D tensor to fill empty pixels.
    Returns:
      Transformed image along X or Y axis, with space outside image
      filled with replace.
    """
    # Flat form of the shear matrix
    #   [1      0
    #    level  1],
    # with zero translation and zero projective terms.
    shear_row = [1.0, 0.0, 0.0, level, 1.0, 0.0, 0.0, 0.0]
    sheared = transform(wrap(image), shear_row)
    return unwrap(sheared, replace)