Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/signal/dct_ops.py: 23%

82 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Discrete Cosine Transform ops.""" 

16import math as _math 

17 

18from tensorflow.python.framework import dtypes as _dtypes 

19from tensorflow.python.framework import ops as _ops 

20from tensorflow.python.framework import smart_cond 

21from tensorflow.python.framework import tensor_shape 

22from tensorflow.python.ops import array_ops as _array_ops 

23from tensorflow.python.ops import math_ops as _math_ops 

24from tensorflow.python.ops.signal import fft_ops 

25from tensorflow.python.util import dispatch 

26from tensorflow.python.util.tf_export import tf_export 

27 

28 

29def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm): 

30 """Checks that DCT/IDCT arguments are compatible and well formed.""" 

31 if axis != -1: 

32 raise NotImplementedError("axis must be -1. Got: %s" % axis) 

33 if n is not None and n < 1: 

34 raise ValueError("n should be a positive integer or None") 

35 if dct_type not in (1, 2, 3, 4): 

36 raise ValueError("Types I, II, III and IV (I)DCT are supported.") 

37 if dct_type == 1: 

38 if norm == "ortho": 

39 raise ValueError("Normalization is not supported for the Type-I DCT.") 

40 if input_tensor.shape[-1] is not None and input_tensor.shape[-1] < 2: 

41 raise ValueError( 

42 "Type-I DCT requires the dimension to be greater than one.") 

43 

44 if norm not in (None, "ortho"): 

45 raise ValueError( 

46 "Unknown normalization. Expected None or 'ortho', got: %s" % norm) 

47 

48 

# TODO(rjryan): Implement `axis` parameter.
@tf_export("signal.dct", v1=["signal.dct", "spectral.dct"])
@dispatch.add_dispatch_support
def dct(input, type=2, n=None, axis=-1, norm=None, name=None):  # pylint: disable=redefined-builtin
  """Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.

  Types I, II, III and IV are supported.
  Type I is implemented using a length `2N - 2` `tf.signal.rfft` of the
  even-symmetric extension of the input.
  Type II is implemented using a length `2N` padded `tf.signal.rfft`, as
  described here: [Type 2 DCT using 2N FFT padded (Makhoul)]
  (https://dsp.stackexchange.com/a/10606).
  Type III is a fairly straightforward inverse of Type II
  (i.e. using a length `2N` padded `tf.signal.irfft`).
  Type IV is calculated by taking a length `2N` DCT-II of the zero-padded
  signal and keeping the odd indices.

  @compatibility(scipy)
  Equivalent to [scipy.fftpack.dct]
  (https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.dct.html)
  for Type-I, Type-II, Type-III and Type-IV DCT.
  @end_compatibility

  Args:
    input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
      signals to take the DCT of.
    type: The DCT type to perform. Must be 1, 2, 3 or 4.
    n: The length of the transform. If length is less than sequence length,
      only the first n elements of the sequence are considered for the DCT.
      If n is greater than the sequence length, zeros are padded and then
      the DCT is computed as usual.
    axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
    norm: The normalization to apply. `None` for no normalization or `'ortho'`
      for orthonormal normalization.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `float32`/`float64` `Tensor` containing the DCT of
    `input`.

  Raises:
    NotImplementedError: If `axis` is not `-1`.
    ValueError: If `type` is not `1`, `2`, `3` or `4`, `n` is not `None` or
      greater than 0, or `norm` is not `None` or `'ortho'`.
    ValueError: If `type` is `1` and `norm` is `'ortho'`, or if `type` is `1`
      and the Type-I length is statically known to be less than 2.

  [dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform
  """
  _validate_dct_arguments(input, type, n, axis, norm)
  return _dct_internal(input, type, n, axis, norm, name)

98 

99 

def _dct_internal(input, type=2, n=None, axis=-1, norm=None, name=None):  # pylint: disable=redefined-builtin
  """Computes the 1D Discrete Cosine Transform (DCT) of `input`.

  This internal version of `dct` does not perform any validation and accepts a
  dynamic value for `n` in the form of a rank 0 tensor.

  Args:
    input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
      signals to take the DCT of.
    type: The DCT type to perform. Must be 1, 2, 3 or 4.
    n: The length of the transform. If length is less than sequence length,
      only the first n elements of the sequence are considered for the DCT.
      If n is greater than the sequence length, zeros are padded and then
      the DCT is computed as usual. Can be an int or rank 0 tensor.
    axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
      This value is only forwarded to the recursive Type-IV call; the
      transform is always taken over the last axis.
    norm: The normalization to apply. `None` for no normalization or `'ortho'`
      for orthonormal normalization. Ignored for Type I.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `float32`/`float64` `Tensor` containing the DCT of
    `input`. If `type` is not 1, 2, 3 or 4, no branch matches and the
    function returns `None`; callers must validate `type` first.
  """
  with _ops.name_scope(name, "dct", [input]):
    input = _ops.convert_to_tensor(input)
    # Scalar zero of the input dtype, used as the real part when building the
    # complex tensors below.
    zero = _ops.convert_to_tensor(0.0, dtype=input.dtype)

    # Prefer the static length; fall back to the dynamic shape when it is
    # unknown. The `or` also takes the dynamic path if the static length is 0.
    seq_len = (
        tensor_shape.dimension_value(input.shape[-1]) or
        _array_ops.shape(input)[-1])
    if n is not None:

      def truncate_input():
        return input[..., 0:n]

      def pad_input():
        # Zero-pad only the last axis, on the right, up to length `n`. Uses
        # `len(input.shape)`, so the rank must be statically known.
        rank = len(input.shape)
        padding = [[0, 0] for _ in range(rank)]
        padding[rank - 1][1] = n - seq_len
        padding = _ops.convert_to_tensor(padding, dtype=_dtypes.int32)
        return _array_ops.pad(input, paddings=padding)

      # smart_cond chooses a branch directly when the predicate is a static
      # Python bool, and builds a conditional otherwise (e.g. tensor `n`).
      input = smart_cond.smart_cond(n <= seq_len, truncate_input, pad_input)

    # Transform length N, taken after any truncation/padding above.
    axis_dim = (tensor_shape.dimension_value(input.shape[-1])
                or _array_ops.shape(input)[-1])
    axis_dim_float = _math_ops.cast(axis_dim, input.dtype)

    if type == 1:
      # Build the even-symmetric extension [x_0, ..., x_{N-1}, x_{N-2}, ..., x_1]
      # of length 2N - 2. The real part of its RFFT yields the N DCT-I outputs.
      # `norm` is not consulted on this path.
      dct1_input = _array_ops.concat([input, input[..., -2:0:-1]], axis=-1)
      dct1 = _math_ops.real(fft_ops.rfft(dct1_input))
      return dct1

    if type == 2:
      # Per-bin complex weights 2 * exp(-i * pi * k / (2N)) for k = 0..N-1.
      scale = 2.0 * _math_ops.exp(
          _math_ops.complex(
              zero, -_math_ops.range(axis_dim_float) * _math.pi * 0.5 /
              axis_dim_float))

      # TODO(rjryan): Benchmark performance and memory usage of the various
      # approaches to computing a DCT via the RFFT.
      # Take a length-2N RFFT, keep the first N bins, apply the weights and
      # take the real part.
      dct2 = _math_ops.real(
          fft_ops.rfft(
              input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale)

      if norm == "ortho":
        # n1 = 1 / (2 * sqrt(N)) for k = 0 and n2 = 1 / sqrt(2N) for k > 0.
        n1 = 0.5 * _math_ops.rsqrt(axis_dim_float)
        n2 = n1 * _math.sqrt(2.0)
        # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
        weights = _array_ops.pad(
            _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
            constant_values=n2)
        dct2 *= weights

      return dct2

    elif type == 3:
      # Pre-scale the input: ortho weights [sqrt(N), sqrt(N/2), ...], otherwise
      # a uniform factor of N.
      if norm == "ortho":
        n1 = _math_ops.sqrt(axis_dim_float)
        n2 = n1 * _math.sqrt(0.5)
        # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
        weights = _array_ops.pad(
            _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
            constant_values=n2)
        input *= weights
      else:
        input *= axis_dim_float
      # Per-bin complex weights 2 * exp(+i * pi * k / (2N)); the exponent sign
      # is the opposite of the Type-II weights.
      scale = 2.0 * _math_ops.exp(
          _math_ops.complex(
              zero,
              _math_ops.range(axis_dim_float) * _math.pi * 0.5 /
              axis_dim_float))
      # Take a length-2N inverse RFFT of the weighted input and keep the first
      # N real samples.
      dct3 = _math_ops.real(
          fft_ops.irfft(
              scale * _math_ops.complex(input, zero),
              fft_length=[2 * axis_dim]))[..., :axis_dim]

      return dct3

    elif type == 4:
      # DCT-2 of 2N length zero-padded signal, unnormalized.
      dct2 = _dct_internal(input, type=2, n=2*axis_dim, axis=axis, norm=None)
      # Get odd indices of DCT-2 of zero padded 2N signal to obtain
      # DCT-4 of the original N length signal.
      dct4 = dct2[..., 1::2]
      if norm == "ortho":
        # Uniform orthonormal factor 1 / sqrt(2N).
        dct4 *= _math.sqrt(0.5) * _math_ops.rsqrt(axis_dim_float)

      return dct4

209 

210 

# TODO(rjryan): Implement `axis` parameter.
@tf_export("signal.idct", v1=["signal.idct", "spectral.idct"])
@dispatch.add_dispatch_support
def idct(input, type=2, n=None, axis=-1, norm=None, name=None):  # pylint: disable=redefined-builtin
  """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.

  Currently Types I, II, III, IV are supported. Type III is the inverse of
  Type II, and vice versa. Types I and IV are their own inverses (up to
  scaling).

  Note that you must re-normalize to obtain an inverse if `norm` is not
  `'ortho'`. For Types II, III and IV the factor is 1/(2N), where N is the
  transform length (`signal.shape[-1]` when `n` is `None`). For example, with
  the default Type II:
  `signal == idct(dct(signal)) * 0.5 / signal.shape[-1]`.
  For Type I the factor is 1/(2(N-1)):
  `signal == idct(dct(signal, type=1), type=1) * 0.5 / (signal.shape[-1] - 1)`.
  When `norm='ortho'` (Types II, III and IV only), we have:
  `signal == idct(dct(signal, norm='ortho'), norm='ortho')`.

  @compatibility(scipy)
  Equivalent to [scipy.fftpack.idct]
  (https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.idct.html)
  for Type-I, Type-II, Type-III and Type-IV DCT.
  @end_compatibility

  Args:
    input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
      signals to take the IDCT of.
    type: The IDCT type to perform. Must be 1, 2, 3 or 4.
    n: The length of the transform. If `n` is less than the input length,
      only the first `n` elements are used; if it is greater, the input is
      zero-padded to length `n`. `None` uses the input length.
    axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
    norm: The normalization to apply. `None` for no normalization or `'ortho'`
      for orthonormal normalization.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `float32`/`float64` `Tensor` containing the IDCT of
    `input`.

  Raises:
    NotImplementedError: If `axis` is not `-1`.
    ValueError: If `type` is not `1`, `2`, `3` or `4`, `n` is not `None` or
      greater than 0, or `norm` is not `None` or `'ortho'`.
    ValueError: If `type` is `1` and `norm` is `'ortho'`, or if `type` is `1`
      and the Type-I length is statically known to be less than 2.

  [idct]:
  https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms
  """
  _validate_dct_arguments(input, type, n, axis, norm)
  # Types II and III are mutual inverses; Types I and IV invert themselves.
  inverse_type = {1: 1, 2: 3, 3: 2, 4: 4}[type]
  return _dct_internal(
      input, type=inverse_type, n=n, axis=axis, norm=norm, name=name)