Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/image/utils.py: 16%

57 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Image util ops.""" 

16 

17import tensorflow as tf 

18 

19 

def get_ndims(image):
    """Return the rank of `image`, preferring the statically known value.

    Args:
      image: An object exposing `get_shape()` (e.g. a `Tensor`).

    Returns:
      The value of `image.get_shape().ndims` when it is not `None`;
      otherwise the result of `tf.rank(image)`.
    """
    static_rank = image.get_shape().ndims
    # Compare against None explicitly: a statically known rank of 0 is falsy,
    # and a truthiness test would wrongly fall through to the dynamic
    # `tf.rank` path for it.
    if static_rank is not None:
        return static_rank
    return tf.rank(image)

22 

23 

def to_4D_image(image):
    """Promote a 2D, 3D or 4D image to a 4D image.

    A rank-2 input gains a leading and a trailing axis, a rank-3 input gains
    a leading axis, and a rank-4 input is returned as is. When the static
    rank is unknown the conversion is delegated to `_dynamic_to_4D_image`.

    Args:
      image: 2/3/4D `Tensor`.

    Returns:
      4D `Tensor` with the same type.
    """
    rank_check = tf.debugging.assert_rank_in(
        image, [2, 3, 4], message="`image` must be 2/3/4D tensor"
    )
    with tf.control_dependencies([rank_check]):
        static_rank = image.get_shape().ndims
        if static_rank is None:
            # Rank is not known statically; build the new shape with ops.
            return _dynamic_to_4D_image(image)
        if static_rank == 2:
            return image[None, :, :, None]
        if static_rank == 3:
            return image[None, :, :, :]
        return image

49 

50 

def _dynamic_to_4D_image(image):
    """Reshape `image` to rank 4 using its runtime shape and rank.

    The target shape is the runtime shape with a 1 prepended when the rank
    is at most 3 and a 1 appended when the rank is exactly 2:

      4D image => [N, H, W, C] or [N, C, H, W]
      3D image => [1, H, W, C] or [1, C, H, W]
      2D image => [1, H, W, 1]
    """
    runtime_shape = tf.shape(image)
    runtime_rank = tf.rank(image)
    # Each count is 0 or 1; `tf.ones` of that length yields either an empty
    # vector or a single 1 to splice into the shape.
    num_leading = tf.cast(tf.less_equal(runtime_rank, 3), dtype=tf.int32)
    num_trailing = tf.cast(tf.equal(runtime_rank, 2), dtype=tf.int32)
    leading_ones = tf.ones(shape=num_leading, dtype=tf.int32)
    trailing_ones = tf.ones(shape=num_trailing, dtype=tf.int32)
    target_shape = tf.concat([leading_ones, runtime_shape, trailing_ones], axis=0)
    return tf.reshape(image, target_shape)

68 

69 

def from_4D_image(image, ndims):
    """Restore a 4D image to a tensor of rank `ndims`.

    When `ndims` is a `tf.Tensor` the conversion is delegated to
    `_dynamic_from_4D_image`. Otherwise `ndims == 2` squeezes axes 0 and 3,
    `ndims == 3` squeezes axis 0, and any other value returns `image` as is.

    Args:
      image: 4D `Tensor`.
      ndims: The original rank of the image.

    Returns:
      `ndims`-D `Tensor` with the same type.
    """
    rank_check = tf.debugging.assert_rank(
        image, 4, message="`image` must be 4D tensor"
    )
    with tf.control_dependencies([rank_check]):
        if isinstance(ndims, tf.Tensor):
            # The original rank is only known at run time.
            return _dynamic_from_4D_image(image, ndims)
        if ndims == 2:
            return tf.squeeze(image, [0, 3])
        if ndims == 3:
            return tf.squeeze(image, [0])
        return image

91 

92 

def _dynamic_from_4D_image(image, original_rank):
    """Reshape a 4D `image` back to `original_rank` using runtime values.

    The target shape is a slice of the runtime shape: the first entry is
    skipped when `original_rank` is at most 3, and the last entry is skipped
    when `original_rank` is exactly 2:

      4D image <= [N, H, W, C] or [N, C, H, W]
      3D image <= [1, H, W, C] or [1, C, H, W]
      2D image <= [1, H, W, 1]
    """
    runtime_shape = tf.shape(image)
    drop_leading = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
    drop_trailing = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)
    stop = 4 - drop_trailing
    target_shape = runtime_shape[drop_leading:stop]
    return tf.reshape(image, target_shape)

102 

103 

def wrap(image):
    """Append a channel filled with 1s to `image` along axis 2.

    The appended channel has shape `[shape[0], shape[1], 1]` and the same
    dtype as `image`.
    """
    dims = tf.shape(image)
    ones_channel = tf.ones(shape=[dims[0], dims[1], 1], dtype=image.dtype)
    return tf.concat(values=[image, ones_channel], axis=2)

110 

111 

def unwrap(image, replace):
    """Unwraps an image produced by `wrap`.

    `wrap` appends a channel of 1s. Operations like translate and shear on a
    wrapped `Tensor` will leave 0s in empty locations. Every pixel whose last
    channel is not equal to 1 is treated as empty and its first three
    channels are filled with `replace`; the extra channel is then dropped.
    Some transformations look at the intensity of values to do
    preprocessing, so callers typically pass an 'average' value for
    `replace` rather than pure black.

    Args:
      image: A 3D image `Tensor` with 4 channels.
      replace: A scalar, or a one or three value 1D `Tensor`, used to fill
        empty pixels. A single value is repeated for all three channels.

    Returns:
      image: A 3D image `Tensor` with 3 channels.
    """
    image_shape = tf.shape(image)
    # Flatten the spatial dimensions.
    flattened_image = tf.reshape(image, [-1, image_shape[2]])

    # The channel appended by `wrap` (index 3) marks valid pixels with 1.
    alpha_channel = flattened_image[:, 3]

    replace = tf.cast(replace, image.dtype)
    # Normalize `replace` to exactly three values with ops instead of a
    # Python `if` on a tensor: a scalar or one-value vector is repeated,
    # and a three-value vector passes through unchanged.
    replace = tf.broadcast_to(tf.reshape(replace, [-1]), [3])
    # Pad with a 1 so the fill value lines up with the 4-channel pixels.
    replace = tf.concat([replace, tf.ones([1], dtype=replace.dtype)], 0)

    # Keep pixels whose last channel is 1; fill all others with `replace`.
    cond = tf.equal(alpha_channel, 1)
    cond = tf.expand_dims(cond, 1)
    cond = tf.concat([cond, cond, cond, cond], 1)
    flattened_image = tf.where(cond, flattened_image, replace)

    image = tf.reshape(flattened_image, image_shape)
    # Drop the extra channel added by `wrap`.
    image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3])
    return image