Created October 2, 2022 00:21
A classifier for keyword spotting (KWS) on the Google Speech Commands dataset v2, using a more complex non-linear model (BC-ResNet) with an nnAudio Mel-spectrogram front-end, based on https://github.com/heungky/nnAudio_tutorial
""" | |
Step 1: import related libraries | |
""" | |
# Libraries related to PyTorch | |
import torch | |
from torch import Tensor | |
import torch.nn as nn | |
from torch.nn.utils.rnn import pad_sequence | |
from torch.utils.data import WeightedRandomSampler,DataLoader | |
import torch.optim as optim | |
# import torch.nn.functional as F | |
# Libraries related to PyTorch Lightning | |
from pytorch_lightning import Trainer | |
from pytorch_lightning.core.lightning import LightningModule | |
# Libraries used in lightning module | |
from sklearn.metrics import precision_recall_fscore_support | |
# Libraried related to dataset | |
from AudioLoader.speech import SPEECHCOMMANDS_12C #for 12 classes KWS task | |
# nnAudio Front-end | |
from nnAudio.features.mel import MelSpectrogram | |
""" | |
Step 2: setting up configuration | |
""" | |
device = 'cuda:0' | |
gpus = 1 | |
batch_size= 100 | |
max_epochs = 200 | |
check_val_every_n_epoch = 2 | |
num_sanity_val_steps = 5 | |
data_root = './' # Download the data here | |
download_option = True | |
n_mels = 40 | |
#number of Mel bins | |
output_dim = 12 | |
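# The 12-class KWS setup on Speech Commands v2 typically consists of the ten command
# words (yes, no, up, down, left, right, on, off, stop, go) plus "_unknown_" and
# "_silence_", which is why output_dim = 12 here.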
""" | |
Step 3: setting up nnAudio basis functions | |
""" | |
mel_layer = MelSpectrogram(sr=16000, | |
n_fft=480, | |
win_length=None, | |
n_mels=n_mels, | |
hop_length=160, | |
window='hann', | |
center=True, | |
pad_mode='reflect', | |
power=2.0, | |
htk=False, | |
fmin=0.0, | |
fmax=None, | |
norm=1, | |
trainable_mel=False, | |
trainable_STFT=False, | |
verbose=True) | |
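# With these settings (sr=16000, n_fft=480, hop_length=160, center=True), a 1-second
# clip of 16000 samples should yield roughly 1 + 16000 // 160 = 101 frames, so the
# spectrogram passed to the model is expected to be [batch, 40, ~101].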
""" | |
Step 4: setting up dataset | |
""" | |
trainset = SPEECHCOMMANDS_12C(root=data_root, | |
url='speech_commands_v0.02', | |
folder_in_archive='SpeechCommands', | |
download= download_option,subset= 'training') | |
validset = SPEECHCOMMANDS_12C(root=data_root, | |
url='speech_commands_v0.02', | |
folder_in_archive='SpeechCommands', | |
download= download_option,subset= 'validation') | |
testset = SPEECHCOMMANDS_12C(root=data_root, | |
url='speech_commands_v0.02', | |
folder_in_archive='SpeechCommands', | |
download= download_option,subset= 'testing') | |
""" | |
Step 5: data rebalancing | |
""" | |
class_weights = [1,1,1,1,1,1,1,1,1,1,4.6,1/17] | |
sample_weights = [0] * len(trainset) | |
#create a list as per length of trainset | |
for idx, (data,rate,label,speaker_id, _) in enumerate(trainset): | |
class_weight = class_weights[label] | |
sample_weights[idx] = class_weight | |
#apply sample_weights in each data base on their label class in class_weight | |
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights),replacement=True) | |
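# WeightedRandomSampler draws num_samples indices per epoch with replacement, with
# probability proportional to each sample's weight; here this presumably up-weights the
# under-represented class (weight 4.6) and down-weights the heavily over-represented
# one (weight 1/17) so that training batches are roughly class-balanced.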
""" | |
Step 6: data processing and loading | |
""" | |
#data processing | |
def data_processing(data): | |
waveforms = [] | |
labels = [] | |
for batch in data: | |
waveforms.append(batch[0].squeeze(0)) #after squeeze => (audio_len) tensor # remove batch dim | |
labels.append(batch[2]) | |
waveform_padded = nn.utils.rnn.pad_sequence(waveforms, batch_first=True) | |
output_batch = {'waveforms': waveform_padded, | |
'labels': torch.tensor(labels), | |
} | |
return output_batch | |
#data loading | |
trainloader = DataLoader(trainset, | |
collate_fn=lambda x: data_processing(x), | |
batch_size=batch_size,sampler=sampler,num_workers=1) | |
validloader = DataLoader(validset, | |
collate_fn=lambda x: data_processing(x), | |
batch_size=batch_size,num_workers=1) | |
testloader = DataLoader(testset, | |
collate_fn=lambda x: data_processing(x), | |
batch_size=batch_size,num_workers=1) | |
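# After collation each batch is a dict: 'waveforms' is [batch_size, padded_len]
# (padded_len should be 16000 for the 1-second Speech Commands clips) and
# 'labels' is [batch_size].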
""" | |
Step 7: setting up the lightning module | |
""" | |
class SpeechCommand(LightningModule): | |
def training_step(self, batch, batch_idx): | |
outputs, spec = self(batch['waveforms']) | |
#return outputs [2D] for calculate loss, return spec [3D] for visual | |
loss = self.criterion(outputs, batch['labels'].long()) | |
acc = sum(outputs.argmax(-1) == batch['labels'])/outputs.shape[0] #batch wise | |
self.log('Train/acc', acc, on_step=False, on_epoch=True) | |
self.log('Train/Loss', loss, on_step=False, on_epoch=True) | |
#log(graph title, take acc as data, on_step: plot every step, on_epch: plot every epoch) | |
return loss | |
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, | |
optimizer_closure, on_tpu, using_native_amp, using_lbfgs): | |
optimizer.step(closure=optimizer_closure) | |
with torch.no_grad(): | |
torch.clamp_(self.mel_layer.mel_basis, 0, 1) | |
#after optimizer step, do clamp function on mel_basis | |
def validation_step(self, batch, batch_idx): | |
outputs, spec = self(batch['waveforms']) | |
loss = self.criterion(outputs, batch['labels'].long()) | |
self.log('Validation/Loss', loss, on_step=False, on_epoch=True) | |
output_dict = {'outputs': outputs, | |
'labels': batch['labels']} | |
return output_dict | |
def validation_epoch_end(self, outputs): | |
pred = [] | |
label = [] | |
for output in outputs: | |
pred.append(output['outputs']) | |
label.append(output['labels']) | |
label = torch.cat(label, 0) | |
pred = torch.cat(pred, 0) | |
acc = sum(pred.argmax(-1) == label)/label.shape[0] | |
self.log('Validation/acc', acc, on_step=False, on_epoch=True) | |
#use the return value from validation_step: output_dict , to calculate the overall accuracy | |
def test_step(self, batch, batch_idx): | |
outputs, spec = self(batch['waveforms']) | |
loss = self.criterion(outputs, batch['labels'].long()) | |
self.log('Test/Loss', loss, on_step=False, on_epoch=True) | |
output_dict = {'outputs': outputs, | |
'labels': batch['labels']} | |
return output_dict | |
def test_epoch_end(self, outputs): | |
pred = [] | |
label = [] | |
for output in outputs: | |
pred.append(output['outputs']) | |
label.append(output['labels']) | |
label = torch.cat(label, 0) | |
pred = torch.cat(pred, 0) | |
result_dict = {} | |
for key in [None, 'micro', 'macro', 'weighted']: | |
result_dict[key] = {} | |
p, r, f1, _ = precision_recall_fscore_support(label.cpu(), pred.argmax(-1).cpu(), average=key, zero_division=0) | |
result_dict[key]['precision'] = p | |
result_dict[key]['recall'] = r | |
result_dict[key]['f1'] = f1 | |
acc = sum(pred.argmax(-1) == label)/label.shape[0] | |
self.log('Test/acc', acc, on_step=False, on_epoch=True) | |
self.log('Test/micro_f1', result_dict['micro']['f1'], on_step=False, on_epoch=True) | |
self.log('Test/macro_f1', result_dict['macro']['f1'], on_step=False, on_epoch=True) | |
self.log('Test/weighted_f1', result_dict['weighted']['f1'], on_step=False, on_epoch=True) | |
return result_dict | |
def configure_optimizers(self): | |
model_param = [] | |
for name, params in self.named_parameters(): | |
if 'mel_layer.' in name: | |
pass | |
else: | |
model_param.append(params) | |
optimizer = optim.SGD(model_param, lr=1e-3, momentum= 0.9, weight_decay= 0.001) | |
return [optimizer] | |
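# Note that configure_optimizers skips everything under 'mel_layer.', so only the
# classifier is updated; with trainable_mel=False and trainable_STFT=False the nnAudio
# front-end should not expose trainable parameters in the first place.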
""" | |
Step 8: setting up model | |
""" | |
#BC_ResNet model | |
class SubSpectralNorm(LightningModule): | |
def __init__(self, C, S, eps=1e-5): | |
super(SubSpectralNorm, self).__init__() | |
self.S = S | |
self.eps = eps | |
self.bn = nn.BatchNorm2d(C*S) | |
def forward(self, x): | |
# x: input features with shape {N, C, F, T} | |
# S: number of sub-bands | |
N, C, F, T = x.size() | |
x = x.view(N, C * self.S, F // self.S, T) | |
x = self.bn(x) | |
return x.view(N, C, F, T) | |
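# SubSpectralNorm splits the F frequency bins into S sub-bands and batch-normalises each
# (channel, sub-band) pair independently, as used in BC-ResNet. The view() call requires
# F to be divisible by S; with S=5 and the stride-(2,1) frequency convolutions below, the
# 40-mel input is reduced to 20, 10 and 5 bins, all divisible by 5.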
class BroadcastedBlock(LightningModule):
    def __init__(
            self,
            planes: int,
            dilation=1,
            stride=1,
            temp_pad=(0, 1),
    ) -> None:
        super(BroadcastedBlock, self).__init__()
        self.freq_dw_conv = nn.Conv2d(planes, planes, kernel_size=(3, 1), padding=(1, 0), groups=planes,
                                      dilation=dilation,
                                      stride=stride, bias=False)
        self.ssn1 = SubSpectralNorm(planes, 5)
        self.temp_dw_conv = nn.Conv2d(planes, planes, kernel_size=(1, 3), padding=temp_pad, groups=planes,
                                      dilation=dilation, stride=stride, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.channel_drop = nn.Dropout2d(p=0.1)
        self.swish = nn.SiLU()
        self.conv1x1 = nn.Conv2d(planes, planes, kernel_size=(1, 1), bias=False)

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        # f2: frequency-wise depthwise convolution + sub-spectral norm
        ##########################
        out = self.freq_dw_conv(x)
        out = self.ssn1(out)
        ##########################

        auxiliary = out
        out = out.mean(2, keepdim=True)  # frequency average pooling

        # f1: temporal depthwise convolution branch
        ############################
        out = self.temp_dw_conv(out)
        out = self.bn(out)
        out = self.swish(out)
        out = self.conv1x1(out)
        out = self.channel_drop(out)
        ############################

        out = out + identity + auxiliary
        out = self.relu(out)
        return out
class TransitionBlock(LightningModule):
    def __init__(
            self,
            inplanes: int,
            planes: int,
            dilation=1,
            stride=1,
            temp_pad=(0, 1),
    ) -> None:
        super(TransitionBlock, self).__init__()
        self.freq_dw_conv = nn.Conv2d(planes, planes, kernel_size=(3, 1), padding=(1, 0), groups=planes,
                                      stride=stride,
                                      dilation=dilation, bias=False)
        self.ssn = SubSpectralNorm(planes, 5)
        self.temp_dw_conv = nn.Conv2d(planes, planes, kernel_size=(1, 3), padding=temp_pad, groups=planes,
                                      dilation=dilation, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.channel_drop = nn.Dropout2d(p=0.5)
        self.swish = nn.SiLU()
        self.conv1x1_1 = nn.Conv2d(inplanes, planes, kernel_size=(1, 1), bias=False)
        self.conv1x1_2 = nn.Conv2d(planes, planes, kernel_size=(1, 1), bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # f2: 1x1 channel projection, then frequency-wise depthwise convolution + sub-spectral norm
        #############################
        out = self.conv1x1_1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.freq_dw_conv(out)
        out = self.ssn(out)
        #############################

        auxiliary = out
        out = out.mean(2, keepdim=True)  # frequency average pooling

        # f1: temporal depthwise convolution branch
        #############################
        out = self.temp_dw_conv(out)
        out = self.bn2(out)
        out = self.swish(out)
        out = self.conv1x1_2(out)
        out = self.channel_drop(out)
        #############################

        out = auxiliary + out
        out = self.relu(out)
        return out
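# TransitionBlock changes the channel count (inplanes -> planes) via its first 1x1
# convolution and has no identity shortcut, while BroadcastedBlock keeps the channel
# count and adds both the identity and the auxiliary (pre-pooling) feature map back in.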
class BCResNet_nnAudio(SpeechCommand):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, stride=(2, 1), padding=(2, 2))
        self.block1_1 = TransitionBlock(16, 8)
        self.block1_2 = BroadcastedBlock(8)
        self.block2_1 = TransitionBlock(8, 12, stride=(2, 1), dilation=(1, 2), temp_pad=(0, 2))
        self.block2_2 = BroadcastedBlock(12, dilation=(1, 2), temp_pad=(0, 2))
        self.block3_1 = TransitionBlock(12, 16, stride=(2, 1), dilation=(1, 4), temp_pad=(0, 4))
        self.block3_2 = BroadcastedBlock(16, dilation=(1, 4), temp_pad=(0, 4))
        self.block3_3 = BroadcastedBlock(16, dilation=(1, 4), temp_pad=(0, 4))
        self.block3_4 = BroadcastedBlock(16, dilation=(1, 4), temp_pad=(0, 4))
        self.block4_1 = TransitionBlock(16, 20, dilation=(1, 8), temp_pad=(0, 8))
        self.block4_2 = BroadcastedBlock(20, dilation=(1, 8), temp_pad=(0, 8))
        self.block4_3 = BroadcastedBlock(20, dilation=(1, 8), temp_pad=(0, 8))
        self.block4_4 = BroadcastedBlock(20, dilation=(1, 8), temp_pad=(0, 8))
        self.conv2 = nn.Conv2d(20, 20, 5, groups=20, padding=(0, 2))
        self.conv3 = nn.Conv2d(20, 32, 1, bias=False)
        self.conv4 = nn.Conv2d(32, output_dim, 1, bias=False)
        self.mel_layer = mel_layer
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        # x: 2D [batch_size, 16000]
        spec = self.mel_layer(x)
        # spec: 3D [B, F(40), T]
        spec = torch.log(spec + 1e-10)
        spec = spec.unsqueeze(1)
        # spec: conv1 needs a 4D input [B, 1, F, T]
        out = self.conv1(spec)
        out = self.block1_1(out)
        out = self.block1_2(out)
        out = self.block2_1(out)
        out = self.block2_2(out)
        out = self.block3_1(out)
        out = self.block3_2(out)
        out = self.block3_3(out)
        out = self.block3_4(out)
        out = self.block4_1(out)
        out = self.block4_2(out)
        out = self.block4_3(out)
        out = self.block4_4(out)
        out = self.conv2(out)
        out = self.conv3(out)
        out = out.mean(-1, keepdim=True)  # temporal average pooling
        out = self.conv4(out)
        # out: 4D [B, output_dim, 1, 1]
        out = out.squeeze(2).squeeze(2)
        # out: 2D [B, output_dim]
        # cross-entropy expects [B, C], so squeeze down to 2D
        spec = spec.squeeze(1)
        # spec: from 4D [B, 1, F, T] back to 3D [B, F, T]
        # the returned spec is used for plotting log images, so it needs to be 3D
        return out, spec

model = BCResNet_nnAudio()
model = model.to(device)
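# Quick, optional shape sanity check: a random 1-second batch should give logits of
# shape [4, output_dim] and a spectrogram of roughly [4, 40, 101].
# out, spec = model(torch.randn(4, 16000).to(device))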
""" | |
Step 9: training model | |
""" | |
trainer = Trainer(gpus=gpus, max_epochs=max_epochs, | |
check_val_every_n_epoch= check_val_every_n_epoch, | |
num_sanity_val_steps=num_sanity_val_steps) | |
trainer.fit(model, trainloader, validloader) | |
trainer.test(model, testloader) |