Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/signal/fft_ops.py: 33%

223 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Fast-Fourier Transform ops.""" 

16import re 

17 

18import numpy as np 

19 

20from tensorflow.python.framework import dtypes as _dtypes 

21from tensorflow.python.framework import ops as _ops 

22from tensorflow.python.framework import tensor_util as _tensor_util 

23from tensorflow.python.ops import array_ops as _array_ops 

24from tensorflow.python.ops import array_ops_stack as _array_ops_stack 

25from tensorflow.python.ops import gen_spectral_ops 

26from tensorflow.python.ops import manip_ops 

27from tensorflow.python.ops import math_ops as _math_ops 

28from tensorflow.python.util import dispatch 

29from tensorflow.python.util.tf_export import tf_export 

30 

31 

def _infer_fft_length_for_rfft(input_tensor, fft_rank):
  """Infers the `fft_length` argument for a `rank` RFFT from `input_tensor`.

  The inferred FFT length is the size of each of the inner-most `fft_rank`
  dimensions of `input_tensor`.

  Args:
    input_tensor: The real-valued input `Tensor` of the RFFT.
    fft_rank: Number of inner-most dimensions the FFT is taken over.

  Returns:
    An integer `Tensor` of FFT lengths. It is an int32 constant when those
    dimensions are statically known; otherwise it is a slice of the
    run-time shape.
  """
  static_inner_shape = input_tensor.get_shape()[-fft_rank:]
  if static_inner_shape.is_fully_defined():
    # Every inner dimension is known at graph-construction time.
    return _ops.convert_to_tensor(static_inner_shape.as_list(), _dtypes.int32)
  # At least one inner dimension is unknown: read the lengths at run time.
  return _array_ops.shape(input_tensor)[-fft_rank:]

43 

44 

def _infer_fft_length_for_irfft(input_tensor, fft_rank):
  """Infers the `fft_length` argument for a `rank` IRFFT from `input_tensor`.

  Of the inner-most `fft_rank` dimensions, all but the last are used as-is.
  The last one, of size `m`, is mapped to `max(0, 2 * (m - 1))`.

  Args:
    input_tensor: The complex-valued input `Tensor` of the IRFFT.
    fft_rank: Number of inner-most dimensions the inverse FFT is taken over.

  Returns:
    An integer `Tensor` of FFT lengths. It is an int32 constant when the inner
    dimensions are statically known; otherwise it is computed from the
    run-time shape.
  """
  static_inner_shape = input_tensor.get_shape()[-fft_rank:]

  if static_inner_shape.is_fully_defined():
    lengths = static_inner_shape.as_list()
    if lengths:
      lengths[-1] = max(0, 2 * (lengths[-1] - 1))
    return _ops.convert_to_tensor(lengths, _dtypes.int32)

  # Some inner dimension is statically unknown; compute the lengths with ops.
  lengths = _array_ops_stack.unstack(
      _array_ops.shape(input_tensor)[-fft_rank:])
  lengths[-1] = _math_ops.maximum(0, 2 * (lengths[-1] - 1))
  return _array_ops_stack.stack(lengths)

62 

63 

def _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length, is_reverse=False):
  """Pads `input_tensor` to `fft_length` on its inner-most `fft_rank` dims.

  Padding is only ever appended at the end of a dimension. A dimension that is
  already at least as long as requested is left unchanged, so this function
  never crops.

  Args:
    input_tensor: The `Tensor` to pad.
    fft_rank: Number of inner-most dimensions the FFT operates on.
    fft_length: An integer `Tensor` holding the requested FFT length for each
      of those dimensions. Callers in this file convert it to int32.
    is_reverse: If True (the IRFFT case), the inner-most dimension is padded
      to `fft_length[-1] // 2 + 1` instead of `fft_length[-1]`.

  Returns:
    `input_tensor`, possibly padded. A tensor with a statically known
    zero-size dimension is returned unchanged.
  """
  # Statically known FFT lengths, if `fft_length` has a constant value.
  fft_shape = _tensor_util.constant_value_as_shape(fft_length)

  # Edge case: skip padding empty tensors.
  if (input_tensor.shape.ndims is not None and
      any(dim.value == 0 for dim in input_tensor.shape.dims)):
    return input_tensor

  # If we know the shapes ahead of time, we can either skip or pre-compute the
  # appropriate paddings. Otherwise, fall back to computing paddings in
  # TensorFlow.
  if fft_shape.is_fully_defined() and input_tensor.shape.ndims is not None:
    # Slice the last FFT-rank dimensions from input_tensor's shape.
    input_fft_shape = input_tensor.shape[-fft_shape.ndims:]  # pylint: disable=invalid-unary-operand-type

    if input_fft_shape.is_fully_defined():
      # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
      if is_reverse:
        fft_shape = fft_shape[:-1].concatenate(
            fft_shape.dims[-1].value // 2 + 1)

      # One [before, after] pair per FFT dimension; only "after" is non-zero.
      paddings = [[0, max(fft_dim.value - input_dim.value, 0)]
                  for fft_dim, input_dim in zip(
                      fft_shape.dims, input_fft_shape.dims)]
      if any(pad > 0 for _, pad in paddings):
        # Leading (batch) dimensions are never padded.
        outer_paddings = [[0, 0]] * max((input_tensor.shape.ndims -
                                         fft_shape.ndims), 0)
        return _array_ops.pad(input_tensor, outer_paddings + paddings)
      return input_tensor

  # If we can't determine the paddings ahead of time, then we have to pad. If
  # the paddings end up as zero, tf.pad has a special-case that does no work.
  input_rank = _array_ops.rank(input_tensor)
  input_fft_shape = _array_ops.shape(input_tensor)[-fft_rank:]
  outer_dims = _math_ops.maximum(0, input_rank - fft_rank)
  outer_paddings = _array_ops.zeros([outer_dims], fft_length.dtype)
  # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
  if is_reverse:
    fft_length = _array_ops.concat([fft_length[:-1],
                                    fft_length[-1:] // 2 + 1], 0)
  fft_paddings = _math_ops.maximum(0, fft_length - input_fft_shape)
  # Per-dimension "after" amounts: zeros for outer dims, then the FFT dims.
  paddings = _array_ops.concat([outer_paddings, fft_paddings], 0)
  # Build the [rank, 2] paddings matrix with all "before" amounts set to zero.
  paddings = _array_ops_stack.stack(
      [_array_ops.zeros_like(paddings), paddings], axis=1)
  return _array_ops.pad(input_tensor, paddings)

110 

111 

def _rfft_wrapper(fft_fn, fft_rank, default_name):
  """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument.

  Args:
    fft_fn: The generated RFFT op function to wrap (for example
      `gen_spectral_ops.rfft2d`).
    fft_rank: Number of inner-most dimensions `fft_fn` transforms.
    default_name: Default name scope for the ops created by the wrapper.

  Returns:
    A function `_rfft(input_tensor, fft_length=None, name=None)`. When
    `fft_fn` has a docstring, it is reused with the `Tcomplex` line removed
    and `` `input` `` renamed to `` `input_tensor` ``.
  """

  def _rfft(input_tensor, fft_length=None, name=None):
    """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""
    with _ops.name_scope(name, default_name,
                         [input_tensor, fft_length]) as name:
      input_tensor = _ops.convert_to_tensor(input_tensor,
                                            preferred_dtype=_dtypes.float32)
      if input_tensor.dtype not in (_dtypes.float32, _dtypes.float64):
        raise ValueError(
            "RFFT requires tf.float32 or tf.float64 inputs, got: %s" %
            input_tensor)
      real_dtype = input_tensor.dtype
      if real_dtype == _dtypes.float32:
        complex_dtype = _dtypes.complex64
      else:
        assert real_dtype == _dtypes.float64
        complex_dtype = _dtypes.complex128
      input_tensor.shape.with_rank_at_least(fft_rank)
      if fft_length is None:
        fft_length = _infer_fft_length_for_rfft(input_tensor, fft_rank)
      else:
        fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
      input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)

      # Prefer passing a constant fft_length to the op when it is known.
      fft_length_static = _tensor_util.constant_value(fft_length)
      if fft_length_static is not None:
        fft_length = fft_length_static
      return fft_fn(input_tensor, fft_length, Tcomplex=complex_dtype, name=name)

  # The wrapped op may have no docstring (e.g. under `python -OO`). `re.sub`
  # raises TypeError on None, and this wrapper runs at module import time, so
  # only rewrite the docstring when there is one. The `input` -> `input_tensor`
  # rename matches the wrapper's parameter name (as `_irfft_wrapper` does).
  if fft_fn.__doc__ is not None:
    _rfft.__doc__ = re.sub("`input`", "`input_tensor`",
                           re.sub(" Tcomplex.*?\n", "", fft_fn.__doc__))
  return _rfft

