
@astariul
Created March 6, 2020 02:06
Quick benchmark comparing the performance of BART between the FairSeq implementation and the HuggingFace implementation.
import time
import torch
import argparse
from tqdm import tqdm
from transformers import BartForConditionalGeneration, BartTokenizer


FS_MODEL = "FairSeq"
HF_MODEL = "HuggingFace"


class HuggingFace():
    """Wrapper around the HuggingFace implementation of BART (bart-large-cnn)."""

    def __init__(self, fp16=True):
        # Note: recent transformers releases use the namespaced id "facebook/bart-large-cnn"
        self.bart = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True)
        self.tokenizer = BartTokenizer.from_pretrained("bart-large-cnn")
        self.bart.cuda()
        self.bart.eval()
        if fp16:
            self.bart.half()

    def predict(self, samples):
        inputs = self.tokenizer.batch_encode_plus(samples,
                                                  max_length=self.bart.config.max_position_embeddings,
                                                  pad_to_max_length=True,
                                                  return_tensors="pt")
        with torch.no_grad():
            hypo = self.bart.generate(input_ids=inputs["input_ids"].cuda(),
                                      attention_mask=inputs["attention_mask"].cuda(),
                                      num_beams=4,
                                      length_penalty=2.0,
                                      max_length=140,
                                      min_len=55,    # recent transformers versions expect `min_length` here
                                      no_repeat_ngram_size=3)
        return [self.tokenizer.decode(h, skip_special_tokens=True) for h in hypo]


class FairSeq():
    """Wrapper around the FairSeq implementation of BART, loaded through torch.hub."""

    def __init__(self, fp16=True):
        self.bart = torch.hub.load('pytorch/fairseq', 'bart.large.cnn')
        self.bart.cuda()
        self.bart.eval()
        if fp16:
            self.bart.half()

    def predict(self, samples):
        with torch.no_grad():
            hypo = self.bart.sample(samples,
                                    beam=4,
                                    lenpen=2.0,
                                    max_len_b=140,
                                    min_len=55,
                                    no_repeat_ngram_size=3)
        return hypo


def _get_samples(file):
    # Yield one source document per line of the input file
    with open(file) as source:
        for line in source:
            yield line.strip()


def timed_predict(model, batch):
    # Return the time (in seconds) spent summarizing a single batch
    t0 = time.time()
    model.predict(batch)
    t1 = time.time()
    return t1 - t0


def main(args):
    if args.model == HF_MODEL:
        model = HuggingFace(fp16=not args.fp32)
    elif args.model == FS_MODEL:
        model = FairSeq(fp16=not args.fp32)

    count_sample = 0
    count_batch = 0
    batch = []
    tot_time = 0
    for sample in tqdm(_get_samples(args.source), total=args.samples):
        count_sample += 1
        batch.append(sample)

        # Run a full batch as soon as it is filled
        if len(batch) % args.batch_size == 0:
            count_batch += 1
            t = timed_predict(model, batch)
            tot_time += t
            batch = []

        # Stop once the requested number of samples has been processed
        if count_sample >= args.samples:
            break

    # Process the last, possibly incomplete batch
    if len(batch) != 0:
        count_batch += 1
        t = timed_predict(model, batch)
        tot_time += t

    print("Using {} model, with batch size of {}, it took:\n".format(args.model, args.batch_size))
    print("{:.4f} s per batch\n".format(tot_time / count_batch))
    print("{:.4f} s per sample\n".format(tot_time / count_sample))
    print("(Average over {} samples)".format(args.samples))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark between HuggingFace BART and FairSeq BART. "
                                                 "Need latest version of transformers (master branch)")
    parser.add_argument("--source", type=str, default="./cnndm/test.source")
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--samples", type=int, default=500)
    parser.add_argument("--model", type=str, default=FS_MODEL, choices=[FS_MODEL, HF_MODEL])
    parser.add_argument('--fp32', dest='fp32', default=False, action='store_true')
    main(parser.parse_args())
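
The script is meant to be run from the command line. For example (the file name benchmark.py is just an assumption, the gist does not name the file; the arguments shown are the defaults defined above):

python benchmark.py --model FairSeq --batch-size 16 --samples 500 --source ./cnndm/test.source
python benchmark.py --model HuggingFace --fp32

The second invocation benchmarks the HuggingFace model in full precision instead of fp16.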

kleag commented Oct 13, 2021

It does not work with the current versions of pytorch, transformers, and fairseq.
Do you remember which versions you used when you wrote this gist?

@astariul (Author)

@kleag It's been quite some time since I wrote this code ^^

I don't remember, but I would try the latest stable version of each package as of the date it was written.
The gist was written on Mar 6, 2020.

So I would try the following versions:

  • transformers 2.5.1
  • fairseq 0.9.0
  • torch 1.4.0
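
For example, pinning them with pip (I haven't re-tested this exact combination, so treat it as a starting point):

pip install torch==1.4.0 transformers==2.5.1 fairseq==0.9.0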

Let me know if it still doesn't work, I'll take a deeper look :)
