PyTorch Transformer Benchmark
import time

import torch
import torch.nn as nn

# Benchmark configuration
bz = 128          # batch size
seq_len = 512     # sequence length
d_model = 64      # model (embedding) dimension
n_heads = 8       # number of attention heads
dropout_rate = 0.2
num_trials = 100
batch_first = False

if batch_first:
    x = torch.randn(bz, seq_len, d_model).cuda()
else:
    x = torch.randn(seq_len, bz, d_model).cuda()

transformer = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model, n_heads, dim_feedforward=512,
        dropout=dropout_rate, batch_first=batch_first
    ),
    num_layers=8
).cuda()

with torch.autocast('cuda'):
    # Standard (math) scaled-dot-product attention
    with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_math=True, enable_mem_efficient=False
    ):
        # warmup
        transformer(x)
        torch.cuda.synchronize()

        start = time.time()
        for i in range(num_trials):
            out = transformer(x)
            out.mean().backward()
        torch.cuda.synchronize()
        end = time.time()

        print('Standard attention took {} seconds for {} trials'.format(end - start, num_trials))

    # Flash attention
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        # warmup
        transformer(x)
        torch.cuda.synchronize()

        start = time.time()
        for i in range(num_trials):
            out = transformer(x)
            out.mean().backward()
        torch.cuda.synchronize()
        end = time.time()

        print('Flash attention took {} seconds for {} trials'.format(end - start, num_trials))

    # Repeat both benchmarks with the model compiled by torch.compile()
    transformer = torch.compile(transformer)

    with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_math=True, enable_mem_efficient=False
    ):
        # warmup (triggers compilation)
        transformer(x)
        torch.cuda.synchronize()

        start = time.time()
        for i in range(num_trials):
            out = transformer(x)
            out.mean().backward()
        torch.cuda.synchronize()
        end = time.time()

        print('Standard attention + torch.compile() took {} seconds for {} trials'.format(end - start, num_trials))

    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        # warmup
        transformer(x)
        torch.cuda.synchronize()

        start = time.time()
        for i in range(num_trials):
            out = transformer(x)
            out.mean().backward()
        torch.cuda.synchronize()
        end = time.time()

        print('Flash attention + torch.compile() took {} seconds for {} trials'.format(end - start, num_trials))
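
Note: on newer PyTorch releases (2.3 and later), torch.backends.cuda.sdp_kernel is deprecated in favor of torch.nn.attention.sdpa_kernel, which selects backends via the SDPBackend enum. A minimal sketch of the equivalent flash-attention selection, assuming PyTorch >= 2.3:

# Equivalent backend selection on PyTorch >= 2.3 (sketch only, version-dependent).
from torch.nn.attention import sdpa_kernel, SDPBackend

with torch.autocast('cuda'):
    # Restrict scaled_dot_product_attention to the flash backend.
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = transformer(x)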