144 

145 

def _irfft_wrapper(ifft_fn, fft_rank, default_name):
  """Wrapper around gen_spectral_ops.irfft* that infers fft_length argument.

  Args:
    ifft_fn: The generated IRFFT op function to wrap (for example
      `gen_spectral_ops.irfft2d`).
    fft_rank: Number of inner-most dimensions `ifft_fn` transforms.
    default_name: Default name scope for the ops created by the wrapper.

  Returns:
    A function `_irfft(input_tensor, fft_length=None, name=None)`. When
    `ifft_fn` has a docstring, it is reused with the `Treal` line removed and
    `` `input` `` renamed to `` `input_tensor` ``.
  """

  def _irfft(input_tensor, fft_length=None, name=None):
    """Wrapper irfft* that infers fft_length argument."""
    with _ops.name_scope(name, default_name,
                         [input_tensor, fft_length]) as name:
      input_tensor = _ops.convert_to_tensor(input_tensor,
                                            preferred_dtype=_dtypes.complex64)
      input_tensor.shape.with_rank_at_least(fft_rank)
      if input_tensor.dtype not in (_dtypes.complex64, _dtypes.complex128):
        raise ValueError(
            "IRFFT requires tf.complex64 or tf.complex128 inputs, got: %s" %
            input_tensor)
      complex_dtype = input_tensor.dtype
      real_dtype = complex_dtype.real_dtype
      if fft_length is None:
        fft_length = _infer_fft_length_for_irfft(input_tensor, fft_rank)
      else:
        fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
      input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
                                         is_reverse=True)
      # Prefer passing a constant fft_length to the op when it is known.
      fft_length_static = _tensor_util.constant_value(fft_length)
      if fft_length_static is not None:
        fft_length = fft_length_static
      return ifft_fn(input_tensor, fft_length, Treal=real_dtype, name=name)

  # The wrapped op may have no docstring (e.g. under `python -OO`). `re.sub`
  # raises TypeError on None, and this wrapper runs at module import time, so
  # only rewrite the docstring when there is one.
  if ifft_fn.__doc__ is not None:
    _irfft.__doc__ = re.sub("`input`", "`input_tensor`",
                            re.sub(" Treal.*?\n", "", ifft_fn.__doc__))
  return _irfft

