Last active
October 25, 2020 19:17
-
-
Save relativeflux/16434d4022d78bf4bbe79201b39c9103 to your computer and use it in GitHub Desktop.
Data augmentation for audio using audiomentations
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import fnmatch | |
import numpy as np | |
import argparse | |
import librosa | |
import soundfile as sf | |
from audiomentations import (Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift) | |
def get_arguments(argv=None):
    '''Parse the command-line options of the augmentation script.

    Args:
        argv: Optional list of argument strings to parse. When left as
            ``None`` argparse falls back to the process command line
            (``sys.argv[1:]``), so existing zero-argument calls behave
            exactly as before.

    Returns:
        An ``argparse.Namespace`` with the string attributes ``data_dir``
        and ``output_dir``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing
            or an unrecognised option is supplied.
    '''
    parser = argparse.ArgumentParser(description='Audio Data Augmentation')
    parser.add_argument('--data_dir', type=str, required=True,
                        help='Path to the data to be augmented.')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Output directory for the augmented data.')
    return parser.parse_args(argv)
# Parse the command line once, at module level; the rest of the script
# reads the resulting namespace through the global `args`.
args = get_arguments()

# exist_ok=True creates the output directory (including missing parents)
# and tolerates it already existing, which avoids the race between a
# separate os.path.exists() check and the makedirs() call.
os.makedirs(args.output_dir, exist_ok=True)
def find_files(directory, pattern='*.wav'):
    '''Collect every file under `directory` whose name matches `pattern`.

    The tree is walked with os.walk, so sub-directories are included.
    `pattern` is a shell-style glob applied to the bare file name (not the
    full path) via fnmatch.filter.

    Returns:
        A list of paths, each built by joining the directory that os.walk
        reported with the matching file name. The list is empty when
        nothing matches.
    '''
    return [
        os.path.join(parent, name)
        for parent, _subdirs, names in os.walk(directory)
        for name in fnmatch.filter(names, pattern)
    ]
def yield_from_list(list):
    '''Yield the items of `list` one at a time, in their original order.

    Args:
        list: The items to iterate over. Any iterable is accepted; a
            sequence is no longer required because the items are no
            longer fetched by index.

    Yields:
        Each item of `list`, first to last.
    '''
    # NOTE: the parameter name shadows the builtin `list`. It is kept
    # unchanged so that callers passing it by keyword continue to work.
    # Delegating to the iterable replaces the former index list and
    # range(len(...)) loop, which only re-derived the same order.
    yield from list
def load_audio(data_dir):
    '''Lazily load every audio file that find_files locates under `data_dir`.

    This is a generator, so nothing runs until it is iterated; in
    particular the "no files" error below surfaces on the first iteration,
    not at call time.

    Yields:
        (path, samples, sample_rate) tuples, one per file, where `samples`
        and `sample_rate` are the two values returned by librosa.load.

    Raises:
        ValueError: If find_files returns no files for `data_dir`.
    '''
    wav_paths = find_files(data_dir)
    if len(wav_paths) == 0:
        raise ValueError("No audio files found in '{}'.".format(data_dir))

    for wav_path in yield_from_list(wav_paths):
        # Per the librosa docs, sr=None keeps the file's native sampling
        # rate and mono=True down-mixes the signal to a single channel.
        audio, rate = librosa.load(wav_path, sr=None, mono=True)
        yield (wav_path, audio, rate)
# Augmentation chain applied to every input file. Each transform carries
# p=0.5, which audiomentations documents as the probability of applying
# that transform on a given call.
augment = Compose([
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
    Shift(min_fraction=-0.5, max_fraction=0.5, p=0.5),
])

# Augment/transform/perturb the audio data, and write the results to disk...
print(f'Augmenting files in {args.data_dir}...')
for (filename, samples, sr) in load_audio(args.data_dir):
    aug_samples = augment(samples=samples, sample_rate=sr)
    # basename/splitext work with the platform's path separator and strip
    # only the final extension, so 'take.01.wav' keeps the stem 'take.01'
    # (splitting on '/' and '.' reduced it to 'take').
    # NOTE: only the base name is used, so files with the same name in
    # different sub-directories of data_dir map to the same output path.
    stem = os.path.splitext(os.path.basename(filename))[0]
    aug_file_path = os.path.join(args.output_dir, f'{stem}_AUGMENTED.wav')
    print(f'Writing {aug_file_path}')
    # The output is written at the sample rate the input was loaded with.
    sf.write(aug_file_path, np.array(aug_samples), sr)
print('Done')
# Use like: | |
''' | |
python ../audio_data_augmentation.py \ | |
--data_dir path/to/input/data \ | |
--output_dir path/to/output/dir | |
''' |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment