Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/signal/fft_ops.py: 33%
223 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Fast-Fourier Transform ops."""
16import re
18import numpy as np
20from tensorflow.python.framework import dtypes as _dtypes
21from tensorflow.python.framework import ops as _ops
22from tensorflow.python.framework import tensor_util as _tensor_util
23from tensorflow.python.ops import array_ops as _array_ops
24from tensorflow.python.ops import array_ops_stack as _array_ops_stack
25from tensorflow.python.ops import gen_spectral_ops
26from tensorflow.python.ops import manip_ops
27from tensorflow.python.ops import math_ops as _math_ops
28from tensorflow.python.util import dispatch
29from tensorflow.python.util.tf_export import tf_export
def _infer_fft_length_for_rfft(input_tensor, fft_rank):
  """Infers the `fft_length` argument for a `rank` RFFT from `input_tensor`."""
  # Static shape of the innermost fft_rank dimensions, if available.
  inner_shape = input_tensor.get_shape()[-fft_rank:]
  if inner_shape.is_fully_defined():
    # All dimensions are statically known: emit a constant.
    return _ops.convert_to_tensor(inner_shape.as_list(), _dtypes.int32)
  # At least one dimension is dynamic; compute the lengths at run time.
  return _array_ops.shape(input_tensor)[-fft_rank:]
def _infer_fft_length_for_irfft(input_tensor, fft_rank):
  """Infers the `fft_length` argument for a `rank` IRFFT from `input_tensor`."""
  # Static shape of the innermost fft_rank dimensions, if available.
  inner_shape = input_tensor.get_shape()[-fft_rank:]
  if inner_shape.is_fully_defined():
    # Static case. An RFFT of length n stores n // 2 + 1 coefficients, so the
    # original length is recovered as 2 * (stored - 1), clamped at zero.
    lengths = inner_shape.as_list()
    if lengths:
      lengths[-1] = max(0, 2 * (lengths[-1] - 1))
    return _ops.convert_to_tensor(lengths, _dtypes.int32)
  # Dynamic case: the same arithmetic, expressed with tensor ops.
  lengths = _array_ops_stack.unstack(
      _array_ops.shape(input_tensor)[-fft_rank:])
  lengths[-1] = _math_ops.maximum(0, 2 * (lengths[-1] - 1))
  return _array_ops_stack.stack(lengths)
def _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length, is_reverse=False):
  """Pads `input_tensor` to `fft_length` on its inner-most `fft_rank` dims.

  Args:
    input_tensor: The tensor to pad (real for RFFT; complex when
      `is_reverse=True`, i.e. the IRFFT case).
    fft_rank: Number of innermost dimensions the FFT operates on.
    fft_length: An int32 tensor of `fft_rank` target lengths.
    is_reverse: If True, the innermost dimension is padded only up to
      `fft_length // 2 + 1` (the Hermitian one-sided length) rather than
      the full `fft_length`.

  Returns:
    `input_tensor` zero-padded at the end of each of the innermost
    `fft_rank` dimensions that are shorter than the target, or the
    original tensor when no padding is required. Dimensions are never
    truncated here.
  """
  fft_shape = _tensor_util.constant_value_as_shape(fft_length)

  # Edge case: skip padding empty tensors.
  if (input_tensor.shape.ndims is not None and
      any(dim.value == 0 for dim in input_tensor.shape.dims)):
    return input_tensor

  # If we know the shapes ahead of time, we can either skip or pre-compute the
  # appropriate paddings. Otherwise, fall back to computing paddings in
  # TensorFlow.
  if fft_shape.is_fully_defined() and input_tensor.shape.ndims is not None:
    # Slice the last FFT-rank dimensions from input_tensor's shape.
    input_fft_shape = input_tensor.shape[-fft_shape.ndims:]  # pylint: disable=invalid-unary-operand-type

    if input_fft_shape.is_fully_defined():
      # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
      if is_reverse:
        fft_shape = fft_shape[:-1].concatenate(
            fft_shape.dims[-1].value // 2 + 1)

      # Right-pad each FFT dimension up to its target length; never truncate.
      paddings = [[0, max(fft_dim.value - input_dim.value, 0)]
                  for fft_dim, input_dim in zip(
                      fft_shape.dims, input_fft_shape.dims)]
      if any(pad > 0 for _, pad in paddings):
        # Outer (batch) dimensions are left untouched.
        outer_paddings = [[0, 0]] * max((input_tensor.shape.ndims -
                                         fft_shape.ndims), 0)
        return _array_ops.pad(input_tensor, outer_paddings + paddings)
      return input_tensor

  # If we can't determine the paddings ahead of time, then we have to pad. If
  # the paddings end up as zero, tf.pad has a special-case that does no work.
  input_rank = _array_ops.rank(input_tensor)
  input_fft_shape = _array_ops.shape(input_tensor)[-fft_rank:]
  outer_dims = _math_ops.maximum(0, input_rank - fft_rank)
  outer_paddings = _array_ops.zeros([outer_dims], fft_length.dtype)
  # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
  if is_reverse:
    fft_length = _array_ops.concat([fft_length[:-1],
                                    fft_length[-1:] // 2 + 1], 0)
  fft_paddings = _math_ops.maximum(0, fft_length - input_fft_shape)
  paddings = _array_ops.concat([outer_paddings, fft_paddings], 0)
  # Convert [p0, p1, ...] into [[0, p0], [0, p1], ...] (pad at the end only).
  paddings = _array_ops_stack.stack(
      [_array_ops.zeros_like(paddings), paddings], axis=1)
  return _array_ops.pad(input_tensor, paddings)
def _rfft_wrapper(fft_fn, fft_rank, default_name):
  """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""

  def _rfft(input_tensor, fft_length=None, name=None):
    """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""
    with _ops.name_scope(name, default_name,
                         [input_tensor, fft_length]) as name:
      input_tensor = _ops.convert_to_tensor(input_tensor,
                                            preferred_dtype=_dtypes.float32)
      real_dtype = input_tensor.dtype
      if real_dtype not in (_dtypes.float32, _dtypes.float64):
        raise ValueError(
            "RFFT requires tf.float32 or tf.float64 inputs, got: %s" %
            input_tensor)
      # Map the real input dtype to its complex output counterpart. The
      # validation above guarantees the lookup succeeds.
      complex_dtype = {_dtypes.float32: _dtypes.complex64,
                       _dtypes.float64: _dtypes.complex128}[real_dtype]
      input_tensor.shape.with_rank_at_least(fft_rank)
      # Infer fft_length from the input when the caller did not supply one.
      if fft_length is None:
        fft_length = _infer_fft_length_for_rfft(input_tensor, fft_rank)
      else:
        fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
      # Zero-pad the innermost dims up to fft_length before calling the op.
      input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)

      # Prefer a statically-known fft_length when one is available.
      static_length = _tensor_util.constant_value(fft_length)
      if static_length is not None:
        fft_length = static_length
      return fft_fn(input_tensor, fft_length, Tcomplex=complex_dtype, name=name)

  # Strip the generated op's Tcomplex attribute line from the wrapper's doc.
  _rfft.__doc__ = re.sub(" Tcomplex.*?\n", "", fft_fn.__doc__)
  return _rfft
def _irfft_wrapper(ifft_fn, fft_rank, default_name):
  """Wrapper around gen_spectral_ops.irfft* that infers fft_length argument."""

  def _irfft(input_tensor, fft_length=None, name=None):
    """Wrapper irfft* that infers fft_length argument."""
    with _ops.name_scope(name, default_name,
                         [input_tensor, fft_length]) as name:
      input_tensor = _ops.convert_to_tensor(input_tensor,
                                            preferred_dtype=_dtypes.complex64)
      input_tensor.shape.with_rank_at_least(fft_rank)
      complex_dtype = input_tensor.dtype
      if complex_dtype not in (_dtypes.complex64, _dtypes.complex128):
        raise ValueError(
            "IRFFT requires tf.complex64 or tf.complex128 inputs, got: %s" %
            input_tensor)
      # The real output dtype is the component dtype of the complex input.
      real_dtype = complex_dtype.real_dtype
      # Infer fft_length from the input when the caller did not supply one.
      if fft_length is None:
        fft_length = _infer_fft_length_for_irfft(input_tensor, fft_rank)
      else:
        fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
      # Pad the innermost dims; the last one only up to fft_length // 2 + 1.
      input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
                                         is_reverse=True)
      # Prefer a statically-known fft_length when one is available.
      static_length = _tensor_util.constant_value(fft_length)
      if static_length is not None:
        fft_length = static_length
      return ifft_fn(input_tensor, fft_length, Treal=real_dtype, name=name)

  # Rewrite the generated op docs: drop the Treal attribute line and rename
  # the `input` argument to `input_tensor`.
  _irfft.__doc__ = re.sub("`input`", "`input_tensor`",
                          re.sub(" Treal.*?\n", "", ifft_fn.__doc__))
  return _irfft
# FFT/IFFT 1/2/3D are exported via
# third_party/tensorflow/core/api_def/python_api/
fft = gen_spectral_ops.fft
ifft = gen_spectral_ops.ifft
fft2d = gen_spectral_ops.fft2d
ifft2d = gen_spectral_ops.ifft2d
fft3d = gen_spectral_ops.fft3d
ifft3d = gen_spectral_ops.ifft3d
# The real-input transforms need Python wrappers (to infer `fft_length` and
# pad the input), so they are built and exported here instead. Each export is
# wrapped with dispatch support and also registered under the legacy
# `spectral.*` v1 names.
rfft = _rfft_wrapper(gen_spectral_ops.rfft, 1, "rfft")
tf_export("signal.rfft", v1=["signal.rfft", "spectral.rfft"])(
    dispatch.add_dispatch_support(rfft))
irfft = _irfft_wrapper(gen_spectral_ops.irfft, 1, "irfft")
tf_export("signal.irfft", v1=["signal.irfft", "spectral.irfft"])(
    dispatch.add_dispatch_support(irfft))
rfft2d = _rfft_wrapper(gen_spectral_ops.rfft2d, 2, "rfft2d")
tf_export("signal.rfft2d", v1=["signal.rfft2d", "spectral.rfft2d"])(
    dispatch.add_dispatch_support(rfft2d))
irfft2d = _irfft_wrapper(gen_spectral_ops.irfft2d, 2, "irfft2d")
tf_export("signal.irfft2d", v1=["signal.irfft2d", "spectral.irfft2d"])(
    dispatch.add_dispatch_support(irfft2d))
rfft3d = _rfft_wrapper(gen_spectral_ops.rfft3d, 3, "rfft3d")
tf_export("signal.rfft3d", v1=["signal.rfft3d", "spectral.rfft3d"])(
    dispatch.add_dispatch_support(rfft3d))
irfft3d = _irfft_wrapper(gen_spectral_ops.irfft3d, 3, "irfft3d")
tf_export("signal.irfft3d", v1=["signal.irfft3d", "spectral.irfft3d"])(
    dispatch.add_dispatch_support(irfft3d))
def _fft_size_for_grad(grad, rank):
  """Returns the total element count of the innermost `rank` dims of `grad`."""
  inner_dims = _array_ops.shape(grad)[-rank:]
  return _math_ops.reduce_prod(inner_dims)
@_ops.RegisterGradient("FFT")
def _fft_grad(_, grad):
  """Gradient of FFT: the IFFT of the upstream grad, scaled by the FFT size."""
  scale = _math_ops.cast(_fft_size_for_grad(grad, 1), grad.dtype)
  return ifft(grad) * scale
@_ops.RegisterGradient("IFFT")
def _ifft_grad(_, grad):
  """Gradient of IFFT: the FFT of the upstream grad, scaled by 1 / size."""
  real_size = _math_ops.cast(_fft_size_for_grad(grad, 1),
                             grad.dtype.real_dtype)
  rsize = _math_ops.cast(1. / real_size, grad.dtype)
  return fft(grad) * rsize
@_ops.RegisterGradient("FFT2D")
def _fft2d_grad(_, grad):
  """Gradient of FFT2D: IFFT2D of the upstream grad, scaled by the FFT size."""
  scale = _math_ops.cast(_fft_size_for_grad(grad, 2), grad.dtype)
  return ifft2d(grad) * scale
@_ops.RegisterGradient("IFFT2D")
def _ifft2d_grad(_, grad):
  """Gradient of IFFT2D: FFT2D of the upstream grad, scaled by 1 / size."""
  real_size = _math_ops.cast(_fft_size_for_grad(grad, 2),
                             grad.dtype.real_dtype)
  rsize = _math_ops.cast(1. / real_size, grad.dtype)
  return fft2d(grad) * rsize
@_ops.RegisterGradient("FFT3D")
def _fft3d_grad(_, grad):
  """Gradient of FFT3D: IFFT3D of the upstream grad, scaled by the FFT size."""
  scale = _math_ops.cast(_fft_size_for_grad(grad, 3), grad.dtype)
  return ifft3d(grad) * scale
@_ops.RegisterGradient("IFFT3D")
def _ifft3d_grad(_, grad):
  """Gradient of IFFT3D: FFT3D of the upstream grad, scaled by 1 / size."""
  real_size = _math_ops.cast(_fft_size_for_grad(grad, 3),
                             grad.dtype.real_dtype)
  rsize = _math_ops.cast(1. / real_size, grad.dtype)
  return fft3d(grad) * rsize
def _rfft_grad_helper(rank, irfft_fn):
  """Returns a gradient function for an RFFT of the provided rank.

  Args:
    rank: The FFT rank (1 or 2) of the RFFT op being differentiated.
    irfft_fn: The inverse transform (`irfft` or `irfft2d`) of matching rank.

  Returns:
    A gradient function suitable for `_ops.RegisterGradient`.
  """
  # Can't happen because we don't register a gradient for RFFT3D.
  assert rank in (1, 2), "Gradient for RFFT3D is not implemented."

  def _grad(op, grad):
    """A gradient function for RFFT with the provided `rank` and `irfft_fn`."""
    fft_length = op.inputs[1]
    complex_dtype = grad.dtype
    real_dtype = complex_dtype.real_dtype
    input_shape = _array_ops.shape(op.inputs[0])
    # 1 when the FFT length is even, 0 when odd (as a complex scalar): the
    # Nyquist-bin correction below only applies for even lengths.
    is_even = _math_ops.cast(1 - (fft_length[-1] % 2), complex_dtype)

    def _tile_for_broadcasting(matrix, t):
      # Reshape `matrix` to rank(t) by prepending singleton dims, then tile it
      # across t's batch dims so the inner 2 dims can be batch-matmul'd.
      expanded = _array_ops.reshape(
          matrix,
          _array_ops.concat([
              _array_ops.ones([_array_ops.rank(t) - 2], _dtypes.int32),
              _array_ops.shape(matrix)
          ], 0))
      return _array_ops.tile(
          expanded, _array_ops.concat([_array_ops.shape(t)[:-2], [1, 1]], 0))

    def _mask_matrix(length):
      """Computes t_n = exp(sqrt(-1) * pi * n^2 / line_len)."""
      # TODO(rjryan): Speed up computation of twiddle factors using the
      # following recurrence relation and cache them across invocations of RFFT.
      #
      # t_n = exp(sqrt(-1) * pi * n^2 / line_len)
      # for n = 0, 1,..., line_len-1.
      # For n > 2, use t_n = t_{n-1}^2 / t_{n-2} * t_1^2
      a = _array_ops.tile(
          _array_ops.expand_dims(_math_ops.range(length), 0), (length, 1))
      b = _array_ops.transpose(a, [1, 0])
      # a * b is the outer product n*m; this builds the (length, length) DFT
      # twiddle matrix exp(-2j*pi*n*m/length).
      return _math_ops.exp(
          -2j * np.pi * _math_ops.cast(a * b, complex_dtype) /
          _math_ops.cast(length, complex_dtype))

    def _ymask(length):
      """A sequence of [1+0j, -1+0j, 1+0j, -1+0j, ...] with length `length`."""
      return _math_ops.cast(1 - 2 * (_math_ops.range(length) % 2),
                            complex_dtype)

    # y0/ym are the DC and Nyquist bins of the incoming gradient; they are the
    # components the RFFT does not duplicate, so they need correction terms.
    y0 = grad[..., 0:1]
    if rank == 1:
      ym = grad[..., -1:]
      extra_terms = y0 + is_even * ym * _ymask(input_shape[-1])
    elif rank == 2:
      # Create a mask matrix for y0 and ym.
      base_mask = _mask_matrix(input_shape[-2])

      # Tile base_mask to match y0 in shape so that we can batch-matmul the
      # inner 2 dimensions.
      tiled_mask = _tile_for_broadcasting(base_mask, y0)

      y0_term = _math_ops.matmul(tiled_mask, _math_ops.conj(y0))
      extra_terms = y0_term

      ym = grad[..., -1:]
      ym_term = _math_ops.matmul(tiled_mask, _math_ops.conj(ym))

      inner_dim = input_shape[-1]
      # Broadcast the Nyquist-column term across the inner dimension and apply
      # the alternating-sign mask.
      ym_term = _array_ops.tile(
          ym_term,
          _array_ops.concat([
              _array_ops.ones([_array_ops.rank(grad) - 1], _dtypes.int32),
              [inner_dim]
          ], 0)) * _ymask(inner_dim)

      extra_terms += is_even * ym_term

    # The gradient of RFFT is the IRFFT of the incoming gradient times a scaling
    # factor, plus some additional terms to make up for the components dropped
    # due to Hermitian symmetry.
    input_size = _math_ops.cast(
        _fft_size_for_grad(op.inputs[0], rank), real_dtype)
    the_irfft = irfft_fn(grad, fft_length)
    # No gradient flows to the fft_length input (second return value).
    return 0.5 * (the_irfft * input_size + _math_ops.real(extra_terms)), None

  return _grad
def _irfft_grad_helper(rank, rfft_fn):
  """Returns a gradient function for an IRFFT of the provided rank.

  Args:
    rank: The FFT rank (1 or 2) of the IRFFT op being differentiated.
    rfft_fn: The forward transform (`rfft` or `rfft2d`) of matching rank.

  Returns:
    A gradient function suitable for `_ops.RegisterGradient`.
  """
  # Can't happen because we don't register a gradient for IRFFT3D.
  assert rank in (1, 2), "Gradient for IRFFT3D is not implemented."

  def _grad(op, grad):
    """A gradient function for IRFFT with the provided `rank` and `rfft_fn`."""
    # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length FFTs
    # and [1.0, 2.0, ..., 2.0] for odd-length FFTs. To reduce extra ops in the
    # graph we special-case the situation where the FFT length and last
    # dimension of the input are known at graph construction time.
    fft_length = op.inputs[1]
    fft_length_static = _tensor_util.constant_value(fft_length)
    if fft_length_static is not None:
      fft_length = fft_length_static
    real_dtype = grad.dtype
    if real_dtype == _dtypes.float32:
      complex_dtype = _dtypes.complex64
    elif real_dtype == _dtypes.float64:
      complex_dtype = _dtypes.complex128
    else:
      # Previously an unexpected dtype fell through both branches and the
      # `_math_ops.cast` below raised a confusing UnboundLocalError for
      # `complex_dtype`. Fail early with a clear message instead.
      raise ValueError(
          "IRFFT gradient requires tf.float32 or tf.float64 inputs, got: %s" %
          real_dtype)
    is_odd = _math_ops.mod(fft_length[-1], 2)
    input_last_dimension = _array_ops.shape(op.inputs[0])[-1]
    mask = _array_ops.concat(
        [[1.0], 2.0 * _array_ops.ones(
            [input_last_dimension - 2 + is_odd], real_dtype),
         _array_ops.ones([1 - is_odd], real_dtype)], 0)

    rsize = _math_ops.reciprocal(_math_ops.cast(
        _fft_size_for_grad(grad, rank), real_dtype))

    # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
    # factor and a mask. The mask scales the gradient for the Hermitian
    # symmetric components of the RFFT by a factor of two, since these
    # components are de-duplicated in the RFFT.
    the_rfft = rfft_fn(grad, fft_length)
    # No gradient flows to the fft_length input (second return value).
    return the_rfft * _math_ops.cast(rsize * mask, complex_dtype), None

  return _grad
@tf_export("signal.fftshift")
@dispatch.add_dispatch_support
def fftshift(x, axes=None, name=None):
  """Shift the zero-frequency component to the center of the spectrum.

  This function swaps half-spaces for all axes listed (defaults to all).
  Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even.

  @compatibility(numpy)
  Equivalent to numpy.fft.fftshift.
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.fft.fftshift.html
  @end_compatibility

  For example:

  ```python
  x = tf.signal.fftshift([ 0., 1., 2., 3., 4., -5., -4., -3., -2., -1.])
  x.numpy() # array([-5., -4., -3., -2., -1., 0., 1., 2., 3., 4.])
  ```

  Args:
    x: `Tensor`, input tensor.
    axes: `int` or shape `tuple`, optional Axes over which to shift. Default is
      None, which shifts all axes.
    name: An optional name for the operation.

  Returns:
    A `Tensor`, The shifted tensor.
  """
  with _ops.name_scope(name, "fftshift") as name:
    x = _ops.convert_to_tensor(x)
    dim_sizes = _array_ops.shape(x)
    if axes is None:
      # Shift every axis by half of its size.
      axes = tuple(range(x.shape.ndims))
      shift = dim_sizes // 2
    elif isinstance(axes, int):
      shift = dim_sizes[axes] // 2
    else:
      # Normalize negative axes before gathering the per-axis half-sizes.
      rank = _array_ops.rank(x)
      axes = _array_ops.where(_math_ops.less(axes, 0), axes + rank, axes)
      shift = _array_ops.gather(dim_sizes, axes) // 2

    return manip_ops.roll(x, shift, axes, name)
@tf_export("signal.ifftshift")
@dispatch.add_dispatch_support
def ifftshift(x, axes=None, name=None):
  """The inverse of fftshift.

  Although identical for even-length x,
  the functions differ by one sample for odd-length x.

  @compatibility(numpy)
  Equivalent to numpy.fft.ifftshift.
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.fft.ifftshift.html
  @end_compatibility

  For example:

  ```python
  x = tf.signal.ifftshift([[ 0., 1., 2.],[ 3., 4., -4.],[-3., -2., -1.]])
  x.numpy() # array([[ 4., -4., 3.],[-2., -1., -3.],[ 1., 2., 0.]])
  ```

  Args:
    x: `Tensor`, input tensor.
    axes: `int` or shape `tuple` Axes over which to calculate. Defaults to None,
      which shifts all axes.
    name: An optional name for the operation.

  Returns:
    A `Tensor`, The shifted tensor.
  """
  with _ops.name_scope(name, "ifftshift") as name:
    x = _ops.convert_to_tensor(x)
    dim_sizes = _array_ops.shape(x)
    if axes is None:
      # Undo fftshift on every axis: roll backwards by half the size.
      axes = tuple(range(x.shape.ndims))
      shift = -(dim_sizes // 2)
    elif isinstance(axes, int):
      shift = -(dim_sizes[axes] // 2)
    else:
      # Normalize negative axes before gathering the per-axis half-sizes.
      rank = _array_ops.rank(x)
      axes = _array_ops.where(_math_ops.less(axes, 0), axes + rank, axes)
      shift = -(_array_ops.gather(dim_sizes, axes) // 2)

    return manip_ops.roll(x, shift, axes, name)
# Register the real-FFT gradients. Each forward transform's gradient uses its
# inverse transform of matching rank. No gradient is registered for the
# rank-3 ops (RFFT3D/IRFFT3D).
_ops.RegisterGradient("RFFT")(_rfft_grad_helper(1, irfft))
_ops.RegisterGradient("IRFFT")(_irfft_grad_helper(1, rfft))
_ops.RegisterGradient("RFFT2D")(_rfft_grad_helper(2, irfft2d))
_ops.RegisterGradient("IRFFT2D")(_irfft_grad_helper(2, rfft2d))