Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/image/translate_ops.py: 36%

28 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Image translate ops.""" 

16 

17import tensorflow as tf 

18from tensorflow_addons.image.transform_ops import transform 

19from tensorflow_addons.image.utils import wrap, unwrap 

20from tensorflow_addons.utils.types import TensorLike 

21 

22from typing import Optional 

23 

24 

def translations_to_projective_transforms(
    translations: TensorLike, name: Optional[str] = None
) -> tf.Tensor:
    """Build flattened projective transform(s) for the given translation(s).

    Each `[dx, dy]` pair is turned into the first eight entries (row-major)
    of the 3x3 matrix

        [[1 0 -dx]
         [0 1 -dy]
         [0 0  1]]

    The final `1` is implicit and is not emitted.

    Args:
      translations: A 2-element list representing `[dx, dy]`, or a matrix of
        2-element lists representing `[dx, dy]` for each image of a batch.
        The value is converted to a float32 tensor, and its rank must be
        statically known (the shape is not `TensorShape(None)`).
      name: The name of the op. Defaults to
        "translations_to_projective_transforms".
    Returns:
      A float32 tensor of shape `(num_images, 8)` holding projective
      transforms which can be given to `tfa.image.transform`.
    Raises:
      TypeError: If the rank of `translations` is not statically known, or
        is neither 1 nor 2.
    """
    with tf.name_scope(name or "translations_to_projective_transforms"):
        offsets = tf.convert_to_tensor(
            translations, name="translations", dtype=tf.dtypes.float32
        )

        # Validate the static rank up front; only rank 1 and rank 2 are valid.
        rank = offsets.get_shape().ndims
        if rank is None:
            raise TypeError(
                "translation_or_translations rank must be statically known"
            )
        if rank not in (1, 2):
            raise TypeError("Translations should have rank 1 or 2.")
        if rank == 1:
            # A single [dx, dy] vector is promoted to a batch of one.
            offsets = offsets[None]

        batch_size = tf.shape(offsets)[0]
        float32 = tf.dtypes.float32  # Translation matrices are always float32.

        # Columns are created in output order: the first matrix row, the
        # second matrix row, then the two leading zeros of the third row.
        row0_col0 = tf.ones((batch_size, 1), float32)
        row0_col1 = tf.zeros((batch_size, 1), float32)
        row0_col2 = -offsets[:, 0, None]
        row1_col0 = tf.zeros((batch_size, 1), float32)
        row1_col1 = tf.ones((batch_size, 1), float32)
        row1_col2 = -offsets[:, 1, None]
        row2_head = tf.zeros((batch_size, 2), float32)

        return tf.concat(
            values=[
                row0_col0,
                row0_col1,
                row0_col2,
                row1_col0,
                row1_col1,
                row1_col2,
                row2_head,
            ],
            axis=1,
        )

71 

72 

@tf.function
def translate(
    images: TensorLike,
    translations: TensorLike,
    interpolation: str = "nearest",
    fill_mode: str = "constant",
    name: Optional[str] = None,
    fill_value: TensorLike = 0.0,
) -> tf.Tensor:
    """Translate image(s) by the passed vector(s).

    The translation vector(s) are converted to projective transforms with
    `translations_to_projective_transforms` and applied through `transform`;
    `interpolation`, `fill_mode` and `fill_value` are forwarded unchanged.

    Args:
      images: A tensor of shape
        `(num_images, num_rows, num_columns, num_channels)` (NHWC),
        `(num_rows, num_columns, num_channels)` (HWC), or
        `(num_rows, num_columns)` (HW). The rank must be statically known (the
        shape is not `TensorShape(None)`).
      translations: A vector representing `[dx, dy]` or (if `images` has rank 4)
        a matrix of length num_images, with a `[dx, dy]` vector for each image
        in the batch.
      interpolation: Interpolation mode. Supported values: "nearest",
        "bilinear".
      fill_mode: Points outside the boundaries of the input are filled according
        to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
        - *reflect*: `(d c b a | a b c d | d c b a)`
          The input is extended by reflecting about the edge of the last pixel.
        - *constant*: `(k k k k | a b c d | k k k k)`
          The input is extended by filling all values beyond the edge with the
          same constant value k, given by `fill_value`.
        - *wrap*: `(a b c d | a b c d | a b c d)`
          The input is extended by wrapping around to the opposite edge.
        - *nearest*: `(a a a a | a b c d | d d d d)`
          The input is extended by the nearest pixel.
      name: The name of the op. Defaults to "translate".
      fill_value: a float represents the value to be filled outside the
        boundaries when `fill_mode` is "constant". Defaults to 0.0.
    Returns:
      The output of `transform`: image(s) with the same type and shape as
      `images`, translated by the given vector(s). Empty space due to the
      translation is filled according to `fill_mode`.
    Raises:
      TypeError: If the rank of `translations` is not statically known or is
        neither 1 nor 2, or if `images` is an invalid type.
    """
    with tf.name_scope(name or "translate"):
        # Build the matrices inside the scope so their ops are named under it.
        projective = translations_to_projective_transforms(translations)
        return transform(
            images,
            projective,
            interpolation=interpolation,
            fill_mode=fill_mode,
            fill_value=fill_value,
        )

124 

125 

def translate_xy(
    image: TensorLike, translate_to: TensorLike, replace: TensorLike
) -> TensorLike:
    """Translate a single image by an `[x, y]` offset.

    The image is passed through `wrap`, shifted with `translate` using that
    function's defaults, and then passed through `unwrap` together with
    `replace`.

    Args:
      image: A 3D image `Tensor`.
      translate_to: A 1D `Tensor` to translate `[x, y]`; only its first two
        elements are used.
      replace: A one or three value 1D `Tensor` to fill empty pixels.
    Returns:
      Translated image along X or Y axis, with space outside image
      filled with replace.
    """
    # NOTE(review): `wrap`/`unwrap` come from tensorflow_addons.image.utils and
    # are not visible here; presumably `wrap` tags the image so `unwrap` can
    # find the pixels exposed by the shift and fill them -- confirm there.
    wrapped = wrap(tf.convert_to_tensor(image))

    offset = tf.convert_to_tensor(translate_to)
    dx, dy = offset[0], offset[1]

    shifted = translate(wrapped, [dx, dy])
    return unwrap(shifted, replace)

143 trans = tf.convert_to_tensor(translate_to) 

144 image = translate(image, [trans[0], trans[1]]) 

145 return unwrap(image, replace)