Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/image/dense_image_warp.py: 16%

83 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Image warping using per-pixel flow vectors.""" 

16 

17import tensorflow as tf 

18 

19from tensorflow_addons.utils import types 

20from typing import Optional 

21 

22 

@tf.function
def interpolate_bilinear(
    grid: types.TensorLike,
    query_points: types.TensorLike,
    indexing: str = "ij",
    name: Optional[str] = None,
) -> tf.Tensor:
    """Bilinearly sample a grid at a batch of query points.

    Comparable to Matlab's interp2: for every query point, the value is a
    bilinear interpolation of the four surrounding grid pixels.

    Args:
      grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`.
      query_points: a 3-D float `Tensor` of N points with shape
        `[batch, N, 2]`.
      indexing: whether the query points are specified as row and column (ij),
        or Cartesian coordinates (xy).
      name: a name for the operation (optional).

    Returns:
      values: a 3-D `Tensor` with shape `[batch, N, channels]`

    Raises:
      ValueError: if the indexing mode is invalid, or if the shape of the
        inputs invalid.
    """
    # Delegate to an undecorated helper so input validation happens eagerly,
    # outside this tf.function's traced graph.
    return _interpolate_bilinear_with_checks(grid, query_points, indexing, name)

50 

51 

def _interpolate_bilinear_with_checks(
    grid: types.TensorLike,
    query_points: types.TensorLike,
    indexing: str,
    name: Optional[str],
) -> tf.Tensor:
    """Validate inputs, then run the interpolation implementation.

    Kept outside any tf.function decorator so the checks do not introduce
    tracing flakiness.
    """
    if indexing not in ("ij", "xy"):
        raise ValueError("Indexing mode must be 'ij' or 'xy'")

    grid = tf.convert_to_tensor(grid)
    query_points = tf.convert_to_tensor(query_points)
    grid_shape = tf.shape(grid)
    query_shape = tf.shape(query_points)

    # Runtime (graph-time) shape assertions; gathered into a list so they can
    # be sequenced before the interpolation via control dependencies.
    assertions = [
        tf.debugging.assert_equal(tf.rank(grid), 4, "Grid must be 4D Tensor"),
        tf.debugging.assert_greater_equal(
            grid_shape[1], 2, "Grid height must be at least 2."
        ),
        tf.debugging.assert_greater_equal(
            grid_shape[2], 2, "Grid width must be at least 2."
        ),
        tf.debugging.assert_equal(
            tf.rank(query_points), 3, "Query points must be 3 dimensional."
        ),
        tf.debugging.assert_equal(
            query_shape[2], 2, "Query points last dimension must be 2."
        ),
    ]
    with tf.control_dependencies(assertions):
        return _interpolate_bilinear_impl(grid, query_points, indexing, name)

85 

86 

def _interpolate_bilinear_impl(
    grid: types.TensorLike,
    query_points: types.TensorLike,
    indexing: str,
    name: Optional[str],
) -> tf.Tensor:
    """tf.function implementation of interpolate_bilinear.

    Inputs are assumed already validated by the caller:
    `grid` is `[batch, height, width, channels]` with height/width >= 2, and
    `query_points` is `[batch, N, 2]`.
    Returns a `[batch, N, channels]` tensor of interpolated values.
    """
    with tf.name_scope(name or "interpolate_bilinear"):
        grid_shape = tf.shape(grid)
        query_shape = tf.shape(query_points)

        batch_size, height, width, channels = (
            grid_shape[0],
            grid_shape[1],
            grid_shape[2],
            grid_shape[3],
        )

        num_queries = query_shape[1]

        query_type = query_points.dtype
        grid_type = grid.dtype

        # Per-axis interpolation state, filled in axis order [row, col]:
        # alphas[i] is the fractional offset, floors[i]/ceils[i] the two
        # integer indices bracketing each query coordinate.
        alphas = []
        floors = []
        ceils = []
        # With "xy" indexing the query's last dimension is (col, row), so the
        # components are visited in reversed order to recover (row, col).
        index_order = [0, 1] if indexing == "ij" else [1, 0]
        unstacked_query_points = tf.unstack(query_points, axis=2, num=2)

        for i, dim in enumerate(index_order):
            with tf.name_scope("dim-" + str(dim)):
                queries = unstacked_query_points[dim]

                size_in_indexing_dimension = grid_shape[i + 1]

                # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1
                # is still a valid index into the grid.
                max_floor = tf.cast(size_in_indexing_dimension - 2, query_type)
                min_floor = tf.constant(0.0, dtype=query_type)
                # Clamp out-of-range queries to the boundary cells, which
                # yields nearest-boundary-pixel behavior outside the grid.
                floor = tf.math.minimum(
                    tf.math.maximum(min_floor, tf.math.floor(queries)), max_floor
                )
                int_floor = tf.cast(floor, tf.dtypes.int32)
                floors.append(int_floor)
                ceil = int_floor + 1
                ceils.append(ceil)

                # alpha has the same type as the grid, as we will directly use alpha
                # when taking linear combinations of pixel values from the image.
                alpha = tf.cast(queries - floor, grid_type)
                min_alpha = tf.constant(0.0, dtype=grid_type)
                max_alpha = tf.constant(1.0, dtype=grid_type)
                alpha = tf.math.minimum(tf.math.maximum(min_alpha, alpha), max_alpha)

                # Expand alpha to [b, n, 1] so we can use broadcasting
                # (since the alpha values don't depend on the channel).
                alpha = tf.expand_dims(alpha, 2)
                alphas.append(alpha)

        # Collapse (batch, y, x) into a single leading dimension so corner
        # pixels can be fetched with a single tf.gather on linear indices.
        flattened_grid = tf.reshape(grid, [batch_size * height * width, channels])
        batch_offsets = tf.reshape(
            tf.range(batch_size) * height * width, [batch_size, 1]
        )

        # This wraps tf.gather. We reshape the image data such that the
        # batch, y, and x coordinates are pulled into the first dimension.
        # Then we gather. Finally, we reshape the output back. It's possible this
        # code would be made simpler by using tf.gather_nd.
        def gather(y_coords, x_coords, name):
            with tf.name_scope("gather-" + name):
                linear_coordinates = batch_offsets + y_coords * width + x_coords
                gathered_values = tf.gather(flattened_grid, linear_coordinates)
                return tf.reshape(gathered_values, [batch_size, num_queries, channels])

        # grab the pixel values in the 4 corners around each query point
        top_left = gather(floors[0], floors[1], "top_left")
        top_right = gather(floors[0], ceils[1], "top_right")
        bottom_left = gather(ceils[0], floors[1], "bottom_left")
        bottom_right = gather(ceils[0], ceils[1], "bottom_right")

        # now, do the actual interpolation: lerp along x on the top and bottom
        # edges, then lerp those two results along y.
        with tf.name_scope("interpolate"):
            interp_top = alphas[1] * (top_right - top_left) + top_left
            interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
            interp = alphas[0] * (interp_bottom - interp_top) + interp_top

        return interp

174 

175 

176def _get_dim(x, idx): 

177 if x.shape.ndims is None: 

178 return tf.shape(x)[idx] 

179 return x.shape[idx] or tf.shape(x)[idx] 

180 

181 

@tf.function
def dense_image_warp(
    image: types.TensorLike, flow: types.TensorLike, name: Optional[str] = None
) -> tf.Tensor:
    """Image warping using per-pixel flow vectors.

    Apply a non-linear warp to the image, where the warp is specified by a
    dense flow field of offset vectors that define the correspondences of
    pixel values in the output image back to locations in the source image.
    Specifically, the pixel value at `output[b, j, i, c]` is
    `images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c]`.

    The locations specified by this formula do not necessarily map to an int
    index. Therefore, the pixel value is obtained by bilinear
    interpolation of the 4 nearest pixels around
    `(b, j - flow[b, j, i, 0], i - flow[b, j, i, 1])`. For locations outside
    of the image, we use the nearest pixel values at the image boundary.

    NOTE: The definition of the flow field above is different from that
    of optical flow. This function expects the negative forward flow from
    output image to source image. Given two images `I_1` and `I_2` and the
    optical flow `F_12` from `I_1` to `I_2`, the image `I_1` can be
    reconstructed by `I_1_rec = dense_image_warp(I_2, -F_12)`.

    Args:
      image: 4-D float `Tensor` with shape `[batch, height, width, channels]`.
      flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`.
      name: A name for the operation (optional).

    Note that image and flow can be of type `tf.half`, `tf.float32`, or
    `tf.float64`, and do not necessarily have to be the same type.

    Returns:
      A 4-D float `Tensor` with shape`[batch, height, width, channels]`
        and same type as input image.

    Raises:
      ValueError: if `height < 2` or `width < 2` or the inputs have the wrong
        number of dimensions.
    """
    with tf.name_scope(name or "dense_image_warp"):
        image = tf.convert_to_tensor(image)
        flow = tf.convert_to_tensor(flow)
        # _get_dim prefers static sizes, keeping shapes Python ints when the
        # graph allows it.
        batch_size, height, width, channels = (
            _get_dim(image, 0),
            _get_dim(image, 1),
            _get_dim(image, 2),
            _get_dim(image, 3),
        )

        # The flow is defined on the image grid. Turn the flow into a list of query
        # points in the grid space.
        grid_x, grid_y = tf.meshgrid(tf.range(width), tf.range(height))
        # Stack as (row, col) to match interpolate_bilinear's default "ij"
        # indexing; subtracting the flow gives the source-image coordinates.
        stacked_grid = tf.cast(tf.stack([grid_y, grid_x], axis=2), flow.dtype)
        batched_grid = tf.expand_dims(stacked_grid, axis=0)
        query_points_on_grid = batched_grid - flow
        query_points_flattened = tf.reshape(
            query_points_on_grid, [batch_size, height * width, 2]
        )
        # Compute values at the query points, then reshape the result back to the
        # image grid.
        interpolated = interpolate_bilinear(image, query_points_flattened)
        interpolated = tf.reshape(interpolated, [batch_size, height, width, channels])
        return interpolated

246 

247 

@tf.function(experimental_implements="addons:DenseImageWarp")
def dense_image_warp_annotated(
    image: types.TensorLike, flow: types.TensorLike
) -> tf.Tensor:
    """Variant of `dense_image_warp` carrying an `experimental_implements` tag.

    IMPORTANT: This is a temporary function and will be removed after
    TensorFlow's next release.

    The annotation makes the serialized function detectable by the TFLite
    MLIR converter, which can then lower it to the corresponding TFLite op.

    However, because of the annotation, this function cannot be used with
    backprop under `tf.GradientTape` objects.
    """
    warped = dense_image_warp(image, flow)
    return warped