Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/optical_flow.py: 20%

80 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Tensorflow op performing correlation cost operation.""" 

16 

17import tensorflow as tf 

18from typeguard import typechecked 

19from tensorflow_addons.utils.resource_loader import LazySO 

20 

# Handle to the compiled custom-op library. Its ``ops`` attribute supplies
# ``addons_correlation_cost`` (forward) and ``addons_correlation_cost_grad``
# (gradient), both looked up at call time by the functions in this module.
# NOTE(review): ``LazySO`` presumably defers loading the .so until first
# access — confirm against tensorflow_addons.utils.resource_loader.
_correlation_cost_so = LazySO(
    "custom_ops/layers/_correlation_cost_ops.so"
)

22 

23 

24def _correlation_cost( 

25 input_a, 

26 input_b, 

27 kernel_size, 

28 max_displacement, 

29 stride_1, 

30 stride_2, 

31 pad, 

32 data_format="channels_last", 

33 name=None, 

34): 

35 """Correlation Cost Volume computation. 

36 

37 See [FlowNet: Learning Optical Flow with Convolutional Networks](https://arxiv.org/abs/1504.06852). 

38 

39 Computes a cost volume using correlation for two inputs. For feature 

40 maps A, B with spatial dimensions w, h, c it computes 

41 

42 output(a, b) = sum_{l in [-k,k]**2} < I(a+l), J(b+l) > 

43 

44 where the patches of size K=2d + 1 are centered in position a resp. b. 

45 

46 The output shape is [B, C', H', W'], where 

47 

48 r = max_displacement / stride_2; 

49 bd = max_displacement + (kernel_size - 1) / 2 

50 C' = (2 * r + 1) ** 2 

51 H' = H + 2 * (pad - bd) / stride_1 

52 W' = W + 2 * (pad - bd) / stride_1 

53 

54 Note: When the data_format requests "channels_last", an additional explicit 

55 transpose operation is executed. 

56 

57 Args: 

58 input_a: A `Tensor` of the format specified by `data_format`. 

59 input_b: A `Tensor` of the format specified by `data_format`. 

60 kernel_size: An integer specifying the height and width of the 

61 patch used to compute the per-patch costs. 

62 max_displacement: An integer specifying the maximum search radius 

63 for each position. 

64 stride_1: An integer specifying the stride length in the input. 

65 stride_2: An integer specifying the stride length in the patch. 

66 pad: An integer specifying the paddings in height and width. 

67 data_format: Specifies the data format. 

68 Possible values are: 

69 "channels_last" float [batch, height, width, channels] 

70 "channels_first" float [batch, channels, height, width] 

71 Defaults to `"channels_last"`. 

72 name: A name for the operation (optional). 

73 

74 Returns: 

75 A `Tensor` of the format specified by `data_format`. 

76 """ 

77 

78 with tf.name_scope(name or "correlation_cost"): 

79 op_call = _correlation_cost_so.ops.addons_correlation_cost 

80 

81 if data_format == "channels_last": 

82 op_data_format = "NHWC" 

83 elif data_format == "channels_first": 

84 op_data_format = "NCHW" 

85 else: 

86 raise ValueError( 

87 "`data_format` must be either `channels_last` or" "`channels_first`" 

88 ) 

89 

90 ret = op_call( 

91 input_a, 

92 input_b, 

93 kernel_size=kernel_size, 

94 max_displacement=max_displacement, 

95 stride_1=stride_1, 

96 stride_2=stride_2, 

97 pad=pad, 

98 data_format=op_data_format, 

99 ) 

100 if data_format == "channels_last": 

101 # this is easier to maintain without 

102 # specializing an additional cuda kernel 

103 return tf.transpose(ret, [0, 2, 3, 1]) 

104 return ret 

105 

106 

@tf.RegisterGradient("Addons>CorrelationCost")
def _correlation_cost_grad(op, grad_output):
    """Gradient function registered for the ``Addons>CorrelationCost`` op.

    Reads the forward op's attributes, then passes them, together with the
    forward op's two inputs and the incoming gradient, to the custom
    ``addons_correlation_cost_grad`` op.

    Args:
      op: The forward ``Addons>CorrelationCost`` operation.
      grad_output: The gradient flowing back into the forward op's output.

    Returns:
      A two-element list: the gradients w.r.t. the first and second inputs.
    """
    # The gradient op takes exactly the same attributes as the forward op.
    attr_names = (
        "kernel_size",
        "max_displacement",
        "stride_1",
        "stride_2",
        "pad",
        "data_format",
    )
    forward_attrs = {attr: op.get_attr(attr) for attr in attr_names}

    first = tf.convert_to_tensor(op.inputs[0], name="input_a")
    second = tf.convert_to_tensor(op.inputs[1], name="input_b")
    upstream = tf.convert_to_tensor(grad_output, name="grad_output")

    grad_kernel = _correlation_cost_so.ops.addons_correlation_cost_grad
    outputs = grad_kernel(first, second, upstream, **forward_attrs)

    return [
        tf.convert_to_tensor(outputs[0], name="grad_input_a"),
        tf.convert_to_tensor(outputs[1], name="grad_input_b"),
    ]

136 

137 

@tf.keras.utils.register_keras_serializable(package="Addons")
class CorrelationCost(tf.keras.layers.Layer):
    """Correlation Cost Layer.

    This layer implements the correlation operation from [FlowNet Learning
    Optical Flow with Convolutional Networks](https://arxiv.org/abs/1504.06852)(Fischer et al.).

    The layer is called on a list of two tensors with matching shapes. With
    bd = max_displacement + (kernel_size - 1) // 2 and
    r = max_displacement // stride_2, the output has (2 * r + 1) ** 2
    channels and spatial size ceil((size + 2 * (pad - bd)) / stride_1)
    along height and width.

    Args:
      kernel_size: An integer specifying the height and width of the
          patch used to compute the per-patch costs.
      max_displacement: An integer specifying the maximum search radius
          for each position.
      stride_1: An integer specifying the stride length in the input.
      stride_2: An integer specifying the stride length in the patch.
      pad: An integer specifying the paddings in height and width.
      data_format: Specifies the data format.
          Possible values are:
              "channels_last" float [batch, height, width, channels]
              "channels_first" float [batch, channels, height, width]
          Defaults to `"channels_last"`.
    """

    @typechecked
    def __init__(
        self,
        kernel_size: int,
        max_displacement: int,
        stride_1: int,
        stride_2: int,
        pad: int,
        data_format: str = "channels_last",
        **kwargs,
    ):
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride_1 = stride_1
        self.stride_2 = stride_2
        self.pad = pad

        if data_format not in ("channels_last", "channels_first"):
            raise ValueError(
                "`data_format` must be either `channels_last` or "
                "`channels_first`, instead got %s" % data_format
            )

        self.data_format = data_format

        super().__init__(**kwargs)

    def build(self, input_shape):
        """Check that the layer receives exactly two input shapes.

        Raises:
          ValueError: If `input_shape` is not a list of two shapes.
        """
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError("Input must be a list of two Tensors to process")
        super().build(input_shape)

    def call(self, inputs):
        """Compute the correlation cost volume of `inputs[0]` and `inputs[1]`.

        Raises:
          ValueError: If `inputs` is not a list of exactly two tensors.
        """
        if not isinstance(inputs, list) or len(inputs) != 2:
            raise ValueError("Input must be a list of two Tensors to process")

        input_a = tf.convert_to_tensor(inputs[0])
        input_b = tf.convert_to_tensor(inputs[1])

        return _correlation_cost(
            input_a,
            input_b,
            kernel_size=self.kernel_size,
            max_displacement=self.max_displacement,
            stride_1=self.stride_1,
            stride_2=self.stride_2,
            pad=self.pad,
            data_format=self.data_format,
        )

    def _output_spatial_dim(self, size, border):
        """Return ceil((size + 2 * (pad - border)) / stride_1), or None if unknown."""
        if size is None:
            return None
        padded_extent = size + 2 * (self.pad - border)
        # Integer ceiling division: -(-a // b) == ceil(a / b).
        return -(-padded_extent // self.stride_1)

    def compute_output_shape(self, input_shape):
        """Compute the static output shape for a pair of input shapes.

        Args:
          input_shape: A list of two rank-4 shapes laid out according to
              `data_format`.

        Returns:
          A one-element list holding the output shape tuple. Height and
          width stay `None` when they are unknown in the input.

        Raises:
          ValueError: If the input is not a list of two matching rank-4
              shapes.
        """
        # Explicit checks rather than `assert`, which is stripped under -O.
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError("Input must be a list of two shapes")

        shape_a, shape_b = input_shape
        if len(shape_a) != 4 or len(shape_b) != 4:
            raise ValueError("Input shapes must be of rank 4")
        if any(shape_a[idx] != shape_b[idx] for idx in range(4)):
            raise ValueError("Input shapes must match")

        n = shape_a[0]
        r = self.max_displacement // self.stride_2
        bd = self.max_displacement + (self.kernel_size - 1) // 2
        output_c = (2 * r + 1) ** 2

        if self.data_format == "channels_first":
            output_h = self._output_spatial_dim(shape_a[2], bd)
            output_w = self._output_spatial_dim(shape_a[3], bd)
            return [(n, output_c, output_h, output_w)]

        if self.data_format == "channels_last":
            output_h = self._output_spatial_dim(shape_a[1], bd)
            output_w = self._output_spatial_dim(shape_a[2], bd)
            return [(n, output_h, output_w, output_c)]

        # Unreachable unless `data_format` was modified after construction.
        raise ValueError(
            "`data_format` must be either `channels_last` or "
            "`channels_first`, instead got %s" % (self.data_format,)
        )

    def get_config(self):
        """Return the layer configuration, including the base Layer config."""
        config = {
            "kernel_size": self.kernel_size,
            "max_displacement": self.max_displacement,
            "stride_1": self.stride_1,
            "stride_2": self.stride_2,
            "pad": self.pad,
            "data_format": self.data_format,
        }

        base_config = super().get_config()
        return {**base_config, **config}