176 

177 

# FFT/IFFT 1/2/3D are exported via
# third_party/tensorflow/core/api_def/python_api/
fft = gen_spectral_ops.fft
ifft = gen_spectral_ops.ifft
fft2d = gen_spectral_ops.fft2d
ifft2d = gen_spectral_ops.ifft2d
fft3d = gen_spectral_ops.fft3d
ifft3d = gen_spectral_ops.ifft3d

# The real-valued transforms go through `_rfft_wrapper` / `_irfft_wrapper`,
# which infer `fft_length` when it is omitted and pad the input as needed.
# Each wrapper is then exported as `signal.*` (plus the TF1 `spectral.*` name)
# with dispatch support. The value returned by `tf_export(...)(...)` is not
# bound here, so the module-level names stay bound to the plain wrapper
# functions.
rfft = _rfft_wrapper(gen_spectral_ops.rfft, 1, "rfft")
tf_export("signal.rfft", v1=["signal.rfft", "spectral.rfft"])(
    dispatch.add_dispatch_support(rfft))
irfft = _irfft_wrapper(gen_spectral_ops.irfft, 1, "irfft")
tf_export("signal.irfft", v1=["signal.irfft", "spectral.irfft"])(
    dispatch.add_dispatch_support(irfft))
rfft2d = _rfft_wrapper(gen_spectral_ops.rfft2d, 2, "rfft2d")
tf_export("signal.rfft2d", v1=["signal.rfft2d", "spectral.rfft2d"])(
    dispatch.add_dispatch_support(rfft2d))
irfft2d = _irfft_wrapper(gen_spectral_ops.irfft2d, 2, "irfft2d")
tf_export("signal.irfft2d", v1=["signal.irfft2d", "spectral.irfft2d"])(
    dispatch.add_dispatch_support(irfft2d))
rfft3d = _rfft_wrapper(gen_spectral_ops.rfft3d, 3, "rfft3d")
tf_export("signal.rfft3d", v1=["signal.rfft3d", "spectral.rfft3d"])(
    dispatch.add_dispatch_support(rfft3d))
irfft3d = _irfft_wrapper(gen_spectral_ops.irfft3d, 3, "irfft3d")
tf_export("signal.irfft3d", v1=["signal.irfft3d", "spectral.irfft3d"])(
    dispatch.add_dispatch_support(irfft3d))

204 

205 

def _fft_size_for_grad(grad, rank):
  """Returns the element count of the inner-most `rank` dimensions of `grad`.

  The count is taken from the run-time shape of `grad`, not its static shape.
  """
  inner_dims = _array_ops.shape(grad)[-rank:]
  return _math_ops.reduce_prod(inner_dims)

208 

209 

@_ops.RegisterGradient("FFT")
def _fft_grad(_, grad):
  """Gradient of FFT: `ifft(grad)` scaled by the transform size."""
  transform_size = _fft_size_for_grad(grad, 1)
  scale = _math_ops.cast(transform_size, grad.dtype)
  return ifft(grad) * scale

214 

215 

@_ops.RegisterGradient("IFFT")
def _ifft_grad(_, grad):
  """Gradient of IFFT: `fft(grad)` scaled by 1 / (transform size)."""
  real_size = _math_ops.cast(
      _fft_size_for_grad(grad, 1), grad.dtype.real_dtype)
  inverse_size = _math_ops.cast(1. / real_size, grad.dtype)
  return fft(grad) * inverse_size

222 

223 

@_ops.RegisterGradient("FFT2D")
def _fft2d_grad(_, grad):
  """Gradient of FFT2D: `ifft2d(grad)` scaled by the 2-D transform size."""
  transform_size = _fft_size_for_grad(grad, 2)
  scale = _math_ops.cast(transform_size, grad.dtype)
  return ifft2d(grad) * scale

228 

229 

@_ops.RegisterGradient("IFFT2D")
def _ifft2d_grad(_, grad):
  """Gradient of IFFT2D: `fft2d(grad)` scaled by 1 / (2-D transform size)."""
  real_size = _math_ops.cast(
      _fft_size_for_grad(grad, 2), grad.dtype.real_dtype)
  inverse_size = _math_ops.cast(1. / real_size, grad.dtype)
  return fft2d(grad) * inverse_size

236 

237 

@_ops.RegisterGradient("FFT3D")
def _fft3d_grad(_, grad):
  """Gradient of FFT3D: `ifft3d(grad)` scaled by the 3-D transform size."""
  transform_size = _fft_size_for_grad(grad, 3)
  scale = _math_ops.cast(transform_size, grad.dtype)
  return ifft3d(grad) * scale

242 

243 

