Skip to content

Instantly share code, notes, and snippets.

@pommedeterresautee
Last active August 24, 2021 05:00
Show Gist options
  • Save pommedeterresautee/a3e207cb3cd1249868adda0ea807852c to your computer and use it in GitHub Desktop.
Save pommedeterresautee/a3e207cb3cd1249868adda0ea807852c to your computer and use it in GitHub Desktop.
Divide Hugging Face Transformers training time by 2 - dataset
def load_train_data(path: str, sort: bool) -> List[Example]:
    """Load tab-separated sentence pairs and optionally order them by length.

    The first line of the file is treated as a header and skipped. Every
    other line must split into exactly three tab-separated fields: the first
    text, the second text, and a label that is a key of ``label_codes``.

    Args:
        path: Path of the TSV file to read. The file is decoded as UTF-8
            (assumes the dataset is UTF-8 encoded -- TODO confirm for other
            data sources).
        sort: When true, the examples are returned by increasing combined
            character length of the two texts; ties keep their file order.
            When false, file order is kept.

    Returns:
        One ``Example`` per data line.

    Raises:
        ValueError: If a line does not split into exactly three fields.
        KeyError: If a label is not present in ``label_codes``.
    """
    sentences = []
    # Explicit encoding: the platform default is not UTF-8 everywhere, and a
    # wrong codec would corrupt or reject accented characters.
    with open(path, encoding="utf-8") as f:
        # Skip the header line; the default keeps an empty file from raising.
        next(f, None)
        for line in f:
            text_a, text_b, label = line.rstrip().split("\t")
            # Combined character length is the (cheap) sort key.
            length = len(text_a) + len(text_b)
            example = Example(text_a=text_a, text_b=text_b, label=label_codes[label])
            sentences.append((length, example))
    if sort:
        # Important operation: order the pairs by length. Sorting on the
        # length alone avoids ever comparing two Example objects.
        sentences.sort(key=lambda pair: pair[0])
    return [example for (_, example) in sentences]
def build_batches(sentences: List[Example], batch_size: int) -> List[Example]:
    """Reorder examples into randomly drawn runs of neighbouring items.

    Repeatedly picks a random position in the not-yet-used examples, takes
    up to ``batch_size`` consecutive items starting there, and appends them
    to the result. Items that were neighbours in the input therefore stay
    together inside a run, while the order of the runs is random. If the
    input is sorted by length, each run holds examples of similar length.

    Args:
        sentences: Examples to reorder. The sequence is copied and is NOT
            modified.
        batch_size: Number of consecutive examples per run; must be >= 1.
            Only the last run drawn can be shorter.

    Returns:
        A new flat list containing every input example exactly once.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (the loop could
            otherwise never consume the input).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    # Work on a copy so the caller's list is not emptied as a side effect.
    remaining = list(sentences)
    batch_ordered_sentences = []
    while remaining:
        to_take = min(batch_size, len(remaining))
        # randint is inclusive on both ends, so the slice always fits.
        select = random.randint(0, len(remaining) - to_take)
        batch_ordered_sentences += remaining[select:select + to_take]
        del remaining[select:select + to_take]
    return batch_ordered_sentences
class TextDataset(Dataset):
    """Dataset that tokenizes sentence pairs on access, in stored order.

    Items are served sequentially from ``examples`` through an internal
    cursor; the index passed to ``__getitem__`` is ignored (see that method).
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, pad_to_max_length: bool, max_len: int,
                 examples: List[Example]) -> None:
        """Store the tokenizer, the encoding options and the examples."""
        self.examples: List[Example] = examples
        self.tokenizer = tokenizer
        self.pad_to_max_length = pad_to_max_length
        self.max_len = max_len
        # Position of the next example that __getitem__ will serve.
        self.current = 0

    def encode(self, ex: Example) -> Features:
        """Tokenize one example and wrap the result in ``Features``.

        Only the input ids and the attention mask of the tokenizer output
        are kept; the label is copied from the example unchanged.
        """
        encoded = self.tokenizer.encode_plus(
            text=ex.text_a,
            text_pair=ex.text_b,
            add_special_tokens=True,
            max_length=self.max_len,
            pad_to_max_length=self.pad_to_max_length,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_overflowing_tokens=False,
            return_special_tokens_mask=False,
        )
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]
        return Features(input_ids=input_ids, attention_mask=attention_mask, label=ex.label)

    def __getitem__(self, _) -> Features:
        """Return the encoding of the next example, ignoring the index.

        Per the original author's note, Trainer doesn't support
        IterableDataset (it defines a sampler), so this is a regular Dataset
        that serves examples in stored order and wraps around at the end.
        """
        at_end = self.current == len(self.examples)
        position = 0 if at_end else self.current
        example = self.examples[position]
        # Advance the cursor before encoding, so it moves even if encode fails.
        self.current = position + 1
        return self.encode(ex=example)

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.examples)
# ...
# Read the training examples (the file name suggests the French MultiNLI
# split); they are length-sorted only when smart batching is enabled.
train_sentences = load_train_data(
    path="resources/XNLI-MT-1.0/multinli/multinli.train.fr.tsv",
    sort=model_args.smart_batching,
)

# Regroup the examples into randomly ordered runs of one batch each.
train_batches = build_batches(
    sentences=train_sentences,
    batch_size=training_args.per_gpu_train_batch_size,
)

# Ask the tokenizer to pad to max_len unless dynamic padding is enabled.
train_set = TextDataset(
    tokenizer=tokenizer,
    max_len=max_sequence_len,
    examples=train_batches,
    pad_to_max_length=not model_args.dynamic_padding,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment