Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/image/sparse_image_warp.py: 20%

55 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Image warping using sparse flow defined at control points.""" 

16 

17import tensorflow as tf 

18 

19from tensorflow_addons.image import dense_image_warp 

20from tensorflow_addons.image import interpolate_spline 

21from tensorflow_addons.image import utils as img_utils 

22from tensorflow_addons.utils.types import TensorLike, FloatTensorLike 

23 

24 

def _get_grid_locations(
    image_height: TensorLike, image_width: TensorLike
) -> TensorLike:
    """Build the dense grid of (y, x) coordinates covering an image.

    Args:
      image_height: number of rows; also the number of samples taken along y.
      image_width: number of columns; also the number of samples taken along x.

    Returns:
      The y and x meshgrids stacked along a new trailing axis, so that entry
      `[i, j]` holds the pair `(y[i], x[j])`.
    """
    row_coords = tf.linspace(0, image_height - 1, image_height)
    col_coords = tf.linspace(0, image_width - 1, image_width)
    # "ij" indexing keeps rows on the first axis and columns on the second,
    # matching the [height, width] layout of the image.
    rows, cols = tf.meshgrid(row_coords, col_coords, indexing="ij")
    return tf.stack([rows, cols], axis=-1)

33 

34 

def _expand_to_minibatch(array: TensorLike, batch_size: TensorLike) -> TensorLike:
    """Prepend a batch axis to `array` and repeat it `batch_size` times.

    Args:
      array: tensor of any rank to replicate.
      batch_size: scalar number of copies along the new leading axis.

    Returns:
      A tensor with one more axis than `array`, tiled `batch_size` times along
      axis 0 and once along every other axis.
    """
    leading_multiple = tf.expand_dims(batch_size, 0)
    # One "repeat once" entry per existing axis of `array`.
    unit_multiples = tf.ones(tf.rank(array), dtype=tf.int32)
    multiples = tf.concat([leading_multiple, unit_multiples], axis=0)
    with_batch_axis = tf.expand_dims(array, 0)
    return tf.tile(with_batch_axis, multiples)

41 

42 

def _get_boundary_locations(
    image_height: TensorLike, image_width: TensorLike, num_points_per_edge: TensorLike
) -> TensorLike:
    """Return evenly spaced (y, x) locations lying on the image border.

    A `(num_points_per_edge + 2)`-by-`(num_points_per_edge + 2)` grid of
    evenly spaced samples spanning the image is generated, and only the
    samples on the first/last row or first/last column are kept.

    Args:
      image_height: image height; cast to `tf.float32` before use.
      image_width: image width; cast to `tf.float32` before use.
      num_points_per_edge: number of interior samples per edge (the two
        corners of each edge are added on top of this).

    Returns:
      A `[num_boundary_points, 2]` `tf.float32` tensor of `(y, x)` pairs.
    """
    last_row = tf.cast(image_height, tf.float32) - 1
    last_col = tf.cast(image_width, tf.float32) - 1
    # The two extra samples are the end points of each edge, i.e. the corners.
    samples_per_axis = num_points_per_edge + 2

    row_coords = tf.linspace(0.0, last_row, samples_per_axis)
    col_coords = tf.linspace(0.0, last_col, samples_per_axis)
    ys, xs = tf.meshgrid(row_coords, col_coords, indexing="ij")

    on_left_or_right = tf.logical_or(tf.equal(xs, 0), tf.equal(xs, last_col))
    on_top_or_bottom = tf.logical_or(tf.equal(ys, 0), tf.equal(ys, last_row))
    on_border = tf.logical_or(on_left_or_right, on_top_or_bottom)

    border_ys = tf.boolean_mask(ys, on_border)
    border_xs = tf.boolean_mask(xs, on_border)
    return tf.stack([border_ys, border_xs], axis=-1)

59 

60 

def _add_zero_flow_controls_at_boundary(
    control_point_locations: TensorLike,
    control_point_flows: TensorLike,
    image_height: TensorLike,
    image_width: TensorLike,
    boundary_points_per_edge: TensorLike,
) -> tf.Tensor:
    """Append zero-flow control points located on the image border.

    The supplied control points are extended with extra points on the
    boundary of the image whose flow is zero, which pins the border in place
    during interpolation.

    Args:
      control_point_locations: input control points; the leading axis is
        the batch axis and its dtype is used for the added points.
      control_point_flows: flows of the input control points.
      image_height: image height.
      image_width: image width.
      boundary_points_per_edge: number of points to add in the middle of
        each edge (not including the corners). The total number of points
        added is `4 + 4*(boundary_points_per_edge)`.

    Returns:
      merged_control_point_locations: augmented set of control point
        locations.
      merged_control_point_flows: augmented set of control point flows.
    """
    # NOTE(review): the signature is annotated `-> tf.Tensor`, but a 2-tuple
    # of tensors is returned (see the return statement below).
    num_batches = tf.shape(control_point_locations)[0]

    edge_locations = _get_boundary_locations(
        image_height, image_width, boundary_points_per_edge
    )
    # One zero-valued 2-vector of flow per boundary location.
    edge_flows = tf.zeros([tf.shape(edge_locations)[0], 2])

    # Replicate across the batch and match the caller's dtype so the
    # concatenations below are well-typed.
    target_dtype = control_point_locations.dtype
    batched_locations, batched_flows = (
        tf.cast(_expand_to_minibatch(points, num_batches), target_dtype)
        for points in (edge_locations, edge_flows)
    )

    # Axis 1 is the control-point axis.
    all_locations = tf.concat([control_point_locations, batched_locations], 1)
    all_flows = tf.concat([control_point_flows, batched_flows], 1)
    return all_locations, all_flows

114 

115 

def sparse_image_warp(
    image: TensorLike,
    source_control_point_locations: TensorLike,
    dest_control_point_locations: TensorLike,
    interpolation_order: int = 2,
    regularization_weight: FloatTensorLike = 0.0,
    num_boundary_points: int = 0,
    name: str = "sparse_image_warp",
) -> tf.Tensor:
    """Warp an image using correspondences between sparse control points.

    A non-linear warp is specified by the source and destination locations
    of a (potentially small) number of control points. The displacements
    `dest - source` are interpolated with a polyharmonic spline
    (`tfa.image.interpolate_spline`), anchored at the destination
    locations, to obtain a dense flow field over every pixel. The image is
    then warped with that field (`tfa.image.dense_image_warp`).

    Let t index the control points. For `regularization_weight = 0`:
      warped_image[b, dest_control_point_locations[b, t, 0],
                      dest_control_point_locations[b, t, 1], :] =
      image[b, source_control_point_locations[b, t, 0],
               source_control_point_locations[b, t, 1], :].

    For `regularization_weight > 0` this holds only approximately, because
    regularized interpolation trades smoothness of the interpolant against
    exact reconstruction at the control points. See
    `tfa.image.interpolate_spline` for details on `interpolation_order` and
    `regularization_weight`.

    Args:
      image: Either a 2-D float `Tensor` of shape `[height, width]`,
        a 3-D `Tensor` of shape `[height, width, channels]`,
        or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
        `batch_size` is assumed as one when `image` is a 2-D or 3-D `Tensor`.
      source_control_point_locations: `[batch_size, num_control_points, 2]`
        float `Tensor`.
      dest_control_point_locations: `[batch_size, num_control_points, 2]`
        float `Tensor`.
      interpolation_order: polynomial order used by the spline interpolation.
      regularization_weight: weight on the smoothness regularizer in the
        interpolation.
      num_boundary_points: how many zero-flow boundary points to include at
        each image edge. Usage:
        - `num_boundary_points=0`: don't add zero-flow points
        - `num_boundary_points=1`: 4 corners of the image
        - `num_boundary_points=2`: 4 corners and one in the middle of each
          edge (8 points total)
        - `num_boundary_points=n`: 4 corners and n-1 along each edge
      name: A name for the operation (optional); a falsy value falls back to
        `"sparse_image_warp"`.

    Returns:
      warped_image: a float `Tensor` with the same shape and dtype as
        `image`.
      flow_field: `[batch_size, height, width, 2]` float `Tensor` containing
        the dense flow field produced by the interpolation.
    """
    # NOTE(review): the signature is annotated `-> tf.Tensor`, but the
    # function returns the pair `(warped_image, flow_field)`.
    # NOTE(review): earlier documentation stated that the image and the flows
    # may be `tf.half`, `tf.float32` or `tf.float64` and need not share a
    # dtype. The query grid below is cast to `image.dtype` while the control
    # points keep their own dtype -- confirm `interpolate_spline` accepts
    # that combination.
    image = tf.convert_to_tensor(image)
    # Remember the caller's rank so the result can be returned in that shape.
    input_ndims = img_utils.get_ndims(image)
    image = img_utils.to_4D_image(image)

    source_control_point_locations = tf.convert_to_tensor(
        source_control_point_locations
    )
    dest_control_point_locations = tf.convert_to_tensor(dest_control_point_locations)

    # Displacement carried by each control point.
    control_point_flows = dest_control_point_locations - source_control_point_locations

    with tf.name_scope(name or "sparse_image_warp"):
        dynamic_shape = tf.shape(image)
        batch_size = dynamic_shape[0]
        image_height = dynamic_shape[1]
        image_width = dynamic_shape[2]

        # Dense pixel locations at which the interpolant is evaluated,
        # flattened to one (y, x) row per pixel and replicated per batch entry.
        query_points = tf.reshape(
            _get_grid_locations(image_height, image_width),
            [image_height * image_width, 2],
        )
        query_points = tf.cast(
            _expand_to_minibatch(query_points, batch_size), image.dtype
        )

        if num_boundary_points > 0:
            # The helper counts interior points per edge, excluding corners,
            # hence the `- 1`.
            (
                dest_control_point_locations,
                control_point_flows,
            ) = _add_zero_flow_controls_at_boundary(
                dest_control_point_locations,
                control_point_flows,
                image_height,
                image_width,
                num_boundary_points - 1,
            )

        per_pixel_flows = interpolate_spline(
            dest_control_point_locations,
            control_point_flows,
            query_points,
            interpolation_order,
            regularization_weight,
        )

        # Restore the spatial layout of the per-pixel flows.
        dense_flows = tf.reshape(
            per_pixel_flows, [batch_size, image_height, image_width, 2]
        )

        warped_image = dense_image_warp(image, dense_flows)

        restored_image = img_utils.from_4D_image(warped_image, input_ndims)
        return restored_image, dense_flows

236 

237 return img_utils.from_4D_image(warped_image, original_ndims), dense_flows