@MLWhiz
Last active September 6, 2020 12:49
import numpy as np
import torch


def collate_text(batch):
    """Pad the text sequences in a batch to a common length and return tensors."""
    # get text sequences in batch
    data = [item[0] for item in batch]
    # get labels in batch
    target = [item[1] for item in batch]
    # get the length of the longest sequence in this batch
    max_seq_len = max(len(x) for x in data)
    # right-pad each sequence with zeros up to max_seq_len
    data = [np.pad(p, (0, max_seq_len - len(p)), 'constant') for p in data]
    # stack into a single array first, then convert data and target to tensors
    data = torch.tensor(np.array(data), dtype=torch.long)
    target = torch.tensor(target, dtype=torch.long)
    return [data, target]
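
To see what the padding step does to a batch without installing PyTorch, here is a minimal dependency-free sketch. It assumes each batch item is a `(token_ids, label)` pair, as `collate_text` does. The names `pad_batch`, `pad_value`, and `toy_batch` are hypothetical and are not part of the gist.

```python
def pad_batch(batch, pad_value=0):
    """Right-pad every sequence in `batch` to the longest length in the batch.

    `batch` is a list of (token_ids, label) pairs. This helper is a
    hypothetical, list-based stand-in for the padding done in collate_text.
    """
    sequences = [list(item[0]) for item in batch]
    labels = [item[1] for item in batch]
    # the longest sequence decides the padded width for this batch only
    max_seq_len = max(len(seq) for seq in sequences)
    padded = [seq + [pad_value] * (max_seq_len - len(seq)) for seq in sequences]
    return padded, labels


# toy batch: three sequences of different lengths with binary labels
toy_batch = [([4, 7, 9], 1), ([3], 0), ([8, 2], 1)]
padded, labels = pad_batch(toy_batch)
print(padded)  # [[4, 7, 9], [3, 0, 0], [8, 2, 0]]
print(labels)  # [1, 0, 1]
```

Because the width is computed per batch, short batches stay short instead of being padded to a global maximum. In PyTorch, a function like `collate_text` is typically passed to `torch.utils.data.DataLoader` through its `collate_fn` argument.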