Last active
June 4, 2024 23:32
-
-
Save vadimkantorov/58edfb48122d9f4819c29c48268ed37d to your computer and use it in GitHub Desktop.
Source code of the model from https://github.com/snakers4/silero-vad v4 extracted from the source code attributes embedded in the TorchScript structures
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# I printed the code listings from the TorchScript silero_vad.jit's .code/_c.code attributes and tidied up the source a bit, nothing really fancy here | |
# This can be used for optimizing inference and enabling GPU inference | |
# Big thanks to the Silero company for making public their VAD checkpoint! | |
# The used checkpoint: | |
# https://github.com/snakers4/silero-vad/blob/a9d2b591dea11451d23aa4b480eff8e55dbd9d99/files/silero_vad.jit
import torch | |
import torch.nn as nn | |
class STFT(nn.Module): | |
def __init__(self, filter_length = 256, hop_length = 64): | |
super().__init__() | |
self.filter_length = filter_length | |
self.hop_length = hop_length | |
self.register_buffer('forward_basis_buffer', torch.zeros(258, 1, filter_length)) #TODO: initialize as cos/sin | |
#print(model_torchscript._model.feature_extractor.transform_.code, file = open('test.txt', 'w')) | |
#print(model_torchscript._model.feature_extractor.code, file = open('test.txt', 'w')) | |
def forward(self, input_data): | |
input_data0 = input_data.unsqueeze(1) | |
to_pad = int(torch.div(torch.sub(self.filter_length, self.hop_length), 2)) | |
input_data1 = torch.nn.functional.pad(torch.unsqueeze(input_data0, 1), [to_pad, to_pad, 0, 0], "reflect") | |
forward_transform = torch.conv1d(torch.squeeze(input_data1, 1), self.forward_basis_buffer, None, [self.hop_length], [0]) | |
cutoff = int(torch.add(torch.div(self.filter_length, 2), 1)) | |
real_part = forward_transform[:, :cutoff, :] | |
imag_part = forward_transform[:, cutoff:, :] | |
magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) | |
#phase = torch.atan2(imag_part, real_part) | |
#return (magnitude, phase) | |
return magnitude | |
class AdaptiveAudioNormalizationNew(nn.Module):
    """Normalizes a spectrogram by subtracting a smoothed mean of its log energy."""

    def __init__(self, to_pad=3):
        super().__init__()
        self.to_pad = to_pad
        # learned smoothing kernel of width 2*to_pad + 1, loaded from the checkpoint
        self.filter_ = nn.Parameter(torch.zeros(1, 1, 2 * to_pad + 1))

    @staticmethod
    def simple_pad(mean, to_pad):
        # Mirror-pad both ends of the time axis, excluding the edge samples
        # themselves (indices 1..to_pad on the left, -1-to_pad..-2 on the right).
        left = torch.flip(mean[:, :, 1:to_pad + 1], [-1])
        right = torch.flip(mean[:, :, -1 - to_pad:-1], [-1])
        return torch.cat([left, mean, right], 2)

    def forward(self, spect):
        # log-compress; the 2**20 gain keeps small magnitudes in a useful range
        log_spect = torch.log1p(spect * 1048576)
        if log_spect.ndim == 2:
            log_spect = torch.unsqueeze(log_spect, 0)
        padded_mean = self.simple_pad(torch.mean(log_spect, [1], True), self.to_pad)
        smoothed = torch.conv1d(padded_mean, self.filter_)
        return log_spect - torch.mean(smoothed, [-1], True)
class ConvBlock(nn.Module):
    """Depthwise-separable 1d conv block with an optional projected residual connection."""

    def __init__(self, in_channels=258, out_channels=16, proj=False):
        super().__init__()
        # depthwise 5-tap conv (one group per channel), then pointwise 1x1 conv
        # changes the channel count; Identity modules keep state_dict keys aligned
        # with the TorchScript checkpoint.
        self.dw_conv = nn.Sequential(
            nn.Conv1d(in_channels, in_channels, 5, padding=2, groups=in_channels),
            nn.Identity(),
            nn.ReLU(),
        )
        self.pw_conv = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, 1),
            nn.Identity(),
        )
        # 1x1 projection aligns channel counts for the residual when they differ
        self.proj = nn.Conv1d(in_channels, out_channels, 1) if proj else nn.Identity()
        self.activation = nn.ReLU()

    def forward(self, x):
        shortcut = self.proj(x)
        out = self.pw_conv(self.dw_conv(x))
        out += shortcut
        return self.activation(out)
class VADDecoderRNNJIT(nn.Module):
    """Two-layer LSTM followed by a 1x1 conv + sigmoid giving per-frame speech probabilities."""

    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(64, 64, num_layers=2, batch_first=True, dropout=0.1)
        self.decoder = nn.Sequential(nn.ReLU(), nn.Conv1d(64, 1, 1), nn.Sigmoid())

    def forward(self, x, h=torch.Tensor(), c=torch.Tensor()):
        # x: (B, C, T) -> (B, T, C) for the batch-first LSTM.
        # Empty h/c tensors mean "no carried state": the LSTM starts from zeros.
        state = (h, c) if h.numel() > 0 else None
        rnn_out, (h, c) = self.rnn(x.permute(0, 2, 1), state)
        probs = self.decoder(rnn_out.permute(0, 2, 1))
        return (probs, h, c)
class VADRNNJIT(nn.Module):
    """Single-sample-rate Silero VAD: STFT features -> conv encoder -> LSTM decoder."""

    def __init__(self):
        super().__init__()
        self.feature_extractor = STFT()
        self.adaptive_normalization = AdaptiveAudioNormalizationNew()
        # 258 input channels = 129 raw magnitude bins + 129 normalized bins.
        self.first_layer = nn.Sequential(ConvBlock(258, 16, proj=True), nn.Dropout(0.15))
        # Three stride-2 stages (each: 1x1 conv, BN, ReLU, ConvBlock+Dropout)
        # followed by a stride-1 stage up to 64 channels.
        self.encoder = nn.Sequential(
            nn.Conv1d(16, 16, 1, stride=2), nn.BatchNorm1d(16), nn.ReLU(),
            nn.Sequential(ConvBlock(16, 32, proj=True), nn.Dropout(0.15)),
            nn.Conv1d(32, 32, 1, stride=2), nn.BatchNorm1d(32), nn.ReLU(),
            nn.Sequential(ConvBlock(32, 32, proj=False), nn.Dropout(0.15)),
            nn.Conv1d(32, 32, 1, stride=2), nn.BatchNorm1d(32), nn.ReLU(),
            nn.Sequential(ConvBlock(32, 64, proj=True), nn.Dropout(0.15)),
            nn.Conv1d(64, 64, 1, stride=1), nn.BatchNorm1d(64), nn.ReLU(),
        )
        self.decoder = VADDecoderRNNJIT()

    def forward(self, x, h=torch.Tensor(), c=torch.Tensor()):
        features = self.feature_extractor(x)
        # stack raw magnitudes with their normalized version along the channel axis
        stacked = torch.cat([features, self.adaptive_normalization(features)], 1)
        encoded = self.encoder(self.first_layer(stacked))
        frame_probs, h_out, c_out = self.decoder(encoded, h, c)
        # collapse per-frame probabilities into one probability per chunk
        out = frame_probs.squeeze(1).mean([1]).unsqueeze(1)
        return (out, h_out, c_out)
class VADRNNJITMerge(nn.Module):
    """Top-level wrapper holding a 16 kHz and an 8 kHz VAD model plus streaming LSTM state."""

    def __init__(self):
        super().__init__()
        self._model = VADRNNJIT()
        self._model_8k = VADRNNJIT()
        self._last_batch_size = None
        self._last_sr = None
        self._h = None
        self.sample_rates = [8000, 16000]
        self.reset_states()

    def reset_states(self, batch_size=1):
        """Drop the streaming LSTM state; empty tensors mean 'start from zeros'."""
        self._h = torch.zeros([0])
        self._c = torch.zeros([0])
        self._last_sr = 0
        self._last_batch_size = 0

    def _validate_input(self, x, sr):
        """Return a (batch, time) waveform and an accepted sample rate, decimating if needed."""
        audio = x.unsqueeze(0) if x.ndim == 1 else x
        assert audio.ndim == 2, f"Too many dimensions for input audio chunk {audio.ndim}"
        if sr != 16000 and sr % 16000 == 0:
            # decimate whole multiples of 16 kHz down to 16 kHz by striding
            audio = audio[:, ::sr // 16000]
            sr = 16000
        assert sr in self.sample_rates, f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)"
        # chunk must be at least sr / 31.25 samples (i.e. 32 ms of audio)
        assert sr / audio.shape[1] <= 31.25, "Input audio chunk is too short"
        return (audio, sr)

    def forward(self, x, sr):
        audio, sr0 = self._validate_input(x, sr)
        # a change of sample rate or batch size invalidates the streaming state
        if self._last_sr and self._last_sr != sr0:
            self.reset_states()
        if self._last_batch_size and self._last_batch_size != audio.shape[0]:
            self.reset_states()
        assert sr0 == 16000 or sr0 == 8000
        # NOTE: dispatch keys on the caller-supplied sr, as in the original
        # TorchScript; sr == 8000 iff sr0 == 8000, so the choice is equivalent.
        model = self._model_8k if sr == 8000 else self._model
        out, self._h, self._c = model(audio, self._h, self._c)
        self._last_sr = sr0
        self._last_batch_size = self._h.shape[1]
        return out

    def audio_forward(self, x, sr, num_samples: int = 512):
        """Run the VAD over a whole recording in fixed-size chunks, returning all chunk probabilities."""
        audio, sr0 = self._validate_input(x, sr)
        remainder = audio.shape[1] % num_samples
        if remainder:
            # zero-pad the tail so the signal splits into whole chunks
            audio = torch.nn.functional.pad(audio, (0, num_samples - remainder), 'constant', value=0.0)
        self.reset_states(audio.shape[0])
        chunk_probs = [
            self(audio[:, start:start + num_samples], sr0)
            for start in range(0, audio.shape[1], num_samples)
        ]
        return torch.cat(chunk_probs, dim=1)
if __name__ == '__main__':
    # Load the reference TorchScript checkpoint and verify (by md5) that it is
    # the exact version these re-implemented modules were extracted from.
    silero_torchscript_checkpoint = 'silero_vad.jit'
    import hashlib; assert '22aced3da46b9d9546686310f779818e' == hashlib.md5(open(silero_torchscript_checkpoint,'rb').read()).hexdigest()
    model_torchscript = torch.jit.load(silero_torchscript_checkpoint)
    model_torchscript.eval()
    state_dict = model_torchscript.state_dict()
    print(model_torchscript)
    # The plain nn.Module re-implementation loads the TorchScript weights
    # directly, so the module tree/key names must match the checkpoint.
    model = VADRNNJITMerge()
    model.eval()
    model.load_state_dict(state_dict)
    print(model)
    torch.set_grad_enabled(False)
    torch.set_num_threads(1)
    import torchaudio
    samples_CT, sample_rate = torchaudio.load('ru.wav') # https://models.silero.ai/vad_models/ru.wav
    assert sample_rate == 16000
    # Reference outputs: whole clip through forward(), then chunked audio_forward().
    model_torchscript.reset_states()
    speech_prob_torchscript = model_torchscript(samples_CT, sample_rate)
    model_torchscript.reset_states()
    speech_prob_torchscript_batch = model_torchscript.audio_forward(samples_CT, sample_rate)
    print(speech_prob_torchscript, speech_prob_torchscript_batch, speech_prob_torchscript_batch.shape)
    # Same two runs with the re-implementation; both must match the reference.
    model.reset_states()
    speech_prob = model(samples_CT, sample_rate)
    model.reset_states()
    speech_prob_batch = model.audio_forward(samples_CT, sample_rate)
    print(speech_prob, speech_prob_batch, speech_prob_batch.shape)
    assert torch.allclose(speech_prob_torchscript, speech_prob)
    assert torch.allclose(speech_prob_torchscript_batch, speech_prob_batch)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