@_ops.RegisterGradient("IFFT3D")
def _ifft3d_grad(_, grad):
  """Gradient of IFFT3D: `fft3d(grad)` scaled by 1 / (3-D transform size)."""
  real_size = _math_ops.cast(
      _fft_size_for_grad(grad, 3), grad.dtype.real_dtype)
  inverse_size = _math_ops.cast(1. / real_size, grad.dtype)
  return fft3d(grad) * inverse_size

250 

251 

def _rfft_grad_helper(rank, irfft_fn):
  """Returns a gradient function for an RFFT of the provided rank.

  Args:
    rank: FFT rank of the RFFT op; only 1 and 2 are supported.
    irfft_fn: The IRFFT function of the same rank, applied to the incoming
      gradient.

  Returns:
    A function `_grad(op, grad)` returning `(input_grad, None)`. No gradient
    is propagated to the `fft_length` input (`op.inputs[1]`).
  """
  # Can't happen because we don't register a gradient for RFFT3D.
  assert rank in (1, 2), "Gradient for RFFT3D is not implemented."

  def _grad(op, grad):
    """A gradient function for RFFT with the provided `rank` and `irfft_fn`."""
    fft_length = op.inputs[1]
    complex_dtype = grad.dtype
    real_dtype = complex_dtype.real_dtype
    input_shape = _array_ops.shape(op.inputs[0])
    # 1 when the inner-most FFT length is even, else 0 (as a complex value);
    # used to gate the terms built from the last output bin `ym` below.
    is_even = _math_ops.cast(1 - (fft_length[-1] % 2), complex_dtype)

    def _tile_for_broadcasting(matrix, t):
      """Tiles `matrix` over the leading rank(t) - 2 dimensions of `t`.

      `matrix` is reshaped by prepending size-1 dims up to rank(t) (this
      assumes `matrix` is 2-D) and is then tiled along those leading dims.
      """
      expanded = _array_ops.reshape(
          matrix,
          _array_ops.concat([
              _array_ops.ones([_array_ops.rank(t) - 2], _dtypes.int32),
              _array_ops.shape(matrix)
          ], 0))
      return _array_ops.tile(
          expanded, _array_ops.concat([_array_ops.shape(t)[:-2], [1, 1]], 0))

    def _mask_matrix(length):
      """Returns the `length x length` matrix exp(-2j * pi * i * j / length).

      Entry (i, j) uses the product of its row and column indices, i.e. this
      is the DFT matrix for `length` points.
      """
      # TODO(rjryan): Speed up computation of twiddle factors using the
      # following recurrence relation and cache them across invocations of RFFT.
      #
      # t_n = exp(sqrt(-1) * pi * n^2 / line_len)
      # for n = 0, 1,..., line_len-1.
      # For n > 2, use t_n = t_{n-1}^2 / t_{n-2} * t_1^2
      a = _array_ops.tile(
          _array_ops.expand_dims(_math_ops.range(length), 0), (length, 1))
      b = _array_ops.transpose(a, [1, 0])
      return _math_ops.exp(
          -2j * np.pi * _math_ops.cast(a * b, complex_dtype) /
          _math_ops.cast(length, complex_dtype))

    def _ymask(length):
      """A sequence of [1+0j, -1+0j, 1+0j, -1+0j, ...] with length `length`."""
      return _math_ops.cast(1 - 2 * (_math_ops.range(length) % 2),
                            complex_dtype)

    # First output bin of the (half-spectrum) gradient along the last axis.
    y0 = grad[..., 0:1]
    if rank == 1:
      # Last output bin; only contributes when the FFT length is even.
      ym = grad[..., -1:]
      extra_terms = y0 + is_even * ym * _ymask(input_shape[-1])
    elif rank == 2:
      # Create a mask matrix for y0 and ym.
      base_mask = _mask_matrix(input_shape[-2])

      # Tile base_mask to match y0 in shape so that we can batch-matmul the
      # inner 2 dimensions.
      tiled_mask = _tile_for_broadcasting(base_mask, y0)

      y0_term = _math_ops.matmul(tiled_mask, _math_ops.conj(y0))
      extra_terms = y0_term

      ym = grad[..., -1:]
      ym_term = _math_ops.matmul(tiled_mask, _math_ops.conj(ym))

      inner_dim = input_shape[-1]
      # Repeat the single-column term across the input's inner dimension and
      # apply the alternating +1/-1 pattern.
      ym_term = _array_ops.tile(
          ym_term,
          _array_ops.concat([
              _array_ops.ones([_array_ops.rank(grad) - 1], _dtypes.int32),
              [inner_dim]
          ], 0)) * _ymask(inner_dim)

      extra_terms += is_even * ym_term

    # The gradient of RFFT is the IRFFT of the incoming gradient times a scaling
    # factor, plus some additional terms to make up for the components dropped
    # due to Hermitian symmetry.
    input_size = _math_ops.cast(
        _fft_size_for_grad(op.inputs[0], rank), real_dtype)
    the_irfft = irfft_fn(grad, fft_length)
    return 0.5 * (the_irfft * input_size + _math_ops.real(extra_terms)), None

  return _grad

332 

333 

def _irfft_grad_helper(rank, rfft_fn):
  """Returns a gradient function for an IRFFT of the provided rank.

  Args:
    rank: FFT rank of the IRFFT op; only 1 and 2 are supported.
    rfft_fn: The RFFT function of the same rank, applied to the incoming
      gradient.

  Returns:
    A function `_grad(op, grad)` returning `(input_grad, None)`. No gradient
    is propagated to the `fft_length` input (`op.inputs[1]`).
  """
  # Can't happen because we don't register a gradient for IRFFT3D.
  assert rank in (1, 2), "Gradient for IRFFT3D is not implemented."

  def _grad(op, grad):
    """A gradient function for IRFFT with the provided `rank` and `rfft_fn`.

    Raises:
      ValueError: If `grad` is neither float32 nor float64.
    """
    # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length FFTs
    # and [1.0, 2.0, ..., 2.0] for odd-length FFTs. To reduce extra ops in the
    # graph we use the constant value of the FFT length when it is known at
    # graph construction time. The input's last dimension is always read from
    # its run-time shape.
    fft_length = op.inputs[1]
    fft_length_static = _tensor_util.constant_value(fft_length)
    if fft_length_static is not None:
      fft_length = fft_length_static
    real_dtype = grad.dtype
    if real_dtype == _dtypes.float32:
      complex_dtype = _dtypes.complex64
    elif real_dtype == _dtypes.float64:
      complex_dtype = _dtypes.complex128
    else:
      # Without this branch `complex_dtype` would be unbound below and the
      # failure would surface as an opaque UnboundLocalError.
      raise ValueError(
          "IRFFT gradient requires a tf.float32 or tf.float64 gradient, "
          "got: %s" % real_dtype)
    is_odd = _math_ops.mod(fft_length[-1], 2)
    input_last_dimension = _array_ops.shape(op.inputs[0])[-1]
    mask = _array_ops.concat(
        [[1.0], 2.0 * _array_ops.ones(
            [input_last_dimension - 2 + is_odd], real_dtype),
         _array_ops.ones([1 - is_odd], real_dtype)], 0)

    rsize = _math_ops.reciprocal(_math_ops.cast(
        _fft_size_for_grad(grad, rank), real_dtype))

    # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
    # factor and a mask. The mask scales the gradient for the Hermitian
    # symmetric components of the RFFT by a factor of two, since these
    # components are de-duplicated in the RFFT.
    the_rfft = rfft_fn(grad, fft_length)
    return the_rfft * _math_ops.cast(rsize * mask, complex_dtype), None

  return _grad

372 

373 

@tf_export("signal.fftshift")
@dispatch.add_dispatch_support
def fftshift(x, axes=None, name=None):
  """Shift the zero-frequency component to the center of the spectrum.

  This function swaps half-spaces for all axes listed (defaults to all).
  Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even.

  @compatibility(numpy)
  Equivalent to numpy.fft.fftshift.
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.fft.fftshift.html
  @end_compatibility

  For example:

  ```python
  x = tf.signal.fftshift([ 0., 1., 2., 3., 4., -5., -4., -3., -2., -1.])
  x.numpy() # array([-5., -4., -3., -2., -1., 0., 1., 2., 3., 4.])
  ```

  Args:
    x: `Tensor`, input tensor.
    axes: `int` or shape `tuple`, optional Axes over which to shift. Default is
      None, which shifts all axes. This also works when the static rank of `x`
      is unknown.
    name: An optional name for the operation.

  Returns:
    A `Tensor`, The shifted tensor.
  """
  with _ops.name_scope(name, "fftshift") as name:
    x = _ops.convert_to_tensor(x)
    if axes is None:
      if x.shape.ndims is None:
        # The static rank is unknown, so `range(x.shape.ndims)` would raise
        # TypeError. Build the list of all axes from the run-time rank.
        axes = _math_ops.range(_array_ops.rank(x))
      else:
        axes = tuple(range(x.shape.ndims))
      shift = _array_ops.shape(x) // 2
    elif isinstance(axes, int):
      shift = _array_ops.shape(x)[axes] // 2
    else:
      rank = _array_ops.rank(x)
      # allows negative axis
      axes = _array_ops.where(_math_ops.less(axes, 0), axes + rank, axes)
      shift = _array_ops.gather(_array_ops.shape(x), axes) // 2

    return manip_ops.roll(x, shift, axes, name)

417 

418 

@tf_export("signal.ifftshift")
@dispatch.add_dispatch_support
def ifftshift(x, axes=None, name=None):
  """The inverse of fftshift.

  Although identical for even-length x,
  the functions differ by one sample for odd-length x.

  @compatibility(numpy)
  Equivalent to numpy.fft.ifftshift.
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.fft.ifftshift.html
  @end_compatibility

  For example:

  ```python
  x = tf.signal.ifftshift([[ 0., 1., 2.],[ 3., 4., -4.],[-3., -2., -1.]])
  x.numpy() # array([[ 4., -4., 3.],[-2., -1., -3.],[ 1., 2., 0.]])
  ```

  Args:
    x: `Tensor`, input tensor.
    axes: `int` or shape `tuple` Axes over which to calculate. Defaults to None,
      which shifts all axes. This also works when the static rank of `x` is
      unknown.
    name: An optional name for the operation.

  Returns:
    A `Tensor`, The shifted tensor.
  """
  with _ops.name_scope(name, "ifftshift") as name:
    x = _ops.convert_to_tensor(x)
    if axes is None:
      if x.shape.ndims is None:
        # The static rank is unknown, so `range(x.shape.ndims)` would raise
        # TypeError. Build the list of all axes from the run-time rank.
        axes = _math_ops.range(_array_ops.rank(x))
      else:
        axes = tuple(range(x.shape.ndims))
      shift = -(_array_ops.shape(x) // 2)
    elif isinstance(axes, int):
      shift = -(_array_ops.shape(x)[axes] // 2)
    else:
      rank = _array_ops.rank(x)
      # allows negative axis
      axes = _array_ops.where(_math_ops.less(axes, 0), axes + rank, axes)
      shift = -(_array_ops.gather(_array_ops.shape(x), axes) // 2)

    return manip_ops.roll(x, shift, axes, name)

462 

463 

# Gradients for the real-valued transforms. Each op's gradient is built from
# the opposite transform of the same rank (RFFT uses IRFFT and vice versa).
# Only ranks 1 and 2 are registered here; the helpers assert rank in (1, 2).
_ops.RegisterGradient("RFFT")(_rfft_grad_helper(1, irfft))
_ops.RegisterGradient("IRFFT")(_irfft_grad_helper(1, rfft))
_ops.RegisterGradient("RFFT2D")(_rfft_grad_helper(2, irfft2d))
_ops.RegisterGradient("IRFFT2D")(_irfft_grad_helper(2, rfft2d))