Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/signal/reconstruction_ops.py: 19%

59 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Signal reconstruction via overlapped addition of frames.""" 

16 

17from tensorflow.python.framework import dtypes 

18from tensorflow.python.framework import ops 

19from tensorflow.python.framework import tensor_util 

20from tensorflow.python.ops import array_ops 

21from tensorflow.python.ops import math_ops 

22from tensorflow.python.util import dispatch 

23from tensorflow.python.util.tf_export import tf_export 

24 

25 

@tf_export("signal.overlap_and_add")
@dispatch.add_dispatch_support
def overlap_and_add(signal, frame_step, name=None):
  """Reconstructs a signal from a framed representation.

  Adds potentially overlapping frames of a signal with shape
  `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
  The resulting tensor has shape `[..., output_size]` where

      output_size = (frames - 1) * frame_step + frame_length

  Args:
    signal: A [..., frames, frame_length] `Tensor`. All dimensions may be
      unknown, and rank must be at least 2.
    frame_step: An integer or scalar `Tensor` denoting overlap offsets. Must be
      positive and less than or equal to `frame_length`.
    name: An optional name for the operation.

  Returns:
    A `Tensor` with shape `[..., output_size]` containing the overlap-added
    frames of `signal`'s inner-most two dimensions.

  Raises:
    ValueError: If `signal`'s rank is less than 2, `frame_step` is not a
      scalar integer, or `frame_step` is statically known to be less than 1
      when the general (non fast-path) overlap-add computation is required.
  """
  with ops.name_scope(name, "overlap_and_add", [signal, frame_step]):
    signal = ops.convert_to_tensor(signal, name="signal")
    signal.shape.with_rank_at_least(2)
    frame_step = ops.convert_to_tensor(frame_step, name="frame_step")
    frame_step.shape.assert_has_rank(0)
    if not frame_step.dtype.is_integer:
      raise ValueError("frame_step must be an integer. Got %s" %
                       frame_step.dtype)

    # Prefer the statically known value of `frame_step` when there is one, so
    # the shape arithmetic below can be done eagerly in Python.
    frame_step_static = tensor_util.constant_value(frame_step)
    frame_step_is_static = frame_step_static is not None
    frame_step = frame_step_static if frame_step_is_static else frame_step

    # Likewise, use the static shape of `signal` whenever it is fully known.
    signal_shape = array_ops.shape(signal)
    signal_shape_static = tensor_util.constant_value(signal_shape)
    if signal_shape_static is not None:
      signal_shape = signal_shape_static

    # All dimensions that are not part of the overlap-and-add. Can be empty for
    # rank 2 inputs.
    outer_dimensions = signal_shape[:-2]
    outer_rank = array_ops.size(outer_dimensions)
    outer_rank_static = tensor_util.constant_value(outer_rank)
    if outer_rank_static is not None:
      outer_rank = outer_rank_static

    def full_shape(inner_shape):
      # Prepends the untouched outer (batch) dimensions to `inner_shape`.
      return array_ops.concat([outer_dimensions, inner_shape], 0)

    frame_length = signal_shape[-1]
    frames = signal_shape[-2]

    # Compute output length.
    output_length = frame_length + frame_step * (frames - 1)

    # If frame_length is equal to frame_step, there's no overlap so just
    # reshape the tensor.
    if (frame_step_is_static and signal.shape.dims is not None and
        frame_step == signal.shape.dims[-1].value):
      output_shape = full_shape([output_length])
      return array_ops.reshape(signal, output_shape, name="fast_path")

    # The general path divides by `frame_step` and uses it as a dimension
    # size, so a non-positive value can only produce an obscure failure further
    # down (division by zero, negative paddings or an invalid reshape). Reject
    # it up front when it is statically known. The check sits after the fast
    # path so that inputs the fast path already handles are unaffected.
    if frame_step_is_static and frame_step < 1:
      raise ValueError("frame_step must be positive. Got %s" % (frame_step,))

    # The following code is documented using this example:
    #
    # frame_step = 2
    # signal.shape = (3, 5)
    # a b c d e
    # f g h i j
    # k l m n o

    # Compute the number of segments, per frame.
    segments = -(-frame_length // frame_step)  # Divide and round up.

    # Pad the frame_length dimension to a multiple of the frame step.
    # Pad the frames dimension by `segments` so that signal.shape = (6, 6)
    # a b c d e 0
    # f g h i j 0
    # k l m n o 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    paddings = [[0, segments], [0, segments * frame_step - frame_length]]
    # The outer (batch) dimensions are never padded.
    outer_paddings = array_ops.zeros([outer_rank, 2], dtypes.int32)
    paddings = array_ops.concat([outer_paddings, paddings], 0)
    signal = array_ops.pad(signal, paddings)

    # Reshape so that signal.shape = (6, 3, 2)
    # ab cd e0
    # fg hi j0
    # kl mn o0
    # 00 00 00
    # 00 00 00
    # 00 00 00
    shape = full_shape([frames + segments, segments, frame_step])
    signal = array_ops.reshape(signal, shape)

    # Transpose dimensions so that signal.shape = (3, 6, 2)
    # ab fg kl 00 00 00
    # cd hi mn 00 00 00
    # e0 j0 o0 00 00 00
    # Only the first two of the three inner-most axes are swapped; the outer
    # axes keep their positions.
    perm = array_ops.concat(
        [math_ops.range(outer_rank), outer_rank + [1, 0, 2]], 0)
    perm_static = tensor_util.constant_value(perm)
    perm = perm_static if perm_static is not None else perm
    signal = array_ops.transpose(signal, perm)

    # Reshape so that signal.shape = (18, 2)
    # ab fg kl 00 00 00 cd hi mn 00 00 00 e0 j0 o0 00 00 00
    shape = full_shape([(frames + segments) * segments, frame_step])
    signal = array_ops.reshape(signal, shape)

    # Truncate so that signal.shape = (15, 2)
    # ab fg kl 00 00 00 cd hi mn 00 00 00 e0 j0 o0
    signal = signal[..., :(frames + segments - 1) * segments, :]

    # Reshape so that signal.shape = (3, 5, 2)
    # ab fg kl 00 00
    # 00 cd hi mn 00
    # 00 00 e0 j0 o0
    # Dropping one row before this reshape is what shifts each successive
    # segment row one step to the right, aligning overlapping samples in
    # columns.
    shape = full_shape([segments, (frames + segments - 1), frame_step])
    signal = array_ops.reshape(signal, shape)

    # Now, reduce over the columns, to achieve the desired sum.
    signal = math_ops.reduce_sum(signal, -3)

    # Flatten the array.
    shape = full_shape([(frames + segments - 1) * frame_step])
    signal = array_ops.reshape(signal, shape)

    # Truncate to final length.
    signal = signal[..., :output_length]

    return signal

80 frame_length = signal_shape[-1] 

81 frames = signal_shape[-2] 

82 

83 # Compute output length. 

84 output_length = frame_length + frame_step * (frames - 1) 

85 

86 # If frame_length is equal to frame_step, there's no overlap so just 

87 # reshape the tensor. 

88 if (frame_step_is_static and signal.shape.dims is not None and 

89 frame_step == signal.shape.dims[-1].value): 

90 output_shape = full_shape([output_length]) 

91 return array_ops.reshape(signal, output_shape, name="fast_path") 

92 

93 # The following code is documented using this example: 

94 # 

95 # frame_step = 2 

96 # signal.shape = (3, 5) 

97 # a b c d e 

98 # f g h i j 

99 # k l m n o 

100 

101 # Compute the number of segments, per frame. 

102 segments = -(-frame_length // frame_step) # Divide and round up. 

103 

104 # Pad the frame_length dimension to a multiple of the frame step. 

105 # Pad the frames dimension by `segments` so that signal.shape = (6, 6) 

106 # a b c d e 0 

107 # f g h i j 0 

108 # k l m n o 0 

109 # 0 0 0 0 0 0 

110 # 0 0 0 0 0 0 

111 # 0 0 0 0 0 0 

112 paddings = [[0, segments], [0, segments * frame_step - frame_length]] 

113 outer_paddings = array_ops.zeros([outer_rank, 2], dtypes.int32) 

114 paddings = array_ops.concat([outer_paddings, paddings], 0) 

115 signal = array_ops.pad(signal, paddings) 

116 

117 # Reshape so that signal.shape = (3, 6, 2) 

118 # ab cd e0 

119 # fg hi j0 

120 # kl mn o0 

121 # 00 00 00 

122 # 00 00 00 

123 # 00 00 00 

124 shape = full_shape([frames + segments, segments, frame_step]) 

125 signal = array_ops.reshape(signal, shape) 

126 

127 # Transpose dimensions so that signal.shape = (3, 6, 2) 

128 # ab fg kl 00 00 00 

129 # cd hi mn 00 00 00 

130 # e0 j0 o0 00 00 00 

131 perm = array_ops.concat( 

132 [math_ops.range(outer_rank), outer_rank + [1, 0, 2]], 0) 

133 perm_static = tensor_util.constant_value(perm) 

134 perm = perm_static if perm_static is not None else perm 

135 signal = array_ops.transpose(signal, perm) 

136 

137 # Reshape so that signal.shape = (18, 2) 

138 # ab fg kl 00 00 00 cd hi mn 00 00 00 e0 j0 o0 00 00 00 

139 shape = full_shape([(frames + segments) * segments, frame_step]) 

140 signal = array_ops.reshape(signal, shape) 

141 

142 # Truncate so that signal.shape = (15, 2) 

143 # ab fg kl 00 00 00 cd hi mn 00 00 00 e0 j0 o0 

144 signal = signal[..., :(frames + segments - 1) * segments, :] 

145 

146 # Reshape so that signal.shape = (3, 5, 2) 

147 # ab fg kl 00 00 

148 # 00 cd hi mn 00 

149 # 00 00 e0 j0 o0 

150 shape = full_shape([segments, (frames + segments - 1), frame_step]) 

151 signal = array_ops.reshape(signal, shape) 

152 

153 # Now, reduce over the columns, to achieve the desired sum. 

154 signal = math_ops.reduce_sum(signal, -3) 

155 

156 # Flatten the array. 

157 shape = full_shape([(frames + segments - 1) * frame_step]) 

158 signal = array_ops.reshape(signal, shape) 

159 

160 # Truncate to final length. 

161 signal = signal[..., :output_length] 

162 

163 return signal