Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/signal/dct_ops.py: 23%
82 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Discrete Cosine Transform ops."""
16import math as _math
18from tensorflow.python.framework import dtypes as _dtypes
19from tensorflow.python.framework import ops as _ops
20from tensorflow.python.framework import smart_cond
21from tensorflow.python.framework import tensor_shape
22from tensorflow.python.ops import array_ops as _array_ops
23from tensorflow.python.ops import math_ops as _math_ops
24from tensorflow.python.ops.signal import fft_ops
25from tensorflow.python.util import dispatch
26from tensorflow.python.util.tf_export import tf_export
def _static_last_dim(input_tensor):
  """Returns the statically known length of the last axis, or `None`.

  Handles objects exposing `.shape` (e.g. `Tensor`, `numpy.ndarray`) as well as
  plain (nested) Python lists/tuples. `dct`/`idct` accept the latter because
  the conversion to a tensor only happens later, in `_dct_internal`. Returns
  `None` when the length cannot be determined statically.
  """
  shape = getattr(input_tensor, "shape", None)
  if shape is None:
    # Python sequences: descend through the first element of each nesting
    # level until reaching the innermost (last-axis) sequence.
    inner = input_tensor
    while (isinstance(inner, (list, tuple)) and inner and
           isinstance(inner[0], (list, tuple))):
      inner = inner[0]
    return len(inner) if isinstance(inner, (list, tuple)) else None
  try:
    last = shape[-1]
  except (IndexError, ValueError):
    # Scalars (rank 0) have no last axis.
    return None
  # TF1-style shapes yield `Dimension` objects; unwrap them to an int or None.
  return getattr(last, "value", last)


def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm):
  """Checks that DCT/IDCT arguments are compatible and well formed.

  Args:
    input_tensor: The signal passed to `dct`/`idct`: a `Tensor` or any value
      later converted to one (e.g. a nested Python list).
    dct_type: The requested DCT type.
    n: The requested transform length, or `None` to use the length of the last
      axis of `input_tensor`.
    axis: The axis to transform along. Only `-1` is supported.
    norm: `None` or `'ortho'`.

  Raises:
    NotImplementedError: If `axis` is not `-1`.
    ValueError: If `n` is not `None` and less than 1, `dct_type` is not one of
      1-4, `norm` is not `None` or `'ortho'`, or a Type-I DCT is requested
      with `norm='ortho'` or with a transform length smaller than 2.
  """
  if axis != -1:
    raise NotImplementedError("axis must be -1. Got: %s" % axis)
  if n is not None and n < 1:
    raise ValueError("n should be a positive integer or None")
  if dct_type not in (1, 2, 3, 4):
    raise ValueError("Types I, II, III and IV (I)DCT are supported.")
  if dct_type == 1:
    if norm == "ortho":
      raise ValueError("Normalization is not supported for the Type-I DCT.")
    # When `n` is given, the input is truncated or zero-padded to `n` samples
    # before the transform, so the effective length is `n`, not the raw input
    # length.
    length = n if n is not None else _static_last_dim(input_tensor)
    if length is not None and length < 2:
      raise ValueError(
          "Type-I DCT requires the dimension to be greater than one.")

  if norm not in (None, "ortho"):
    raise ValueError(
        "Unknown normalization. Expected None or 'ortho', got: %s" % norm)
# TODO(rjryan): Implement `axis` parameter.
@tf_export("signal.dct", v1=["signal.dct", "spectral.dct"])
@dispatch.add_dispatch_support
def dct(input, type=2, n=None, axis=-1, norm=None, name=None):  # pylint: disable=redefined-builtin
  """Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.

  Types I, II, III and IV are supported.
  Type I is implemented with a `tf.signal.rfft` of the length `2N - 2`
  even-symmetric extension of the input (the input followed by its reversed
  interior samples).
  Type II is implemented using a length `2N` padded `tf.signal.rfft`, as
  described here: [Type 2 DCT using 2N FFT padded (Makhoul)]
  (https://dsp.stackexchange.com/a/10606).
  Type III is a fairly straightforward inverse of Type II
  (i.e. using a length `2N` padded `tf.signal.irfft`).
  Type IV is calculated through a length `2N` DCT-II of the zero-padded signal,
  keeping the odd indices.

  @compatibility(scipy)
  Equivalent to [scipy.fftpack.dct]
  (https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.dct.html)
  for Type-I, Type-II, Type-III and Type-IV DCT.
  @end_compatibility

  Args:
    input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
      signals to take the DCT of.
    type: The DCT type to perform. Must be 1, 2, 3 or 4.
    n: The length of the transform. If `n` is less than the sequence length,
      only the first `n` elements of the sequence are considered for the DCT.
      If `n` is greater than the sequence length, zeros are padded and then
      the DCT is computed as usual.
    axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
    norm: The normalization to apply. `None` for no normalization or `'ortho'`
      for orthonormal normalization.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `float32`/`float64` `Tensor` containing the DCT of
    `input`.

  Raises:
    ValueError: If `type` is not `1`, `2`, `3` or `4`, `n` is not `None` and
      is less than 1, or `norm` is not `None` or `'ortho'`.
    ValueError: If `type` is `1` and `norm` is `'ortho'`, or the input is too
      short for a Type-I transform.
    NotImplementedError: If `axis` is not `-1`.

  [dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform
  """
  _validate_dct_arguments(input, type, n, axis, norm)
  return _dct_internal(input, type, n, axis, norm, name)
def _dct_internal(input, type=2, n=None, axis=-1, norm=None, name=None):  # pylint: disable=redefined-builtin
  """Computes the 1D Discrete Cosine Transform (DCT) of `input`.

  This internal version of `dct` does not perform any validation and accepts a
  dynamic value for `n` in the form of a rank 0 tensor. The transform is always
  taken along the last axis; `axis` is only forwarded to the recursive Type-IV
  call and is otherwise unused. `norm` is ignored for Type I. For a `type`
  outside 1-4 no branch matches and the function implicitly returns `None`,
  so callers must validate first.

  Args:
    input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
      signals to take the DCT of.
    type: The DCT type to perform. Must be 1, 2, 3 or 4.
    n: The length of the transform. If length is less than sequence length,
      only the first n elements of the sequence are considered for the DCT.
      If n is greater than the sequence length, zeros are padded and then
      the DCT is computed as usual. Can be an int or rank 0 tensor.
    axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
    norm: The normalization to apply. `None` for no normalization or `'ortho'`
      for orthonormal normalization.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `float32`/`float64` `Tensor` containing the DCT of
    `input`.
  """
  with _ops.name_scope(name, "dct", [input]):
    input = _ops.convert_to_tensor(input)
    # Zero of the input dtype, used as the real or imaginary part when building
    # complex tensors below.
    zero = _ops.convert_to_tensor(0.0, dtype=input.dtype)

    # Static last-axis length when known, otherwise a dynamic scalar tensor.
    seq_len = (
        tensor_shape.dimension_value(input.shape[-1]) or
        _array_ops.shape(input)[-1])
    if n is not None:

      def truncate_input():
        return input[..., 0:n]

      def pad_input():
        # Zero-pads only the end of the last axis up to length `n`.
        # NOTE(review): len(input.shape) needs a statically known rank.
        rank = len(input.shape)
        padding = [[0, 0] for _ in range(rank)]
        padding[rank - 1][1] = n - seq_len
        padding = _ops.convert_to_tensor(padding, dtype=_dtypes.int32)
        return _array_ops.pad(input, paddings=padding)

      # Truncate or zero-pad the last axis to exactly `n` samples. smart_cond
      # resolves the branch at trace time when the predicate is static and
      # otherwise builds a conditional.
      input = smart_cond.smart_cond(n <= seq_len, truncate_input, pad_input)

    # N: length of the (possibly resized) last axis, static when known.
    axis_dim = (tensor_shape.dimension_value(input.shape[-1])
                or _array_ops.shape(input)[-1])
    axis_dim_float = _math_ops.cast(axis_dim, input.dtype)

    if type == 1:
      # Append the interior samples in reverse (x[N-2], ..., x[1]) to form an
      # even-symmetric sequence of length 2N - 2; the real part of its rfft
      # (N bins) is the unnormalized DCT-I.
      dct1_input = _array_ops.concat([input, input[..., -2:0:-1]], axis=-1)
      dct1 = _math_ops.real(fft_ops.rfft(dct1_input))
      return dct1

    if type == 2:
      # Twiddle factors 2 * exp(-i * pi * k / (2N)) for k = 0..N-1.
      scale = 2.0 * _math_ops.exp(
          _math_ops.complex(
              zero, -_math_ops.range(axis_dim_float) * _math.pi * 0.5 /
              axis_dim_float))

      # TODO(rjryan): Benchmark performance and memory usage of the various
      # approaches to computing a DCT via the RFFT.
      # fft_length=2N zero-pads the signal to 2N samples (Makhoul). The first N
      # bins are kept, multiplied by the twiddles, and the real part is taken.
      dct2 = _math_ops.real(
          fft_ops.rfft(
              input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale)

      if norm == "ortho":
        # n1 = sqrt(1 / (4N)) for k == 0; n2 = sqrt(1 / (2N)) for k > 0.
        n1 = 0.5 * _math_ops.rsqrt(axis_dim_float)
        n2 = n1 * _math.sqrt(2.0)
        # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
        weights = _array_ops.pad(
            _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
            constant_values=n2)
        dct2 *= weights

      return dct2

    elif type == 3:
      # Pre-scale the coefficients before the inverse FFT below.
      # NOTE(review): the factors of N / sqrt(N) presumably offset irfft's
      # 1/fft_length normalization; confirm against tf.signal.irfft docs.
      if norm == "ortho":
        n1 = _math_ops.sqrt(axis_dim_float)
        n2 = n1 * _math.sqrt(0.5)
        # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
        weights = _array_ops.pad(
            _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
            constant_values=n2)
        input *= weights
      else:
        input *= axis_dim_float
      # Conjugate of the Type-II twiddles: 2 * exp(+i * pi * k / (2N)).
      scale = 2.0 * _math_ops.exp(
          _math_ops.complex(
              zero,
              _math_ops.range(axis_dim_float) * _math.pi * 0.5 /
              axis_dim_float))
      # Inverse real FFT of length 2N; keep the first N output samples.
      dct3 = _math_ops.real(
          fft_ops.irfft(
              scale * _math_ops.complex(input, zero),
              fft_length=[2 * axis_dim]))[..., :axis_dim]

      return dct3

    elif type == 4:
      # DCT-2 of 2N length zero-padded signal, unnormalized.
      dct2 = _dct_internal(input, type=2, n=2*axis_dim, axis=axis, norm=None)
      # Get odd indices of DCT-2 of zero padded 2N signal to obtain
      # DCT-4 of the original N length signal.
      dct4 = dct2[..., 1::2]
      if norm == "ortho":
        # sqrt(1 / (2N)): turns the unnormalized 2 * sum(...) into the
        # orthonormal sqrt(2 / N) * sum(...).
        dct4 *= _math.sqrt(0.5) * _math_ops.rsqrt(axis_dim_float)

      return dct4
# TODO(rjryan): Implement `axis` parameter.
@tf_export("signal.idct", v1=["signal.idct", "spectral.idct"])
@dispatch.add_dispatch_support
def idct(input, type=2, n=None, axis=-1, norm=None, name=None):  # pylint: disable=redefined-builtin
  """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.

  Currently Types I, II, III and IV are supported. Type III is the inverse of
  Type II, and vice versa. Types I and IV are their own inverses.

  Note that, when `norm` is not `'ortho'`, the result must be re-normalized to
  obtain an inverse. For Types II, III and IV the factor is `1/(2N)`:
  `signal == idct(dct(signal)) * 0.5 / signal.shape[-1]`.
  For Type I the factor is `1/(2(N-1))`:
  `signal == idct(dct(signal, type=1), type=1) * 0.5 / (signal.shape[-1] - 1)`.
  When `norm='ortho'`, we have:
  `signal == idct(dct(signal, norm='ortho'), norm='ortho')`.

  @compatibility(scipy)
  Equivalent to [scipy.fftpack.idct]
  (https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.idct.html)
  for Type-I, Type-II, Type-III and Type-IV DCT.
  @end_compatibility

  Args:
    input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
      signals to take the IDCT of.
    type: The IDCT type to perform. Must be 1, 2, 3 or 4.
    n: The length of the transform. If `n` is less than the sequence length,
      only the first `n` elements of the sequence are used; if it is greater,
      the input is zero-padded to length `n` before the transform.
    axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
    norm: The normalization to apply. `None` for no normalization or `'ortho'`
      for orthonormal normalization.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `float32`/`float64` `Tensor` containing the IDCT of
    `input`.

  Raises:
    ValueError: If `type` is not `1`, `2`, `3` or `4`, `n` is not `None` and
      is less than 1, or `norm` is not `None` or `'ortho'`.
    ValueError: If `type` is `1` and `norm` is `'ortho'`, or the input is too
      short for a Type-I transform.
    NotImplementedError: If `axis` is not `-1`.

  [idct]:
  https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms
  """
  _validate_dct_arguments(input, type, n, axis, norm)
  # II and III are mutual inverses; I and IV are self-inverse (up to scaling).
  inverse_type = {1: 1, 2: 3, 3: 2, 4: 4}[type]
  return _dct_internal(
      input, type=inverse_type, n=n, axis=axis, norm=norm, name=name)