Reproducing Roblox DistilBERT Medium Post
# Reproducing Roblox DistilBERT Medium Post
# https://blog.roblox.com/2020/05/scaled-bert-serve-1-billion-daily-requests-cpus/
#
# 1. Launch a c5.12xlarge instance with Deep Learning AMI (Ubuntu 18.04) Version 32.0 (ami-0dc2264cd927ca9eb)
# 2. pip install 'transformers[torch]'
# 3. python reproduce_roblox_distilbert.py
import timeit
from transformers import DistilBertTokenizerFast, \
    DistilBertForSequenceClassification
import torch
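# Run single-threaded, presumably to match the blog post's scaling
# strategy of one worker process per core: per-core latency is the
# number that matters there.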
torch.set_num_threads(1)
text = """ | |
Call me Ishmael. Some years ago—never mind how long precisely—having | |
little or no money in my purse, and nothing particular to interest me | |
on shore, I thought I would sail about a little and see the watery part | |
of the world. It is a way I have of driving off the spleen and | |
regulating the circulation. Whenever I find myself growing grim about | |
the mouth; whenever it is a damp, drizzly November in my soul; whenever | |
I find myself involuntarily pausing before coffin warehouses, and | |
bringing up the rear of every funeral I meet; and especially whenever | |
my hypos get such an upper hand of me, that it requires a strong moral | |
principle to prevent me from deliberately stepping into the street, and | |
methodically knocking people’s hats off—then, I account it high time to | |
get to sea as soon as I can. This is my substitute for pistol and ball. | |
With a philosophical flourish Cato throws himself upon his sword; I | |
quietly take to the ship. There is nothing surprising in this. If they | |
but knew it, almost all men in their degree, some time or other, | |
cherish very nearly the same feelings towards the ocean with me. | |
""" | |
# Repeat the passage so the input exceeds 512 tokens and every
# benchmarked sequence length is fully populated after truncation.
text = text * 3
tokenizer = DistilBertTokenizerFast.from_pretrained(
    'distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased')
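# Dynamic quantization converts the weights of every nn.Linear module
# to int8; activations are quantized on the fly at inference time, so
# no calibration pass is needed.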
model_quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8)
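# Tokenize the passage and truncate it to the target sequence length
# (512 is also DistilBERT's maximum position-embedding size).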
def get_inputs(sequence_length):
    return tokenizer(
        text,
        truncation='longest_first',
        return_tensors="pt",
        max_length=sequence_length)
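# Inference only: torch.no_grad() skips autograd bookkeeping so the
# timings reflect the forward pass alone.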
def run(sequence_length):
    with torch.no_grad():
        inputs = get_inputs(sequence_length)
        outputs = model(**inputs)

def run_quantized(sequence_length):
    with torch.no_grad():
        inputs = get_inputs(sequence_length)
        outputs = model_quantized(**inputs)
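# Time 100 single-shot runs and report the median (the upper median,
# since len(times) is even) in milliseconds to damp outliers.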
def median_ms(stmt):
    times = timeit.repeat(
        stmt,
        globals=globals(),
        number=1,
        repeat=100)
    median_seconds = sorted(times)[len(times)//2]
    return round(median_seconds * 1000)
def time():
    print('Before quantization.')
    print(f'512 tokens: {median_ms("run(512)")}ms')
    print(f'256 tokens: {median_ms("run(256)")}ms')
    print(f'128 tokens: {median_ms("run(128)")}ms')
    print(f'64 tokens: {median_ms("run(64)")}ms')
    print(f'32 tokens: {median_ms("run(32)")}ms')
    print(f'16 tokens: {median_ms("run(16)")}ms')
    print('After quantization.')
    print(f'512 tokens: {median_ms("run_quantized(512)")}ms')
    print(f'256 tokens: {median_ms("run_quantized(256)")}ms')
    print(f'128 tokens: {median_ms("run_quantized(128)")}ms')
    print(f'64 tokens: {median_ms("run_quantized(64)")}ms')
    print(f'32 tokens: {median_ms("run_quantized(32)")}ms')
    print(f'16 tokens: {median_ms("run_quantized(16)")}ms')

if __name__ == '__main__':
    time()
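# Output observed on the c5.12xlarge instance described above: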
# Before quantization.
# 512 tokens: 345ms
# 256 tokens: 170ms
# 128 tokens: 93ms
# 64 tokens: 61ms
# 32 tokens: 45ms
# 16 tokens: 37ms
# After quantization.
# 512 tokens: 196ms
# 256 tokens: 79ms
# 128 tokens: 37ms
# 64 tokens: 21ms
# 32 tokens: 14ms
# 16 tokens: 10ms
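# Speedup from int8 dynamic quantization, computed from the medians
# above: 1.8x at 512 tokens, 2.2x at 256, 2.5x at 128, 2.9x at 64,
# 3.2x at 32, and 3.7x at 16. Shorter sequences presumably benefit
# more because the (unquantized) attention matmuls shrink
# quadratically, leaving the quantized Linear layers dominant.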