Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/signal/spectral_ops.py: 23%
133 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Spectral operations (e.g. Short-time Fourier Transform)."""
17import numpy as np
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops.signal import dct_ops
26from tensorflow.python.ops.signal import fft_ops
27from tensorflow.python.ops.signal import reconstruction_ops
28from tensorflow.python.ops.signal import shape_ops
29from tensorflow.python.ops.signal import window_ops
30from tensorflow.python.util import dispatch
31from tensorflow.python.util.tf_export import tf_export
@tf_export('signal.stft')
@dispatch.add_dispatch_support
def stft(signals, frame_length, frame_step, fft_length=None,
         window_fn=window_ops.hann_window,
         pad_end=False, name=None):
  """Computes the [Short-time Fourier Transform][stft] of `signals`.

  Implemented with TPU/GPU-compatible ops and supports gradients.

  Args:
    signals: A `[..., samples]` `float32`/`float64` `Tensor` of real-valued
      signals.
    frame_length: An integer scalar `Tensor`. The window length in samples.
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    fft_length: An integer scalar `Tensor`. The size of the FFT to apply.
      If not provided, uses the smallest power of 2 enclosing `frame_length`.
    window_fn: A callable that takes a window length and a `dtype` keyword
      argument and returns a `[window_length]` `Tensor` of samples in the
      provided datatype. If set to `None`, no windowing is used.
    pad_end: Whether to pad the end of `signals` with zeros when the provided
      frame length and step produces a frame that lies partially past its end.
    name: An optional name for the operation.

  Returns:
    A `[..., frames, fft_unique_bins]` `Tensor` of `complex64`/`complex128`
    STFT values where `fft_unique_bins` is `fft_length // 2 + 1` (the unique
    components of the FFT).

  Raises:
    ValueError: If `signals` is not at least rank 1, `frame_length` is
      not scalar, `frame_step` is not scalar, or an explicitly provided
      `fft_length` is not scalar.

  [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
  """
  with ops.name_scope(name, 'stft', [signals, frame_length,
                                     frame_step]):
    signals = ops.convert_to_tensor(signals, name='signals')
    signals.shape.with_rank_at_least(1)
    frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
    frame_length.shape.assert_has_rank(0)
    frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
    frame_step.shape.assert_has_rank(0)

    if fft_length is None:
      fft_length = _enclosing_power_of_two(frame_length)
    else:
      fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
      # Same validation as `inverse_stft`: `rfft` below receives
      # `[fft_length]`, which presumes `fft_length` is a scalar.
      fft_length.shape.assert_has_rank(0)

    framed_signals = shape_ops.frame(
        signals, frame_length, frame_step, pad_end=pad_end)

    # Optionally window the framed signals.
    if window_fn is not None:
      window = window_fn(frame_length, dtype=framed_signals.dtype)
      framed_signals *= window

    # fft_ops.rfft produces the (fft_length/2 + 1) unique components of the
    # FFT of the real windowed signals in framed_signals.
    return fft_ops.rfft(framed_signals, [fft_length])
@tf_export('signal.inverse_stft_window_fn')
@dispatch.add_dispatch_support
def inverse_stft_window_fn(frame_step,
                           forward_window_fn=window_ops.hann_window,
                           name=None):
  """Generates a window function that can be used in `inverse_stft`.

  Constructs a window that is equal to the forward window with a further
  pointwise amplitude correction. `inverse_stft_window_fn` is equivalent to
  `forward_window_fn` in the case where it would produce an exact inverse.

  See examples in `inverse_stft` documentation for usage.

  Args:
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    forward_window_fn: window_fn used in the forward transform, `stft`.
    name: An optional name for the operation.

  Returns:
    A callable that takes a window length and a `dtype` keyword argument and
    returns a `[window_length]` `Tensor` of samples in the provided datatype.
    The returned window is suitable for reconstructing original waveform in
    inverse_stft.
  """
  def inverse_stft_window_fn_inner(frame_length, dtype):
    """Computes a window that can be used in `inverse_stft`.

    `forward_window_fn` is not validated here. It must be a callable that
    takes a window length and a `dtype` keyword argument and returns a
    `[window_length]` `Tensor` of samples in the provided datatype. Any error
    it raises propagates unchanged.

    Args:
      frame_length: An integer scalar `Tensor`. The window length in samples.
      dtype: Data type of waveform passed to `stft`.

    Returns:
      A window suitable for reconstructing original waveform in `inverse_stft`.

    Raises:
      ValueError: If `frame_length` is not scalar or `frame_step` is not
        scalar.
    """
    with ops.name_scope(name, 'inverse_stft_window_fn', [forward_window_fn]):
      frame_step_ = ops.convert_to_tensor(frame_step, name='frame_step')
      frame_step_.shape.assert_has_rank(0)
      frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
      frame_length.shape.assert_has_rank(0)

      # Use equation 7 from Griffin + Lim: divide the forward window by the
      # sum of squared forward-window values that overlap-add onto each
      # sample position when frames are spaced `frame_step_` apart.
      forward_window = forward_window_fn(frame_length, dtype=dtype)
      denom = math_ops.square(forward_window)
      overlaps = -(-frame_length // frame_step_)  # Ceiling division. # pylint: disable=invalid-unary-operand-type
      # Zero-pad to a whole number of hops so the window folds into an
      # [overlaps, frame_step_] matrix whose rows are consecutive hops.
      denom = array_ops.pad(denom, [(0, overlaps * frame_step_ - frame_length)])
      denom = array_ops.reshape(denom, [overlaps, frame_step_])
      # Summing over rows adds up every value sharing the same offset within
      # a hop; tiling and flattening spreads that sum back to full length.
      denom = math_ops.reduce_sum(denom, 0, keepdims=True)
      denom = array_ops.tile(denom, [overlaps, 1])
      denom = array_ops.reshape(denom, [overlaps * frame_step_])

      return forward_window / denom[:frame_length]
  return inverse_stft_window_fn_inner
@tf_export('signal.inverse_stft')
@dispatch.add_dispatch_support
def inverse_stft(stfts,
                 frame_length,
                 frame_step,
                 fft_length=None,
                 window_fn=window_ops.hann_window,
                 name=None):
  """Computes the inverse [Short-time Fourier Transform][stft] of `stfts`.

  Recovering the original waveform requires a window that complements the one
  used by `tf.signal.stft`. `tf.signal.inverse_stft_window_fn` builds such a
  window.
  Example:

  ```python
  frame_length = 400
  frame_step = 160
  waveform = tf.random.normal(dtype=tf.float32, shape=[1000])
  stft = tf.signal.stft(waveform, frame_length, frame_step)
  inverse_stft = tf.signal.inverse_stft(
      stft, frame_length, frame_step,
      window_fn=tf.signal.inverse_stft_window_fn(frame_step))
  ```

  When `tf.signal.stft` was called with a custom `window_fn`, pass that same
  function to `tf.signal.inverse_stft_window_fn`:

  ```python
  frame_length = 400
  frame_step = 160
  window_fn = tf.signal.hamming_window
  waveform = tf.random.normal(dtype=tf.float32, shape=[1000])
  stft = tf.signal.stft(
      waveform, frame_length, frame_step, window_fn=window_fn)
  inverse_stft = tf.signal.inverse_stft(
      stft, frame_length, frame_step,
      window_fn=tf.signal.inverse_stft_window_fn(
         frame_step, forward_window_fn=window_fn))
  ```

  Implemented with TPU/GPU-compatible ops and supports gradients.

  Args:
    stfts: A `complex64`/`complex128` `[..., frames, fft_unique_bins]`
      `Tensor` of STFT bins representing a batch of `fft_length`-point STFTs
      where `fft_unique_bins` is `fft_length // 2 + 1`.
    frame_length: An integer scalar `Tensor`. The window length in samples.
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    fft_length: An integer scalar `Tensor`. The size of the FFT that produced
      `stfts`. If not provided, uses the smallest power of 2 enclosing
      `frame_length`.
    window_fn: A callable that takes a window length and a `dtype` keyword
      argument and returns a `[window_length]` `Tensor` of samples in the
      provided datatype. If set to `None`, no windowing is used.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `Tensor` of `float32`/`float64` signals representing
    the inverse STFT for each input STFT in `stfts`.

  Raises:
    ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar,
      `frame_step` is not scalar, or `fft_length` is not scalar.

  [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
  """
  with ops.name_scope(name, 'inverse_stft', [stfts]):
    stfts = ops.convert_to_tensor(stfts, name='stfts')
    stfts.shape.with_rank_at_least(2)
    frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
    frame_length.shape.assert_has_rank(0)
    frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
    frame_step.shape.assert_has_rank(0)
    if fft_length is not None:
      fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
      fft_length.shape.assert_has_rank(0)
    else:
      fft_length = _enclosing_power_of_two(frame_length)

    frames = fft_ops.irfft(stfts, [fft_length])

    # The inverse FFT yields `fft_length` samples per frame. `frame_length`
    # may be larger or smaller, so the innermost axis is padded or truncated.
    target_len = tensor_util.constant_value(frame_length)
    known_rank = frames.shape.ndims
    inner_len = None if known_rank is None else frames.shape.as_list()[-1]

    if target_len is None or inner_len is None:
      # Some size is only known at run time: truncate, then pad any shortfall
      # using tensors computed in-graph.
      frames = frames[..., :frame_length]
      rank_t = array_ops.rank(frames)
      shape_t = array_ops.shape(frames)
      leading = array_ops.zeros([rank_t - 1, 2], dtype=frame_length.dtype)
      trailing = [[0, math_ops.maximum(0, frame_length - shape_t[-1])]]
      frames = array_ops.pad(frames, array_ops.concat([leading, trailing], 0))
    elif inner_len > target_len:
      frames = frames[..., :target_len]
    elif inner_len < target_len:
      frames = array_ops.pad(
          frames,
          [[0, 0]] * (known_rank - 1) + [[0, target_len - inner_len]])

    # The dynamic padding above can hide the inner size from shape inference;
    # restore it when both the length and the rank are known.
    final_rank = frames.shape.ndims
    if target_len is not None and final_rank is not None:
      frames.set_shape([None] * (final_rank - 1) + [target_len])

    # Optionally window, then overlap-add the two innermost axes into one
    # [samples] axis.
    if window_fn is not None:
      frames *= window_fn(frame_length, dtype=stfts.dtype.real_dtype)
    return reconstruction_ops.overlap_and_add(frames, frame_step)
def _enclosing_power_of_two(value):
  """Return 2**N for the smallest integer N such that 2**N >= value.

  Args:
    value: A scalar `Tensor`, presumably an integer length such as a frame
      length.

  Returns:
    A scalar `Tensor` with the same dtype as `value`. If `value` is
    statically known, the result is a constant computed exactly in Python.
    Otherwise it is computed in-graph through `float32` log/pow ops.
  """
  value_static = tensor_util.constant_value(value)
  if value_static is not None:
    return constant_op.constant(
        _static_enclosing_power_of_two(value_static), value.dtype)
  # NOTE(review): the in-graph path stays in float32, presumably for TPU
  # compatibility. float32 cannot represent every integer above 2**24 exactly,
  # so results for very large values may be imprecise.
  return math_ops.cast(
      math_ops.pow(
          2.0,
          math_ops.ceil(
              math_ops.log(math_ops.cast(value, dtypes.float32)) /
              math_ops.log(2.0))), value.dtype)


def _static_enclosing_power_of_two(value_static):
  """Exactly computes the smallest power of two >= a statically known value.

  `np.log(x) / np.log(2.0)` is subject to rounding error. For `x = 2**29` it
  evaluates to `29.000000000000004`, so taking `ceil` would overshoot to
  `2**30`. Integer `bit_length` arithmetic has no such error.

  Args:
    value_static: A Python or NumPy scalar (or 0-d array).

  Returns:
    A Python `int`.
  """
  if value_static >= 1:
    # Round non-integral inputs up without going through float, so large
    # integers stay exact.
    n = int(value_static)
    if n < value_static:
      n += 1
    return 1 << (n - 1).bit_length()
  # Keep the historical float-based result for degenerate (< 1) inputs.
  return int(2**np.ceil(np.log(value_static) / np.log(2.0)))
@tf_export('signal.mdct')
@dispatch.add_dispatch_support
def mdct(signals, frame_length, window_fn=window_ops.vorbis_window,
         pad_end=False, norm=None, name=None):
  """Computes the [Modified Discrete Cosine Transform][mdct] of `signals`.

  `signals` is cut into frames of `frame_length` samples with a hop of
  `frame_length // 2`. Each frame is windowed, folded to half its length and
  passed through a type-IV DCT. Implemented with TPU/GPU-compatible ops and
  supports gradients.

  Args:
    signals: A `[..., samples]` `float32`/`float64` `Tensor` of real-valued
      signals.
    frame_length: An integer scalar `Tensor`. The window length in samples
      which must be divisible by 4.
    window_fn: A callable that takes a frame_length and a `dtype` keyword
      argument and returns a `[frame_length]` `Tensor` of samples in the
      provided datatype. If set to `None`, a rectangular window with a scale of
      1/sqrt(2) is used. For perfect reconstruction of a signal from `mdct`
      followed by `inverse_mdct`, please use `tf.signal.vorbis_window`,
      `tf.signal.kaiser_bessel_derived_window` or `None`. If using another
      window function, make sure that w[n]^2 + w[n + frame_length // 2]^2 = 1
      and w[n] = w[frame_length - n - 1] for n = 0,...,frame_length // 2 - 1 to
      achieve perfect reconstruction.
    pad_end: Whether to pad the end of `signals` with zeros when the provided
      frame length and step produces a frame that lies partially past its end.
    norm: `None` for an unnormalized DCT-IV, or `"ortho"` for an orthonormal
      DCT-IV.
    name: An optional name for the operation.

  Returns:
    A `[..., frames, frame_length // 2]` `Tensor` of `float32`/`float64`
    MDCT values where `frames` is roughly `samples // (frame_length // 2)`
    when `pad_end=False`.

  Raises:
    ValueError: If `signals` is not at least rank 1, `frame_length` is
      not scalar, or `frame_length` is not a multiple of `4`. The multiple-of-4
      check only runs when `frame_length` is statically known.

  [mdct]: https://en.wikipedia.org/wiki/Modified_discrete_cosine_transform
  """
  with ops.name_scope(name, 'mdct', [signals, frame_length]):
    signals = ops.convert_to_tensor(signals, name='signals')
    signals.shape.with_rank_at_least(1)
    frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
    frame_length.shape.assert_has_rank(0)

    # Consecutive MDCT frames overlap by half, so the hop is frame_length // 2.
    static_length = tensor_util.constant_value(frame_length)
    if static_length is None:
      frame_step = frame_length // 2
    elif static_length % 4:
      raise ValueError('The frame length must be a multiple of 4.')
    else:
      frame_step = ops.convert_to_tensor(static_length // 2,
                                         dtype=frame_length.dtype)

    frames = shape_ops.frame(signals, frame_length, frame_step,
                             pad_end=pad_end)

    if window_fn is None:
      frames *= 1.0 / np.sqrt(2)
    else:
      frames *= window_fn(frame_length, dtype=frames.dtype)

    # Split each frame into quarters (q0, q1, q2, q3) and fold it to half
    # length as [-rev(q2) - q3, q0 - rev(q1)], where rev() reverses the last
    # axis. The DCT-IV of this folded frame gives the frame_length // 2 MDCT
    # coefficients.
    q0, q1, q2, q3 = array_ops.split(frames, 4, axis=-1)
    first = -array_ops.reverse(q2, [-1]) - q3
    second = q0 - array_ops.reverse(q1, [-1])
    folded = array_ops.concat((first, second), axis=-1)
    return dct_ops.dct(folded, type=4, norm=norm)
@tf_export('signal.inverse_mdct')
@dispatch.add_dispatch_support
def inverse_mdct(mdcts,
                 window_fn=window_ops.vorbis_window,
                 norm=None,
                 name=None):
  """Computes the inverse modified DCT of `mdcts`.

  To reconstruct an original waveform, the same window function should
  be used with `mdct` and `inverse_mdct`.

  Example usage:

  >>> @tf.function
  ... def compare_round_trip():
  ...   samples = 1000
  ...   frame_length = 400
  ...   halflen = frame_length // 2
  ...   waveform = tf.random.normal(dtype=tf.float32, shape=[samples])
  ...   waveform_pad = tf.pad(waveform, [[halflen, 0],])
  ...   mdct = tf.signal.mdct(waveform_pad, frame_length, pad_end=True,
  ...                         window_fn=tf.signal.vorbis_window)
  ...   inverse_mdct = tf.signal.inverse_mdct(mdct,
  ...                                         window_fn=tf.signal.vorbis_window)
  ...   inverse_mdct = inverse_mdct[halflen: halflen + samples]
  ...   return waveform, inverse_mdct
  >>> waveform, inverse_mdct = compare_round_trip()
  >>> np.allclose(waveform.numpy(), inverse_mdct.numpy(), rtol=1e-3, atol=1e-4)
  True

  Implemented with TPU/GPU-compatible ops and supports gradients.

  Args:
    mdcts: A `float32`/`float64` `[..., frames, frame_length // 2]`
      `Tensor` of MDCT bins representing a batch of `frame_length // 2`-point
      MDCTs.
    window_fn: A callable that takes a frame_length and a `dtype` keyword
      argument and returns a `[frame_length]` `Tensor` of samples in the
      provided datatype. If set to `None`, a rectangular window with a scale of
      1/sqrt(2) is used. For perfect reconstruction of a signal from `mdct`
      followed by `inverse_mdct`, please use `tf.signal.vorbis_window`,
      `tf.signal.kaiser_bessel_derived_window` or `None`. If using another
      window function, make sure that w[n]^2 + w[n + frame_length // 2]^2 = 1
      and w[n] = w[frame_length - n - 1] for n = 0,...,frame_length // 2 - 1 to
      achieve perfect reconstruction.
    norm: If "ortho", orthonormal inverse DCT4 is performed, if it is None,
      a regular dct4 followed by scaling of `1/frame_length` is performed.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `Tensor` of `float32`/`float64` signals representing
    the inverse MDCT for each input MDCT in `mdcts` where `samples` is
    `(frames - 1) * (frame_length // 2) + frame_length`.

  Raises:
    ValueError: If `mdcts` is not at least rank 2, or `norm` is neither
      `None` nor `'ortho'`.

  [mdct]: https://en.wikipedia.org/wiki/Modified_discrete_cosine_transform
  """
  with ops.name_scope(name, 'inverse_mdct', [mdcts]):
    mdcts = ops.convert_to_tensor(mdcts, name='mdcts')
    mdcts.shape.with_rank_at_least(2)
    # Without this check an unsupported `norm` left `result_idct4` unbound
    # and surfaced as an UnboundLocalError.
    if norm not in (None, 'ortho'):
      raise ValueError(
          f"Argument `norm` must be None or 'ortho'. Received: norm={norm!r}.")
    half_len = math_ops.cast(mdcts.shape[-1], dtype=dtypes.int32)

    if norm is None:
      # 0.5 / half_len == 1 / frame_length: DCT-IV is its own inverse up to
      # this scale factor.
      half_len_float = math_ops.cast(half_len, dtype=mdcts.dtype)
      result_idct4 = (0.5 / half_len_float) * dct_ops.dct(mdcts, type=4)
    else:  # norm == 'ortho' (validated above).
      result_idct4 = dct_ops.dct(mdcts, type=4, norm='ortho')

    # Unfold each half-length DCT-IV output (h0, h1) back to a full frame
    # [h1, -rev(h1), -rev(h0), -h0], where rev() reverses the last axis.
    first_half, second_half = array_ops.split(result_idct4, 2, axis=-1)
    real_frames = array_ops.concat((second_half,
                                    -array_ops.reverse(second_half, [-1]),
                                    -array_ops.reverse(first_half, [-1]),
                                    -first_half), axis=-1)

    # Optionally window and overlap-add the inner 2 dimensions of real_frames
    # into a single [samples] dimension.
    if window_fn is not None:
      window = window_fn(2 * half_len, dtype=mdcts.dtype)
      real_frames *= window
    else:
      real_frames *= 1.0 / np.sqrt(2)
    return reconstruction_ops.overlap_and_add(real_frames, half_len)