Skip to content

Instantly share code, notes, and snippets.

@MaximumEntropy
Created January 17, 2017 22:29
Show Gist options
  • Save MaximumEntropy/4610b67fd8cdf2c18039f59cd1584274 to your computer and use it in GitHub Desktop.
Save MaximumEntropy/4610b67fd8cdf2c18039f59cd1584274 to your computer and use it in GitHub Desktop.
Returns a minibatch for teacher forcing, together with a mask and the length of each sentence in the minibatch.
def get_minibatch(lines, index, batch_size, word2ind, max_len, add_start=False, add_end=True):
    """Prepare a padded minibatch for teacher forcing.

    The batch is ``lines[index:index + batch_size]``. Each sentence is
    optionally wrapped in ``'<s>'`` / ``'</s>'`` and then truncated to
    ``max_len`` tokens. Truncation happens *after* the markers are added, so
    an over-long sentence loses its trailing ``'</s>'``.

    For a sentence ``t_0 .. t_{k-1}`` the input row holds ``t_0 .. t_{k-2}``
    and the output row holds ``t_1 .. t_{k-1}``, i.e. the output is the input
    shifted left by one token.

    Args:
        lines: Sequence of tokenized sentences (each a sequence of tokens).
        index: Position in ``lines`` where the minibatch starts.
        batch_size: Maximum number of sentences in the minibatch.
        word2ind: Mapping from token to integer id. ``'<pad>'`` must be
            present; ``'<unk>'`` is looked up only when a token is missing
            from the mapping.
        max_len: Maximum number of tokens kept per sentence, counting any
            added start/end markers.
        add_start: Prepend ``'<s>'`` to every sentence.
        add_end: Append ``'</s>'`` to every sentence.

    Returns:
        A tuple ``(input_lines, output_lines, lens, mask)``:

        * ``input_lines`` -- ``int32`` array of shape ``(batch, longest - 1)``
          where ``longest`` is the longest sentence in the batch.
        * ``output_lines`` -- ``int32`` array of the same shape.
        * ``lens`` -- list with the token count of each sentence after
          marker insertion and truncation (one more than the number of
          unmasked positions in its row).
        * ``mask`` -- ``float32`` array of the same shape, ``1`` on real
          positions and ``0`` on padding.

        An empty batch yields arrays of shape ``(0, 0)`` and ``lens == []``.
    """
    prefix = ['<s>'] if add_start else []
    suffix = ['</s>'] if add_end else []

    # Slice unconditionally: previously the batch window was only applied
    # inside the add_start/add_end branches, so with both flags False the
    # entire corpus was returned.
    batch = [
        (prefix + list(line) + suffix)[:max_len]
        for line in lines[index:index + batch_size]
    ]

    lens = [len(line) for line in batch]
    # Inputs/outputs drop one token each, so rows are one shorter than the
    # longest sentence (and never negative for empty sentences/batches).
    n_cols = max(max(lens, default=0) - 1, 0)

    def encode(tokens):
        # '<unk>' is resolved lazily so a vocabulary without it still works
        # as long as every token is known.
        return [
            word2ind[w] if w in word2ind else word2ind['<unk>']
            for w in tokens
        ]

    def pad_row(ids):
        # Pad relative to the actual row length so that empty sentences do
        # not produce a row that is one element too long.
        return ids + [word2ind['<pad>']] * (n_cols - len(ids))

    def to_array(rows, dtype):
        # The explicit reshape keeps the result 2-D even when there are no
        # rows or no columns.
        return np.array(rows).astype(dtype).reshape(len(batch), n_cols)

    input_lines = to_array(
        [pad_row(encode(line[:-1])) for line in batch], np.int32
    )
    output_lines = to_array(
        [pad_row(encode(line[1:])) for line in batch], np.int32
    )
    mask = to_array(
        [
            [1] * max(n - 1, 0) + [0] * (n_cols - max(n - 1, 0))
            for n in lens
        ],
        np.float32,
    )
    return input_lines, output_lines, lens, mask
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment