Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/image/sparse_image_warp.py: 20%
55 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Image warping using sparse flow defined at control points."""
from typing import Tuple

import tensorflow as tf

from tensorflow_addons.image import dense_image_warp
from tensorflow_addons.image import interpolate_spline
from tensorflow_addons.image import utils as img_utils
from tensorflow_addons.utils.types import TensorLike, FloatTensorLike
def _get_grid_locations(
    image_height: TensorLike, image_width: TensorLike
) -> TensorLike:
    """Build the dense (row, column) coordinate grid covering an image.

    Args:
      image_height: number of rows in the image.
      image_width: number of columns in the image.

    Returns:
      A `[image_height, image_width, 2]` tensor whose entry `[i, j]` holds the
      coordinates `(i, j)`. Each axis is sampled with `tf.linspace(0, n - 1, n)`
      using integer endpoints.
    """
    # One evenly spaced 0..n-1 ramp per spatial axis, rows first.
    rows, cols = (
        tf.linspace(0, extent - 1, extent) for extent in (image_height, image_width)
    )
    # "ij" indexing keeps rows on axis 0 and columns on axis 1 of each grid.
    return tf.stack(tf.meshgrid(rows, cols, indexing="ij"), axis=-1)
def _expand_to_minibatch(array: TensorLike, batch_size: TensorLike) -> TensorLike:
    """Repeat `array` along a new leading batch axis.

    Args:
      array: tensor of any rank.
      batch_size: scalar number of copies to stack along the new axis 0.

    Returns:
      A tensor of shape `[batch_size] + shape(array)` in which every slice
      along axis 0 equals `array`.
    """
    # Tile multiples: `batch_size` for the new leading axis and 1 (no
    # repetition) for every original axis; `tf.ones_like(tf.shape(...))`
    # yields exactly one int32 `1` per axis of `array`.
    multiples = tf.concat(
        [tf.expand_dims(batch_size, 0), tf.ones_like(tf.shape(array))], axis=0
    )
    return tf.tile(tf.expand_dims(array, 0), multiples)
def _get_boundary_locations(
    image_height: TensorLike, image_width: TensorLike, num_points_per_edge: TensorLike
) -> TensorLike:
    """Return evenly spaced (row, column) points lying on the image border.

    Each axis is sampled at `num_points_per_edge + 2` evenly spaced float32
    coordinates, including both ends. From the resulting grid, only the points
    whose row or column is the first/last sample are kept. For a grid of
    `k + 2` samples per side that is `4 + 4 * k` points (the four corners plus
    `k` interior points per edge).

    Args:
      image_height: image height.
      image_width: image width.
      num_points_per_edge: number of points strictly between the corners on
        each edge.

    Returns:
      A `[num_border_points, 2]` float32 tensor of (row, column) coordinates.
    """
    height = tf.cast(image_height, tf.float32)
    width = tf.cast(image_width, tf.float32)
    last_row = height - 1
    last_col = width - 1
    num_samples = num_points_per_edge + 2

    ys, xs = tf.meshgrid(
        tf.linspace(0.0, last_row, num_samples),
        tf.linspace(0.0, last_col, num_samples),
        indexing="ij",
    )
    # Exact float comparison: this relies on tf.linspace returning its
    # endpoints exactly.
    on_left_or_right = tf.logical_or(tf.equal(xs, 0), tf.equal(xs, last_col))
    on_top_or_bottom = tf.logical_or(tf.equal(ys, 0), tf.equal(ys, last_row))
    border_mask = tf.logical_or(on_left_or_right, on_top_or_bottom)

    border_rows = tf.boolean_mask(ys, border_mask)
    border_cols = tf.boolean_mask(xs, border_mask)
    return tf.stack([border_rows, border_cols], axis=-1)
def _add_zero_flow_controls_at_boundary(
    control_point_locations: TensorLike,
    control_point_flows: TensorLike,
    image_height: TensorLike,
    image_width: TensorLike,
    boundary_points_per_edge: TensorLike,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Add control points for zero-flow boundary conditions.

    Augment the set of control points with extra points on the
    boundary of the image that have zero flow.

    Args:
      control_point_locations: input control points, a rank-3 tensor whose
        leading axis is the batch and whose last axis has size 2.
      control_point_flows: their flows, laid out like
        `control_point_locations` and of the same dtype.
      image_height: image height.
      image_width: image width.
      boundary_points_per_edge: number of points to add in the middle of each
        edge (not including the corners).
        The total number of points added is
        `4 + 4*(boundary_points_per_edge)`.

    Returns:
      A `(merged_control_point_locations, merged_control_point_flows)` tuple:
      the inputs with the boundary points (and their zero flows) appended
      along axis 1, both in the dtype of `control_point_locations`.
    """
    batch_size = tf.shape(control_point_locations)[0]
    type_to_use = control_point_locations.dtype

    boundary_point_locations = _get_boundary_locations(
        image_height, image_width, boundary_points_per_edge
    )
    # The same boundary points are shared by every batch element.
    boundary_point_locations = tf.cast(
        _expand_to_minibatch(boundary_point_locations, batch_size), type_to_use
    )
    # Boundary points are pinned in place: their flow is identically zero.
    # `zeros_like` already has the batched shape and target dtype.
    boundary_point_flows = tf.zeros_like(boundary_point_locations)

    merged_control_point_locations = tf.concat(
        [control_point_locations, boundary_point_locations], 1
    )
    merged_control_point_flows = tf.concat(
        [control_point_flows, boundary_point_flows], 1
    )
    return merged_control_point_locations, merged_control_point_flows
def sparse_image_warp(
    image: TensorLike,
    source_control_point_locations: TensorLike,
    dest_control_point_locations: TensorLike,
    interpolation_order: int = 2,
    regularization_weight: FloatTensorLike = 0.0,
    num_boundary_points: int = 0,
    name: str = "sparse_image_warp",
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Image warping using correspondences between sparse control points.

    Apply a non-linear warp to the image, where the warp is specified by
    the source and destination locations of a (potentially small) number of
    control points. First, we use a polyharmonic spline
    (`tfa.image.interpolate_spline`) to interpolate the displacements
    between the corresponding control points to a dense flow field.
    Then, we warp the image using this dense flow field
    (`tfa.image.dense_image_warp`).

    Let t index our control points. For `regularization_weight = 0`, we have:
    warped_image[b, dest_control_point_locations[b, t, 0],
                    dest_control_point_locations[b, t, 1], :] =
    image[b, source_control_point_locations[b, t, 0],
             source_control_point_locations[b, t, 1], :].

    For `regularization_weight > 0`, this condition is met approximately, since
    regularized interpolation trades off smoothness of the interpolant vs.
    reconstruction of the interpolant at the control points.
    See `tfa.image.interpolate_spline` for further documentation of the
    `interpolation_order` and `regularization_weight` arguments.

    Args:
      image: Either a 2-D float `Tensor` of shape `[height, width]`,
        a 3-D `Tensor` of shape `[height, width, channels]`,
        or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
        `batch_size` is assumed as one when `image` is a 2-D or 3-D `Tensor`.
      source_control_point_locations: `[batch_size, num_control_points, 2]` float
        `Tensor` of (row, column) coordinates.
      dest_control_point_locations: `[batch_size, num_control_points, 2]` float
        `Tensor` of (row, column) coordinates, with the same dtype as
        `source_control_point_locations`.
      interpolation_order: polynomial order used by the spline interpolation.
      regularization_weight: weight on smoothness regularizer in interpolation.
      num_boundary_points: How many zero-flow boundary points to include at
        each image edge. Usage:
          - `num_boundary_points=0`: don't add zero-flow points
          - `num_boundary_points=1`: 4 corners of the image
          - `num_boundary_points=2`: 4 corners and one in the middle of each edge
            (8 points total)
          - `num_boundary_points=n`: 4 corners and n-1 along each edge
      name: A name for the operation (optional).

    Note that `image` and the control point locations can be of type
    `tf.half`, `tf.float32`, or `tf.float64`, and do not necessarily have to be
    the same type. The dense flow field is computed in the dtype of the control
    point locations.

    Returns:
      warped_image: a float `Tensor` with the same shape and dtype as `image`.
      flow_field: `[batch_size, height, width, 2]` float `Tensor` containing the
        dense flow field produced by the interpolation.
    """
    image = tf.convert_to_tensor(image)
    original_ndims = img_utils.get_ndims(image)
    image = img_utils.to_4D_image(image)

    source_control_point_locations = tf.convert_to_tensor(
        source_control_point_locations
    )
    dest_control_point_locations = tf.convert_to_tensor(dest_control_point_locations)
    # The spline is fitted and evaluated in the control points' dtype, which
    # may differ from the image's dtype.
    control_dtype = dest_control_point_locations.dtype

    control_point_flows = dest_control_point_locations - source_control_point_locations

    clamp_boundaries = num_boundary_points > 0
    # `num_boundary_points` counts the corners; the helper expects only the
    # points strictly between corners on each edge.
    boundary_points_per_edge = num_boundary_points - 1

    with tf.name_scope(name or "sparse_image_warp"):
        image_shape = tf.shape(image)
        batch_size, image_height, image_width = (
            image_shape[0],
            image_shape[1],
            image_shape[2],
        )

        # This generates the dense locations where the interpolant
        # will be evaluated.
        grid_locations = _get_grid_locations(image_height, image_width)

        flattened_grid_locations = tf.reshape(
            grid_locations, [image_height * image_width, 2]
        )

        # Cast the query grid to the control points' dtype, not `image.dtype`.
        # `interpolate_spline` receives these query points together with the
        # control points as training points. Mismatched dtypes there are
        # expected to fail whenever the image and control-point dtypes differ.
        # NOTE(review): confirm against interpolate_spline.
        flattened_grid_locations = tf.cast(
            _expand_to_minibatch(flattened_grid_locations, batch_size), control_dtype
        )

        if clamp_boundaries:
            (
                dest_control_point_locations,
                control_point_flows,
            ) = _add_zero_flow_controls_at_boundary(
                dest_control_point_locations,
                control_point_flows,
                image_height,
                image_width,
                boundary_points_per_edge,
            )

        flattened_flows = interpolate_spline(
            dest_control_point_locations,
            control_point_flows,
            flattened_grid_locations,
            interpolation_order,
            regularization_weight,
        )

        dense_flows = tf.reshape(
            flattened_flows, [batch_size, image_height, image_width, 2]
        )

        warped_image = dense_image_warp(image, dense_flows)

    return img_utils.from_4D_image(warped_image, original_ndims), dense_flows