Created
July 27, 2016 17:40
-
-
Save keunwoochoi/8bb92d08541ad5add7d0e26455a1b151 to your computer and use it in GitHub Desktop.
How to prepare audio on-the-fly
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://github.com/benanne/kaggle-ndsb/blob/master/buffering.py | |
import multiprocessing as mp | |
import Queue | |
import threading | |
def _mp_producer(source_gen, buffer):
    """Producer body for buffered_gen_mp; runs in the child process.

    Module-level (not nested) so it stays picklable under the
    'spawn'/'forkserver' multiprocessing start methods.
    """
    for data in source_gen:
        buffer.put(data, block=True)
    buffer.put(None)  # sentinel: signal the end of the iterator
    buffer.close()    # closing alone does not suffice: a pending get() on a
                      # closed queue would block forever, hence the sentinel.


def buffered_gen_mp(source_gen, buffer_size=2):
    """Run a slow source iterable in a separate process, pre-generating items.

    Parameters
    ----------
    source_gen : iterable
        The (slow) source of items. It must not yield None, which is used
        as the end-of-stream sentinel. With the 'spawn'/'forkserver' start
        methods it must also be picklable (generators are not).
    buffer_size : int
        Maximal number of items to pre-generate. The effective buffer is one
        item smaller, because the producer generates one extra element and
        then blocks until there is room in the queue.

    Returns
    -------
    A generator yielding the items of source_gen, in order.

    Raises
    ------
    RuntimeError
        If buffer_size < 2 (a zero-capacity queue would deadlock).
    """
    # Validate eagerly at call time: a generator body would only run --
    # and only raise -- once iteration starts.
    if buffer_size < 2:
        raise RuntimeError("Minimal buffer size is 2!")
    return _buffered_iter_mp(source_gen, buffer_size)


def _buffered_iter_mp(source_gen, buffer_size):
    """Generator half of buffered_gen_mp: spawn the producer and drain it."""
    buffer = mp.Queue(maxsize=buffer_size - 1)
    process = mp.Process(target=_mp_producer, args=(source_gen, buffer))
    # Daemonize so an abandoned (not fully drained) producer cannot keep
    # the interpreter alive at exit; the original could hang there.
    process.daemon = True
    process.start()
    for data in iter(buffer.get, None):
        yield data
    process.join()  # sentinel seen => producer is done; reap it (no zombie)
def buffered_gen_threaded(source_gen, buffer_size=2):
    """Run a slow source iterable in a background thread. Beware of the GIL!

    Parameters
    ----------
    source_gen : iterable
        The (slow) source of items. It must not yield None, which is used
        as the end-of-stream sentinel.
    buffer_size : int
        Maximal number of items to pre-generate. The effective buffer is one
        item smaller, because the producer generates one extra element and
        then blocks until there is room in the queue.

    Yields
    ------
    The items of source_gen, in order.

    Raises
    ------
    RuntimeError
        If buffer_size < 2 (a zero-capacity queue would deadlock).
        Raised on first next(), since this is a generator function.
    """
    # Py2/Py3 compatibility: the stdlib module was renamed Queue -> queue;
    # the file-level `import Queue` only exists on Python 2.
    try:
        import queue as _queue_mod  # Python 3
    except ImportError:
        import Queue as _queue_mod  # Python 2
    if buffer_size < 2:
        raise RuntimeError("Minimal buffer size is 2!")
    buffer = _queue_mod.Queue(maxsize=buffer_size - 1)

    def _produce(source_gen, buffer):
        for data in source_gen:
            buffer.put(data, block=True)
        buffer.put(None)  # sentinel: signal the end of the iterator

    thread = threading.Thread(target=_produce, args=(source_gen, buffer))
    thread.daemon = True  # don't block interpreter exit on a stuck producer
    thread.start()
    for data in iter(buffer.get, None):
        yield data
    thread.join()  # sentinel seen => producer finished; join is immediate
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import buffering | |
import time | |
import os, sys | |
import multiprocessing as mp | |
import librosa | |
import numpy as np | |
import pdb | |
import time | |
STFT_DURA = 10.   # seconds of audio loaded per file before the STFT
TRAIN_DURA = 3.   # seconds the dummy training step sleeps
BATCH_SIZE = 8    # files per training batch
NUM_CPU = mp.cpu_count()  # worker count cap for the multiprocessing pool
def training_nn_dummy(X, batch_idx, train_dura=None):
    """Stand-in for a network training step: log and sleep.

    Parameters
    ----------
    X : array-like
        The prepared batch of input data (unused by this dummy).
    batch_idx : int
        Global batch counter, used only for logging.
    train_dura : float or None
        Seconds to sleep; defaults to the module-level TRAIN_DURA.
        (New keyword parameter with a backward-compatible default.)

    Returns
    -------
    None
    """
    # Parenthesized print is valid and identical on both Python 2 and 3;
    # the original `print 'x'` statement syntax is Python-2-only.
    dura = TRAIN_DURA if train_dura is None else train_dura
    print('-' * 30)
    print('Start training of batch idx: %d' % batch_idx)
    print('Will wait for %5.3f seconds as if it is learning something' % dura)
    time.sleep(dura)
    print('Work (super cool *DEEP* something) is done!')
    print('-' * 30)
    return
def get_TF(x_path):
    """Load an audio file and return its STFT with a leading batch axis.

    Parameters
    ----------
    x_path : str
        Path to an audio file readable by librosa.

    Returns
    -------
    Complex STFT array of shape (1, n_freq, n_frames); the leading singleton
    axis makes per-file results easy to concatenate into a batch.
    """
    # Load at most STFT_DURA seconds (librosa's default sample rate).
    # The original routed through an identity `transform_func` alias and
    # empty *args lists; inlined here since they carried no information.
    x, _ = librosa.load(x_path, duration=STFT_DURA)
    X = librosa.stft(x)
    print('stft done for %s' % x_path)
    return X[np.newaxis, :]
def gen_TF(x_paths):
    """Yield a log-amplitude spectrogram for each audio path, one at a time.

    Parameters
    ----------
    x_paths : list of str
        Audio file paths readable by librosa.

    Yields
    ------
    Log-amplitude |STFT| array of shape (1, n_freq, n_frames) per file.
    """
    for i_file, x_path in enumerate(x_paths):
        # Load at most STFT_DURA seconds (librosa's default sample rate).
        x, _ = librosa.load(x_path, duration=STFT_DURA)
        X = librosa.stft(x)
        time.sleep(1.0)  # simulate an extra-slow transform
        print(' ' * 30 + 'yielding one transform of [%d]' % i_file)
        # NOTE(review): librosa.logamplitude was removed in librosa 0.6;
        # newer versions spell this librosa.amplitude_to_db.
        yield librosa.logamplitude(np.abs(X))[np.newaxis, :]
def create_X_batch(x_paths, batch_size):
    """Wrap gen_TF in a process-backed buffer so transforms are precomputed.

    Parameters
    ----------
    x_paths : list of str
        Audio file paths, forwarded to gen_TF.
    batch_size : int
        Used as the buffer size of the background producer.

    Returns
    -------
    A generator yielding the same items as gen_TF(x_paths), produced ahead
    of time in a separate process.
    """
    # The original wrapped gen_TF in a second pass-through generator
    # (`gen_chunk`) that only re-yielded its items; the buffered generator
    # can consume gen_TF directly.
    return buffering.buffered_gen_mp(gen_TF(x_paths), batch_size)
def simple_TF_mp(source_gen, x_paths, batch_idx, buffer_size=2):
    """Apply a per-file transform to x_paths in parallel and stack results.

    Parameters
    ----------
    source_gen : callable
        Picklable function mapping one item of x_paths to an array of shape
        (1, d1, d2) (e.g. get_TF).
    x_paths : list
        Items to transform; must be non-empty.
    batch_idx : int
        Batch counter, used only for logging.
    buffer_size : int
        Upper bound on the number of pool workers.

    Returns
    -------
    np.ndarray of shape (len(x_paths), d1, d2): all results concatenated
    along the leading axis.

    Raises
    ------
    RuntimeError
        If buffer_size < 2.
    """
    if buffer_size < 2:
        raise RuntimeError("Minimal buffer size is 2!")
    print("Let's do STFT or something for batch %d" % batch_idx)
    n_worker = min(buffer_size, mp.cpu_count())  # same value as module NUM_CPU
    pool = mp.Pool(n_worker)
    try:
        results = pool.map(source_gen, x_paths)
    finally:
        # Always release the worker processes; the original never closed
        # the pool, leaking processes on every batch.
        pool.close()
        pool.join()
    # Single concatenate instead of growing the array once per item
    # (O(n) copies instead of O(n^2)). The empty float64 seed reproduces
    # the original's dtype promotion from its np.zeros accumulator.
    seed = np.zeros((0, results[0].shape[1], results[0].shape[2]))
    return np.concatenate([seed] + results, axis=0)
def main():
    """Demo loop: prepare each batch with a multiprocessing pool while the
    previous dummy 'training' step runs in a separate process."""
    wav_path = '../Srcs/SS_HPS/'  # whatever that has some wav files.
    x_paths = [wav_path + p for p in os.listdir(wav_path) if p.endswith('.wav')][:40]
    batch_size = BATCH_SIZE
    accm_batch_idx = 0
    # Explicit sentinel instead of the original bare `try/except: pass`,
    # which existed only to swallow the NameError on the first batch
    # (and would also have hidden any real error in join()).
    p_train = None
    for i_epoch in range(2):
        print('=' * 10 + ('EPOCH %d' % i_epoch) + '=' * 10)
        # Floor division: plain `/` returns a float on Python 3 and would
        # break range(); `//` behaves identically for ints on Python 2.
        num_batch = len(x_paths) // batch_size
        # NOTE(review): gen_data is never consumed below -- this starts a
        # producer process whose output is discarded. Kept for parity with
        # the original; consider removing or actually iterating it.
        gen_data = create_X_batch(x_paths, batch_size)
        for i_batch in range(num_batch):
            accm_batch_idx += 1
            x_paths_chunk = x_paths[i_batch * batch_size: (i_batch + 1) * batch_size]
            data = simple_TF_mp(get_TF, x_paths_chunk, accm_batch_idx,
                                buffer_size=batch_size)
            print("STFT or something for batch %d is ready." % accm_batch_idx)
            if p_train is not None:
                print('Waiting for the previous processing complete...')
                p_train.join()
                print('Waiting is done!')
            # Kick off 'training' on the fresh batch while the next loop
            # iteration prepares more data in parallel.
            p_train = mp.Process(target=training_nn_dummy,
                                 args=(data, accm_batch_idx))
            p_train.start()
    # Reap the last training process before exiting.
    if p_train is not None:
        p_train.join()
# Run the demo only when executed as a script, not on import.
if __name__=='__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment