This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ---------------------------- | |
# Inference | |
# ---------------------------- | |
def inference (model, val_dl): | |
correct_prediction = 0 | |
total_prediction = 0 | |
# Disable gradient updates | |
with torch.no_grad(): | |
for data in val_dl: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data import random_split

# Build the dataset from `df` and the audio directory `data_path`
# (presumably the UrbanSound8K metadata DataFrame — confirm against the caller).
myds = SoundDS(df, data_path)

# Random split of 80:20 between training and validation.
# round() fixes the training size and the validation set takes the remainder,
# so num_train + num_val always equals len(myds) exactly, as random_split
# requires when it is given absolute lengths.
# NOTE(review): no `generator=` is passed, so the split differs between runs.
num_items = len(myds)
num_train = round(num_items * 0.8)
num_val = num_items - num_train
train_ds, val_ds = random_split(myds, [num_train, num_val])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ---------------------------- | |
# Training Loop | |
# ---------------------------- | |
def training(model, train_dl, num_epochs): | |
# Loss Function, Optimizer and Scheduler | |
criterion = nn.CrossEntropyLoss() | |
optimizer = torch.optim.Adam(model.parameters(),lr=0.001) | |
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, | |
steps_per_epoch=int(len(train_dl)), | |
epochs=num_epochs, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn.functional as F | |
from torch.nn import init | |
# ---------------------------- | |
# Audio Classification Model | |
# ---------------------------- | |
class AudioClassifier (nn.Module): | |
# ---------------------------- | |
# Build the model architecture | |
# ---------------------------- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ----------------------------
# Prepare training data from Metadata file
# ----------------------------
import pandas as pd
from pathlib import Path

# Root of the UrbanSound8K download, resolved against the current working
# directory at the moment this line runs.
download_path = Path.cwd()/'UrbanSound8K'

# Location of the dataset's metadata CSV (only the path is built here).
metadata_file = download_path/'metadata'/'UrbanSound8K.csv'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data import DataLoader, Dataset, random_split | |
import torchaudio | |
# ---------------------------- | |
# Sound Dataset | |
# ---------------------------- | |
class SoundDS(Dataset): | |
def __init__(self, df, data_path): | |
self.df = df | |
self.data_path = str(data_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ---------------------------- | |
# Augment the Spectrogram by masking out some sections of it in both the frequency | |
# dimension (ie. horizontal bars) and the time dimension (vertical bars) to prevent | |
# overfitting and to help the model generalise better. The masked sections are | |
# replaced with the mean value. | |
# ---------------------------- | |
@staticmethod | |
def spectro_augment(spec, max_mask_pct=0.1, n_freq_masks=1, n_time_masks=1): | |
_, n_mels, n_steps = spec.shape | |
mask_value = spec.mean() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ---------------------------- | |
# Generate a Spectrogram | |
# ---------------------------- | |
@staticmethod | |
def spectro_gram(aud, n_mels=64, n_fft=1024, hop_len=None): | |
sig,sr = aud | |
top_db = 80 | |
# spec has shape [channel, n_mels, time], where channel is mono, stereo etc | |
spec = transforms.MelSpectrogram(sr, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels)(sig) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ---------------------------- | |
# Since Resample applies to a single channel, we resample one channel at a time | |
# ---------------------------- | |
@staticmethod | |
def resample(aud, newsr): | |
sig, sr = aud | |
if (sr == newsr): | |
# Nothing to do | |
return aud |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ---------------------------- | |
# Convert the given audio to the desired number of channels | |
# ---------------------------- | |
@staticmethod | |
def rechannel(aud, new_channel): | |
sig, sr = aud | |
if (sig.shape[0] == new_channel): | |
# Nothing to do | |
return aud |
Newer Older