Last active
September 6, 2020 12:49
import numpy as np
import torch


def collate_text(batch):
    """Pad the text sequences in a batch to a common length and return tensors."""
    # get text sequences in batch
    data = [item[0] for item in batch]
    # get labels in batch
    target = [item[1] for item in batch]
    # get the length of the longest sequence in the batch
    max_seq_len = max(len(x) for x in data)
    # right-pad each sequence with zeros up to max_seq_len
    data = [np.pad(p, (0, max_seq_len - len(p)), 'constant') for p in data]
    # stack into a single array first: building a tensor from a list of arrays is slow
    data = torch.tensor(np.stack(data), dtype=torch.long)
    target = torch.tensor(target, dtype=torch.long)
    return [data, target]
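A minimal usage sketch, assuming the function is meant to be passed as `collate_fn` to a `torch.utils.data.DataLoader`. The `ToyTextDataset` class and its sample values are hypothetical and exist only for illustration. The padding logic of `collate_text` is repeated so the block runs on its own.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


def collate_text(batch):
    # same padding logic as collate_text above, repeated so the example is self-contained
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    max_seq_len = max(len(x) for x in data)
    data = [np.pad(p, (0, max_seq_len - len(p)), 'constant') for p in data]
    return [torch.tensor(np.stack(data), dtype=torch.long),
            torch.tensor(target, dtype=torch.long)]


class ToyTextDataset(Dataset):
    """Hypothetical dataset of (token_ids, label) pairs with varying lengths."""

    def __init__(self):
        self.samples = [
            ([4, 9, 2], 0),
            ([7], 1),
            ([3, 3, 8, 1, 5], 0),
            ([6, 2], 1),
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


# each batch is padded only to the longest sequence it contains
loader = DataLoader(ToyTextDataset(), batch_size=2, shuffle=False,
                    collate_fn=collate_text)
batches = list(loader)
for data, target in batches:
    print(data.shape, target.tolist())
```

Because padding happens per batch, the first batch here is padded to length 3 and the second to length 5, rather than every sequence being padded to a global maximum.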