Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/signal/window_ops.py: 36%
81 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ops for computing common window functions."""
17import numpy as np
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import cond
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops import special_math_ops
27from tensorflow.python.util import dispatch
28from tensorflow.python.util.tf_export import tf_export
def _check_params(window_length, dtype):
  """Validates the arguments shared by the window generators.

  Args:
    window_length: A scalar Python value or `Tensor` giving the window length.
    dtype: The requested output `DType`. Must be a floating point type.

  Returns:
    `window_length` converted to an int32 `Tensor`.

  Raises:
    ValueError: If `dtype` is not a floating point type, or if
      `window_length` is not a scalar (rank 0).
  """
  if dtype.is_floating:
    length_tensor = ops.convert_to_tensor(window_length, dtype=dtypes.int32)
    # Windows are 1-D; the length itself must be a single scalar.
    length_tensor.shape.assert_has_rank(0)
    return length_tensor
  raise ValueError('dtype must be a floating point type. Found %s' % dtype)
@tf_export('signal.kaiser_window')
@dispatch.add_dispatch_support
def kaiser_window(window_length, beta=12., dtype=dtypes.float32, name=None):
  """Generate a [Kaiser window][kaiser].

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    beta: Beta parameter for Kaiser window, see reference below.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`. A length-1 window
    is `[1.]`, whether the length is known statically or only at run time.

  Raises:
    ValueError: If `dtype` is not a floating point type or `window_length` is
      not a scalar.

  [kaiser]:
      https://docs.scipy.org/doc/numpy/reference/generated/numpy.kaiser.html
  """
  with ops.name_scope(name, 'kaiser_window'):
    window_length = _check_params(window_length, dtype)
    window_length_const = tensor_util.constant_value(window_length)
    if window_length_const == 1:
      return array_ops.ones([1], dtype=dtype)
    # tf.range does not support float16 so we work with float32 initially.
    halflen_float = (
        math_ops.cast(window_length, dtype=dtypes.float32) - 1.0) / 2.0
    # Sample positions -halflen, ..., +halflen in steps of 1; the +0.1 makes
    # the (exclusive) limit include +halflen itself.
    arg = math_ops.range(-halflen_float, halflen_float + 0.1,
                         dtype=dtypes.float32)
    # Convert everything into given dtype which can be float16.
    arg = math_ops.cast(arg, dtype=dtype)
    beta = math_ops.cast(beta, dtype=dtype)
    one = math_ops.cast(1.0, dtype=dtype)
    halflen_float = math_ops.cast(halflen_float, dtype=dtype)
    # Normalize to [-1, 1] *before* squaring. Squaring `arg` and `halflen`
    # separately overflows float16 (max ~65504) once halflen exceeds ~255,
    # which turned the endpoints into inf/inf = NaN.
    ratio = arg / halflen_float
    num = beta * math_ops.sqrt(one - ratio * ratio)

    def _window():
      # i0e(x) = exp(-|x|) * I0(x), so for non-negative num and beta this is
      # I0(num) / I0(beta) without overflowing I0 for large beta.
      return math_ops.exp(num - beta) * (
          special_math_ops.bessel_i0e(num) / special_math_ops.bessel_i0e(beta))

    if window_length_const is not None:
      return _window()
    # The length is only known at run time. A length of 1 makes halflen 0,
    # so `ratio` is 0/0 = NaN; return the all-ones window in that case, as
    # the static check above does.
    return cond.cond(
        math_ops.equal(window_length, 1),
        lambda: array_ops.ones([1], dtype=dtype),
        _window)
@tf_export('signal.kaiser_bessel_derived_window')
@dispatch.add_dispatch_support
def kaiser_bessel_derived_window(window_length, beta=12.,
                                 dtype=dtypes.float32, name=None):
  """Generate a [Kaiser Bessel derived window][kbd].

  The rising half is `sqrt(cumsum(k)[:-1] / sum(k))`, where `k` is a Kaiser
  window of length `window_length // 2 + 1`. The falling half is the rising
  half reversed.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    beta: Beta parameter for Kaiser window.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`. The output holds
    `2 * (window_length // 2)` samples, so an odd `window_length` yields
    `window_length - 1` samples.

  Raises:
    ValueError: If `dtype` is not a floating point type or `window_length` is
      not a scalar.

  [kbd]:
  https://en.wikipedia.org/wiki/Kaiser_window#Kaiser%E2%80%93Bessel-derived_(KBD)_window
  """
  with ops.name_scope(name, 'kaiser_bessel_derived_window'):
    length = _check_params(window_length, dtype)
    half = length // 2
    # Running sum of a Kaiser window one sample longer than the half window.
    running_total = math_ops.cumsum(kaiser_window(half + 1, beta, dtype=dtype))
    # Drop the last partial sum and normalize by the grand total.
    rising = math_ops.sqrt(running_total[:-1] / running_total[-1])
    falling = rising[::-1]
    return array_ops.concat([rising, falling], axis=0)
@tf_export('signal.vorbis_window')
@dispatch.add_dispatch_support
def vorbis_window(window_length, dtype=dtypes.float32, name=None):
  """Generate a [Vorbis power complementary window][vorbis].

  Sample `k` of a length-`N` window is
  `sin(pi / 2 * sin(pi / N * (k + 0.5)) ** 2)`.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type or `window_length` is
      not a scalar.

  [vorbis]:
  https://en.wikipedia.org/wiki/Modified_discrete_cosine_transform#Window_functions
  """
  with ops.name_scope(name, 'vorbis_window'):
    length = _check_params(window_length, dtype)
    sample_index = math_ops.cast(math_ops.range(length), dtype=dtype)
    # Inner sine, evaluated at half-sample offsets.
    inner = math_ops.sin(
        np.pi / math_ops.cast(length, dtype=dtype) * (sample_index + 0.5))
    return math_ops.sin(np.pi / 2.0 * math_ops.pow(inner, 2.0))
@tf_export('signal.hann_window')
@dispatch.add_dispatch_support
def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
  """Generate a [Hann window][hann].

  This is the raised cosine window `0.5 - 0.5 * cos(2 * pi * k / N)`.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window. Periodic windows are typically used for spectral
      analysis while symmetric windows are typically used for digital
      filter design.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type, or if
      `window_length` or `periodic` is not a scalar.

  [hann]: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
  """
  return _raised_cosine_window(
      name=name,
      default_name='hann_window',
      window_length=window_length,
      periodic=periodic,
      dtype=dtype,
      a=0.5,
      b=0.5)
@tf_export('signal.hamming_window')
@dispatch.add_dispatch_support
def hamming_window(window_length, periodic=True, dtype=dtypes.float32,
                   name=None):
  """Generate a [Hamming][hamming] window.

  This is the raised cosine window `0.54 - 0.46 * cos(2 * pi * k / N)`.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window. Periodic windows are typically used for spectral
      analysis while symmetric windows are typically used for digital
      filter design.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type, or if
      `window_length` or `periodic` is not a scalar.

  [hamming]:
  https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
  """
  # Hamming coefficients: constant term and cosine coefficient.
  alpha, beta = 0.54, 0.46
  return _raised_cosine_window(name, 'hamming_window', window_length,
                               periodic, dtype, alpha, beta)
def _raised_cosine_window(name, default_name, window_length, periodic,
                          dtype, a, b):
  """Builds the raised cosine window `a - b * cos(2 * pi * k / N)`.

  This is the shared implementation behind `hann_window` and `hamming_window`.

  Args:
    name: Optional scope name supplied by the caller.
    default_name: Scope name used when `name` is None.
    window_length: A scalar `Tensor` or Python integer giving the length.
    periodic: A scalar bool `Tensor` or Python bool. If true and
      `window_length` is even, the denominator `N` is `window_length`.
      Otherwise `N` is `window_length - 1`, so for odd lengths the periodic
      and symmetric windows are identical.
    dtype: A floating point `DType` for the result.
    a: The constant term (alpha) of the raised cosine window.
    b: The cosine coefficient (beta) of the raised cosine window.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`. A length-1 window
    is `[1.]`.

  Raises:
    ValueError: If `dtype` is not a floating point type, or if
      `window_length` or `periodic` is not a scalar.
  """
  if not dtype.is_floating:
    raise ValueError('dtype must be a floating point type. Found %s' % dtype)

  with ops.name_scope(name, default_name, [window_length, periodic]):
    length = ops.convert_to_tensor(
        window_length, dtype=dtypes.int32, name='window_length')
    length.shape.assert_has_rank(0)
    static_length = tensor_util.constant_value(length)
    if static_length == 1:
      return array_ops.ones([1], dtype=dtype)

    periodic_int = math_ops.cast(
        ops.convert_to_tensor(periodic, dtype=dtypes.bool, name='periodic'),
        dtypes.int32)
    periodic_int.shape.assert_has_rank(0)

    # 1 for even lengths and 0 for odd ones. `periodic` only widens the
    # denominator when the length is even.
    is_even = 1 - math_ops.mod(length, 2)
    denominator = math_ops.cast(length + periodic_int * is_even - 1,
                                dtype=dtype)
    sample_index = math_ops.cast(math_ops.range(length), dtype)
    two_pi = constant_op.constant(2 * np.pi, dtype=dtype)
    phase = two_pi * sample_index / denominator

    def _window():
      return math_ops.cast(a - b * math_ops.cos(phase), dtype=dtype)

    if static_length is not None:
      return _window()
    # The length is only known at run time. A length of 1 gives a zero
    # denominator (0/0 = NaN), so substitute the all-ones window.
    return cond.cond(
        math_ops.equal(length, 1),
        lambda: array_ops.ones([length], dtype=dtype),
        _window)