# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Polyharmonic spline interpolation."""

import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike

EPSILON = 0.0000000001


def _cross_squared_distance_matrix(x: TensorLike, y: TensorLike) -> tf.Tensor:
    """Pairwise squared distance between two (batch) matrices' rows (2nd dim).

    Computes the pairwise squared distances between rows of x and rows of y.

    Args:
      x: `[batch_size, n, d]` float `Tensor`.
      y: `[batch_size, m, d]` float `Tensor`.

    Returns:
      squared_dists: `[batch_size, n, m]` float `Tensor`, where
        `squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2`.
    """
    x_norm_squared = tf.reduce_sum(tf.square(x), 2)
    y_norm_squared = tf.reduce_sum(tf.square(y), 2)

    # Expand so that we can broadcast.
    x_norm_squared_tile = tf.expand_dims(x_norm_squared, 2)
    y_norm_squared_tile = tf.expand_dims(y_norm_squared, 1)

    x_y_transpose = tf.matmul(x, y, adjoint_b=True)

    # squared_dists[b,i,j] = ||x_bi - y_bj||^2 =
    # x_bi'x_bi - 2 x_bi'y_bj + y_bj'y_bj
    squared_dists = x_norm_squared_tile - 2 * x_y_transpose + y_norm_squared_tile

    return squared_dists
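

# Illustrative sketch (not part of the original module): the broadcasted
# expansion above can be sanity-checked against a direct pairwise computation.
# The helper name `_example_cross_squared_distance_check` is hypothetical.
def _example_cross_squared_distance_check():
    """Compare the broadcasted formula with an explicit pairwise difference."""
    x = tf.constant([[[0.0, 0.0], [1.0, 2.0]]])  # [1, 2, 2]
    y = tf.constant([[[1.0, 0.0], [0.0, 3.0], [2.0, 2.0]]])  # [1, 3, 2]
    fast = _cross_squared_distance_matrix(x, y)
    # Direct computation of ||x[b,i,:] - y[b,j,:]||^2 via broadcasting.
    brute = tf.reduce_sum(
        tf.square(x[:, :, None, :] - y[:, None, :, :]), axis=-1
    )
    tf.debugging.assert_near(fast, brute)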


def _pairwise_squared_distance_matrix(x: TensorLike) -> tf.Tensor:
    """Pairwise squared distance among a (batch) matrix's rows (2nd dim).

    This saves a bit of computation vs. using
    `_cross_squared_distance_matrix(x, x)`.

    Args:
      x: `[batch_size, n, d]` float `Tensor`.

    Returns:
      squared_dists: `[batch_size, n, n]` float `Tensor`, where
        `squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2`.
    """
    x_x_transpose = tf.matmul(x, x, adjoint_b=True)
    x_norm_squared = tf.linalg.diag_part(x_x_transpose)
    x_norm_squared_tile = tf.expand_dims(x_norm_squared, 2)

    # squared_dists[b,i,j] = ||x_bi - x_bj||^2 =
    # x_bi'x_bi - 2 x_bi'x_bj + x_bj'x_bj
    squared_dists = (
        x_norm_squared_tile
        - 2 * x_x_transpose
        + tf.transpose(x_norm_squared_tile, [0, 2, 1])
    )

    return squared_dists
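

# As the docstring notes, this should agree with
# `_cross_squared_distance_matrix(x, x)`. A hypothetical, illustrative check
# (not part of the library):
def _example_pairwise_matches_cross():
    x = tf.random.normal([2, 5, 3])  # [batch_size=2, n=5, d=3]
    tf.debugging.assert_near(
        _pairwise_squared_distance_matrix(x),
        _cross_squared_distance_matrix(x, x),
        atol=1e-5,
    )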


def _solve_interpolation(
    train_points: TensorLike,
    train_values: TensorLike,
    order: int,
    regularization_weight: FloatTensorLike,
) -> TensorLike:
    r"""Solve for interpolation coefficients.

    Computes the coefficients of the polyharmonic interpolant for the
    'training' data defined by `(train_points, train_values)` using the kernel
    $\phi$.

    Args:
      train_points: `[b, n, d]` interpolation centers.
      train_values: `[b, n, k]` function values.
      order: order of the interpolation.
      regularization_weight: weight to place on the smoothness regularization term.

    Returns:
      w: `[b, n, k]` weights on each interpolation center.
      v: `[b, d, k]` weights on each input dimension.

    Raises:
      ValueError: if d or k is not fully specified.
    """
    # These dimensions are set dynamically at runtime.
    b, n, _ = tf.unstack(tf.shape(train_points), num=3)

    d = train_points.shape[-1]
    if d is None:
        raise ValueError(
            "The dimensionality of the input points (d) must be "
            "statically-inferrable."
        )

    k = train_values.shape[-1]
    if k is None:
        raise ValueError(
            "The dimensionality of the output values (k) must be "
            "statically-inferrable."
        )

    # First, rename variables so that the notation (c, f, w, v, A, B, etc.)
    # follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
    # To account for Python style guidelines we use
    # matrix_a for A and matrix_b for B.

    c = train_points
    f = train_values

    # Next, construct the linear system.
    with tf.name_scope("construct_linear_system"):

        matrix_a = _phi(_pairwise_squared_distance_matrix(c), order)  # [b, n, n]
        if regularization_weight > 0:
            batch_identity_matrix = tf.expand_dims(tf.eye(n, dtype=c.dtype), 0)
            matrix_a += regularization_weight * batch_identity_matrix

        # Append ones to the point coordinates for the bias term
        # in the linear model.
        ones = tf.ones_like(c[..., :1], dtype=c.dtype)
        matrix_b = tf.concat([c, ones], 2)  # [b, n, d + 1]

        # [b, n + d + 1, n]
        left_block = tf.concat([matrix_a, tf.transpose(matrix_b, [0, 2, 1])], 1)

        num_b_cols = matrix_b.get_shape()[2]  # d + 1
        lhs_zeros = tf.zeros([b, num_b_cols, num_b_cols], train_points.dtype)
        right_block = tf.concat([matrix_b, lhs_zeros], 1)  # [b, n + d + 1, d + 1]
        lhs = tf.concat([left_block, right_block], 2)  # [b, n + d + 1, n + d + 1]

        rhs_zeros = tf.zeros([b, d + 1, k], train_points.dtype)
        rhs = tf.concat([f, rhs_zeros], 1)  # [b, n + d + 1, k]

    # Then, solve the linear system and unpack the results.
    with tf.name_scope("solve_linear_system"):
        w_v = tf.linalg.solve(lhs, rhs)
        w = w_v[:, :n, :]
        v = w_v[:, n:, :]

    return w, v
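

# A small, illustrative sketch (not part of the library): with zero
# regularization, the first n rows of the linear system are the interpolation
# conditions A w + B v = f, so the solved coefficients should reproduce the
# training values when the spline is evaluated back at the training points.
# The helper name `_example_exact_fit` is hypothetical.
def _example_exact_fit():
    train_points = tf.constant([[[0.0], [1.0], [2.0], [3.0]]])  # [1, 4, 1]
    train_values = tf.constant([[[0.0], [1.0], [4.0], [9.0]]])  # [1, 4, 1]
    w, v = _solve_interpolation(
        train_points, train_values, order=2, regularization_weight=0.0
    )
    fitted = _apply_interpolation(train_points, train_points, w, v, order=2)
    tf.debugging.assert_near(fitted, train_values, atol=1e-3)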


def _apply_interpolation(
    query_points: TensorLike,
    train_points: TensorLike,
    w: TensorLike,
    v: TensorLike,
    order: int,
) -> TensorLike:
    """Apply polyharmonic interpolation model to data.

    Given coefficients w and v for the interpolation model, we evaluate
    interpolated function values at query_points.

    Args:
      query_points: `[b, m, d]` x values to evaluate the interpolation at.
      train_points: `[b, n, d]` x values that act as the interpolation centers
        (the c variables in the Wikipedia article).
      w: `[b, n, k]` weights on each interpolation center.
      v: `[b, d, k]` weights on each input dimension.
      order: order of the interpolation.

    Returns:
      Polyharmonic interpolation evaluated at points defined in `query_points`.
    """
    # First, compute the contribution from the RBF term.
    pairwise_dists = _cross_squared_distance_matrix(query_points, train_points)
    phi_pairwise_dists = _phi(pairwise_dists, order)

    rbf_term = tf.matmul(phi_pairwise_dists, w)

    # Then, compute the contribution from the linear term.
    # Pad query_points with ones, for the bias term in the linear model.
    query_points_pad = tf.concat(
        [query_points, tf.ones_like(query_points[..., :1], train_points.dtype)], 2
    )
    linear_term = tf.matmul(query_points_pad, v)

    return rbf_term + linear_term


def _phi(r: FloatTensorLike, order: int) -> FloatTensorLike:
    """Coordinate-wise nonlinearity used to define the order of the
    interpolation.

    See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.

    Args:
      r: input op; callers pass the matrix of squared pairwise distances.
      order: interpolation order.

    Returns:
      `phi_k` evaluated coordinate-wise on `r`, for `k = order`.
    """
    # Using EPSILON prevents log(0), sqrt(0), etc.
    # sqrt(0) is well-defined, but its gradient is not.
    with tf.name_scope("phi"):
        if order == 1:
            r = tf.maximum(r, EPSILON)
            r = tf.sqrt(r)
            return r
        elif order == 2:
            return 0.5 * r * tf.math.log(tf.maximum(r, EPSILON))
        elif order == 4:
            return 0.5 * tf.square(r) * tf.math.log(tf.maximum(r, EPSILON))
        elif order % 2 == 0:
            r = tf.maximum(r, EPSILON)
            return 0.5 * tf.pow(r, 0.5 * order) * tf.math.log(r)
        else:
            r = tf.maximum(r, EPSILON)
            return tf.pow(r, 0.5 * order)
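

# Illustrative sketch (not part of the original module): because `_phi`
# receives *squared* distances, the order-2 branch `0.5 * r * log(r)`
# evaluated at r = s**2 equals the familiar thin-plate kernel s**2 * log(s).
def _example_thin_plate_kernel():
    s = tf.constant([0.5, 1.5, 3.0])  # true (unsquared) distances
    from_phi = _phi(tf.square(s), order=2)
    thin_plate = tf.square(s) * tf.math.log(s)
    tf.debugging.assert_near(from_phi, thin_plate)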


def interpolate_spline(
    train_points: TensorLike,
    train_values: TensorLike,
    query_points: TensorLike,
    order: int,
    regularization_weight: FloatTensorLike = 0.0,
    name: str = "interpolate_spline",
) -> tf.Tensor:
    r"""Interpolate signal using polyharmonic interpolation.

    The interpolant has the form
    $$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$

    This is a sum of two terms: (1) a weighted sum of radial basis function
    (RBF) terms, with the centers \\(c_1, ... c_n\\), and (2) a linear term
    with a bias. The \\(c_i\\) vectors are 'training' points.
    In the code, b is absorbed into v
    by appending 1 as a final dimension to x. The coefficients w and v are
    estimated such that the interpolant exactly fits the value of the function
    at the \\(c_i\\) points, the vector w is orthogonal to each \\(c_i\\),
    and the vector w sums to 0. With these constraints, the coefficients
    can be obtained by solving a linear system.

    \\(\phi\\) is an RBF, parametrized by an interpolation
    order. Using order=2 produces the well-known thin-plate spline.

    We also provide the option to perform regularized interpolation. Here, the
    interpolant is selected to trade off between the squared loss on the
    training data and a certain measure of its curvature
    ([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)).
    Using a regularization weight greater than zero has the effect that the
    interpolant will no longer exactly fit the training data. However, it may
    be less vulnerable to overfitting, particularly for high-order
    interpolation.

    Note that the interpolation procedure is differentiable with respect to
    all inputs besides the order parameter.

    We support dynamically-shaped inputs, where batch_size, n, and m are None
    at graph construction time. However, d and k must be known.

    Args:
      train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional
        locations. These do not need to be regularly-spaced.
      train_values: `[batch_size, n, k]` float `Tensor` of n k-dimensional
        values evaluated at train_points.
      query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations
        where we will output the interpolant's values.
      order: order of the interpolation. Common values are 1 for
        \\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\)
        (thin-plate spline), or 3 for \\(\phi(r) = r^3\\).
      regularization_weight: weight placed on the regularization term.
        This will depend substantially on the problem, and it should always be
        tuned. For many problems, it is reasonable to use no regularization.
        If using a non-zero value, we recommend a small value like 0.001.
      name: name prefix for ops created by this function.

    Returns:
      `[b, m, k]` float `Tensor` of query values. We use train_points and
      train_values to perform polyharmonic interpolation. The query values are
      the values of the interpolant evaluated at the locations specified in
      query_points.
    """
    with tf.name_scope(name or "interpolate_spline"):
        train_points = tf.convert_to_tensor(train_points)
        train_values = tf.convert_to_tensor(train_values)
        query_points = tf.convert_to_tensor(query_points)

        # First, fit the spline to the observed data.
        with tf.name_scope("solve"):
            w, v = _solve_interpolation(
                train_points, train_values, order, regularization_weight
            )

        # Then, evaluate the spline at the query locations.
        with tf.name_scope("predict"):
            query_values = _apply_interpolation(query_points, train_points, w, v, order)

        return query_values
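

# A minimal usage sketch (not part of the module): interpolate a 1-D signal
# at new locations with a thin-plate (order=2) spline. Shapes follow the
# docstring above; the helper name and variable names are illustrative only.
def _example_interpolate_spline_usage():
    train_points = tf.constant([[[0.0], [1.0], [2.0], [3.0], [4.0]]])  # [1, 5, 1]
    train_values = tf.sin(train_points)  # [1, 5, 1]
    query_points = tf.constant([[[0.5], [1.5], [2.5], [3.5]]])  # [1, 4, 1]
    query_values = interpolate_spline(
        train_points, train_values, query_points, order=2
    )
    return query_values  # [1, 4, 1] interpolated values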