Last active
September 19, 2016 15:25
-
-
Save mehdidc/ea52a3524f0ae614f665bd08eff7d38e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def minibatcher(fn, batchsize=1000):
    """Wrap ``fn`` so that it is applied to its input in minibatches.

    Parameters
    ----------
    fn : callable
        Function that takes a batch of inputs (a slice of ``X`` along its
        first axis) and returns an array-like. The per-batch results are
        concatenated along axis 0, so every call must return arrays that
        agree on all other axes.
    batchsize : int, optional
        Maximum number of inputs passed to ``fn`` in a single call.

    Returns
    -------
    callable
        A function ``f(X)`` that divides ``X`` into consecutive divisions of
        at most ``batchsize`` items, calls ``fn`` on each division, and
        returns the results concatenated along axis 0.
    """
    import numpy as np  # this file has no module-level numpy import

    def f(X):
        results = [fn(X[sl]) for sl in iterate_minibatches(len(X), batchsize)]
        if not results:
            # Empty input yields no batches, and np.concatenate([]) raises.
            # Let fn see the (empty) input once, so the output keeps the
            # shape/dtype fn produces for a zero-length batch.
            results = [fn(X)]
        return np.concatenate(results, axis=0)

    return f
def iterate_minibatches(nb_inputs, batchsize, shuffle=False):
    """Yield index selectors that cover ``range(nb_inputs)`` in batches.

    Parameters
    ----------
    nb_inputs : int
        Total number of inputs to split.
    batchsize : int
        Maximum size of each batch; must be >= 1. The last batch is smaller
        when ``nb_inputs`` is not a multiple of ``batchsize``.
    shuffle : bool, optional
        If False (default), yield ``slice`` objects over consecutive index
        ranges. If True, yield integer index arrays taken from a single
        random permutation (drawn from the global ``np.random`` state), so
        every index appears in exactly one batch.

    Yields
    ------
    slice or numpy.ndarray
        A selector usable to index a NumPy array along its first axis.

    Raises
    ------
    ValueError
        If ``batchsize`` is less than 1 (raised when iteration starts).
    """
    import numpy as np  # this file has no module-level numpy import

    if batchsize < 1:
        raise ValueError(
            "batchsize must be a positive integer, got {!r}".format(batchsize))
    if shuffle:
        indices = np.arange(nb_inputs)
        np.random.shuffle(indices)
    # Stepping up to nb_inputs (rather than nb_inputs - batchsize + 1) keeps
    # the final partial batch instead of silently dropping it.
    for start_idx in range(0, nb_inputs, batchsize):
        if shuffle:
            yield indices[start_idx:start_idx + batchsize]
        else:
            yield slice(start_idx, start_idx + batchsize)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment