Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/optical_flow.py: 20%
80 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Tensorflow op performing correlation cost operation."""
17import tensorflow as tf
18from typeguard import typechecked
19from tensorflow_addons.utils.resource_loader import LazySO
# Module-level handle for the correlation-cost custom-op shared library.
# The forward and gradient functions in this module reach their kernels
# through its `.ops` attribute at call time.
# NOTE(review): presumably the library is only loaded on first `.ops` access,
# as the `LazySO` name suggests -- confirm in `resource_loader`.
_correlation_cost_so = LazySO("custom_ops/layers/_correlation_cost_ops.so")
24def _correlation_cost(
25 input_a,
26 input_b,
27 kernel_size,
28 max_displacement,
29 stride_1,
30 stride_2,
31 pad,
32 data_format="channels_last",
33 name=None,
34):
35 """Correlation Cost Volume computation.
37 See [FlowNet: Learning Optical Flow with Convolutional Networks](https://arxiv.org/abs/1504.06852).
39 Computes a cost volume using correlation for two inputs. For feature
40 maps A, B with spatial dimensions w, h, c it computes
42 output(a, b) = sum_{l in [-k,k]**2} < I(a+l), J(b+l) >
44 where the patches of size K=2d + 1 are centered in position a resp. b.
46 The output shape is [B, C', H', W'], where
48 r = max_displacement / stride_2;
49 bd = max_displacement + (kernel_size - 1) / 2
50 C' = (2 * r + 1) ** 2
51 H' = H + 2 * (pad - bd) / stride_1
52 W' = W + 2 * (pad - bd) / stride_1
54 Note: When the data_format requests "channels_last", an additional explicit
55 transpose operation is executed.
57 Args:
58 input_a: A `Tensor` of the format specified by `data_format`.
59 input_b: A `Tensor` of the format specified by `data_format`.
60 kernel_size: An integer specifying the height and width of the
61 patch used to compute the per-patch costs.
62 max_displacement: An integer specifying the maximum search radius
63 for each position.
64 stride_1: An integer specifying the stride length in the input.
65 stride_2: An integer specifying the stride length in the patch.
66 pad: An integer specifying the paddings in height and width.
67 data_format: Specifies the data format.
68 Possible values are:
69 "channels_last" float [batch, height, width, channels]
70 "channels_first" float [batch, channels, height, width]
71 Defaults to `"channels_last"`.
72 name: A name for the operation (optional).
74 Returns:
75 A `Tensor` of the format specified by `data_format`.
76 """
78 with tf.name_scope(name or "correlation_cost"):
79 op_call = _correlation_cost_so.ops.addons_correlation_cost
81 if data_format == "channels_last":
82 op_data_format = "NHWC"
83 elif data_format == "channels_first":
84 op_data_format = "NCHW"
85 else:
86 raise ValueError(
87 "`data_format` must be either `channels_last` or" "`channels_first`"
88 )
90 ret = op_call(
91 input_a,
92 input_b,
93 kernel_size=kernel_size,
94 max_displacement=max_displacement,
95 stride_1=stride_1,
96 stride_2=stride_2,
97 pad=pad,
98 data_format=op_data_format,
99 )
100 if data_format == "channels_last":
101 # this is easier to maintain without
102 # specializing an additional cuda kernel
103 return tf.transpose(ret, [0, 2, 3, 1])
104 return ret
@tf.RegisterGradient("Addons>CorrelationCost")
def _correlation_cost_grad(op, grad_output):
    """Gradient function registered for the `Addons>CorrelationCost` op.

    Args:
      op: The forward operation. Its attributes and its first two inputs
          are read and handed to the gradient kernel.
      grad_output: The incoming gradient passed by the gradient registry
          for the forward op's output.

    Returns:
      A two-element list with the gradients for `op.inputs[0]` and
      `op.inputs[1]`, in that order.
    """
    # The gradient kernel takes exactly the attributes the forward op was
    # built with, so read them off the op and forward them unchanged.
    attr_names = (
        "kernel_size",
        "max_displacement",
        "stride_1",
        "stride_2",
        "pad",
        "data_format",
    )
    op_attrs = {attr: op.get_attr(attr) for attr in attr_names}

    first_input = tf.convert_to_tensor(op.inputs[0], name="input_a")
    second_input = tf.convert_to_tensor(op.inputs[1], name="input_b")
    upstream_grad = tf.convert_to_tensor(grad_output, name="grad_output")

    grad_kernel = _correlation_cost_so.ops.addons_correlation_cost_grad
    raw_grads = grad_kernel(first_input, second_input, upstream_grad, **op_attrs)

    return [
        tf.convert_to_tensor(raw_grads[0], name="grad_input_a"),
        tf.convert_to_tensor(raw_grads[1], name="grad_input_b"),
    ]
@tf.keras.utils.register_keras_serializable(package="Addons")
class CorrelationCost(tf.keras.layers.Layer):
    """Correlation Cost Layer.

    This layer implements the correlation operation from [FlowNet: Learning
    Optical Flow with Convolutional Networks](https://arxiv.org/abs/1504.06852)
    (Fischer et al.).

    Args:
      kernel_size: An integer specifying the height and width of the
          patch used to compute the per-patch costs.
      max_displacement: An integer specifying the maximum search radius
          for each position.
      stride_1: An integer specifying the stride length in the input.
      stride_2: An integer specifying the stride length in the patch.
      pad: An integer specifying the paddings in height and width.
      data_format: Specifies the data format.
          Possible values are:
          "channels_last" float [batch, height, width, channels]
          "channels_first" float [batch, channels, height, width]
          Defaults to `"channels_last"`.

    Raises:
      ValueError: If `data_format` is neither `"channels_last"` nor
          `"channels_first"`.
    """

    @typechecked
    def __init__(
        self,
        kernel_size: int,
        max_displacement: int,
        stride_1: int,
        stride_2: int,
        pad: int,
        data_format: str,
        **kwargs,
    ):
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride_1 = stride_1
        self.stride_2 = stride_2
        self.pad = pad

        if data_format not in ("channels_last", "channels_first"):
            raise ValueError(
                "`data_format` must be either `channels_last` or "
                "`channels_first`, instead got %s" % data_format
            )

        self.data_format = data_format

        super().__init__(**kwargs)

    def build(self, input_shape):
        """Check that the layer is built on a list of input shapes.

        Raises:
          ValueError: If `input_shape` is not a list.
        """
        if not isinstance(input_shape, list):
            raise ValueError("Input must be a list of two Tensors to process")
        super().build(input_shape)

    def call(self, inputs):
        """Compute the correlation cost volume of `inputs[0]` and `inputs[1]`.

        Args:
          inputs: A list whose first two entries are the feature maps to
              correlate, laid out as specified by `data_format`.

        Returns:
          The result of `_correlation_cost` for the two inputs, using the
          parameters this layer was constructed with.

        Raises:
          ValueError: If `inputs` is not a list.
        """
        if not isinstance(inputs, list):
            raise ValueError("Input must be a list of two Tensors to process")

        input_a = tf.convert_to_tensor(inputs[0])
        input_b = tf.convert_to_tensor(inputs[1])

        return _correlation_cost(
            input_a,
            input_b,
            kernel_size=self.kernel_size,
            max_displacement=self.max_displacement,
            stride_1=self.stride_1,
            stride_2=self.stride_2,
            pad=self.pad,
            data_format=self.data_format,
        )

    def compute_output_shape(self, input_shape):
        """Derive the output shape from the two input shapes.

        Args:
          input_shape: A list of two 4-D shapes that must agree in every
              dimension.

        Returns:
          A single-element list holding the output shape tuple, ordered
          according to `data_format`. A spatial dimension that is unknown
          (`None`) in the input stays `None` in the output.

        Raises:
          ValueError: If `input_shape` is not a list of two shapes, if the
              two shapes differ, or if `data_format` is not supported.
        """
        # Explicit raise rather than `assert`: assertions are stripped when
        # Python runs with -O, and `build`/`call` already use ValueError.
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError("Input must be a list of two shapes")

        shape_a, shape_b = input_shape
        if any(shape_a[idx] != shape_b[idx] for idx in range(4)):
            raise ValueError("Input shapes must match")

        n = shape_a[0]
        r = self.max_displacement // self.stride_2
        bd = self.max_displacement + (self.kernel_size - 1) // 2
        output_c = (2 * r + 1) ** 2

        # Amount added to each spatial dimension. By operator precedence this
        # is (2 * (pad - bd)) // stride_1, i.e. only the border term is
        # divided by the stride.
        # NOTE(review): confirm this matches the kernel's actual output size
        # when stride_1 > 1.
        spatial_delta = 2 * (self.pad - bd) // self.stride_1

        def _out_dim(dim):
            # An unknown input dimension yields an unknown output dimension
            # instead of a TypeError from `None + int`.
            return None if dim is None else dim + spatial_delta

        if self.data_format == "channels_first":
            return [(n, output_c, _out_dim(shape_a[2]), _out_dim(shape_a[3]))]
        if self.data_format == "channels_last":
            return [(n, _out_dim(shape_a[1]), _out_dim(shape_a[2]), output_c)]
        raise ValueError(
            "`data_format` must be either `channels_last` or `channels_first`"
        )

    def get_config(self):
        """Return the layer configuration used for serialization.

        The constructor arguments are merged on top of the base layer's
        configuration.
        """
        config = {
            "kernel_size": self.kernel_size,
            "max_displacement": self.max_displacement,
            "stride_1": self.stride_1,
            "stride_2": self.stride_2,
            "pad": self.pad,
            "data_format": self.data_format,
        }

        base_config = super().get_config()
        return {**base_config, **config}