Created
January 4, 2018 10:01
-
-
Save andreh7/b444c932b0e6dfda51c2efd19732581c to your computer and use it in GitHub Desktop.
Script for packing Kaggle TensorFlow Speech Recognition Challenge data into single .npy files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python

#----------------------------------------------------------------------
# main
#----------------------------------------------------------------------
# Packs the Kaggle TensorFlow Speech Recognition Challenge data into
# one .npy file per sample group ('train', 'noise', 'test') plus a
# matching .csv metadata file, for faster bulk loading (e.g. on GCP).
#
# usage: pack-for-gcp.py data-directory output-directory
import sys

# parse the command line before the heavyweight imports below so the
# usage message appears without delay
argv = sys.argv[1:]

if len(argv) != 2:
    print("usage: pack-for-gcp.py data-directory output-directory")
    print()
    print("WARNING: existing files with the same name in the output directory will be overwritten without warning")
    sys.exit(1)

data_dir, output_dir = argv

import os, fnmatch
import gc

import numpy as np
from scipy.io import wavfile
from tqdm import tqdm

# resolve both directories once, relative to the invocation directory:
# we chdir() back and forth below, which would otherwise break
# relative paths on the second loop iteration
data_dir = os.path.abspath(data_dir)
output_dir = os.path.abspath(output_dir)

# we divide the samples into three groups:
#   train   files in directory train/
#   noise   train samples longer than one second
#   test    test samples for leaderboard
#
# note that this does not take into account the lists in files
# train/testing_list.txt and train/validation_list.txt
cwd = os.getcwd()

#----------
os.chdir(data_dir)

# read list of validation files within the train set
# (entries look like "<label>/<filename>.wav")
with open("train/validation_list.txt") as fin:
    validation_files = set(fin.read().splitlines())

for subdir in ('train', 'test'):
    print("processing", subdir)

    # re-enter the data directory: at the end of the previous
    # iteration we chdir()'d back to cwd (no-op on the first pass)
    os.chdir(data_dir)

    # key is the sample group name ('train', 'noise' or 'test', see above)
    filenames = {}
    labels = {}
    speakers = {}
    is_validation = {}
    wav_data = {}

    # run garbage collection after trashing previous
    # directory's data to minimize chances of running out of memory
    gc.collect()

    # total file count is unknown up front, so no total= for the bar
    progbar = tqdm(mininterval=1, unit='files')

    for dirname, dirnames, fnames in os.walk(subdir):
        if dirname == '.':
            continue
        if dirname.startswith("./"):
            dirname = dirname[2:]

        if dirname.startswith("train"):
            dirkey = 'train'
        elif dirname.startswith("test"):
            dirkey = 'test'
        else:
            raise Exception("do not know what type of directory " + os.path.abspath(dirname) + " is")

        label = None

        for filename in fnmatch.filter(fnames, "*.wav"):
            if dirkey != 'test' and label is None:
                # train directories look like train/audio/<label>;
                # all files in one directory share the same label
                parts = dirname.split('/')
                assert parts[1] == 'audio', "parts[1] is " + parts[1]
                label = parts[2]

            full_filename = os.path.join(dirname, filename)

            # assume all files have the same sample rate (16 kHz)
            # note that this will give np arrays with dtype int16
            sample_rate, samples = wavfile.read(full_filename)

            if len(samples) > 16000:
                # longer than one second at 16 kHz: treat as background noise
                assert dirkey == 'train'
                key = 'noise'
            else:
                key = dirkey

            filenames.setdefault(key, []).append(full_filename)
            wav_data.setdefault(key, []).append(samples)

            if label is not None:
                labels.setdefault(key, []).append(label)

            if dirkey == 'train':
                # speaker id is the file name part before the first '_'
                speakers.setdefault(key, []).append(filename.split('_')[0])

                # validation_list.txt entries are "<label>/<file>.wav"
                short_name = "/".join(full_filename.split("/")[-2:])
                is_validation.setdefault(key, []).append(int(short_name in validation_files))

            progbar.update()

    progbar.close()

    # go back to the invocation directory before writing output
    os.chdir(cwd)

    #----------
    # pad to maximum length seen for noise (longer) and train samples
    #----------
    for key in ('noise', 'train'):
        if key not in wav_data:
            continue

        this_wav_data = wav_data[key]
        max_len = max(len(u) for u in this_wav_data)

        for index in range(len(this_wav_data)):
            if len(this_wav_data[index]) < max_len:
                # pad on the right (append silence at the end)
                # NOTE: mode 'median' pads with the median sample value,
                # not literal zeros
                this_wav_data[index] = np.pad(this_wav_data[index], (0, max_len - len(this_wav_data[index])), 'median')
            assert len(this_wav_data[index]) == max_len

    # make a copy of keys so we can delete entries
    # while iterating
    keys = list(wav_data.keys())

    for key in keys:
        # stack
        print("stacking", key)
        wav_data[key] = np.stack(wav_data[key])

        # run garbage collection to minimize chances of running out of memory
        gc.collect()

        # do NOT convert to fp32 before storing to GCE --
        # this would double the storage volume !
        # wav_data[key] = wav_data[key].astype(np.float32) / np.iinfo(np.int16).max

        # write out wav data
        fname = os.path.join(output_dir, "wav-%s.npy" % key)
        np.save(fname, wav_data[key])
        print("wrote", fname)

        # release the big array before the next stack to keep peak
        # memory down
        del wav_data[key]
        gc.collect()

        # write out file lists; only columns present for this group
        # are included (e.g. test files have no label/speaker)
        fname = os.path.join(output_dir, "files-%s.csv" % key)

        header = ['file']
        data = [filenames[key]]

        if key in labels:
            header.append('label')
            data.append(labels[key])

        if key in speakers:
            header.append('speaker')
            data.append(speakers[key])

        if key in is_validation:
            header.append('is_validation')
            data.append(is_validation[key])

        with open(fname, "w") as fout:
            # print header
            print(",".join(header), file=fout)

            for items in zip(*data):
                print(",".join([str(x) for x in items]), file=fout)

        print("wrote", fname)

    # sanity check: every collected group was written out and freed
    assert len(wav_data) == 0, "unexpected keys left: " + ", ".join([str(x) for x in wav_data.keys()])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment