# Reproducing Roblox DistilBERT Medium Post
# https://blog.roblox.com/2020/05/scaled-bert-serve-1-billion-daily-requests-cpus/
#
# 1. Launch a c5.12xlarge instance with the Deep Learning AMI (Ubuntu 18.04) Version 32.0 (ami-0dc2264cd927ca9eb)
# 2. pip install transformers[torch]
# 3. python reproduce_roblox_distilbert.py
import timeit
from transformers import DistilBertTokenizerFast, \
    DistilBertForSequenceClassification
import torch
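# Measure single-threaded latency; the blog post scales throughput by running
# many single-threaded inference workers rather than one multi-threaded one.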
torch.set_num_threads(1)
text = """
Call me Ishmael. Some years ago—never mind how long precisely—having
little or no money in my purse, and nothing particular to interest me
on shore, I thought I would sail about a little and see the watery part
of the world. It is a way I have of driving off the spleen and
regulating the circulation. Whenever I find myself growing grim about
the mouth; whenever it is a damp, drizzly November in my soul; whenever
I find myself involuntarily pausing before coffin warehouses, and
bringing up the rear of every funeral I meet; and especially whenever
my hypos get such an upper hand of me, that it requires a strong moral
principle to prevent me from deliberately stepping into the street, and
methodically knocking people’s hats off—then, I account it high time to
get to sea as soon as I can. This is my substitute for pistol and ball.
With a philosophical flourish Cato throws himself upon his sword; I
quietly take to the ship. There is nothing surprising in this. If they
but knew it, almost all men in their degree, some time or other,
cherish very nearly the same feelings towards the ocean with me.
"""
# Make text longer than 512 tokens.
text = text * 3
tokenizer = DistilBertTokenizerFast.from_pretrained(
    'distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased')
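# Note: distilbert-base-uncased has no fine-tuned classification head, so the
# classifier weights are randomly initialized; that is fine for a timing-only
# benchmark like this one.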
model_quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8)
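# Dynamic quantization stores the weights of the selected modules (here all
# torch.nn.Linear layers) as int8 and quantizes activations on the fly at
# inference time, trading a small amount of accuracy for faster CPU matmuls.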
def get_inputs(sequence_length):
    # The text is longer than 512 tokens, so this always truncates the input
    # down to exactly sequence_length tokens.
    return tokenizer(
        text,
        truncation='longest_first',
        return_tensors="pt",
        max_length=sequence_length)
def run(sequence_length):
    # Each timed run includes tokenization plus the forward pass.
    with torch.no_grad():
        inputs = get_inputs(sequence_length)
        outputs = model(**inputs)
def run_quantized(sequence_length):
    with torch.no_grad():
        inputs = get_inputs(sequence_length)
        outputs = model_quantized(**inputs)
def median_ms(stmt):
    # Median wall-clock time in milliseconds over 100 single executions.
    times = timeit.repeat(
        stmt,
        globals=globals(),
        number=1,
        repeat=100)
    median_seconds = sorted(times)[len(times)//2]
    return round(median_seconds * 1000)
def time():
    print('Before quantization.')
    print(f'512 tokens: {median_ms("run(512)")}ms')
    print(f'256 tokens: {median_ms("run(256)")}ms')
    print(f'128 tokens: {median_ms("run(128)")}ms')
    print(f'64 tokens: {median_ms("run(64)")}ms')
    print(f'32 tokens: {median_ms("run(32)")}ms')
    print(f'16 tokens: {median_ms("run(16)")}ms')
    print('After quantization.')
    print(f'512 tokens: {median_ms("run_quantized(512)")}ms')
    print(f'256 tokens: {median_ms("run_quantized(256)")}ms')
    print(f'128 tokens: {median_ms("run_quantized(128)")}ms')
    print(f'64 tokens: {median_ms("run_quantized(64)")}ms')
    print(f'32 tokens: {median_ms("run_quantized(32)")}ms')
    print(f'16 tokens: {median_ms("run_quantized(16)")}ms')
if __name__ == '__main__':
    time()
# Before quantization.
# 512 tokens: 345ms
# 256 tokens: 170ms
# 128 tokens: 93ms
# 64 tokens: 61ms
# 32 tokens: 45ms
# 16 tokens: 37ms
# After quantization.
# 512 tokens: 196ms
# 256 tokens: 79ms
# 128 tokens: 37ms
# 64 tokens: 21ms
# 32 tokens: 14ms
# 16 tokens: 10ms
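# From these medians, dynamic quantization gives roughly a 1.8x speedup at
# 512 tokens, rising to about 3.7x at 16 tokens on this instance.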