RecursiveScriptModule( | |
original_name=VADRNNJITMerge | |
(_model): RecursiveScriptModule( | |
original_name=VADRNNJIT | |
(adaptive_normalization): RecursiveScriptModule(original_name=AdaptiveAudioNormalizationNew) | |
(feature_extractor): RecursiveScriptModule(original_name=STFT) | |
(first_layer): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule( | |
original_name=ConvBlock | |
(dw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
(2): RecursiveScriptModule(original_name=ReLU) | |
) | |
(pw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
) | |
(proj): RecursiveScriptModule(original_name=Conv1d) | |
(activation): RecursiveScriptModule(original_name=ReLU) | |
) | |
(1): RecursiveScriptModule(original_name=Dropout) | |
) | |
(encoder): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=BatchNorm1d) | |
(2): RecursiveScriptModule(original_name=ReLU) | |
(3): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule( | |
original_name=ConvBlock | |
(dw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
(2): RecursiveScriptModule(original_name=ReLU) | |
) | |
(pw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
) | |
(proj): RecursiveScriptModule(original_name=Conv1d) | |
(activation): RecursiveScriptModule(original_name=ReLU) | |
) | |
(1): RecursiveScriptModule(original_name=Dropout) | |
) | |
(4): RecursiveScriptModule(original_name=Conv1d) | |
(5): RecursiveScriptModule(original_name=BatchNorm1d) | |
(6): RecursiveScriptModule(original_name=ReLU) | |
(7): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule( | |
original_name=ConvBlock | |
(dw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
(2): RecursiveScriptModule(original_name=ReLU) | |
) | |
(pw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
) | |
(activation): RecursiveScriptModule(original_name=ReLU) | |
) | |
(1): RecursiveScriptModule(original_name=Dropout) | |
) | |
(8): RecursiveScriptModule(original_name=Conv1d) | |
(9): RecursiveScriptModule(original_name=BatchNorm1d) | |
(10): RecursiveScriptModule(original_name=ReLU) | |
(11): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule( | |
original_name=ConvBlock | |
(dw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
(2): RecursiveScriptModule(original_name=ReLU) | |
) | |
(pw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
) | |
(proj): RecursiveScriptModule(original_name=Conv1d) | |
(activation): RecursiveScriptModule(original_name=ReLU) | |
) | |
(1): RecursiveScriptModule(original_name=Dropout) | |
) | |
(12): RecursiveScriptModule(original_name=Conv1d) | |
(13): RecursiveScriptModule(original_name=BatchNorm1d) | |
(14): RecursiveScriptModule(original_name=ReLU) | |
) | |
(decoder): RecursiveScriptModule( | |
original_name=VADDecoderRNNJIT | |
(rnn): RecursiveScriptModule(original_name=LSTM) | |
(decoder): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=ReLU) | |
(1): RecursiveScriptModule(original_name=Conv1d) | |
(2): RecursiveScriptModule(original_name=Sigmoid) | |
) | |
) | |
) | |
(_model_8k): RecursiveScriptModule( | |
original_name=VADRNNJIT | |
(adaptive_normalization): RecursiveScriptModule(original_name=AdaptiveAudioNormalizationNew) | |
(feature_extractor): RecursiveScriptModule(original_name=STFT) | |
(first_layer): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule( | |
original_name=ConvBlock | |
(dw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
(2): RecursiveScriptModule(original_name=ReLU) | |
) | |
(pw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
) | |
(proj): RecursiveScriptModule(original_name=Conv1d) | |
(activation): RecursiveScriptModule(original_name=ReLU) | |
) | |
(1): RecursiveScriptModule(original_name=Dropout) | |
) | |
(encoder): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=BatchNorm1d) | |
(2): RecursiveScriptModule(original_name=ReLU) | |
(3): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule( | |
original_name=ConvBlock | |
(dw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
(2): RecursiveScriptModule(original_name=ReLU) | |
) | |
(pw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
) | |
(proj): RecursiveScriptModule(original_name=Conv1d) | |
(activation): RecursiveScriptModule(original_name=ReLU) | |
) | |
(1): RecursiveScriptModule(original_name=Dropout) | |
) | |
(4): RecursiveScriptModule(original_name=Conv1d) | |
(5): RecursiveScriptModule(original_name=BatchNorm1d) | |
(6): RecursiveScriptModule(original_name=ReLU) | |
(7): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule( | |
original_name=ConvBlock | |
(dw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
(2): RecursiveScriptModule(original_name=ReLU) | |
) | |
(pw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
) | |
(activation): RecursiveScriptModule(original_name=ReLU) | |
) | |
(1): RecursiveScriptModule(original_name=Dropout) | |
) | |
(8): RecursiveScriptModule(original_name=Conv1d) | |
(9): RecursiveScriptModule(original_name=BatchNorm1d) | |
(10): RecursiveScriptModule(original_name=ReLU) | |
(11): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule( | |
original_name=ConvBlock | |
(dw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
(2): RecursiveScriptModule(original_name=ReLU) | |
) | |
(pw_conv): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=Conv1d) | |
(1): RecursiveScriptModule(original_name=Identity) | |
) | |
(proj): RecursiveScriptModule(original_name=Conv1d) | |
(activation): RecursiveScriptModule(original_name=ReLU) | |
) | |
(1): RecursiveScriptModule(original_name=Dropout) | |
) | |
(12): RecursiveScriptModule(original_name=Conv1d) | |
(13): RecursiveScriptModule(original_name=BatchNorm1d) | |
(14): RecursiveScriptModule(original_name=ReLU) | |
) | |
(decoder): RecursiveScriptModule( | |
original_name=VADDecoderRNNJIT | |
(rnn): RecursiveScriptModule(original_name=LSTM) | |
(decoder): RecursiveScriptModule( | |
original_name=Sequential | |
(0): RecursiveScriptModule(original_name=ReLU) | |
(1): RecursiveScriptModule(original_name=Conv1d) | |
(2): RecursiveScriptModule(original_name=Sigmoid) | |
) | |
) | |
) | |
) | |
VADRNNJITMerge( | |
(_model): VADRNNJIT( | |
(feature_extractor): STFT() | |
(adaptive_normalization): AdaptiveAudioNormalizationNew() | |
(first_layer): Sequential( | |
(0): ConvBlock( | |
(dw_conv): Sequential( | |
(0): Conv1d(258, 258, kernel_size=(5,), stride=(1,), padding=(2,), groups=258) | |
(1): Identity() | |
(2): ReLU() | |
) | |
(pw_conv): Sequential( | |
(0): Conv1d(258, 16, kernel_size=(1,), stride=(1,)) | |
(1): Identity() | |
) | |
(proj): Conv1d(258, 16, kernel_size=(1,), stride=(1,)) | |
(activation): ReLU() | |
) | |
(1): Dropout(p=0.15, inplace=False) | |
) | |
(encoder): Sequential( | |
(0): Conv1d(16, 16, kernel_size=(1,), stride=(2,)) | |
(1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) | |
(2): ReLU() | |
(3): Sequential( | |
(0): ConvBlock( | |
(dw_conv): Sequential( | |
(0): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,), groups=16) | |
(1): Identity() | |
(2): ReLU() | |
) | |
(pw_conv): Sequential( | |
(0): Conv1d(16, 32, kernel_size=(1,), stride=(1,)) | |
(1): Identity() | |
) | |
(proj): Conv1d(16, 32, kernel_size=(1,), stride=(1,)) | |
(activation): ReLU() | |
) | |
(1): Dropout(p=0.15, inplace=False) | |
) | |
(4): Conv1d(32, 32, kernel_size=(1,), stride=(2,)) | |
(5): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) | |
(6): ReLU() | |
(7): Sequential( | |
(0): ConvBlock( | |
(dw_conv): Sequential( | |
(0): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,), groups=32) | |
(1): Identity() | |
(2): ReLU() | |
) | |
(pw_conv): Sequential( | |
(0): Conv1d(32, 32, kernel_size=(1,), stride=(1,)) | |
(1): Identity() | |
) | |
(proj): Identity() | |
(activation): ReLU() | |
) | |
(1): Dropout(p=0.15, inplace=False) | |
) | |
(8): Conv1d(32, 32, kernel_size=(1,), stride=(2,)) | |
(9): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) | |
(10): ReLU() | |
(11): Sequential( | |
(0): ConvBlock( | |
(dw_conv): Sequential( | |
(0): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,), groups=32) | |
(1): Identity() | |
(2): ReLU() | |
) | |
(pw_conv): Sequential( | |
(0): Conv1d(32, 64, kernel_size=(1,), stride=(1,)) | |
(1): Identity() | |
) | |
(proj): Conv1d(32, 64, kernel_size=(1,), stride=(1,)) | |
(activation): ReLU() | |
) | |
(1): Dropout(p=0.15, inplace=False) | |
) | |
(12): Conv1d(64, 64, kernel_size=(1,), stride=(1,)) | |
(13): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) | |
(14): ReLU() | |
) | |
(decoder): VADDecoderRNNJIT( | |
(rnn): LSTM(64, 64, num_layers=2, batch_first=True, dropout=0.1) | |
(decoder): Sequential( | |
(0): ReLU() | |
(1): Conv1d(64, 1, kernel_size=(1,), stride=(1,)) | |
(2): Sigmoid() | |
) | |
) | |
) | |
(_model_8k): VADRNNJIT( | |
(feature_extractor): STFT() | |
(adaptive_normalization): AdaptiveAudioNormalizationNew() | |
(first_layer): Sequential( | |
(0): ConvBlock( | |
(dw_conv): Sequential( | |
(0): Conv1d(258, 258, kernel_size=(5,), stride=(1,), padding=(2,), groups=258) | |
(1): Identity() | |
(2): ReLU() | |
) | |
(pw_conv): Sequential( | |
(0): Conv1d(258, 16, kernel_size=(1,), stride=(1,)) | |
(1): Identity() | |
) | |
(proj): Conv1d(258, 16, kernel_size=(1,), stride=(1,)) | |
(activation): ReLU() | |
) | |
(1): Dropout(p=0.15, inplace=False) | |
) | |
(encoder): Sequential( | |
(0): Conv1d(16, 16, kernel_size=(1,), stride=(2,)) | |
(1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) | |
(2): ReLU() | |
(3): Sequential( | |
(0): ConvBlock( | |
(dw_conv): Sequential( | |
(0): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,), groups=16) | |
(1): Identity() | |
(2): ReLU() | |
) | |
(pw_conv): Sequential( | |
(0): Conv1d(16, 32, kernel_size=(1,), stride=(1,)) | |
(1): Identity() | |
) | |
(proj): Conv1d(16, 32, kernel_size=(1,), stride=(1,)) | |
(activation): ReLU() | |
) | |
(1): Dropout(p=0.15, inplace=False) | |
) | |
(4): Conv1d(32, 32, kernel_size=(1,), stride=(2,)) | |
(5): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) | |
(6): ReLU() | |
(7): Sequential( | |
(0): ConvBlock( | |
(dw_conv): Sequential( | |
(0): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,), groups=32) | |
(1): Identity() | |
(2): ReLU() | |
) | |
(pw_conv): Sequential( | |
(0): Conv1d(32, 32, kernel_size=(1,), stride=(1,)) | |
(1): Identity() | |
) | |
(proj): Identity() | |
(activation): ReLU() | |
) | |
(1): Dropout(p=0.15, inplace=False) | |
) | |
(8): Conv1d(32, 32, kernel_size=(1,), stride=(2,)) | |
(9): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) | |
(10): ReLU() | |
(11): Sequential( | |
(0): ConvBlock( | |
(dw_conv): Sequential( | |
(0): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,), groups=32) | |
(1): Identity() | |
(2): ReLU() | |
) | |
(pw_conv): Sequential( | |
(0): Conv1d(32, 64, kernel_size=(1,), stride=(1,)) | |
(1): Identity() | |
) | |
(proj): Conv1d(32, 64, kernel_size=(1,), stride=(1,)) | |
(activation): ReLU() | |
) | |
(1): Dropout(p=0.15, inplace=False) | |
) | |
(12): Conv1d(64, 64, kernel_size=(1,), stride=(1,)) | |
(13): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) | |
(14): ReLU() | |
) | |
(decoder): VADDecoderRNNJIT( | |
(rnn): LSTM(64, 64, num_layers=2, batch_first=True, dropout=0.1) | |
(decoder): Sequential( | |
(0): ReLU() | |
(1): Conv1d(64, 1, kernel_size=(1,), stride=(1,)) | |
(2): Sigmoid() | |
) | |
) | |
) | |
) | |
tensor([[0.5635]]) tensor([[0.0592, 0.0315, 0.0283, ..., 0.3480, 0.3140, 0.2010]]) torch.Size([1, 1875]) | |
tensor([[0.5635]]) tensor([[0.0592, 0.0315, 0.0283, ..., 0.3480, 0.3140, 0.2010]]) torch.Size([1, 1875]) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment