Skip to content

Instantly share code, notes, and snippets.

@tam17aki
Last active October 23, 2024 21:50
Show Gist options
  • Save tam17aki/8ff8fbd12437ec59a1e898a61cbe98e3 to your computer and use it in GitHub Desktop.
Save tam17aki/8ff8fbd12437ec59a1e898a61cbe98e3 to your computer and use it in GitHub Desktop.
Demonstration script for phase recovery via rational function approximation.
# -*- coding: utf-8 -*-
"""Demonstration script for phase recovery via rational function approximation.
Copyright (C) 2024 by Akira TAMAMORI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import math
from dataclasses import dataclass
import numpy as np
import soundfile as sf
import torch
from pesq import pesq
from pystoi import stoi
from scipy import signal
from torch import nn, optim
@dataclass(frozen=True)
class NetworkConfig:
    """Immutable hyperparameters of the phase-recovery network."""

    # Channel count fed to the first convolution; also the number of
    # frequency bins the final linear layer emits coefficients for.
    input_channels: int = 257
    # Width of both convolutional layers.
    hidden_channels: int = 64
    # Kernel width of both convolutions (padding is (kernel_size - 1) // 2).
    kernel_size: int = 3
    # Number of numerator coefficients, and likewise denominator
    # coefficients, predicted per frequency bin.
    n_order: int = 10
@dataclass(frozen=True)
class FeatureConfig:
    """Immutable settings for STFT feature extraction and resynthesis."""

    # Hop between successive STFT frames, in samples.
    hop_length: int = 128
    # Hann window length; also passed as the FFT size (mfft).
    n_fft: int = 512
    # Sampling rate in Hz used for resynthesis and for the written output file.
    rate: int = 16000
    # Path of the audio file to analyse.
    in_wavefile: str = "input.wav"
@dataclass(frozen=True)
class TrainingConfig:
    """Immutable settings for the optimisation loop."""

    # Number of optimisation steps; also embedded in the output file name.
    n_epochs: int = 5000
    # Learning rate handed to the Adam optimiser.
    lr: float = 0.001
    # The loss is printed once every n_interval epochs.
    n_interval: int = 1000
class PhaseRecoveryNet(nn.Module):
    """Phase recovery via rational function approximation.

    Two ``Conv1d -> InstanceNorm1d -> ReLU`` stages are applied along the
    last axis of the input, followed by a linear layer that maps every
    frame's hidden vector to ``2 * n_order * input_channels`` coefficients.
    """

    def __init__(self, cfg: "NetworkConfig | None" = None):
        """Build the layers.

        Args:
            cfg: Network hyperparameters; any object exposing
                ``input_channels``, ``hidden_channels``, ``kernel_size`` and
                ``n_order``. When omitted, ``NetworkConfig()`` is used, which
                reproduces the previous hard-coded behaviour.
        """
        super().__init__()
        if cfg is None:
            cfg = NetworkConfig()

        # Symmetric padding; keeps the frame count unchanged for odd kernels.
        padding = (cfg.kernel_size - 1) // 2

        self.conv1 = nn.Conv1d(
            cfg.input_channels,
            cfg.hidden_channels,
            cfg.kernel_size,
            padding=padding,
        )
        self.conv2 = nn.Conv1d(
            cfg.hidden_channels,
            cfg.hidden_channels,
            cfg.kernel_size,
            padding=padding,
        )
        self.norm1 = nn.InstanceNorm1d(cfg.hidden_channels)
        self.norm2 = nn.InstanceNorm1d(cfg.hidden_channels)
        # One numerator and one denominator coefficient set per input channel.
        self.fc = nn.Linear(
            cfg.hidden_channels, (2 * cfg.n_order) * cfg.input_channels
        )
        self.activation = nn.ReLU()

    def forward(self, inputs):
        """Predict rational-function coefficients for every frame.

        Args:
            inputs: Tensor laid out as (batch, input_channels, frames), as
                required by ``Conv1d``.

        Returns:
            Tensor of shape (batch, frames, 2 * n_order * input_channels).
        """
        # Remove each channel's mean over the last (frame) axis.
        inputs = inputs - torch.mean(inputs, dim=-1, keepdim=True)
        hidden = self.activation(self.norm1(self.conv1(inputs)))
        hidden = self.activation(self.norm2(self.conv2(hidden)))
        # The linear layer acts on the channel axis, so move it last.
        return self.fc(hidden.transpose(1, 2))
def compute_ratiofunc(coeffs, freqs):
    """Evaluate the rational function and return its angle.

    Args:
        coeffs: Tensor of shape (batch, frames, bins * 2 * terms). For each
            bin the first ``terms`` values weight the numerator and the
            remaining ``terms`` values weight the denominator.
        freqs: Tensor of shape (bins, terms) holding the basis values.

    Returns:
        Tensor of shape (batch, bins, frames) equal to
        ``atan2(numerator, denominator)``.
    """
    n_bins, n_terms = freqs.shape
    batch_size, n_frames, _ = coeffs.size()

    # Unflatten the last axis into (bin, numerator-then-denominator terms).
    per_bin = coeffs.reshape(batch_size, n_frames, n_bins, 2 * n_terms)

    # Split the two halves and move the bin axis ahead of the frame axis.
    numer_weights = per_bin[..., :n_terms].permute(0, 2, 1, 3)
    denom_weights = per_bin[..., n_terms:].permute(0, 2, 1, 3)

    # Share the same basis across every batch item and frame.
    basis = freqs[None, :, None, :].expand(batch_size, n_bins, n_frames, n_terms)

    numerator = torch.sum(numer_weights * basis, dim=-1)
    denominator = torch.sum(denom_weights * basis, dim=-1)
    return torch.atan2(numerator, denominator)
def compensate_phase(phase, win_len, n_batch, n_frame):
    """Compensate uniform linear phases.

    For every bin ``k`` in ``0 .. win_len // 2`` the term
    ``(2 * pi / win_len) * k * (win_len - 1) / 2`` is wrapped to its
    principal value with ``torch.angle`` and added to ``phase``.

    Args:
        phase: Phase tensor; expected to broadcast against
            (n_batch, win_len // 2 + 1, n_frame).
        win_len: Window length in samples.
        n_batch: Batch size of the compensation term.
        n_frame: Number of frames of the compensation term.

    Returns:
        ``phase`` plus the wrapped linear-phase term.
    """
    # Build the constants on the device of `phase` instead of forcing CUDA,
    # so the function also works on CPU-only hosts.
    bins = torch.arange(win_len // 2 + 1, device=phase.device)
    angle_freq = (2 * math.pi / win_len) * bins * (win_len - 1) / 2

    # Wrap through the complex exponential to get the principal value.
    wrapped = torch.angle(torch.exp(1j * angle_freq))
    wrapped = wrapped.view(1, -1, 1).expand(n_batch, -1, n_frame)
    return phase + wrapped
def extract_log_amp_and_phase(cfg: FeatureConfig):
    """Compute log-magnitude and phase spectrograms of the configured file.

    The file at ``cfg.in_wavefile`` is read and analysed with a Hann-windowed
    STFT of size ``cfg.n_fft`` and hop ``cfg.hop_length``. The sampling rate
    reported by the file, not ``cfg.rate``, is passed to the STFT object.

    Returns:
        Tuple ``(log_amplitude, phase)`` of float32 tensors derived from the
        STFT output; the log is taken after adding 1e-8.
    """
    wav_path, fft_size, hop = cfg.in_wavefile, cfg.n_fft, cfg.hop_length
    waveform, sample_rate = sf.read(wav_path)

    analyzer = signal.ShortTimeFFT(
        win=signal.windows.hann(fft_size),
        hop=hop,
        fs=sample_rate,
        mfft=fft_size,
    )
    spectrogram = analyzer.stft(waveform)

    magnitude = torch.from_numpy(np.abs(spectrogram).astype(np.float32))
    phase = torch.from_numpy(np.angle(spectrogram).astype(np.float32))

    # The small offset keeps the logarithm finite for silent bins.
    return torch.log(magnitude + 1e-8), phase
def make_power_freq(n_splits, n_order):
    """Build the table of powers of angular frequency.

    Args:
        n_splits: Number of equally spaced frequencies sampled on [0, pi].
        n_order: Number of powers; exponents run from 0 to ``n_order - 1``.

    Returns:
        Float32 tensor of shape (n_splits, n_order) whose column ``p`` holds
        the frequencies raised to the power ``p``.
    """
    grid = torch.from_numpy(np.linspace(0, np.pi, n_splits)).float()
    return torch.stack([grid**exponent for exponent in range(n_order)], dim=1)
def loss_func(criterion, log_amplitude, pred_phase, target_phase):
    """Compare predicted and target complex spectra.

    Both spectra are built as ``exp(log_amplitude + 1j * phase)``, so they
    share the same magnitude and differ only in phase.

    Args:
        criterion: Callable returning a loss for a (prediction, target) pair.
        log_amplitude: Log-magnitude spectrogram.
        pred_phase: Predicted phase.
        target_phase: Reference phase.

    Returns:
        Sum of the criterion on the real parts and on the imaginary parts.
    """
    estimate = torch.exp(log_amplitude + 1j * pred_phase)
    reference = torch.exp(log_amplitude + 1j * target_phase)

    real_term = criterion(estimate.real, reference.real)
    imag_term = criterion(estimate.imag, reference.imag)
    return real_term + imag_term
@torch.no_grad()
def recover_phase(log_amplitude, model, power_freq):
    """Predict the phase spectrogram with gradient tracking disabled.

    Args:
        log_amplitude: Input passed to ``model``.
        model: Callable that returns rational-function coefficients.
        power_freq: Frequency-power table passed to ``compute_ratiofunc``.

    Returns:
        NumPy array of the predicted phase with size-1 axes removed.
    """
    estimated = compute_ratiofunc(model(log_amplitude), power_freq)
    # Copy so the returned array does not share storage with the tensor.
    return np.squeeze(estimated.detach().cpu().numpy().copy())
def generate_wave(
    log_amplitude, phase, feat_cfg: FeatureConfig, train_cfg: TrainingConfig
):
    """Resynthesize a waveform from the log-magnitude and recovered phase.

    The complex spectrogram ``exp(log_amplitude + 1j * phase)`` is inverted
    with a Hann-windowed inverse STFT and written to
    ``output_<n_epochs>.wav`` at ``feat_cfg.rate``.

    Args:
        log_amplitude: Log-magnitude tensor; size-1 axes are squeezed away.
        phase: NumPy phase array matching the squeezed log-magnitude.
        feat_cfg: STFT settings and output sampling rate.
        train_cfg: Supplies ``n_epochs`` for the output file name.
    """
    log_mag = np.squeeze(log_amplitude.detach().cpu().numpy().copy())
    spectrum = np.exp(log_mag + 1j * phase)

    synthesizer = signal.ShortTimeFFT(
        win=signal.windows.hann(feat_cfg.n_fft),
        hop=feat_cfg.hop_length,
        fs=feat_cfg.rate,
        mfft=feat_cfg.n_fft,
    )
    waveform = synthesizer.istft(spectrum)

    out_path = f"output_{train_cfg.n_epochs}.wav"
    sf.write(out_path, waveform, feat_cfg.rate)
def eval_scores(feat_cfg: FeatureConfig, train_cfg: TrainingConfig):
    """Print PESQ, extended STOI and LSC for the generated audio.

    The reference is ``feat_cfg.in_wavefile`` and the evaluated signal is
    ``output_<n_epochs>.wav``. Both are cut to the shorter of the two lengths.
    LSC is 20 * log10 of the norm of the magnitude-spectrogram difference
    divided by the norm of the reference magnitude spectrogram.
    """
    reference, rate = sf.read(feat_cfg.in_wavefile)
    audio, _ = sf.read(f"output_{train_cfg.n_epochs}.wav")

    # Align lengths so the metrics compare sample-for-sample.
    n_common = min(len(audio), len(reference))
    audio, reference = audio[:n_common], reference[:n_common]

    analyzer = signal.ShortTimeFFT(
        win=signal.windows.hann(feat_cfg.n_fft),
        hop=feat_cfg.hop_length,
        fs=feat_cfg.rate,
        mfft=feat_cfg.n_fft,
    )
    ref_mag = np.abs(analyzer.stft(reference).T)
    eval_mag = np.abs(analyzer.stft(audio).T)

    ratio = np.linalg.norm(ref_mag - eval_mag) / np.linalg.norm(ref_mag)
    lsc = 20 * np.log10(ratio)

    pesq_score = pesq(rate, reference, audio, "wb")
    estoi_score = stoi(reference, audio, rate, extended=True)
    print(pesq_score, estoi_score, lsc)
def main():
    """Fit the network to one utterance, resynthesize it and print scores.

    The model is trained to reproduce the phase of the single file named in
    ``FeatureConfig``. After training, the recovered phase is combined with
    the original log-magnitude, written to disk and evaluated.
    """
    net_cfg = NetworkConfig()
    feat_cfg = FeatureConfig()
    train_cfg = TrainingConfig()

    # Prefer the GPU, but fall back to the CPU so the demo does not crash on
    # machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # setup training modules
    model = PhaseRecoveryNet().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=train_cfg.lr)

    # feature extraction
    log_amplitude, target_phase = extract_log_amp_and_phase(feat_cfg)
    power_freq = make_power_freq(log_amplitude.size(0), net_cfg.n_order).to(device)
    # Add a leading batch axis of size one.
    log_amplitude = log_amplitude.unsqueeze(0).to(device)
    target_phase = target_phase.unsqueeze(0).to(device)

    # training
    model.train()
    for epoch in range(1, train_cfg.n_epochs + 1):
        optimizer.zero_grad()
        coefficients = model(log_amplitude)
        pred_phase = compute_ratiofunc(coefficients, power_freq)
        loss = loss_func(criterion, log_amplitude, pred_phase, target_phase)
        loss.backward()
        optimizer.step()

        # Read the scalar loss only on logging epochs; .item() forces a
        # device-to-host synchronisation.
        if epoch % train_cfg.n_interval == 0:
            print(f"Epoch [{epoch}/{train_cfg.n_epochs}], Loss: {loss.item():.4f}")

    # inference
    model.eval()
    phase = recover_phase(log_amplitude, model, power_freq)
    generate_wave(log_amplitude, phase, feat_cfg, train_cfg)
    eval_scores(feat_cfg, train_cfg)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment