Skip to content

Instantly share code, notes, and snippets.

@relativeflux
Last active October 25, 2020 19:17
Show Gist options
  • Save relativeflux/16434d4022d78bf4bbe79201b39c9103 to your computer and use it in GitHub Desktop.
Save relativeflux/16434d4022d78bf4bbe79201b39c9103 to your computer and use it in GitHub Desktop.
Data augmentation for audio using audiomentations
import os
import fnmatch
import numpy as np
import argparse
import librosa
import soundfile as sf
from audiomentations import (Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift)
def get_arguments(argv=None):
    '''Parse the command-line options of the augmentation script.

    Args:
        argv: Optional sequence of argument strings to parse. When left as
            ``None`` argparse falls back to ``sys.argv[1:]``, which is what
            the script itself relies on; passing an explicit list makes the
            function usable from other code and from tests.

    Returns:
        An ``argparse.Namespace`` with the attributes ``data_dir`` and
        ``output_dir`` (both strings, both required).

    Raises:
        SystemExit: Raised by argparse when a required option is missing
            or an unknown option is supplied.
    '''
    parser = argparse.ArgumentParser(description='Audio Data Augmentation')
    parser.add_argument(
        '--data_dir',
        type=str,
        required=True,
        help='Path to the data to be augmented.',
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        required=True,
        help='Output directory for the augmented data.',
    )
    return parser.parse_args(argv)
args = get_arguments()

# Create the output directory (including missing parents) up front.
# exist_ok=True replaces the former exists()-then-makedirs() sequence, which
# could fail if the directory appeared between the check and the creation.
os.makedirs(args.output_dir, exist_ok=True)
def find_files(directory, pattern='*.wav'):
    '''Recursively find all files below `directory` whose name matches `pattern`.

    Args:
        directory: Root directory to walk.
        pattern: Shell-style wildcard (as understood by ``fnmatch``) that is
            matched against each file name, not against the full path.

    Returns:
        A sorted list of paths, each built by joining the directory in which
        the file was found with the file name. The list is empty when nothing
        matches or when `directory` does not exist.
    '''
    matches = [
        os.path.join(root, filename)
        for root, _dirnames, filenames in os.walk(directory)
        for filename in fnmatch.filter(filenames, pattern)
    ]
    # os.walk yields entries in an arbitrary, filesystem-dependent order;
    # sorting makes the processing order reproducible between runs.
    matches.sort()
    return matches
def yield_from_list(list):
    '''Yield the elements of `list` one by one, in their original order.

    Args:
        list: The collection to iterate over. The parameter keeps its
            historical name (which shadows the builtin inside this function)
            so that callers passing it by keyword continue to work.

    Yields:
        Each element of `list`, first to last.
    '''
    # Delegating to the iterable replaces the former detour through a list
    # of indices, which produced exactly the same sequence of elements.
    yield from list
def load_audio(data_dir):
    '''Lazily load every audio file that `find_files` discovers in `data_dir`.

    This is a generator: nothing is searched or read until iteration starts,
    so the ValueError below surfaces on the first ``next()`` call rather than
    when `load_audio` itself is called.

    Args:
        data_dir: Directory handed to `find_files` (default pattern).

    Yields:
        ``(filename, samples, sr)`` tuples, where `samples` and `sr` are the
        two values returned by ``librosa.load`` for that file.

    Raises:
        ValueError: If `find_files` returns no files for `data_dir`.
    '''
    wav_paths = find_files(data_dir)
    if not wav_paths:
        raise ValueError("No audio files found in '{}'.".format(data_dir))

    for wav_path in yield_from_list(wav_paths):
        # sr=None and mono=True are passed through to librosa unchanged; per
        # the librosa documentation this keeps the file's native sampling
        # rate and returns a single-channel signal.
        audio, rate = librosa.load(wav_path, sr=None, mono=True)
        yield (wav_path, audio, rate)
# Augmentation pipeline. Every transform carries p=0.5; see the
# audiomentations documentation for the exact meaning of each parameter.
_gaussian_noise = AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5)
_time_stretch = TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5)
_pitch_shift = PitchShift(min_semitones=-4, max_semitones=4, p=0.5)
_time_shift = Shift(min_fraction=-0.5, max_fraction=0.5, p=0.5)

# The transforms are handed to Compose in this fixed order:
# noise, time stretch, pitch shift, shift.
augment = Compose([
    _gaussian_noise,
    _time_stretch,
    _pitch_shift,
    _time_shift,
])
# Augment/transform/perturb the audio data, and write the results to disk.
# Each input file yields exactly one augmented file in args.output_dir.
print(f'Augmenting files in {args.data_dir}...')
for (filename, samples, sr) in load_audio(args.data_dir):
    aug_samples = augment(samples=samples, sample_rate=sr)

    # Derive the output stem with os.path so this works with any path
    # separator, and strip only the final extension: splitting on the first
    # '.' mapped e.g. 'take.1.wav' and 'take.2.wav' to the same output name.
    stem = os.path.splitext(os.path.basename(filename))[0]
    # NOTE(review): inputs from different sub-directories that share a file
    # name still map to the same output path and overwrite each other.
    aug_file_path = os.path.join(args.output_dir, f'{stem}_AUGMENTED.wav')

    print(f'Writing {aug_file_path}')
    # asarray avoids copying when the pipeline already returned an ndarray.
    sf.write(aug_file_path, np.asarray(aug_samples), sr)
print('Done')
# Use like:
'''
python ../audio_data_augmentation.py \
--data_dir path/to/input/data \
--output_dir path/to/output/dir
'''
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment