Created March 6, 2020 02:06
Quick benchmark comparing the performance of the FairSeq and HuggingFace implementations of BART.
import time
import argparse

import torch
from tqdm import tqdm
from transformers import BartForConditionalGeneration, BartTokenizer

FS_MODEL = "FairSeq"
HF_MODEL = "HuggingFace"


class HuggingFace:
    def __init__(self, fp16=True):
        # output_past=True makes generation reuse the decoder cache
        self.bart = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True)
        self.tokenizer = BartTokenizer.from_pretrained("bart-large-cnn")
        self.bart.cuda()
        self.bart.eval()
        if fp16:
            self.bart.half()

    def predict(self, samples):
        # Tokenize the whole batch, padding every sample to the model's maximum input size
        inputs = self.tokenizer.batch_encode_plus(samples,
                                                  max_length=self.bart.config.max_position_embeddings,
                                                  pad_to_max_length=True,
                                                  return_tensors="pt")
        with torch.no_grad():
            hypo = self.bart.generate(input_ids=inputs["input_ids"].cuda(),
                                      attention_mask=inputs["attention_mask"].cuda(),
                                      num_beams=4,
                                      length_penalty=2.0,
                                      max_length=140,
                                      min_length=55,
                                      no_repeat_ngram_size=3)
        return [self.tokenizer.decode(h, skip_special_tokens=True) for h in hypo]


class FairSeq:
    def __init__(self, fp16=True):
        self.bart = torch.hub.load('pytorch/fairseq', 'bart.large.cnn')
        self.bart.cuda()
        self.bart.eval()
        if fp16:
            self.bart.half()

    def predict(self, samples):
        # Same decoding hyper-parameters as the HuggingFace side, so the timings are comparable
        with torch.no_grad():
            hypo = self.bart.sample(samples,
                                    beam=4,
                                    lenpen=2.0,
                                    max_len_b=140,
                                    min_len=55,
                                    no_repeat_ngram_size=3)
        return hypo


def _get_samples(file):
    """Yield one article per line from the source file."""
    with open(file) as source:
        for line in source:
            yield line.strip()


def timed_predict(model, batch):
    """Run one batch through the model and return the elapsed time in seconds."""
    t0 = time.time()
    model.predict(batch)
    t1 = time.time()
    return t1 - t0


def main(args):
    if args.model == HF_MODEL:
        model = HuggingFace(fp16=not args.fp32)
    elif args.model == FS_MODEL:
        model = FairSeq(fp16=not args.fp32)

    count_sample = 0
    count_batch = 0
    batch = []
    tot_time = 0
    for sample in tqdm(_get_samples(args.source), total=args.samples):
        count_sample += 1
        batch.append(sample)
        if len(batch) == args.batch_size:
            count_batch += 1
            tot_time += timed_predict(model, batch)
            batch = []
        if count_sample >= args.samples:
            break

    # Run the last, possibly smaller, batch
    if len(batch) != 0:
        count_batch += 1
        tot_time += timed_predict(model, batch)

    print("Using {} model, with batch size of {}, it took:\n".format(args.model, args.batch_size))
    print("{:.4f} s per batch\n".format(tot_time / count_batch))
    print("{:.4f} s per sample\n".format(tot_time / count_sample))
    print("(Average over {} samples)".format(count_sample))


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Benchmark between HuggingFace BART and FairSeq BART. "
                                     "Needs the latest version of transformers (master branch).")
    parser.add_argument("--source", type=str, default="./cnndm/test.source")
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--samples", type=int, default=500)
    parser.add_argument("--model", type=str, default=FS_MODEL, choices=[FS_MODEL, HF_MODEL])
    parser.add_argument('--fp32', dest='fp32', default=False, action='store_true')

    main(parser.parse_args())
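For reference, a hypothetical invocation (the file name benchmark.py is an assumption; the flags and their defaults come from the argparse setup above, and the --source file is expected to contain one article per line):

python benchmark.py --model FairSeq --batch-size 16 --samples 500
python benchmark.py --model HuggingFace --batch-size 16 --samples 500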
@kleag It's been quite some time since I wrote this code ^^
I don't remember exactly, but I would try the latest stable version of each package as of the date it was written.
The gist was written on Mar 6, 2020.
So I would try with the following versions:
transformers 2.5.1
fairseq 0.9.0
torch 1.4.0
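For example, something like this should pin them (a sketch assuming a pip-based environment, using those release numbers):

pip install torch==1.4.0 transformers==2.5.1 fairseq==0.9.0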
Let me know if it still doesn't work, and I'll take a deeper look :)