Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/image/resampler_ops.py: 61%
18 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Python layer for Resampler."""
from typing import Optional, Sequence

import tensorflow as tf

from tensorflow_addons.utils import types
from tensorflow_addons.utils.resource_loader import LazySO
# Handle to the compiled custom-op library that provides the resampler ops
# (accessed below via `_resampler_so.ops`). LazySO presumably defers loading
# the shared object until first use -- confirm in resource_loader.
_resampler_so = LazySO(
    "custom_ops/image/_resampler_ops.so",
)
@tf.function
def resampler(
    data: types.TensorLike,
    warp: types.TensorLike,
    name: Optional[str] = None,
) -> tf.Tensor:
    """Samples `data` at the user-supplied coordinates in `warp`.

    Only bilinear interpolation of 2D data is supported at present.

    Args:
      data: `Tensor` of shape `[batch_size, data_height, data_width,
        data_num_channels]` holding the 2D data to be resampled.
      warp: Tensor of rank at least 2 holding the sampling coordinates.
        Because only bilinear interpolation is available, its last
        dimension must be 2, giving `(x, y)` pairs where `x` indexes
        width and `y` indexes height.
      name: Optional op name; a falsy value falls back to "resampler".
    Returns:
      Tensor of values resampled from `data`, shaped after `warp`. For
      `data` of shape `[batch_size, data_height, data_width,
      data_num_channels]` and `warp` of shape
      `[batch_size, dim_0, ... , dim_n, 2]`, the result has shape
      `[batch_size, dim_0, ... , dim_n, data_num_channels]`.
    Raises:
      ImportError: if the wrapper generated during compilation is not
        present when the function is called.
    """
    scope = name if name else "resampler"
    with tf.name_scope(scope):
        data_t = tf.convert_to_tensor(data, name="data")
        warp_t = tf.convert_to_tensor(warp, name="warp")
        # The actual sampling is done by the compiled custom op.
        return _resampler_so.ops.addons_resampler(data_t, warp_t)
@tf.RegisterGradient("Addons>Resampler")
def _resampler_grad(
    op: tf.Operation, grad_output: types.TensorLike
) -> Sequence[tf.Tensor]:
    """Gradient function registered for the `Addons>Resampler` op.

    Args:
      op: The forward `Addons>Resampler` operation. Its two inputs are
        unpacked as `(data, warp)`.
      grad_output: Upstream gradient with respect to the op's output.
    Returns:
      The outputs of the `addons_resampler_grad` custom op. TensorFlow
      expects one gradient per input of `op`, so this is presumably
      `[grad_data, grad_warp]` -- confirm against the op definition.
    """
    # `op` is a tf.Operation (not a tensor); its inputs are the tensors
    # originally passed to `addons_resampler`.
    data, warp = op.inputs
    grad_output_tensor = tf.convert_to_tensor(grad_output, name="grad_output")
    return _resampler_so.ops.addons_resampler_grad(data, warp, grad_output_tensor)
# Mark the backward op itself as non-differentiable, so TensorFlow does not
# attempt to differentiate through it (e.g. for higher-order gradients).
tf.no_gradient(op_type="Addons>ResamplerGrad")