Skip to content

Instantly share code, notes, and snippets.

@vadimkantorov
Last active September 26, 2023 21:49
Show Gist options
  • Save vadimkantorov/4f34fe60d2ef00e72dcad16512d224af to your computer and use it in GitHub Desktop.
Save vadimkantorov/4f34fe60d2ef00e72dcad16512d224af to your computer and use it in GitHub Desktop.
Sinc convolution module in PyTorch (adapted and simplified from the original SincNet codebase)
# Learnable sinc filter banks were proposed in "Speaker Recognition from Raw Waveform with SincNet", Ravanelli and Bengio, https://arxiv.org/abs/1808.00158
# The code below is a simplified adaptation of https://github.com/mravanelli/SincNet/blob/master/dnn_models.py
import math
import torch
class SincConv1d(torch.nn.Conv1d):
    """Conv1d whose kernels are parametrised band-pass sinc filters (SincNet).

    Each of the ``out_channels`` filters is described by two learnable scalars,
    ``low_hz_`` and ``band_hz_`` (in Hz). The dense kernel is recomputed from them
    on every access to ``weight``. It is a Hamming-windowed difference of two sinc
    low-pass responses, normalised so that the centre tap equals 1.

    Args:
        in_channels: must be 1.
        out_channels: number of band-pass filters.
        kernel_size: filter length in samples; must be odd so there is a centre tap.
        stride, padding, dilation, padding_mode: forwarded to ``torch.nn.Conv1d``.
        groups: must be 1.
        bias: must be False.
        sample_rate: sampling rate of the input signal in Hz.
        min_low_hz: added to ``|low_hz_|`` to get the low cutoff, so it is a lower bound on it.
        min_band_hz: added to ``|band_hz_|`` to get the bandwidth (before clamping to Nyquist).
        low_hz: lowest band edge used for the mel-spaced initialisation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 bias=False, padding_mode='zeros', sample_rate=16_000, min_low_hz=50, min_band_hz=50, low_hz=30):
        assert in_channels == 1 and kernel_size % 2 == 1 and bias is False and groups == 1, \
            'SincConv1d requires in_channels == 1, odd kernel_size, bias=False and groups == 1'
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
                         groups=groups, bias=bias, padding_mode=padding_mode)
        # The kernel is produced by the ``weight`` property, so drop the dense Parameter
        # that torch.nn.Conv1d allocated.
        self.register_parameter('weight', None)

        to_mel = lambda hz: 2595 * torch.log10(1 + hz / 700)
        to_hz = lambda mel: 700 * (10 ** (mel / 2595) - 1)

        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz
        self.max_hi_hz = sample_rate / 2

        # Initialise out_channels adjacent bands whose out_channels + 1 edges are evenly
        # spaced on the mel scale between low_hz and high_hz.
        high_hz = self.max_hi_hz - (self.min_low_hz + self.min_band_hz)
        hz = to_hz(torch.linspace(to_mel(torch.tensor(float(low_hz))), to_mel(torch.tensor(float(high_hz))),
                                  steps=out_channels + 1))
        self.low_hz_ = torch.nn.Parameter(hz[:-1].unsqueeze(-1))               # (out_channels, 1)
        self.band_hz_ = torch.nn.Parameter((hz[1:] - hz[:-1]).unsqueeze(-1))   # (out_channels, 1)

        # Only the left half of the filter (taps -(K//2) .. -1) is stored; the right half
        # is built as its mirror image in ``weight``.
        self.register_buffer('window', torch.hamming_window(kernel_size)[:kernel_size // 2])
        self.register_buffer('sinct', 2 * math.pi * torch.arange(-(kernel_size // 2), 0, dtype=torch.float32) / sample_rate)

    @property
    def weight(self):
        """Return the kernel of shape ``(out_channels, 1, kernel_size)``.

        After ``freeze()`` the cached buffer is returned instead of being recomputed.
        """
        frozen = self._buffers.get('weight')
        if frozen is not None:
            return frozen
        # NOTE: torch.nn.Conv1d.__init__ reads ``self.weight`` (to initialise its placeholder)
        # before the attributes below exist. The AttributeError raised here then makes Python
        # fall back to nn.Module.__getattr__, which returns that placeholder Parameter.
        lo = self.min_low_hz + self.low_hz_.abs()
        hi = (lo + self.min_band_hz + self.band_hz_.abs()).clamp(min=self.min_low_hz, max=self.max_hi_hz)
        # (out_channels, 1) * (K // 2,) -> (out_channels, K // 2)
        sincarg_hi, sincarg_lo = hi * self.sinct, lo * self.sinct
        band_pass_left = (sincarg_hi.sin() - sincarg_lo.sin()) / (self.sinct / 2) * self.window
        # The centre tap is the n -> 0 limit of the expression above (which would be 0/0 there).
        band_pass_center = (hi - lo) * 2
        band_pass_right = band_pass_left.flip(dims=[1])
        # Normalise every filter so its centre tap equals 1.
        band_pass = torch.cat([band_pass_left, band_pass_center, band_pass_right], dim=1) / band_pass_center
        return band_pass.unsqueeze(1)

    def freeze(self):
        """Cache the current kernel as a non-trainable ``weight`` buffer and drop the cutoff parameters."""
        # register_buffer('weight', ...) would raise KeyError here. The ``weight`` property makes
        # hasattr(self, 'weight') true while 'weight' is not yet a buffer, so insert directly.
        # detach() so the cached kernel neither requires grad nor keeps the autograd graph alive.
        self._buffers['weight'] = self.weight.detach()
        self.register_parameter('low_hz_', None)
        self.register_parameter('band_hz_', None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment