Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/image/transform_ops.py: 18%
85 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image transform ops."""

import tensorflow as tf
from tensorflow_addons.image import utils as img_utils
from tensorflow_addons.utils.types import TensorLike
from tensorflow_addons.image.utils import wrap, unwrap

from typing import Optional


_IMAGE_DTYPES = {
    tf.dtypes.uint8,
    tf.dtypes.int32,
    tf.dtypes.int64,
    tf.dtypes.float16,
    tf.dtypes.float32,
    tf.dtypes.float64,
}


def transform(
    images: TensorLike,
    transforms: TensorLike,
    interpolation: str = "nearest",
    fill_mode: str = "constant",
    output_shape: Optional[list] = None,
    name: Optional[str] = None,
    fill_value: TensorLike = 0.0,
) -> tf.Tensor:
    """Applies the given transform(s) to the image(s).

    Args:
      images: A tensor of shape `(num_images, num_rows, num_columns,
        num_channels)` (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or
        `(num_rows, num_columns)` (HW).
      transforms: Projective transform matrix/matrices. A vector of length 8 or
        a tensor of size N x 8. If one row of transforms is
        `[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps the *output* point
        `(x, y)` to a transformed *input* point
        `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
        where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
        the transform mapping input points to output points. Note that
        gradients are not backpropagated into transformation parameters.
      interpolation: Interpolation mode.
        Supported values: "nearest", "bilinear".
      fill_mode: Points outside the boundaries of the input are filled according
        to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
        - *reflect*: `(d c b a | a b c d | d c b a)`
          The input is extended by reflecting about the edge of the last pixel.
        - *constant*: `(k k k k | a b c d | k k k k)`
          The input is extended by filling all values beyond the edge with the
          same constant value k = 0.
        - *wrap*: `(a b c d | a b c d | a b c d)`
          The input is extended by wrapping around to the opposite edge.
        - *nearest*: `(a a a a | a b c d | d d d d)`
          The input is extended by the nearest pixel.
      fill_value: A float representing the value to be filled outside the
        boundaries when `fill_mode` is "constant".
      output_shape: Output dimensions after the transform, `[height, width]`.
        If `None`, the output is the same size as the input image.
      name: The name of the op.

    Returns:
      Image(s) with the same type and shape as `images`, with the given
      transform(s) applied. Transformed coordinates outside of the input image
      will be filled with zeros.

    Raises:
      TypeError: If `images` is an invalid type.
      ValueError: If `output_shape` is not a 1-D int32 Tensor.
    """
    with tf.name_scope(name or "transform"):
        image_or_images = tf.convert_to_tensor(images, name="images")
        transform_or_transforms = tf.convert_to_tensor(
            transforms, name="transforms", dtype=tf.dtypes.float32
        )
        if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
            raise TypeError("Invalid dtype %s." % image_or_images.dtype)
        images = img_utils.to_4D_image(image_or_images)
        original_ndims = img_utils.get_ndims(image_or_images)

        if output_shape is None:
            output_shape = tf.shape(images)[1:3]

        output_shape = tf.convert_to_tensor(
            output_shape, tf.dtypes.int32, name="output_shape"
        )

        if not output_shape.get_shape().is_compatible_with([2]):
            raise ValueError(
                "output_shape must be a 1-D Tensor of 2 elements: "
                "new_height, new_width"
            )

        if len(transform_or_transforms.get_shape()) == 1:
            transforms = transform_or_transforms[None]
        elif transform_or_transforms.get_shape().ndims is None:
            raise ValueError("transforms rank must be statically known")
        elif len(transform_or_transforms.get_shape()) == 2:
            transforms = transform_or_transforms
        else:
            transforms = transform_or_transforms
            raise ValueError(
                "transforms should have rank 1 or 2, but got rank %d"
                % len(transforms.get_shape())
            )

        fill_value = tf.convert_to_tensor(
            fill_value, dtype=tf.float32, name="fill_value"
        )
        output = tf.raw_ops.ImageProjectiveTransformV3(
            images=images,
            transforms=transforms,
            output_shape=output_shape,
            interpolation=interpolation.upper(),
            fill_mode=fill_mode.upper(),
            fill_value=fill_value,
        )
        return img_utils.from_4D_image(output, original_ndims)
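
# Illustrative usage (a sketch, not part of the upstream module): the flat
# transform below maps each *output* pixel (x, y) to the *input* pixel
# (x - 10, y), which shifts the image content 10 pixels to the right and
# fills the exposed strip with `fill_value`.
#
#     image = tf.random.uniform([64, 64, 3])  # HWC, float32
#     shifted = transform(
#         image,
#         [1.0, 0.0, -10.0, 0.0, 1.0, 0.0, 0.0, 0.0],  # x' = x - 10, y' = y
#         interpolation="bilinear",
#         fill_mode="constant",
#         fill_value=0.0,
#     )
#     # `shifted` keeps the shape and dtype of `image`.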


def compose_transforms(transforms: TensorLike, name: Optional[str] = None) -> tf.Tensor:
    """Composes the transforms tensors.

    Args:
      transforms: List of image projective transforms to be composed. Each
        transform is length 8 (single transform) or shape (N, 8) (batched
        transforms). The shapes of all inputs must be equal, and at least one
        input must be given.
      name: The name for the op.

    Returns:
      A composed transform tensor. When passed to `transform` op,
      equivalent to applying each of the given transforms to the image in
      order.
    """
    assert transforms, "transforms cannot be empty"
    with tf.name_scope(name or "compose_transforms"):
        composed = flat_transforms_to_matrices(transforms[0])
        for tr in transforms[1:]:
            # Multiply batches of matrices.
            composed = tf.matmul(composed, flat_transforms_to_matrices(tr))
        return matrices_to_flat_transforms(composed)
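
# Illustrative usage (a sketch, not part of the upstream module): composing two
# flat transforms into one, so the image is resampled only once. The names and
# values below are made up for the example.
#
#     shift_right = [1.0, 0.0, -10.0, 0.0, 1.0, 0.0, 0.0, 0.0]  # +10 px in x
#     shift_down = [1.0, 0.0, 0.0, 0.0, 1.0, -5.0, 0.0, 0.0]    # +5 px in y
#     combined = compose_transforms([shift_right, shift_down])
#     # Passing `combined` to `transform` applies both shifts in one call.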


def flat_transforms_to_matrices(
    transforms: TensorLike, name: Optional[str] = None
) -> tf.Tensor:
    """Converts projective transforms to affine matrices.

    Note that the output matrices map output coordinates to input coordinates.
    For the forward transformation matrix, call `tf.linalg.inv` on the result.

    Args:
      transforms: Vector of length 8, or batches of transforms with shape
        `(N, 8)`.
      name: The name for the op.

    Returns:
      3D tensor of matrices with shape `(N, 3, 3)`. The output matrices map the
      *output coordinates* (in homogeneous coordinates) of each transform to
      the corresponding *input coordinates*.

    Raises:
      ValueError: If `transforms` have an invalid shape.
    """
    with tf.name_scope(name or "flat_transforms_to_matrices"):
        transforms = tf.convert_to_tensor(transforms, name="transforms")
        if transforms.shape.ndims not in (1, 2):
            raise ValueError("Transforms should be 1D or 2D, got: %s" % transforms)
        # Make the transform(s) 2D in case the input is a single transform.
        transforms = tf.reshape(transforms, tf.constant([-1, 8]))
        num_transforms = tf.shape(transforms)[0]
        # Add a column of ones for the implicit last entry in the matrix.
        return tf.reshape(
            tf.concat([transforms, tf.ones([num_transforms, 1])], axis=1),
            tf.constant([-1, 3, 3]),
        )
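
# Illustrative usage (a sketch, not part of the upstream module): a length-8
# transform becomes a single 3x3 homogeneous matrix, with `1` appended as the
# implicit bottom-right entry.
#
#     flat = [1.0, 0.0, -10.0, 0.0, 1.0, 0.0, 0.0, 0.0]
#     matrices = flat_transforms_to_matrices(flat)
#     # `matrices` has shape (1, 3, 3); matrices[0] is
#     # [[1, 0, -10], [0, 1, 0], [0, 0, 1]].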


def matrices_to_flat_transforms(
    transform_matrices: TensorLike, name: Optional[str] = None
) -> tf.Tensor:
    """Converts affine matrices to projective transforms.

    Note that we expect matrices that map output coordinates to input
    coordinates. To convert forward transformation matrices,
    call `tf.linalg.inv` on the matrices and use the result here.

    Args:
      transform_matrices: One or more affine transformation matrices, for the
        reverse transformation in homogeneous coordinates. Shape `(3, 3)` or
        `(N, 3, 3)`.
      name: The name for the op.

    Returns:
      2D tensor of flat transforms with shape `(N, 8)`, which may be passed
      into `transform` op.

    Raises:
      ValueError: If `transform_matrices` have an invalid shape.
    """
    with tf.name_scope(name or "matrices_to_flat_transforms"):
        transform_matrices = tf.convert_to_tensor(
            transform_matrices, name="transform_matrices"
        )
        if transform_matrices.shape.ndims not in (2, 3):
            raise ValueError(
                "Matrices should be 2D or 3D, got: %s" % transform_matrices
            )
        # Flatten each matrix.
        transforms = tf.reshape(transform_matrices, tf.constant([-1, 9]))
        # Divide each matrix by the last entry (normally 1).
        transforms /= transforms[:, 8:9]
        return transforms[:, :8]
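
# Illustrative usage (a sketch, not part of the upstream module): converting a
# *forward* affine matrix into the output-to-input flat transform expected by
# `transform`, by inverting it first.
#
#     forward = tf.constant(
#         [[1.0, 0.0, 10.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
#     )  # moves content 10 px to the right
#     flat = matrices_to_flat_transforms(tf.linalg.inv(forward))
#     # `flat` has shape (1, 8) and equals [1, 0, -10, 0, 1, 0, 0, 0].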


def angles_to_projective_transforms(
    angles: TensorLike,
    image_height: TensorLike,
    image_width: TensorLike,
    name: Optional[str] = None,
) -> tf.Tensor:
    """Returns projective transform(s) for the given angle(s).

    Args:
      angles: A scalar angle to rotate all images by, or (for batches of
        images) a vector with an angle to rotate each image in the batch. The
        rank must be statically known (the shape is not `TensorShape(None)`).
      image_height: Height of the image(s) to be transformed.
      image_width: Width of the image(s) to be transformed.
      name: The name of the op.

    Returns:
      A tensor of shape (num_images, 8). Projective transforms which can be
      given to `transform` op.
    """
    with tf.name_scope(name or "angles_to_projective_transforms"):
        angle_or_angles = tf.convert_to_tensor(
            angles, name="angles", dtype=tf.dtypes.float32
        )
        if len(angle_or_angles.get_shape()) == 0:
            angles = angle_or_angles[None]
        elif len(angle_or_angles.get_shape()) == 1:
            angles = angle_or_angles
        else:
            raise ValueError("angles should have rank 0 or 1.")
        cos_angles = tf.math.cos(angles)
        sin_angles = tf.math.sin(angles)
        x_offset = (
            (image_width - 1)
            - (cos_angles * (image_width - 1) - sin_angles * (image_height - 1))
        ) / 2.0
        y_offset = (
            (image_height - 1)
            - (sin_angles * (image_width - 1) + cos_angles * (image_height - 1))
        ) / 2.0
        num_angles = tf.shape(angles)[0]
        return tf.concat(
            values=[
                cos_angles[:, None],
                -sin_angles[:, None],
                x_offset[:, None],
                sin_angles[:, None],
                cos_angles[:, None],
                y_offset[:, None],
                tf.zeros((num_angles, 2), tf.dtypes.float32),
            ],
            axis=1,
        )
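
# Illustrative note (a sketch, not part of the upstream module): each returned
# row is `[cos(a), -sin(a), x_offset, sin(a), cos(a), y_offset, 0, 0]`, i.e. a
# rotation about the image center `((width - 1) / 2, (height - 1) / 2)` written
# as the output-to-input mapping expected by `transform`.
#
#     transforms = angles_to_projective_transforms(
#         angles=0.5, image_height=64.0, image_width=64.0
#     )
#     # `transforms` has shape (1, 8) and can be passed to `transform`.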


def rotate(
    images: TensorLike,
    angles: TensorLike,
    interpolation: str = "nearest",
    fill_mode: str = "constant",
    name: Optional[str] = None,
    fill_value: TensorLike = 0.0,
) -> tf.Tensor:
    """Rotate image(s) counterclockwise by the passed angle(s) in radians.

    Args:
      images: A tensor of shape
        `(num_images, num_rows, num_columns, num_channels)`
        (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or
        `(num_rows, num_columns)` (HW).
      angles: A scalar angle to rotate all images by, or (if `images` has rank 4)
        a vector of length num_images, with an angle for each image in the
        batch.
      interpolation: Interpolation mode. Supported values: "nearest",
        "bilinear".
      fill_mode: Points outside the boundaries of the input are filled according
        to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
        - *reflect*: `(d c b a | a b c d | d c b a)`
          The input is extended by reflecting about the edge of the last pixel.
        - *constant*: `(k k k k | a b c d | k k k k)`
          The input is extended by filling all values beyond the edge with the
          same constant value k = 0.
        - *wrap*: `(a b c d | a b c d | a b c d)`
          The input is extended by wrapping around to the opposite edge.
        - *nearest*: `(a a a a | a b c d | d d d d)`
          The input is extended by the nearest pixel.
      fill_value: A float representing the value to be filled outside the
        boundaries when `fill_mode` is "constant".
      name: The name of the op.

    Returns:
      Image(s) with the same type and shape as `images`, rotated by the given
      angle(s). Empty space due to the rotation will be filled with zeros.

    Raises:
      TypeError: If `images` is an invalid type.
    """
    with tf.name_scope(name or "rotate"):
        image_or_images = tf.convert_to_tensor(images)
        if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
            raise TypeError("Invalid dtype %s." % image_or_images.dtype)
        images = img_utils.to_4D_image(image_or_images)
        original_ndims = img_utils.get_ndims(image_or_images)

        image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None]
        image_width = tf.cast(tf.shape(images)[2], tf.dtypes.float32)[None]
        output = transform(
            images,
            angles_to_projective_transforms(angles, image_height, image_width),
            interpolation=interpolation,
            fill_mode=fill_mode,
            fill_value=fill_value,
        )
        return img_utils.from_4D_image(output, original_ndims)
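
# Illustrative usage (a sketch, not part of the upstream module): rotating a
# batch of images, each by its own angle, with bilinear resampling and
# reflected padding at the borders.
#
#     images = tf.random.uniform([4, 64, 64, 3])  # NHWC, float32
#     rotated = rotate(
#         images,
#         angles=[0.0, 0.5, 1.0, 1.5],  # radians, one per image
#         interpolation="bilinear",
#         fill_mode="reflect",
#     )
#     # `rotated` has the same shape and dtype as `images`.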


def shear_x(image: TensorLike, level: float, replace: TensorLike) -> TensorLike:
    """Perform a shear operation on an image along the x-axis.

    Args:
      image: A 3D image `Tensor`.
      level: A float denoting the shear element along the y-axis.
      replace: A one or three value 1D tensor to fill empty pixels.

    Returns:
      The transformed image, with the space outside the original image
      filled with `replace`.
    """
    # Shear parallel to the x axis is a projective transform
    # with a matrix form of:
    # [1 level
    #  0 1].
    image = transform(wrap(image), [1.0, level, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
    return unwrap(image, replace)
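
# Illustrative usage (a sketch, not part of the upstream module): with this
# transform each output pixel (x, y) samples the input at (x + level * y, y),
# i.e. a horizontal shear; pixels that fall outside the input are filled with
# `replace`.
#
#     image = tf.random.uniform([32, 32, 3])  # HWC, float32
#     sheared = shear_x(image, level=0.3, replace=tf.zeros([3]))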


def shear_y(image: TensorLike, level: float, replace: TensorLike) -> TensorLike:
    """Perform a shear operation on an image along the y-axis.

    Args:
      image: A 3D image `Tensor`.
      level: A float denoting the shear element along the x-axis.
      replace: A one or three value 1D tensor to fill empty pixels.

    Returns:
      The transformed image, with the space outside the original image
      filled with `replace`.
    """
    # Shear parallel to the y axis is a projective transform
    # with a matrix form of:
    # [1     0
    #  level 1].
    image = transform(wrap(image), [1.0, 0.0, 0.0, level, 1.0, 0.0, 0.0, 0.0])
    return unwrap(image, replace)
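
# Illustrative usage (a sketch, not part of the upstream module): the y-axis
# counterpart of `shear_x`; each output pixel (x, y) samples the input at
# (x, level * x + y).
#
#     image = tf.random.uniform([32, 32, 3])  # HWC, float32
#     sheared = shear_y(image, level=0.3, replace=tf.zeros([3]))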