Decrease Hugging Face Transformers training times by 2 - collator

The collator below implements dynamic padding: each mini-batch is padded to the length of its own longest sequence rather than to a fixed dataset-wide maximum. This avoids wasting computation on padding tokens and is what the title's 2x training-time reduction refers to.
# ... (the original gist elides its imports here; the ones below are assumed)
from dataclasses import dataclass
from typing import Dict, List

import torch
from transformers import DataCollator  # abstract class in transformers ~2.9; newer versions accept any callable


@dataclass
class Features:
    # assumed shape of one encoded example; defined elsewhere in the original gist
    input_ids: List[int]
    attention_mask: List[int]
    label: int


def pad_seq(seq: List[int], max_batch_len: int, pad_value: int) -> List[int]:
    # IRL, use torch.nn.utils.rnn.pad_sequence
    # https://pytorch.org/docs/master/generated/torch.nn.utils.rnn.pad_sequence.html
    return seq + (max_batch_len - len(seq)) * [pad_value]


@dataclass
class SmartCollator(DataCollator):
    pad_token_id: int

    def collate_batch(self, batch: List[Features]) -> Dict[str, torch.Tensor]:
        batch_inputs = []
        batch_attention_masks = []
        labels = []
        # find the longest sequence in this mini-batch
        max_size = max(len(ex.input_ids) for ex in batch)
        for item in batch:
            # apply padding at the mini-batch level, not the dataset level
            batch_inputs += [pad_seq(item.input_ids, max_size, self.pad_token_id)]
            batch_attention_masks += [pad_seq(item.attention_mask, max_size, 0)]
            labels.append(item.label)
        # expected Transformers input format (dict of tensors)
        return {
            "input_ids": torch.tensor(batch_inputs, dtype=torch.long),
            "attention_mask": torch.tensor(batch_attention_masks, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
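For context, here is a minimal sketch of how such a collator would plug into the Trainer API of that era, where data_collator accepted a DataCollator instance (in current transformers versions it is any callable). The model name, example texts, labels, and training arguments are illustrative assumptions, not part of the original gist.

# hypothetical usage sketch -- model, texts, and labels are assumptions, not from the gist
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    Trainer,
    TrainingArguments,
)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

# encode without fixed-length padding; SmartCollator pads per batch instead
texts, targets = ["short text", "a somewhat longer example sentence"], [0, 1]
train_dataset = [
    Features(
        input_ids=tokenizer.encode(t),
        attention_mask=[1] * len(tokenizer.encode(t)),
        label=y,
    )
    for t, y in zip(texts, targets)
]

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="/tmp/out", per_device_train_batch_size=32),
    train_dataset=train_dataset,
    data_collator=SmartCollator(pad_token_id=tokenizer.pad_token_id),
)
trainer.train()

Because padding now depends on the longest sequence per batch rather than in the whole dataset, sorting or bucketing examples by length before batching would further reduce padding, though the gist shows only the collator itself.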