A classifier for the keyword spotting (KWS) task on the Google Speech Commands dataset v2, using more complex non-linear models, based on https://github.com/heungky/nnAudio_tutorial
"""
Step 1: import related libraries
"""
# Libraries related to PyTorch
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import WeightedRandomSampler,DataLoader
import torch.optim as optim
# import torch.nn.functional as F
# Libraries related to PyTorch Lightning
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
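# Note: this script follows the PyTorch Lightning 1.x API
# (pytorch_lightning.core.lightning, validation_epoch_end/test_epoch_end hooks,
# Trainer(gpus=...)); these were removed or renamed in Lightning 2.x.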
# Libraries used in lightning module
from sklearn.metrics import precision_recall_fscore_support
# Libraries related to the dataset
from AudioLoader.speech import SPEECHCOMMANDS_12C  # for the 12-class KWS task
# nnAudio Front-end
from nnAudio.features.mel import MelSpectrogram
"""
Step 2: setting up configuration
"""
device = 'cuda:0'
gpus = 1
batch_size = 100
max_epochs = 200
check_val_every_n_epoch = 2
num_sanity_val_steps = 5
data_root = './'  # download the data here
download_option = True
n_mels = 40  # number of Mel bins
output_dim = 12
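# In the standard 12-class KWS setup these correspond to the ten command words
# plus "_unknown_" and "_silence_"; the exact label order here depends on
# AudioLoader's SPEECHCOMMANDS_12C implementation.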
"""
Step 3: setting up nnAudio basis functions
"""
mel_layer = MelSpectrogram(sr=16000,
                           n_fft=480,
                           win_length=None,
                           n_mels=n_mels,
                           hop_length=160,
                           window='hann',
                           center=True,
                           pad_mode='reflect',
                           power=2.0,
                           htk=False,
                           fmin=0.0,
                           fmax=None,
                           norm=1,
                           trainable_mel=False,
                           trainable_STFT=False,
                           verbose=True)
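# Optional sanity check (not part of the original tutorial): with sr=16000 and
# hop_length=160, a 1-second clip gives n_mels x ~101 frames. Uncomment to verify:
# print(mel_layer(torch.randn(1, 16000)).shape)  # expected: torch.Size([1, 40, 101])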
"""
Step 4: setting up dataset
"""
trainset = SPEECHCOMMANDS_12C(root=data_root,
                              url='speech_commands_v0.02',
                              folder_in_archive='SpeechCommands',
                              download=download_option, subset='training')
validset = SPEECHCOMMANDS_12C(root=data_root,
                              url='speech_commands_v0.02',
                              folder_in_archive='SpeechCommands',
                              download=download_option, subset='validation')
testset = SPEECHCOMMANDS_12C(root=data_root,
                             url='speech_commands_v0.02',
                             folder_in_archive='SpeechCommands',
                             download=download_option, subset='testing')
"""
Step 5: data rebalancing
"""
class_weights = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4.6, 1/17]
sample_weights = [0] * len(trainset)
# create a list with one entry per sample in trainset, then
# assign each sample the weight of its label class from class_weights
for idx, (data, rate, label, speaker_id, _) in enumerate(trainset):
    class_weight = class_weights[label]
    sample_weights[idx] = class_weight
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
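# Optional check (a sketch, assuming index 2 of each dataset item is the label,
# as in the loop above): draw a few thousand indices from the sampler and confirm
# the rebalanced label counts are roughly uniform across the 12 classes.
# sampled_labels = torch.tensor([trainset[i][2] for i in list(sampler)[:2000]])
# print(torch.bincount(sampled_labels, minlength=output_dim))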
"""
Step 6: data processing and loading
"""
#data processing
def data_processing(data):
    waveforms = []
    labels = []
    for batch in data:
        waveforms.append(batch[0].squeeze(0))  # remove the leading dim => (audio_len,) tensor
        labels.append(batch[2])
    waveform_padded = nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
    output_batch = {'waveforms': waveform_padded,
                    'labels': torch.tensor(labels),
                    }
    return output_batch
#data loading
trainloader = DataLoader(trainset,
                         collate_fn=lambda x: data_processing(x),
                         batch_size=batch_size, sampler=sampler, num_workers=1)
validloader = DataLoader(validset,
                         collate_fn=lambda x: data_processing(x),
                         batch_size=batch_size, num_workers=1)
testloader = DataLoader(testset,
                        collate_fn=lambda x: data_processing(x),
                        batch_size=batch_size, num_workers=1)
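# Optional check (a sketch): one collated batch should contain zero-padded
# waveforms of shape [batch_size, max_audio_len] and integer labels of shape [batch_size].
# batch = next(iter(validloader))
# print(batch['waveforms'].shape, batch['labels'].shape)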
"""
Step 7: setting up the lightning module
"""
class SpeechCommand(LightningModule):
    def training_step(self, batch, batch_idx):
        outputs, spec = self(batch['waveforms'])
        # outputs [2D] are used to compute the loss; spec [3D] is returned for visualization
        loss = self.criterion(outputs, batch['labels'].long())
        acc = sum(outputs.argmax(-1) == batch['labels']) / outputs.shape[0]  # batch-wise accuracy
        self.log('Train/acc', acc, on_step=False, on_epoch=True)
        self.log('Train/Loss', loss, on_step=False, on_epoch=True)
        # log(name, value, on_step: log every step, on_epoch: log once per epoch)
        return loss

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        optimizer.step(closure=optimizer_closure)
        with torch.no_grad():
            torch.clamp_(self.mel_layer.mel_basis, 0, 1)
            # after the optimizer step, clamp the Mel filter bank to [0, 1]

    def validation_step(self, batch, batch_idx):
        outputs, spec = self(batch['waveforms'])
        loss = self.criterion(outputs, batch['labels'].long())
        self.log('Validation/Loss', loss, on_step=False, on_epoch=True)
        output_dict = {'outputs': outputs,
                       'labels': batch['labels']}
        return output_dict

    def validation_epoch_end(self, outputs):
        # use the output_dict returned by validation_step to compute the overall accuracy
        pred = []
        label = []
        for output in outputs:
            pred.append(output['outputs'])
            label.append(output['labels'])
        label = torch.cat(label, 0)
        pred = torch.cat(pred, 0)
        acc = sum(pred.argmax(-1) == label) / label.shape[0]
        self.log('Validation/acc', acc, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        outputs, spec = self(batch['waveforms'])
        loss = self.criterion(outputs, batch['labels'].long())
        self.log('Test/Loss', loss, on_step=False, on_epoch=True)
        output_dict = {'outputs': outputs,
                       'labels': batch['labels']}
        return output_dict

    def test_epoch_end(self, outputs):
        pred = []
        label = []
        for output in outputs:
            pred.append(output['outputs'])
            label.append(output['labels'])
        label = torch.cat(label, 0)
        pred = torch.cat(pred, 0)
        result_dict = {}
        for key in [None, 'micro', 'macro', 'weighted']:
            result_dict[key] = {}
            p, r, f1, _ = precision_recall_fscore_support(label.cpu(), pred.argmax(-1).cpu(),
                                                          average=key, zero_division=0)
            result_dict[key]['precision'] = p
            result_dict[key]['recall'] = r
            result_dict[key]['f1'] = f1
        acc = sum(pred.argmax(-1) == label) / label.shape[0]
        self.log('Test/acc', acc, on_step=False, on_epoch=True)
        self.log('Test/micro_f1', result_dict['micro']['f1'], on_step=False, on_epoch=True)
        self.log('Test/macro_f1', result_dict['macro']['f1'], on_step=False, on_epoch=True)
        self.log('Test/weighted_f1', result_dict['weighted']['f1'], on_step=False, on_epoch=True)
        return result_dict

    def configure_optimizers(self):
        # exclude the nnAudio front-end parameters from the optimizer
        model_param = []
        for name, params in self.named_parameters():
            if 'mel_layer.' in name:
                pass
            else:
                model_param.append(params)
        optimizer = optim.SGD(model_param, lr=1e-3, momentum=0.9, weight_decay=0.001)
        return [optimizer]
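# Note: configure_optimizers excludes the 'mel_layer.*' parameters from SGD, and
# optimizer_step clamps mel_basis to [0, 1]. With trainable_mel=False and
# trainable_STFT=False (Step 3) these have little practical effect here; they become
# relevant when the nnAudio front end is made trainable.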
"""
Step 8: setting up model
"""
#BC_ResNet model
class SubSpectralNorm(LightningModule):
    def __init__(self, C, S, eps=1e-5):
        super(SubSpectralNorm, self).__init__()
        self.S = S
        self.eps = eps
        self.bn = nn.BatchNorm2d(C * S)

    def forward(self, x):
        # x: input features with shape (N, C, F, T)
        # S: number of frequency sub-bands
        N, C, F, T = x.size()
        x = x.view(N, C * self.S, F // self.S, T)
        x = self.bn(x)
        return x.view(N, C, F, T)
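# Example (a sketch): with C=8 channels and S=5 sub-bands, a [N, 8, 20, T] input is
# viewed as [N, 40, 4, T], batch-normalized so each (channel, sub-band) pair gets its
# own statistics, and then reshaped back to [N, 8, 20, T].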
class BroadcastedBlock(LightningModule):
    def __init__(
            self,
            planes: int,
            dilation=1,
            stride=1,
            temp_pad=(0, 1),
    ) -> None:
        super(BroadcastedBlock, self).__init__()
        self.freq_dw_conv = nn.Conv2d(planes, planes, kernel_size=(3, 1), padding=(1, 0), groups=planes,
                                      dilation=dilation, stride=stride, bias=False)
        self.ssn1 = SubSpectralNorm(planes, 5)
        self.temp_dw_conv = nn.Conv2d(planes, planes, kernel_size=(1, 3), padding=temp_pad, groups=planes,
                                      dilation=dilation, stride=stride, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.channel_drop = nn.Dropout2d(p=0.1)
        self.swish = nn.SiLU()
        self.conv1x1 = nn.Conv2d(planes, planes, kernel_size=(1, 1), bias=False)

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        # f2: frequency-wise depthwise conv + sub-spectral norm
        out = self.freq_dw_conv(x)
        out = self.ssn1(out)

        auxiliary = out
        out = out.mean(2, keepdim=True)  # frequency average pooling

        # f1: temporal depthwise conv -> BN -> SiLU -> 1x1 conv -> channel dropout
        out = self.temp_dw_conv(out)
        out = self.bn(out)
        out = self.swish(out)
        out = self.conv1x1(out)
        out = self.channel_drop(out)

        out = out + identity + auxiliary
        out = self.relu(out)
        return out
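# Note: in the residual sum above, the temporal branch `out` has a frequency dimension
# of 1 after frequency average pooling, so it is broadcast back across all frequency
# bins when added to `identity` and `auxiliary`; this broadcasting is what gives
# BC-ResNet (broadcasted residual learning) its name.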
class TransitionBlock(LightningModule):
    def __init__(
            self,
            inplanes: int,
            planes: int,
            dilation=1,
            stride=1,
            temp_pad=(0, 1),
    ) -> None:
        super(TransitionBlock, self).__init__()
        self.freq_dw_conv = nn.Conv2d(planes, planes, kernel_size=(3, 1), padding=(1, 0), groups=planes,
                                      stride=stride, dilation=dilation, bias=False)
        self.ssn = SubSpectralNorm(planes, 5)
        self.temp_dw_conv = nn.Conv2d(planes, planes, kernel_size=(1, 3), padding=temp_pad, groups=planes,
                                      dilation=dilation, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.channel_drop = nn.Dropout2d(p=0.5)
        self.swish = nn.SiLU()
        self.conv1x1_1 = nn.Conv2d(inplanes, planes, kernel_size=(1, 1), bias=False)
        self.conv1x1_2 = nn.Conv2d(planes, planes, kernel_size=(1, 1), bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # f2: 1x1 conv changes the channel count, then frequency-wise depthwise conv + sub-spectral norm
        out = self.conv1x1_1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.freq_dw_conv(out)
        out = self.ssn(out)

        auxiliary = out
        out = out.mean(2, keepdim=True)  # frequency average pooling

        # f1: temporal depthwise conv -> BN -> SiLU -> 1x1 conv -> channel dropout
        out = self.temp_dw_conv(out)
        out = self.bn2(out)
        out = self.swish(out)
        out = self.conv1x1_2(out)
        out = self.channel_drop(out)

        out = auxiliary + out
        out = self.relu(out)
        return out
class BCResNet_nnAudio(SpeechCommand):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, stride=(2, 1), padding=(2, 2))
        self.block1_1 = TransitionBlock(16, 8)
        self.block1_2 = BroadcastedBlock(8)
        self.block2_1 = TransitionBlock(8, 12, stride=(2, 1), dilation=(1, 2), temp_pad=(0, 2))
        self.block2_2 = BroadcastedBlock(12, dilation=(1, 2), temp_pad=(0, 2))
        self.block3_1 = TransitionBlock(12, 16, stride=(2, 1), dilation=(1, 4), temp_pad=(0, 4))
        self.block3_2 = BroadcastedBlock(16, dilation=(1, 4), temp_pad=(0, 4))
        self.block3_3 = BroadcastedBlock(16, dilation=(1, 4), temp_pad=(0, 4))
        self.block3_4 = BroadcastedBlock(16, dilation=(1, 4), temp_pad=(0, 4))
        self.block4_1 = TransitionBlock(16, 20, dilation=(1, 8), temp_pad=(0, 8))
        self.block4_2 = BroadcastedBlock(20, dilation=(1, 8), temp_pad=(0, 8))
        self.block4_3 = BroadcastedBlock(20, dilation=(1, 8), temp_pad=(0, 8))
        self.block4_4 = BroadcastedBlock(20, dilation=(1, 8), temp_pad=(0, 8))
        self.conv2 = nn.Conv2d(20, 20, 5, groups=20, padding=(0, 2))
        self.conv3 = nn.Conv2d(20, 32, 1, bias=False)
        self.conv4 = nn.Conv2d(32, output_dim, 1, bias=False)
        self.mel_layer = mel_layer
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        # x: 2D [batch_size, 16000]
        spec = self.mel_layer(x)
        # spec: 3D [B, F(40), T]
        spec = torch.log(spec + 1e-10)
        spec = spec.unsqueeze(1)
        # spec: conv1 needs a 4D input [B, 1, F, T]
        out = self.conv1(spec)
        out = self.block1_1(out)
        out = self.block1_2(out)
        out = self.block2_1(out)
        out = self.block2_2(out)
        out = self.block3_1(out)
        out = self.block3_2(out)
        out = self.block3_3(out)
        out = self.block3_4(out)
        out = self.block4_1(out)
        out = self.block4_2(out)
        out = self.block4_3(out)
        out = self.block4_4(out)
        out = self.conv2(out)
        out = self.conv3(out)
        out = out.mean(-1, keepdim=True)  # temporal average pooling
        out = self.conv4(out)
        # out: 4D [B, output_dim, 1, 1]
        out = out.squeeze(2).squeeze(2)
        # out: 2D [B, output_dim]; CrossEntropyLoss expects [B, C], so the trailing dims are squeezed
        spec = spec.squeeze(1)
        # spec: from 4D [B, 1, F, T] back to 3D [B, F, T]
        # the returned spec is used for plotting log-images, which needs 3D
        return out, spec
model = BCResNet_nnAudio()
model = model.to(device)
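# Optional check (a sketch): a single forward pass on random audio should return
# class logits of shape [batch, 12] and a log-Mel spectrogram of shape [batch, 40, n_frames].
# with torch.no_grad():
#     logits, spec = model(torch.randn(2, 16000).to(device))
#     print(logits.shape, spec.shape)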
"""
Step 9: training model
"""
trainer = Trainer(gpus=gpus, max_epochs=max_epochs,
                  check_val_every_n_epoch=check_val_every_n_epoch,
                  num_sanity_val_steps=num_sanity_val_steps)
trainer.fit(model, trainloader, validloader)
trainer.test(model, testloader)