Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/signal/spectral_ops.py: 23%

133 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Spectral operations (e.g. Short-time Fourier Transform).""" 

16 

17import numpy as np 

18 

19from tensorflow.python.framework import constant_op 

20from tensorflow.python.framework import dtypes 

21from tensorflow.python.framework import ops 

22from tensorflow.python.framework import tensor_util 

23from tensorflow.python.ops import array_ops 

24from tensorflow.python.ops import math_ops 

25from tensorflow.python.ops.signal import dct_ops 

26from tensorflow.python.ops.signal import fft_ops 

27from tensorflow.python.ops.signal import reconstruction_ops 

28from tensorflow.python.ops.signal import shape_ops 

29from tensorflow.python.ops.signal import window_ops 

30from tensorflow.python.util import dispatch 

31from tensorflow.python.util.tf_export import tf_export 

32 

33 

@tf_export('signal.stft')
@dispatch.add_dispatch_support
def stft(signals, frame_length, frame_step, fft_length=None,
         window_fn=window_ops.hann_window,
         pad_end=False, name=None):
  """Computes the [Short-time Fourier Transform][stft] of `signals`.

  Implemented with TPU/GPU-compatible ops and supports gradients.

  Args:
    signals: A `[..., samples]` `float32`/`float64` `Tensor` of real-valued
      signals.
    frame_length: An integer scalar `Tensor`. The window length in samples.
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    fft_length: An integer scalar `Tensor`. The size of the FFT to apply.
      If not provided, uses the smallest power of 2 enclosing `frame_length`.
    window_fn: A callable that takes a window length and a `dtype` keyword
      argument and returns a `[window_length]` `Tensor` of samples in the
      provided datatype. If set to `None`, no windowing is used.
    pad_end: Whether to pad the end of `signals` with zeros when the provided
      frame length and step produces a frame that lies partially past its end.
    name: An optional name for the operation.

  Returns:
    A `[..., frames, fft_unique_bins]` `Tensor` of `complex64`/`complex128`
    STFT values where `fft_unique_bins` is `fft_length // 2 + 1` (the unique
    components of the FFT).

  Raises:
    ValueError: If `signals` is not at least rank 1, `frame_length` is
      not scalar, `frame_step` is not scalar, or `fft_length` is not scalar.

  [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
  """
  with ops.name_scope(name, 'stft', [signals, frame_length,
                                     frame_step]):
    signals = ops.convert_to_tensor(signals, name='signals')
    signals.shape.with_rank_at_least(1)
    frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
    frame_length.shape.assert_has_rank(0)
    frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
    frame_step.shape.assert_has_rank(0)

    if fft_length is None:
      fft_length = _enclosing_power_of_two(frame_length)
    else:
      fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
      # Validate here, as `inverse_stft` does, so that a non-scalar
      # `fft_length` is reported against this argument rather than surfacing
      # later from the `[fft_length]` list handed to `rfft`.
      fft_length.shape.assert_has_rank(0)

    framed_signals = shape_ops.frame(
        signals, frame_length, frame_step, pad_end=pad_end)

    # Optionally window the framed signals.
    if window_fn is not None:
      window = window_fn(frame_length, dtype=framed_signals.dtype)
      framed_signals *= window

    # fft_ops.rfft produces the (fft_length/2 + 1) unique components of the
    # FFT of the real windowed signals in framed_signals.
    return fft_ops.rfft(framed_signals, [fft_length])

93 

94 

@tf_export('signal.inverse_stft_window_fn')
@dispatch.add_dispatch_support
def inverse_stft_window_fn(frame_step,
                           forward_window_fn=window_ops.hann_window,
                           name=None):
  """Builds a window function suitable for use with `inverse_stft`.

  The generated window is the forward window with an additional pointwise
  amplitude correction applied. When the forward window would already yield
  an exact inverse, `inverse_stft_window_fn` is equivalent to
  `forward_window_fn`.

  The `inverse_stft` documentation contains usage examples.

  Args:
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    forward_window_fn: window_fn used in the forward transform, `stft`.
    name: An optional name for the operation.

  Returns:
    A callable that takes a window length and a `dtype` keyword argument and
    returns a `[window_length]` `Tensor` of samples in the provided datatype.
    The returned window is suitable for reconstructing original waveform in
    inverse_stft.
  """
  def inverse_stft_window_fn_inner(frame_length, dtype):
    """Returns the corrected window for frames of `frame_length` samples.

    Args:
      frame_length: An integer scalar `Tensor`. The window length in samples.
      dtype: Data type of waveform passed to `stft`.

    Returns:
      A window suitable for reconstructing original waveform in `inverse_stft`.

    Raises:
      ValueError: If `frame_length` is not scalar or `frame_step` is not
        scalar.
    """
    with ops.name_scope(name, 'inverse_stft_window_fn', [forward_window_fn]):
      frame_step_ = ops.convert_to_tensor(frame_step, name='frame_step')
      frame_step_.shape.assert_has_rank(0)
      frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
      frame_length.shape.assert_has_rank(0)

      # Use equation 7 from Griffin + Lim: divide the forward window by the
      # sum of its squared, hop-shifted copies.
      window = forward_window_fn(frame_length, dtype=dtype)
      window_energy = math_ops.square(window)

      # Ceiling division: number of hops needed to span one frame.
      num_overlaps = -(-frame_length // frame_step_)  # pylint: disable=invalid-unary-operand-type

      # Zero-pad the squared window up to a whole number of hops so it can be
      # viewed as a [num_overlaps, frame_step] matrix.
      tail_padding = num_overlaps * frame_step_ - frame_length
      padded_energy = array_ops.pad(window_energy, [(0, tail_padding)])
      energy_by_hop = array_ops.reshape(padded_energy,
                                        [num_overlaps, frame_step_])

      # Sum the rows, then repeat that sum for every hop and flatten it back
      # to one value per (padded) window sample.
      summed_energy = math_ops.reduce_sum(energy_by_hop, 0, keepdims=True)
      tiled_energy = array_ops.tile(summed_energy, [num_overlaps, 1])
      normalizer = array_ops.reshape(tiled_energy,
                                     [num_overlaps * frame_step_])

      # Drop the padding again before normalizing the forward window.
      return window / normalizer[:frame_length]
  return inverse_stft_window_fn_inner

153 

154 

@tf_export('signal.inverse_stft')
@dispatch.add_dispatch_support
def inverse_stft(stfts,
                 frame_length,
                 frame_step,
                 fft_length=None,
                 window_fn=window_ops.hann_window,
                 name=None):
  """Computes the inverse [Short-time Fourier Transform][stft] of `stfts`.

  Reconstructing an original waveform requires a window function that is
  complementary to the one used in the forward transform. Such a window can
  be built with `tf.signal.inverse_stft_window_fn`.
  Example:

  ```python
  frame_length = 400
  frame_step = 160
  waveform = tf.random.normal(dtype=tf.float32, shape=[1000])
  stft = tf.signal.stft(waveform, frame_length, frame_step)
  inverse_stft = tf.signal.inverse_stft(
      stft, frame_length, frame_step,
      window_fn=tf.signal.inverse_stft_window_fn(frame_step))
  ```

  When `tf.signal.stft` was called with a custom `window_fn`, that same
  function must be handed to `tf.signal.inverse_stft_window_fn`:

  ```python
  frame_length = 400
  frame_step = 160
  window_fn = tf.signal.hamming_window
  waveform = tf.random.normal(dtype=tf.float32, shape=[1000])
  stft = tf.signal.stft(
      waveform, frame_length, frame_step, window_fn=window_fn)
  inverse_stft = tf.signal.inverse_stft(
      stft, frame_length, frame_step,
      window_fn=tf.signal.inverse_stft_window_fn(
         frame_step, forward_window_fn=window_fn))
  ```

  Implemented with TPU/GPU-compatible ops and supports gradients.

  Args:
    stfts: A `complex64`/`complex128` `[..., frames, fft_unique_bins]`
      `Tensor` of STFT bins representing a batch of `fft_length`-point STFTs
      where `fft_unique_bins` is `fft_length // 2 + 1`
    frame_length: An integer scalar `Tensor`. The window length in samples.
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    fft_length: An integer scalar `Tensor`. The size of the FFT that produced
      `stfts`. If not provided, uses the smallest power of 2 enclosing
      `frame_length`.
    window_fn: A callable that takes a window length and a `dtype` keyword
      argument and returns a `[window_length]` `Tensor` of samples in the
      provided datatype. If set to `None`, no windowing is used.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `Tensor` of `float32`/`float64` signals representing
    the inverse STFT for each input STFT in `stfts`.

  Raises:
    ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar,
      `frame_step` is not scalar, or `fft_length` is not scalar.

  [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
  """
  with ops.name_scope(name, 'inverse_stft', [stfts]):
    stfts = ops.convert_to_tensor(stfts, name='stfts')
    stfts.shape.with_rank_at_least(2)
    frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
    frame_length.shape.assert_has_rank(0)
    frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
    frame_step.shape.assert_has_rank(0)
    if fft_length is not None:
      fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
      fft_length.shape.assert_has_rank(0)
    else:
      fft_length = _enclosing_power_of_two(frame_length)

    real_frames = fft_ops.irfft(stfts, [fft_length])

    # The inverse FFT yields frames whose inner size may be larger or smaller
    # than frame_length, so the frames are truncated or zero-padded to
    # frame_length below.
    frame_length_static = tensor_util.constant_value(frame_length)
    frames_rank_static = real_frames.shape.ndims
    if frames_rank_static is None:
      inner_size_static = None
    else:
      inner_size_static = real_frames.shape.as_list()[-1]

    if frame_length_static is None or inner_size_static is None:
      # At least one of the two sizes is only known at run time: truncate
      # first, then pad by however many samples are still missing (possibly
      # none).
      real_frames = real_frames[..., :frame_length]
      dynamic_rank = array_ops.rank(real_frames)
      dynamic_shape = array_ops.shape(real_frames)
      outer_paddings = array_ops.zeros([dynamic_rank - 1, 2],
                                       dtype=frame_length.dtype)
      missing = math_ops.maximum(0, frame_length - dynamic_shape[-1])
      paddings = array_ops.concat([outer_paddings, [[0, missing]]], 0)
      real_frames = array_ops.pad(real_frames, paddings)
    elif inner_size_static > frame_length_static:
      # Both sizes are static and the frames are too long: truncate.
      real_frames = real_frames[..., :frame_length_static]
    elif inner_size_static < frame_length_static:
      # Both sizes are static and the frames are too short: zero-pad the
      # inner dimension only.
      pad_amount = frame_length_static - inner_size_static
      static_paddings = ([[0, 0]] * (frames_rank_static - 1) +
                         [[0, pad_amount]])
      real_frames = array_ops.pad(real_frames, static_paddings)

    # The resizing above may not be shape-inference friendly, so restore the
    # static inner dimension whenever enough information is available.
    resized_rank = real_frames.shape.ndims
    if frame_length_static is not None and resized_rank is not None:
      real_frames.set_shape([None] * (resized_rank - 1) +
                            [frame_length_static])

    # Optionally window, then overlap-add the inner 2 dimensions of
    # real_frames into a single [samples] dimension.
    if window_fn is not None:
      window = window_fn(frame_length, dtype=stfts.dtype.real_dtype)
      real_frames = real_frames * window
    return reconstruction_ops.overlap_and_add(real_frames, frame_step)

275 

276 

def _enclosing_power_of_two(value):
  """Return 2**N for integer N such that 2**N >= value.

  Args:
    value: An integer `Tensor`. When its value can be determined statically
      the result is computed with NumPy and returned as a constant; otherwise
      it is computed with TensorFlow ops.

  Returns:
    A `Tensor` with the same dtype as `value` holding the enclosing power of
    two.
  """
  value_static = tensor_util.constant_value(value)
  if value_static is not None:
    # Use log2 directly instead of log(value) / log(2): the quotient of two
    # rounded natural logarithms can land slightly above the true exponent
    # when `value` is already a power of two, and the ceil() would then
    # double the result.
    return constant_op.constant(
        int(2**np.ceil(np.log2(value_static))), value.dtype)
  # NOTE(review): this dynamic path still evaluates log(value) / log(2) in
  # float32, so it may presumably be subject to the same rounding concern for
  # large values -- confirm before relying on it for very long frames.
  return math_ops.cast(
      math_ops.pow(
          2.0,
          math_ops.ceil(
              math_ops.log(math_ops.cast(value, dtypes.float32)) /
              math_ops.log(2.0))), value.dtype)

289 

290 

@tf_export('signal.mdct')
@dispatch.add_dispatch_support
def mdct(signals, frame_length, window_fn=window_ops.vorbis_window,
         pad_end=False, norm=None, name=None):
  """Computes the [Modified Discrete Cosine Transform][mdct] of `signals`.

  Implemented with TPU/GPU-compatible ops and supports gradients.

  Args:
    signals: A `[..., samples]` `float32`/`float64` `Tensor` of real-valued
      signals.
    frame_length: An integer scalar `Tensor`. The window length in samples
      which must be divisible by 4.
    window_fn: A callable that takes a frame_length and a `dtype` keyword
      argument and returns a `[frame_length]` `Tensor` of samples in the
      provided datatype. If set to `None`, a rectangular window with a scale of
      1/sqrt(2) is used. For perfect reconstruction of a signal from `mdct`
      followed by `inverse_mdct`, please use `tf.signal.vorbis_window`,
      `tf.signal.kaiser_bessel_derived_window` or `None`. If using another
      window function, make sure that w[n]^2 + w[n + frame_length // 2]^2 = 1
      and w[n] = w[frame_length - n - 1] for n = 0,...,frame_length // 2 - 1 to
      achieve perfect reconstruction.
    pad_end: Whether to pad the end of `signals` with zeros when the provided
      frame length and step produces a frame that lies partially past its end.
    norm: If it is None, unnormalized dct4 is used, if it is "ortho"
      orthonormal dct4 is used.
    name: An optional name for the operation.

  Returns:
    A `[..., frames, frame_length // 2]` `Tensor` of `float32`/`float64`
    MDCT values where `frames` is roughly `samples // (frame_length // 2)`
    when `pad_end=False`.

  Raises:
    ValueError: If `signals` is not at least rank 1, `frame_length` is
      not scalar, or `frame_length` is not a multiple of `4`.

  [mdct]: https://en.wikipedia.org/wiki/Modified_discrete_cosine_transform
  """
  with ops.name_scope(name, 'mdct', [signals, frame_length]):
    signals = ops.convert_to_tensor(signals, name='signals')
    signals.shape.with_rank_at_least(1)
    frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
    frame_length.shape.assert_has_rank(0)

    # Frames advance by half a frame. Divisibility by 4 can only be verified
    # when the frame length is known statically.
    frame_length_static = tensor_util.constant_value(frame_length)
    if frame_length_static is None:
      frame_step = frame_length // 2
    else:
      if frame_length_static % 4 != 0:
        raise ValueError('The frame length must be a multiple of 4.')
      frame_step = ops.convert_to_tensor(frame_length_static // 2,
                                         dtype=frame_length.dtype)

    framed_signals = shape_ops.frame(
        signals, frame_length, frame_step, pad_end=pad_end)

    # Apply the requested window, or the 1/sqrt(2) rectangular scaling when no
    # window function is given.
    if window_fn is None:
      framed_signals = framed_signals * (1.0 / np.sqrt(2))
    else:
      window = window_fn(frame_length, dtype=framed_signals.dtype)
      framed_signals = framed_signals * window

    # Fold the four quarters of every frame into a half-length frame.
    quarter_a, quarter_b, quarter_c, quarter_d = array_ops.split(
        framed_signals, 4, axis=-1)
    folded_first = -array_ops.reverse(quarter_c, [-1]) - quarter_d
    folded_second = quarter_a - array_ops.reverse(quarter_b, [-1])
    folded_frames = array_ops.concat((folded_first, folded_second), axis=-1)

    # The type 4 DCT (normalized according to `norm`) of the folded frames
    # produces the (frame_length // 2) MDCT components.
    return dct_ops.dct(folded_frames, type=4, norm=norm)

365 

366 

@tf_export('signal.inverse_mdct')
@dispatch.add_dispatch_support
def inverse_mdct(mdcts,
                 window_fn=window_ops.vorbis_window,
                 norm=None,
                 name=None):
  """Computes the inverse modified DCT of `mdcts`.

  To reconstruct an original waveform, the same window function should
  be used with `mdct` and `inverse_mdct`.

  Example usage:

  >>> @tf.function
  ... def compare_round_trip():
  ...   samples = 1000
  ...   frame_length = 400
  ...   halflen = frame_length // 2
  ...   waveform = tf.random.normal(dtype=tf.float32, shape=[samples])
  ...   waveform_pad = tf.pad(waveform, [[halflen, 0],])
  ...   mdct = tf.signal.mdct(waveform_pad, frame_length, pad_end=True,
  ...                         window_fn=tf.signal.vorbis_window)
  ...   inverse_mdct = tf.signal.inverse_mdct(mdct,
  ...                                         window_fn=tf.signal.vorbis_window)
  ...   inverse_mdct = inverse_mdct[halflen: halflen + samples]
  ...   return waveform, inverse_mdct
  >>> waveform, inverse_mdct = compare_round_trip()
  >>> np.allclose(waveform.numpy(), inverse_mdct.numpy(), rtol=1e-3, atol=1e-4)
  True

  Implemented with TPU/GPU-compatible ops and supports gradients.

  Args:
    mdcts: A `float32`/`float64` `[..., frames, frame_length // 2]`
      `Tensor` of MDCT bins representing a batch of `frame_length // 2`-point
      MDCTs.
    window_fn: A callable that takes a frame_length and a `dtype` keyword
      argument and returns a `[frame_length]` `Tensor` of samples in the
      provided datatype. If set to `None`, a rectangular window with a scale of
      1/sqrt(2) is used. For perfect reconstruction of a signal from `mdct`
      followed by `inverse_mdct`, please use `tf.signal.vorbis_window`,
      `tf.signal.kaiser_bessel_derived_window` or `None`. If using another
      window function, make sure that w[n]^2 + w[n + frame_length // 2]^2 = 1
      and w[n] = w[frame_length - n - 1] for n = 0,...,frame_length // 2 - 1 to
      achieve perfect reconstruction.
    norm: If "ortho", orthonormal inverse DCT4 is performed, if it is None,
      a regular dct4 followed by scaling of `1/frame_length` is performed.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `Tensor` of `float32`/`float64` signals representing
    the inverse MDCT for each input MDCT in `mdcts` where `samples` is
    `(frames - 1) * (frame_length // 2) + frame_length`.

  Raises:
    ValueError: If `mdcts` is not at least rank 2, or `norm` is neither
      `None` nor `'ortho'`.

  [mdct]: https://en.wikipedia.org/wiki/Modified_discrete_cosine_transform
  """
  with ops.name_scope(name, 'inverse_mdct', [mdcts]):
    mdcts = ops.convert_to_tensor(mdcts, name='mdcts')
    mdcts.shape.with_rank_at_least(2)
    # Each MDCT frame holds frame_length // 2 coefficients.
    half_len = math_ops.cast(mdcts.shape[-1], dtype=dtypes.int32)

    if norm is None:
      # Unnormalized DCT-IV scaled by 0.5 / half_len, i.e. 1 / frame_length.
      half_len_float = math_ops.cast(half_len, dtype=mdcts.dtype)
      result_idct4 = (0.5 / half_len_float) * dct_ops.dct(mdcts, type=4)
    elif norm == 'ortho':
      result_idct4 = dct_ops.dct(mdcts, type=4, norm='ortho')
    else:
      # Without this branch `result_idct4` would be left unbound and the
      # function would fail below with an UnboundLocalError.
      raise ValueError(
          "Unknown normalization. Expected None or 'ortho', got: %r" % (norm,))

    # Unfold each half-length frame back into a full frame of 2 * half_len
    # samples.
    split_result = array_ops.split(result_idct4, 2, axis=-1)
    real_frames = array_ops.concat((split_result[1],
                                    -array_ops.reverse(split_result[1], [-1]),
                                    -array_ops.reverse(split_result[0], [-1]),
                                    -split_result[0]), axis=-1)

    # Optionally window and overlap-add the inner 2 dimensions of real_frames
    # into a single [samples] dimension.
    if window_fn is not None:
      window = window_fn(2 * half_len, dtype=mdcts.dtype)
      real_frames *= window
    else:
      real_frames *= 1.0 / np.sqrt(2)
    return reconstruction_ops.overlap_and_add(real_frames, half_len)