Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/signal/reconstruction_ops.py: 19%
59 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Signal reconstruction via overlapped addition of frames."""
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import ops
19from tensorflow.python.framework import tensor_util
20from tensorflow.python.ops import array_ops
21from tensorflow.python.ops import math_ops
22from tensorflow.python.util import dispatch
23from tensorflow.python.util.tf_export import tf_export
@tf_export("signal.overlap_and_add")
@dispatch.add_dispatch_support
def overlap_and_add(signal, frame_step, name=None):
  """Reconstructs a signal from a framed representation.

  Adds potentially overlapping frames of a signal with shape
  `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
  The resulting tensor has shape `[..., output_size]` where

      output_size = (frames - 1) * frame_step + frame_length

  Args:
    signal: A [..., frames, frame_length] `Tensor`. All dimensions may be
      unknown, and rank must be at least 2.
    frame_step: An integer or scalar `Tensor` denoting overlap offsets. Must be
      positive and less than or equal to `frame_length`.
    name: An optional name for the operation.

  Returns:
    A `Tensor` with shape `[..., output_size]` containing the overlap-added
    frames of `signal`'s inner-most two dimensions.

  Raises:
    ValueError: If `signal`'s rank is less than 2, if `frame_step` is not a
      scalar integer, or if `frame_step` is statically known and is not
      positive.
  """
  with ops.name_scope(name, "overlap_and_add", [signal, frame_step]):
    signal = ops.convert_to_tensor(signal, name="signal")
    signal.shape.with_rank_at_least(2)
    frame_step = ops.convert_to_tensor(frame_step, name="frame_step")
    frame_step.shape.assert_has_rank(0)
    if not frame_step.dtype.is_integer:
      raise ValueError("frame_step must be an integer. Got %s" %
                       frame_step.dtype)
    # When frame_step is a constant, use its static value so the shape
    # arithmetic below is done in Python instead of building graph ops.
    frame_step_static = tensor_util.constant_value(frame_step)
    frame_step_is_static = frame_step_static is not None
    # A non-positive step would otherwise divide by zero (step == 0) or produce
    # negative paddings (step < 0) further down, failing with an obscure error.
    if frame_step_is_static and frame_step_static <= 0:
      raise ValueError("frame_step must be positive. Got %s" %
                       frame_step_static)
    frame_step = frame_step_static if frame_step_is_static else frame_step

    # Prefer the static shape when it is fully known, for the same reason.
    signal_shape = array_ops.shape(signal)
    signal_shape_static = tensor_util.constant_value(signal_shape)
    if signal_shape_static is not None:
      signal_shape = signal_shape_static

    # All dimensions that are not part of the overlap-and-add. Can be empty for
    # rank 2 inputs.
    outer_dimensions = signal_shape[:-2]
    outer_rank = array_ops.size(outer_dimensions)
    outer_rank_static = tensor_util.constant_value(outer_rank)
    if outer_rank_static is not None:
      outer_rank = outer_rank_static

    def full_shape(inner_shape):
      # Prepend the untouched outer (batch) dimensions to `inner_shape`.
      return array_ops.concat([outer_dimensions, inner_shape], 0)

    frame_length = signal_shape[-1]
    frames = signal_shape[-2]

    # Compute output length.
    output_length = frame_length + frame_step * (frames - 1)

    # If frame_length is equal to frame_step, there's no overlap so just
    # reshape the tensor.
    if (frame_step_is_static and signal.shape.dims is not None and
        frame_step == signal.shape.dims[-1].value):
      output_shape = full_shape([output_length])
      return array_ops.reshape(signal, output_shape, name="fast_path")

    # The following code is documented using this example:
    #
    # frame_step = 2
    # signal.shape = (3, 5)
    # a b c d e
    # f g h i j
    # k l m n o

    # Compute the number of segments, per frame.
    segments = -(-frame_length // frame_step)  # Divide and round up.

    # Pad the frame_length dimension to a multiple of the frame step.
    # Pad the frames dimension by `segments` so that signal.shape = (6, 6)
    # a b c d e 0
    # f g h i j 0
    # k l m n o 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    paddings = [[0, segments], [0, segments * frame_step - frame_length]]
    outer_paddings = array_ops.zeros([outer_rank, 2], dtypes.int32)
    paddings = array_ops.concat([outer_paddings, paddings], 0)
    signal = array_ops.pad(signal, paddings)

    # Reshape so that signal.shape = (6, 3, 2)
    # ab cd e0
    # fg hi j0
    # kl mn o0
    # 00 00 00
    # 00 00 00
    # 00 00 00
    shape = full_shape([frames + segments, segments, frame_step])
    signal = array_ops.reshape(signal, shape)

    # Transpose dimensions so that signal.shape = (3, 6, 2)
    # ab fg kl 00 00 00
    # cd hi mn 00 00 00
    # e0 j0 o0 00 00 00
    perm = array_ops.concat(
        [math_ops.range(outer_rank), outer_rank + [1, 0, 2]], 0)
    perm_static = tensor_util.constant_value(perm)
    perm = perm_static if perm_static is not None else perm
    signal = array_ops.transpose(signal, perm)

    # Reshape so that signal.shape = (18, 2)
    # ab fg kl 00 00 00 cd hi mn 00 00 00 e0 j0 o0 00 00 00
    shape = full_shape([(frames + segments) * segments, frame_step])
    signal = array_ops.reshape(signal, shape)

    # Truncate so that signal.shape = (15, 2)
    # ab fg kl 00 00 00 cd hi mn 00 00 00 e0 j0 o0
    signal = signal[..., :(frames + segments - 1) * segments, :]

    # Reshape so that signal.shape = (3, 5, 2)
    # ab fg kl 00 00
    # 00 cd hi mn 00
    # 00 00 e0 j0 o0
    shape = full_shape([segments, (frames + segments - 1), frame_step])
    signal = array_ops.reshape(signal, shape)

    # Now, reduce over the columns, to achieve the desired sum.
    signal = math_ops.reduce_sum(signal, -3)

    # Flatten the array.
    shape = full_shape([(frames + segments - 1) * frame_step])
    signal = array_ops.reshape(signal, shape)

    # Truncate to final length.
    signal = signal[..., :output_length]

    return signal