Skip to content

Instantly share code, notes, and snippets.

@ketanhdoshi
ketanhdoshi / sound_classification_inference.py
Created March 14, 2021 05:44
Sound Classification Inference
# ----------------------------
# Inference
# ----------------------------
def inference (model, val_dl):
correct_prediction = 0
total_prediction = 0
# Disable gradient updates
with torch.no_grad():
for data in val_dl:
@ketanhdoshi
ketanhdoshi / sound_classification_split.py
Last active March 14, 2021 13:43
Create Data Loaders for Training and Validation
from torch.utils.data import random_split
# Build the dataset from the metadata frame and the audio location
myds = SoundDS(df, data_path)

# Hold out 20% of the items for validation; the remaining 80% are used for
# training. The validation size is derived by subtraction so that the two
# lengths always sum to the dataset size, as random_split expects.
num_items = len(myds)
num_train = round(0.8 * num_items)
num_val = num_items - num_train
train_ds, val_ds = random_split(dataset=myds, lengths=[num_train, num_val])
@ketanhdoshi
ketanhdoshi / sound_classification_training.py
Last active March 14, 2021 05:44
Sound Classification Training
# ----------------------------
# Training Loop
# ----------------------------
def training(model, train_dl, num_epochs):
# Loss Function, Optimizer and Scheduler
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001,
steps_per_epoch=int(len(train_dl)),
epochs=num_epochs,
@ketanhdoshi
ketanhdoshi / sound_classification_model.py
Last active December 14, 2021 07:07
Sound Classification Model
import torch.nn.functional as F
from torch.nn import init
# ----------------------------
# Audio Classification Model
# ----------------------------
class AudioClassifier (nn.Module):
# ----------------------------
# Build the model architecture
# ----------------------------
@ketanhdoshi
ketanhdoshi / audio_metadata.py
Last active November 1, 2023 09:18
Audio Classification Metadata
# ----------------------------
# Prepare training data from Metadata file
# ----------------------------
import pandas as pd
from pathlib import Path
# Dataset folder, resolved relative to the current working directory
download_path = Path.cwd().joinpath('UrbanSound8K')

# Path of the metadata CSV inside the dataset folder
metadata_file = download_path.joinpath('metadata', 'UrbanSound8K.csv')
from torch.utils.data import DataLoader, Dataset, random_split
import torchaudio
# ----------------------------
# Sound Dataset
# ----------------------------
class SoundDS(Dataset):
def __init__(self, df, data_path):
self.df = df
self.data_path = str(data_path)
@ketanhdoshi
ketanhdoshi / transform_spec_augment.py
Last active October 13, 2023 09:55
Transform SpecAugment on the Mel Spectrogram
# ----------------------------
# Augment the Spectrogram by masking out sections of it in both the frequency
# dimension (i.e. horizontal bars) and the time dimension (i.e. vertical bars),
# to prevent overfitting and help the model generalise better. The masked
# sections are replaced with the mean value of the spectrogram.
# ----------------------------
@staticmethod
def spectro_augment(spec, max_mask_pct=0.1, n_freq_masks=1, n_time_masks=1):
_, n_mels, n_steps = spec.shape
mask_value = spec.mean()
@ketanhdoshi
ketanhdoshi / transform_mel.py
Last active February 5, 2023 02:40
Transform Mel Spectrogram
# ----------------------------
# Generate a Spectrogram
# ----------------------------
@staticmethod
def spectro_gram(aud, n_mels=64, n_fft=1024, hop_len=None):
sig,sr = aud
top_db = 80
# spec has shape [channel, n_mels, time], where channel is mono, stereo etc
spec = transforms.MelSpectrogram(sr, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels)(sig)
@ketanhdoshi
ketanhdoshi / transform_resample.py
Last active March 7, 2021 09:40
Transform Change Sampling Rate
# ----------------------------
# Since Resample applies to a single channel, we resample one channel at a time
# ----------------------------
@staticmethod
def resample(aud, newsr):
sig, sr = aud
if (sr == newsr):
# Nothing to do
return aud
@ketanhdoshi
ketanhdoshi / transform_rechannel.py
Last active March 7, 2021 09:38
Transform Adjust Number of Channels
# ----------------------------
# Convert the given audio to the desired number of channels
# ----------------------------
@staticmethod
def rechannel(aud, new_channel):
sig, sr = aud
if (sig.shape[0] == new_channel):
# Nothing to do
return aud