Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/image/resampler_ops.py: 61%

18 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Python layer for Resampler.""" 

16 

from typing import Optional, Sequence

import tensorflow as tf

from tensorflow_addons.utils import types
from tensorflow_addons.utils.resource_loader import LazySO

23 

# Handle to the compiled custom-op shared library. Its `.ops` attribute
# supplies the `addons_resampler` and `addons_resampler_grad` op wrappers
# called by `resampler` and `_resampler_grad` in this module.
_resampler_so = LazySO("custom_ops/image/_resampler_ops.so")

25 

26 

@tf.function
def resampler(
    data: types.TensorLike, warp: types.TensorLike, name: Optional[str] = None
) -> tf.Tensor:
    """Sample 2D `data` at the user-supplied coordinates in `warp`.

    Only bilinear interpolation of 2D data is supported at present.

    Args:
      data: `Tensor` of shape
        `[batch_size, data_height, data_width, data_num_channels]` holding
        the 2D data to be resampled.
      warp: `Tensor` of rank at least 2 holding the sampling coordinates.
        Because only bilinear interpolation is supported, its last
        dimension must have size 2 and store `(x, y)` pairs, where `x`
        indexes the width axis and `y` indexes the height axis.
      name: Optional name for the op. When it is `None` or empty, the name
        scope "resampler" is used instead.

    Returns:
      `Tensor` of values resampled from `data`, shaped according to `warp`.
      For example, with `data` of shape
      `[batch_size, data_height, data_width, data_num_channels]` and `warp`
      of shape `[batch_size, dim_0, ... , dim_n, 2]`, the result has shape
      `[batch_size, dim_0, ... , dim_n, data_num_channels]`.

    Raises:
      ImportError: if the wrapper generated during compilation is not
        present when the function is called.
    """
    scope_name = "resampler" if not name else name
    with tf.name_scope(scope_name):
        # Convert both inputs inside the scope so the conversion ops share it.
        data_as_tensor = tf.convert_to_tensor(data, name="data")
        warp_as_tensor = tf.convert_to_tensor(warp, name="warp")
        sampled = _resampler_so.ops.addons_resampler(data_as_tensor, warp_as_tensor)
    return sampled

58 

59 

@tf.RegisterGradient("Addons>Resampler")
def _resampler_grad(
    op: tf.Operation, grad_output: types.TensorLike
) -> Sequence[tf.Tensor]:
    """Gradient function registered for the "Addons>Resampler" op.

    Args:
      op: The forward "Addons>Resampler" operation. Its inputs are unpacked
        as `(data, warp)`, matching the two tensors `resampler` passes to
        `addons_resampler`.
      grad_output: Upstream gradient with respect to the op's output.

    Returns:
      The result of the `addons_resampler_grad` custom op: one gradient per
      forward input. Presumably this is `(grad_data, grad_warp)` in input
      order; TODO confirm against the op's C++ registration.
    """
    # TensorFlow hands the gradient function the forward Operation, not a
    # tensor; the unpacking assumes exactly two inputs.
    data, warp = op.inputs
    grad_output_tensor = tf.convert_to_tensor(grad_output, name="grad_output")
    return _resampler_so.ops.addons_resampler_grad(data, warp, grad_output_tensor)

65 

66 

# Declare the "Addons>ResamplerGrad" op (emitted by `_resampler_grad`) as
# non-differentiable, so TensorFlow will not look for a gradient of the
# gradient; differentiating through it a second time is not supported.
tf.no_gradient("Addons>ResamplerGrad")