Last active
September 26, 2023 21:49
-
-
Save vadimkantorov/4f34fe60d2ef00e72dcad16512d224af to your computer and use it in GitHub Desktop.
Sinc convolution module in PyTorch (adapted and simplified from the original SincNet codebase)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sinc learned filter banks were proposed in "Speaker Recognition from raw waveform with SincNet", Ravanelli and Bengio, http://arxiv.org/abs/1808.00158
# Code is simplified and adapted from https://github.com/mravanelli/SincNet/blob/master/dnn_models.py
import math | |
import torch | |
class SincConv1d(torch.nn.Conv1d):
    """1-D convolution whose kernels are learnable band-pass sinc filters.

    Instead of a dense ``weight`` parameter, every output channel is described
    by two learnable scalars stored in ``low_hz_`` and ``band_hz_`` (both of
    shape ``(out_channels, 1)``).  The convolution kernel is rebuilt from them
    each time ``weight`` is read, so the inherited ``Conv1d.forward`` works
    unchanged.  Call :meth:`freeze` to bake the current kernel into a buffer
    and drop the learnable parameters.

    Args:
        in_channels: must be 1.
        out_channels: number of band-pass filters.
        kernel_size: filter length in samples; must be an odd integer.
        stride, padding, dilation, padding_mode: forwarded to ``Conv1d``.
        groups: must be 1.
        bias: must be ``False``.
        sample_rate: sampling rate used to convert Hz values to radians/sample.
        min_low_hz: constant added to ``|low_hz_|`` to get the lower cut-off.
        min_band_hz: constant added to ``|band_hz_|`` to get the band width.
        low_hz: lower cut-off of the first filter at initialisation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, groups = 1, bias = False, padding_mode = 'zeros', sample_rate = 16_000, min_low_hz = 50, min_band_hz = 50, low_hz = 30):
        assert in_channels == 1 and kernel_size % 2 == 1 and bias is False and groups == 1
        super().__init__(in_channels, out_channels, kernel_size, stride = stride, padding = padding, dilation = dilation, groups = groups, bias = bias, padding_mode = padding_mode)

        # Conv1d allocated a dense `weight` parameter; remove it and reserve an
        # (empty) buffer slot of the same name so that `freeze()` can fill it
        # later.  `register_buffer` rejects a name for which `hasattr` is true
        # unless that name is already a buffer, hence the slot is created here,
        # before the `weight` property is able to produce a value.
        del self.weight
        self.register_buffer('weight', None)

        def hz_to_mel(hz):
            return 2595 * math.log10(1 + hz / 700)

        def mel_to_hz(mel):
            return 700 * (10 ** (mel / 2595) - 1)

        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_lo_hz = min_low_hz  # earlier spelling of the same attribute, kept for existing readers
        self.min_band_hz = min_band_hz
        self.max_hi_hz = sample_rate / 2

        # Initial cut-offs: out_channels + 1 band edges equally spaced on the mel scale.
        high_hz = self.max_hi_hz - (self.min_low_hz + self.min_band_hz)
        hz = mel_to_hz(torch.linspace(hz_to_mel(low_hz), hz_to_mel(high_hz), steps = out_channels + 1))
        self.low_hz_ = torch.nn.Parameter(hz[:-1].unsqueeze(-1).clone())
        self.band_hz_ = torch.nn.Parameter((hz[1:] - hz[:-1]).unsqueeze(-1))

        # Only the left half of the (symmetric) kernel is computed explicitly:
        # left half of a Hamming window and the matching negative time axis,
        # pre-multiplied by 2*pi / sample_rate.
        self.register_buffer('window', torch.hamming_window(kernel_size)[:kernel_size // 2])
        self.register_buffer('sinct', 2 * math.pi * torch.arange(-(kernel_size // 2), 0, dtype = torch.float32) / sample_rate)

    @property
    def weight(self):
        """Return the kernel tensor of shape ``(out_channels, 1, kernel_size)``.

        After :meth:`freeze` the stored buffer is returned; before that the
        kernel is recomputed from ``low_hz_`` and ``band_hz_`` on every access.
        """
        frozen = self._buffers.get('weight')
        if frozen is not None:
            return frozen
        if self._parameters.get('low_hz_') is None:
            # The filter parameters do not exist yet (base-class constructor is
            # still running).  Raising AttributeError makes Python fall back to
            # Module.__getattr__, which serves the regular registered entry.
            raise AttributeError('weight')

        lo = self.min_low_hz + self.low_hz_.abs()
        hi = (lo + self.min_band_hz + self.band_hz_.abs()).clamp(min = self.min_low_hz, max = self.max_hi_hz)
        sincarg_hi, sincarg_lo = hi * self.sinct, lo * self.sinct

        band_pass_left = (sincarg_hi.sin() - sincarg_lo.sin()) / (self.sinct / 2) * self.window
        band_pass_center = (hi - lo) * 2
        band_pass_right = band_pass_left.flip(dims = [1])
        # Dividing by the centre value normalises every filter so that its centre tap is 1.
        band_pass = torch.cat([band_pass_left, band_pass_center, band_pass_right], dim = 1) / band_pass_center
        return band_pass.unsqueeze(1)

    def freeze(self):
        """Store the current kernel as the ``weight`` buffer and drop the learnable parameters.

        The kernel is detached so the buffer carries no autograd history.
        Calling this method again leaves the stored kernel unchanged.
        """
        self.register_buffer('weight', self.weight.detach())
        self.register_parameter('low_hz_', None)
        self.register_parameter('band_hz_', None)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment