import math

import numpy as np
from keras.preprocessing.sequence import pad_sequences


def generate_sequences(X, y, mask, batch_size=32, seed=0):
    """
    Returns a generator of batched timeseries padded to the longest sequence in the batch,
    using right zero padding.

    Can be used directly with model.fit_generator if the keys of X match the names of the
    keras Input tensors. Note that masks typically have one dimension less than labels.

    Args:
        X, list. Each element is a dictionary of 'feature_name': np.array
        y, list. Each element is an np.array of labels
        mask, list. Each element is an np.array of 1 / 0 used as a weight in the loss.
            Should be an array of ones by default (which will be padded with zeros)
        batch_size, int. Size of each batch.
        seed, int. Seed for the random permutations. Uses the global np.random state.

    Returns:
        Generator of (dict, np.array, np.array) batches. The dict has the same keys as X.
        steps, the number of steps in one epoch, computed from batch_size.
    """
    np.random.seed(seed)
    steps = int(math.ceil(1.0 * len(y) / batch_size))

    def _gen():
        while True:
            # New random ordering of the samples for every epoch.
            perm = np.random.permutation(range(len(y)))
            for i in range(steps):
                # Collect the raw (unpadded) samples for this batch.
                xx = {k: [] for k in X[0]}
                yy = []
                mm = []
                for j in perm[i * batch_size:(i + 1) * batch_size]:
                    for k in xx:
                        xx[k].append(X[j][k])
                    yy.append(y[j])
                    mm.append(mask[j])
                # Right-pad every feature, the labels and the mask to the longest
                # sequence in the batch.
                xx = {
                    k: pad_sequences(v, padding='post', dtype=v[0].dtype)
                    for k, v in xx.items()
                }
                yy = pad_sequences(yy, padding='post')
                mm = pad_sequences(mm, padding='post')
                yield xx, yy, mm

    return _gen(), steps
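A minimal usage sketch, not part of the original gist: the model, the 'tokens' input name,
and the toy data below are illustrative assumptions, chosen only to show that the generator
plugs into model.fit_generator when the X keys match the Input names and the model is
compiled with temporal sample weights.

    # Hypothetical example; assumes standalone keras (use tensorflow.keras equivalently).
    import numpy as np
    from keras.layers import Input, Dense, TimeDistributed
    from keras.models import Model

    # Three variable-length sequences of 4-dimensional features (made-up data).
    lengths = (5, 7, 3)
    X = [{'tokens': np.random.rand(n, 4).astype('float32')} for n in lengths]
    y = [np.random.randint(0, 2, size=(n, 1)) for n in lengths]
    mask = [np.ones(n) for n in lengths]

    # Input name 'tokens' matches the key used in X above.
    inp = Input(shape=(None, 4), name='tokens')
    out = TimeDistributed(Dense(1, activation='sigmoid'))(inp)
    model = Model(inp, out)
    model.compile(optimizer='adam', loss='binary_crossentropy',
                  sample_weight_mode='temporal')  # per-timestep weights from the mask

    gen, steps = generate_sequences(X, y, mask, batch_size=2)
    model.fit_generator(gen, steps_per_epoch=steps, epochs=1)

Because the generator yields (inputs, targets, sample_weights), the padded positions get a
weight of zero and therefore do not contribute to the loss.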