Skip to content

Commit d035d7b

Browse files
authored
Merge pull request #132 from sp-nitech/griffin
Fix bug of istft and add griffin
2 parents 7119b2e + 59fb2c5 commit d035d7b

25 files changed

+507
-28
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
*diffsptk* is a differentiable version of [SPTK](https://github.com/sp-nitech/SPTK) based on the PyTorch framework.
44

5-
[![Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/3.1.0/)
5+
[![Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/3.2.0/)
66
[![Downloads](https://static.pepy.tech/badge/diffsptk)](https://pepy.tech/project/diffsptk)
77
[![ClickPy](https://img.shields.io/badge/downloads-clickpy-yellow.svg)](https://clickpy.clickhouse.com/dashboard/diffsptk)
88
[![Python Version](https://img.shields.io/pypi/pyversions/diffsptk.svg)](https://pypi.python.org/pypi/diffsptk)

diffsptk/functional.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,101 @@ def gnorm(x: Tensor, gamma: float = 0, c: int | None = None) -> Tensor:
864864
return nn.GeneralizedCepstrumGainNormalization._func(x, gamma=gamma, c=c)
865865

866866

867+
def griffin(
    y: Tensor,
    *,
    out_length: int | None = None,
    frame_length: int = 400,
    frame_period: int = 80,
    fft_length: int = 512,
    center: bool = True,
    mode: str = "constant",
    window: str = "blackman",
    norm: str = "power",
    symmetric: bool = True,
    n_iter: int = 100,
    alpha: float = 0.99,
    beta: float = 0.99,
    gamma: float = 1.1,
    init_phase: str = "random",
    verbose: bool = False,
) -> Tensor:
    """Estimate a waveform from a power spectrogram via Griffin-Lim iterations.

    This is the functional counterpart of ``nn.GriffinLim``: every option is
    forwarded by keyword to ``nn.GriffinLim._func``.

    Parameters
    ----------
    y : Tensor [shape=(..., T/P, N/2+1)]
        The power spectrum.

    out_length : int > 0 or None
        The length of the output waveform.

    frame_length : int >= 1
        The frame length in samples, :math:`L`.

    frame_period : int >= 1
        The frame period in samples, :math:`P`.

    fft_length : int >= L
        The number of FFT bins, :math:`N`.

    center : bool
        If True, pad the input on both sides so that the frame is centered.

    mode : str
        Forwarded unchanged to ``nn.GriffinLim`` (default ``"constant"``).
        NOTE(review): presumably the padding method used when ``center`` is
        True; confirm the accepted values against the STFT module.

    window : ['blackman', 'hamming', 'hanning', 'bartlett', 'trapezoidal', \
              'rectangular', 'nuttall']
        The window type.

    norm : ['none', 'power', 'magnitude']
        The normalization type of the window.

    symmetric : bool
        If True, the window is symmetric, otherwise periodic.

    n_iter : int >= 1
        The number of iterations for phase reconstruction.

    alpha : float >= 0
        The momentum factor, :math:`\\alpha`.

    beta : float >= 0
        The momentum factor, :math:`\\beta`.

    gamma : float >= 0
        The smoothing factor, :math:`\\gamma`.

    init_phase : ['zeros', 'random']
        The initial phase for the reconstruction.

    verbose : bool
        If True, print the SNR at each iteration.

    Returns
    -------
    out : Tensor [shape=(..., T)]
        The reconstructed waveform.

    """
    # Framing and windowing options.
    analysis_options = dict(
        frame_length=frame_length,
        frame_period=frame_period,
        fft_length=fft_length,
        center=center,
        mode=mode,
        window=window,
        norm=norm,
        symmetric=symmetric,
    )
    # Options that steer the iterative phase reconstruction.
    iteration_options = dict(
        n_iter=n_iter,
        alpha=alpha,
        beta=beta,
        gamma=gamma,
        init_phase=init_phase,
        verbose=verbose,
    )
    return nn.GriffinLim._func(
        y,
        out_length=out_length,
        **analysis_options,
        **iteration_options,
    )
960+
961+
867962
def grpdelay(
868963
b: Tensor | None = None,
869964
a: Tensor | None = None,

diffsptk/modules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from .gmm import GaussianMixtureModeling
5757
from .gmm import GaussianMixtureModeling as GMM
5858
from .gnorm import GeneralizedCepstrumGainNormalization
59+
from .griffin import GriffinLim
5960
from .grpdelay import GroupDelay
6061
from .hilbert import HilbertTransform
6162
from .hilbert2 import TwoDimensionalHilbertTransform

diffsptk/modules/griffin.py

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
# ------------------------------------------------------------------------ #
2+
# Copyright 2022 SPTK Working Group #
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License"); #
5+
# you may not use this file except in compliance with the License. #
6+
# You may obtain a copy of the License at #
7+
# #
8+
# http://www.apache.org/licenses/LICENSE-2.0 #
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software #
11+
# distributed under the License is distributed on an "AS IS" BASIS, #
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
13+
# See the License for the specific language governing permissions and #
14+
# limitations under the License. #
15+
# ------------------------------------------------------------------------ #
16+
17+
import inspect
18+
import logging
19+
20+
import torch
21+
from torch import nn
22+
23+
from ..typing import Callable, Precomputed
24+
from ..utils.private import TAU, get_layer, get_logger, get_values
25+
from .base import BaseFunctionalModule
26+
from .istft import InverseShortTimeFourierTransform
27+
from .stft import ShortTimeFourierTransform
28+
29+
30+
class GriffinLim(BaseFunctionalModule):
    """Griffin-Lim phase reconstruction module.

    Given a power spectrum, the module iteratively estimates a phase such that
    the magnitude spectrum combined with that phase can be inverted into a
    waveform. Each iteration runs an inverse STFT followed by an STFT and then
    applies the momentum/smoothing update controlled by ``alpha``, ``beta`` and
    ``gamma``.

    Parameters
    ----------
    frame_length : int >= 1
        The frame length in samples, :math:`L`.

    frame_period : int >= 1
        The frame period in samples, :math:`P`.

    fft_length : int >= L
        The number of FFT bins, :math:`N`.

    center : bool
        If True, pad the input on both sides so that the frame is centered.

    mode : str
        Forwarded unchanged to the STFT layer (default ``"constant"``); it is
        not passed to the inverse STFT layer.
        NOTE(review): presumably the padding method applied when ``center`` is
        True; confirm the accepted values against ShortTimeFourierTransform.

    window : ['blackman', 'hamming', 'hanning', 'bartlett', 'trapezoidal', \
              'rectangular', 'nuttall']
        The window type.

    norm : ['none', 'power', 'magnitude']
        The normalization type of the window.

    symmetric : bool
        If True, the window is symmetric, otherwise periodic.

    n_iter : int >= 1
        The number of iterations for phase reconstruction.

    alpha : float >= 0
        The momentum factor, :math:`\\alpha`.

    beta : float >= 0
        The momentum factor, :math:`\\beta`.

    gamma : float >= 0
        The smoothing factor, :math:`\\gamma`.

    init_phase : ['zeros', 'random']
        The initial phase for the reconstruction.

    verbose : bool
        If True, print the SNR at each iteration.

    References
    ----------
    .. [1] R. Nenov et al., "Faster than fast: Accelerating the Griffin-Lim algorithm,"
           *Proceedings of ICASSP*, 2023.

    """

    def __init__(
        self,
        frame_length: int,
        frame_period: int,
        fft_length: int,
        *,
        center: bool = True,
        mode: str = "constant",
        window: str = "blackman",
        norm: str = "power",
        symmetric: bool = True,
        n_iter: int = 100,
        alpha: float = 0.99,
        beta: float = 0.99,
        gamma: float = 1.1,
        init_phase: str = "random",
        verbose: bool = False,
    ) -> None:
        super().__init__()

        # NOTE(review): get_values(locals()) presumably extracts the constructor
        # arguments in declaration order. Do not introduce extra local variables
        # before this line, or they would be forwarded too. Confirm against
        # get_values.
        self.values, layers, _ = self._precompute(*get_values(locals()))
        # Register the STFT / inverse STFT layers as submodules.
        self.layers = nn.ModuleList(layers)

    def forward(self, y: torch.Tensor, out_length: int | None = None) -> torch.Tensor:
        """Reconstruct a waveform from the spectrum using the Griffin-Lim algorithm.

        Parameters
        ----------
        y : Tensor [shape=(..., T/P, N/2+1)]
            The power spectrum.

        out_length : int > 0 or None
            The length of the output waveform.

        Returns
        -------
        out : Tensor [shape=(..., T)]
            The reconstructed waveform.

        Examples
        --------
        >>> x = diffsptk.ramp(1, 3)
        >>> x
        tensor([1., 2., 3.])
        >>> stft_params = {"frame_length": 3, "frame_period": 1, "fft_length": 8}
        >>> stft = diffsptk.STFT(**stft_params, out_format="power")
        >>> griffin = diffsptk.GriffinLim(**stft_params, n_iter=10, init_phase="zeros")
        >>> y = griffin(stft(x), out_length=3)
        >>> y
        tensor([ 1.0000, 2.0000, -3.0000])

        """
        # The precomputed scalars/callables and the two layers are passed
        # positionally, in the order expected by _forward.
        return self._forward(y, out_length, *self.values, *self.layers)

    @staticmethod
    def _func(y: torch.Tensor, out_length: int | None, *args, **kwargs) -> torch.Tensor:
        """Functional entry point: build the values and layers, then run once.

        ``args`` and ``kwargs`` are handed to ``_precompute`` as-is. The name of
        this method matters, because ``_precompute`` checks whether its direct
        caller is called ``_func``.
        """
        values, layers, _ = GriffinLim._precompute(*args, **kwargs)
        return GriffinLim._forward(y, out_length, *values, *layers)

    @staticmethod
    def _takes_input_size() -> bool:
        """Return False.

        NOTE(review): presumably tells the base class that no input size has to
        be supplied to ``_precompute``; confirm against BaseFunctionalModule.
        """
        return False

    @staticmethod
    def _check(
        n_iter: int,
        alpha: float,
        beta: float,
        gamma: float,
    ) -> None:
        """Validate the iteration parameters.

        Raises
        ------
        ValueError
            If ``n_iter`` is not positive, or if ``alpha``, ``beta`` or
            ``gamma`` is negative.

        """
        if n_iter <= 0:
            raise ValueError("n_iter must be positive.")
        if alpha < 0:
            raise ValueError("alpha must be non-negative.")
        if beta < 0:
            raise ValueError("beta must be non-negative.")
        if gamma < 0:
            raise ValueError("gamma must be non-negative.")

    @staticmethod
    def _precompute(
        frame_length: int,
        frame_period: int,
        fft_length: int,
        center: bool,
        mode: str,
        window: str,
        norm: str,
        symmetric: bool,
        n_iter: int,
        alpha: float,
        beta: float,
        gamma: float,
        init_phase: str,
        verbose: bool,
    ) -> Precomputed:
        """Validate the arguments and build everything ``_forward`` needs.

        Returns
        -------
        out : Precomputed
            A triple of
            ``(n_iter, alpha, beta, gamma, phase_generator, logger)``,
            ``(stft, istft)`` and ``None``.

        Raises
        ------
        ValueError
            If the iteration parameters are invalid (see ``_check``) or if
            ``init_phase`` is neither ``"zeros"`` nor ``"random"``.

        """
        GriffinLim._check(n_iter, alpha, beta, gamma)
        # True unless the direct caller is a function named "_func". Wrapping
        # this call in another helper would change the inspected frame.
        # NOTE(review): get_layer presumably uses the flag to choose between a
        # module instance and a functional callable; confirm against get_layer.
        module = inspect.stack()[1].function != "_func"

        # Callable that produces the initial phase, shaped like its argument.
        if init_phase == "zeros":
            phase_generator = lambda x: torch.zeros_like(x)
        elif init_phase == "random":
            # torch.rand_like draws from [0, 1), scaled here by TAU.
            # NOTE(review): TAU is presumably 2*pi; confirm in utils.private.
            phase_generator = lambda x: TAU * torch.rand_like(x)
        else:
            raise ValueError(f"init_phase: {init_phase} is not supported.")

        # A logger is only created when per-iteration output is requested;
        # _forward treats None as "do not log".
        if verbose:
            logger = get_logger("griffin")
        else:
            logger = None

        # Analysis layer: configured to return the complex spectrum, with
        # zmean disabled, eps set to 0 and no relative floor.
        stft = get_layer(
            module,
            ShortTimeFourierTransform,
            dict(
                frame_length=frame_length,
                frame_period=frame_period,
                fft_length=fft_length,
                center=center,
                zmean=False,
                mode=mode,
                window=window,
                norm=norm,
                symmetric=symmetric,
                eps=0,
                relative_floor=None,
                out_format="complex",
            ),
        )
        # Synthesis layer: receives the same framing/window settings, but not
        # ``mode``.
        istft = get_layer(
            module,
            InverseShortTimeFourierTransform,
            dict(
                frame_length=frame_length,
                frame_period=frame_period,
                fft_length=fft_length,
                center=center,
                window=window,
                norm=norm,
                symmetric=symmetric,
            ),
        )
        return (
            (n_iter, alpha, beta, gamma, phase_generator, logger),
            (stft, istft),
            None,
        )

    @staticmethod
    def _forward(
        y: torch.Tensor,
        out_length: int | None,
        n_iter: int,
        alpha: float,
        beta: float,
        gamma: float,
        phase_generator: Callable,
        logger: logging.Logger | None,
        stft: Callable,
        istft: Callable,
    ) -> torch.Tensor:
        """Run the phase reconstruction iterations and synthesize the waveform.

        Parameters
        ----------
        y : Tensor
            The power spectrum.

        out_length : int > 0 or None
            Passed to every ``istft`` call as ``out_length``.

        n_iter, alpha, beta, gamma, phase_generator, logger
            The values produced by ``_precompute``.

        stft, istft : Callable
            The analysis and synthesis transforms produced by ``_precompute``.

        Returns
        -------
        out : Tensor
            The output of ``istft`` for the target magnitude combined with the
            final phase estimate.

        """
        if logger is not None:
            logger.info(f"alpha: {alpha}, beta: {beta}, gamma: {gamma}")

        # y is a power spectrum, so its square root is the target magnitude.
        s = torch.sqrt(y)
        # Unit-modulus complex factor carrying the initial phase.
        angle = torch.exp(1j * phase_generator(s))

        t_prev = d_prev = 0  # This suppresses F821 and F841.
        for n in range(n_iter):
            # Impose the target magnitude, go to the time domain and back.
            t = stft(istft(s * angle, out_length=out_length))
            # Keep at most as many frames as the target has, so that t can be
            # combined with s.
            t = t[..., : s.shape[-2], :]

            if 0 == n:
                # No history yet: both auxiliary sequences start at t.
                c = d = t
            else:
                # Blend with the previous d, then extrapolate along the change
                # since the previous t using the alpha and beta factors.
                t = (1 - gamma) * d_prev + gamma * t
                diff = t - t_prev
                c = t + alpha * diff
                d = t + beta * diff

            # Keep only the phase of c; the small constant avoids a division
            # by zero where c vanishes.
            angle = c / (c.abs() + 1e-16)
            t_prev = t
            d_prev = d

            if logger is not None:
                # Log-ratio between the magnitude mismatch of the current
                # iterate and the target magnitude; larger means closer.
                snr = -10 * torch.log10(
                    torch.linalg.norm(c.abs() - s) / torch.linalg.norm(s)
                )
                logger.info(f" iter {n + 1:5d}: SNR = {snr:g}")

        # Final synthesis with the target magnitude and the estimated phase.
        return istft(s * angle, out_length=out_length)

0 commit comments

Comments
 (0)