Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/image/dense_image_warp.py: 16%
83 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Image warping using per-pixel flow vectors."""
17import tensorflow as tf
19from tensorflow_addons.utils import types
20from typing import Optional
@tf.function
def interpolate_bilinear(
    grid: types.TensorLike,
    query_points: types.TensorLike,
    indexing: str = "ij",
    name: Optional[str] = None,
) -> tf.Tensor:
    """Similar to Matlab's interp2 function.

    Finds values for query points on a grid using bilinear interpolation.

    Args:
      grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`.
      query_points: a 3-D float `Tensor` of N points with shape
        `[batch, N, 2]`.
      indexing: whether the query points are specified as row and column (ij),
        or Cartesian coordinates (xy).
      name: a name for the operation (optional).

    Returns:
      values: a 3-D `Tensor` with shape `[batch, N, channels]`

    Raises:
      ValueError: if the indexing mode is invalid, or if the shape of the
        inputs invalid.
    """
    # Input validation lives in an undecorated helper so the checks run
    # eagerly rather than inside this tf.function's trace.
    interpolated = _interpolate_bilinear_with_checks(
        grid, query_points, indexing, name
    )
    return interpolated
def _interpolate_bilinear_with_checks(
    grid: types.TensorLike,
    query_points: types.TensorLike,
    indexing: str,
    name: Optional[str],
) -> tf.Tensor:
    """Perform checks on inputs without tf.function decorator to avoid flakiness."""
    if indexing not in ("ij", "xy"):
        raise ValueError("Indexing mode must be 'ij' or 'xy'")

    grid = tf.convert_to_tensor(grid)
    query_points = tf.convert_to_tensor(query_points)
    grid_shape = tf.shape(grid)
    query_shape = tf.shape(query_points)

    # Shape/rank assertions, kept in the original order so the first
    # failing check reported for an invalid input is unchanged.
    checks = [
        tf.debugging.assert_equal(tf.rank(grid), 4, "Grid must be 4D Tensor"),
        tf.debugging.assert_greater_equal(
            grid_shape[1], 2, "Grid height must be at least 2."
        ),
        tf.debugging.assert_greater_equal(
            grid_shape[2], 2, "Grid width must be at least 2."
        ),
        tf.debugging.assert_equal(
            tf.rank(query_points), 3, "Query points must be 3 dimensional."
        ),
        tf.debugging.assert_equal(
            query_shape[2], 2, "Query points last dimension must be 2."
        ),
    ]
    with tf.control_dependencies(checks):
        return _interpolate_bilinear_impl(grid, query_points, indexing, name)
def _interpolate_bilinear_impl(
    grid: types.TensorLike,
    query_points: types.TensorLike,
    indexing: str,
    name: Optional[str],
) -> tf.Tensor:
    """tf.function implementation of interpolate_bilinear.

    Assumes validation already happened in
    `_interpolate_bilinear_with_checks`: `grid` is 4-D
    `[batch, height, width, channels]` with height/width >= 2 and
    `query_points` is 3-D `[batch, N, 2]`.
    """
    with tf.name_scope(name or "interpolate_bilinear"):
        grid_shape = tf.shape(grid)
        query_shape = tf.shape(query_points)

        batch_size, height, width, channels = (
            grid_shape[0],
            grid_shape[1],
            grid_shape[2],
            grid_shape[3],
        )

        num_queries = query_shape[1]

        query_type = query_points.dtype
        grid_type = grid.dtype

        # Per-dimension interpolation weights (alphas) and integer neighbor
        # indices (floors/ceils), accumulated in (row, col) order regardless
        # of the `indexing` mode.
        alphas = []
        floors = []
        ceils = []
        # For "xy" indexing the query coordinates arrive as (col, row), so
        # read the unstacked components in swapped order.
        index_order = [0, 1] if indexing == "ij" else [1, 0]
        unstacked_query_points = tf.unstack(query_points, axis=2, num=2)

        for i, dim in enumerate(index_order):
            with tf.name_scope("dim-" + str(dim)):
                queries = unstacked_query_points[dim]

                size_in_indexing_dimension = grid_shape[i + 1]

                # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1
                # is still a valid index into the grid.
                max_floor = tf.cast(size_in_indexing_dimension - 2, query_type)
                min_floor = tf.constant(0.0, dtype=query_type)
                # Clamp: queries outside the grid sample the boundary pixels.
                floor = tf.math.minimum(
                    tf.math.maximum(min_floor, tf.math.floor(queries)), max_floor
                )
                int_floor = tf.cast(floor, tf.dtypes.int32)
                floors.append(int_floor)
                ceil = int_floor + 1
                ceils.append(ceil)

                # alpha has the same type as the grid, as we will directly use alpha
                # when taking linear combinations of pixel values from the image.
                alpha = tf.cast(queries - floor, grid_type)
                min_alpha = tf.constant(0.0, dtype=grid_type)
                max_alpha = tf.constant(1.0, dtype=grid_type)
                alpha = tf.math.minimum(tf.math.maximum(min_alpha, alpha), max_alpha)

                # Expand alpha to [b, n, 1] so we can use broadcasting
                # (since the alpha values don't depend on the channel).
                alpha = tf.expand_dims(alpha, 2)
                alphas.append(alpha)

        # Flatten (batch, y, x) into one leading dimension so a single
        # tf.gather can fetch corner pixels for every query point.
        flattened_grid = tf.reshape(grid, [batch_size * height * width, channels])
        batch_offsets = tf.reshape(
            tf.range(batch_size) * height * width, [batch_size, 1]
        )

        # This wraps tf.gather. We reshape the image data such that the
        # batch, y, and x coordinates are pulled into the first dimension.
        # Then we gather. Finally, we reshape the output back. It's possible this
        # code would be made simpler by using tf.gather_nd.
        def gather(y_coords, x_coords, name):
            with tf.name_scope("gather-" + name):
                linear_coordinates = batch_offsets + y_coords * width + x_coords
                gathered_values = tf.gather(flattened_grid, linear_coordinates)
                return tf.reshape(gathered_values, [batch_size, num_queries, channels])

        # grab the pixel values in the 4 corners around each query point
        top_left = gather(floors[0], floors[1], "top_left")
        top_right = gather(floors[0], ceils[1], "top_right")
        bottom_left = gather(ceils[0], floors[1], "bottom_left")
        bottom_right = gather(ceils[0], ceils[1], "bottom_right")

        # now, do the actual interpolation: lerp along the column axis for
        # the top and bottom edges, then lerp those results along rows.
        with tf.name_scope("interpolate"):
            interp_top = alphas[1] * (top_right - top_left) + top_left
            interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
            interp = alphas[0] * (interp_bottom - interp_top) + interp_top

        return interp
176def _get_dim(x, idx):
177 if x.shape.ndims is None:
178 return tf.shape(x)[idx]
179 return x.shape[idx] or tf.shape(x)[idx]
@tf.function
def dense_image_warp(
    image: types.TensorLike, flow: types.TensorLike, name: Optional[str] = None
) -> tf.Tensor:
    """Image warping using per-pixel flow vectors.

    Apply a non-linear warp to the image, where the warp is specified by a
    dense flow field of offset vectors that define the correspondences of
    pixel values in the output image back to locations in the source image.
    Specifically, the pixel value at `output[b, j, i, c]` is
    `images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c]`.

    The locations specified by this formula do not necessarily map to an int
    index, so the pixel value is obtained by bilinear interpolation of the
    4 nearest pixels around
    `(b, j - flow[b, j, i, 0], i - flow[b, j, i, 1])`. For locations outside
    of the image, we use the nearest pixel values at the image boundary.

    NOTE: The definition of the flow field above is different from that
    of optical flow. This function expects the negative forward flow from
    output image to source image. Given two images `I_1` and `I_2` and the
    optical flow `F_12` from `I_1` to `I_2`, the image `I_1` can be
    reconstructed by `I_1_rec = dense_image_warp(I_2, -F_12)`.

    Args:
      image: 4-D float `Tensor` with shape `[batch, height, width, channels]`.
      flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`.
      name: A name for the operation (optional).

    Note that image and flow can be of type `tf.half`, `tf.float32`, or
    `tf.float64`, and do not necessarily have to be the same type.

    Returns:
      A 4-D float `Tensor` with shape`[batch, height, width, channels]`
      and same type as input image.

    Raises:
      ValueError: if `height < 2` or `width < 2` or the inputs have the wrong
        number of dimensions.
    """
    with tf.name_scope(name or "dense_image_warp"):
        image = tf.convert_to_tensor(image)
        flow = tf.convert_to_tensor(flow)

        batch_size = _get_dim(image, 0)
        height = _get_dim(image, 1)
        width = _get_dim(image, 2)
        channels = _get_dim(image, 3)

        # The flow is defined on the image grid: build the regular (row, col)
        # sampling grid and shift it by -flow to obtain the source location
        # queried for each output pixel.
        grid_cols, grid_rows = tf.meshgrid(tf.range(width), tf.range(height))
        stacked_grid = tf.cast(tf.stack([grid_rows, grid_cols], axis=2), flow.dtype)
        batched_grid = tf.expand_dims(stacked_grid, axis=0)
        query_points_on_grid = batched_grid - flow
        query_points_flattened = tf.reshape(
            query_points_on_grid, [batch_size, height * width, 2]
        )

        # Interpolate at the query points, then fold the flat point axis
        # back into the (height, width) image grid.
        interpolated = interpolate_bilinear(image, query_points_flattened)
        return tf.reshape(
            interpolated, [batch_size, height, width, channels]
        )
@tf.function(experimental_implements="addons:DenseImageWarp")
def dense_image_warp_annotated(
    image: types.TensorLike, flow: types.TensorLike
) -> tf.Tensor:
    """Similar to dense_image_warp but annotated with experimental_implements.

    IMPORTANT: This is a temporary function and will be removed after TensorFlow's
    next release.

    The annotation makes the serialized function detectable by the TFLite MLIR
    converter, allowing it to be converted to the corresponding TFLite op.
    However, with the annotation, this function cannot be used with backprop
    under `tf.GradientTape` objects.
    """
    warped = dense_image_warp(image, flow)
    return warped