Last active
May 5, 2024 02:00
-
-
Save zhuangh/df682b954e36484e6c274f3446b408aa to your computer and use it in GitHub Desktop.
MQA reshape_go_faster.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
baseline runtime(s) 0.5243263244628906 | |
with reshape runtime (s) 0.0022399425506591797 | |
@ cpu | |
========= | |
baseline runtime (s) 0.25386476516723633 | |
with reshape runtime (s) 0.0008966922760009766 | |
@ cuda:0 | |
""" | |
import torch | |
import torch.nn.functional as F | |
from einops import einsum, rearrange | |
import time | |
# Start timing | |
# b, seq, g*head, emb | |
qq = torch.ones(1, 256, 8, 64) | |
# b, seq, head, emb | |
kk = torch.ones(1, 256, 2, 64) | |
vv = torch.ones(1, 256, 2, 64) | |
num_head_groups = qq.shape[2] // kk.shape[2] | |
scale = qq.size(-1) ** 0.5 | |
start_time = time.time() | |
q = rearrange(qq, "b s (g h) e -> b s g h e", g=num_head_groups) | |
scores = einsum(q, kk, "b s g h e, b s m e -> b s h m") | |
att = F.softmax(scores / scale, dim=-1) | |
out1 = einsum(att, vv, "b s h m, b s m e -> b s h e") | |
end_time = time.time() - start_time | |
print("baseline runtime(s)", end_time) | |
start_time = time.time() | |
q = rearrange(qq, "b s g e -> b g s e") | |
v = rearrange(vv, "b s h e -> b h s e") | |
k = rearrange(kk, "b s h e -> b h s e") | |
q = rearrange(q, "b (g h) s e -> b g h s e", g=num_head_groups) | |
scores = einsum(q, k, "b g h s e, b h ss e -> b h s ss") | |
att = F.softmax(scores / scale, dim=-1) | |
out = einsum(att, v, "b h s ss, b h ss e -> b h s e") | |
out = rearrange(out, "b h s e -> b s h e") | |
end_time = time.time() - start_time | |
print("with reshape runtime (s)", end_time) | |
torch.testing.assert_close(out1, out) | |
print("@", out.device) | |
print("=========") | |
# --- CUDA run --------------------------------------------------------------
# Same comparison on GPU. Fixes over the original:
#  * guarded on torch.cuda.is_available() instead of crashing with a
#    RuntimeError on CPU-only machines;
#  * torch.cuda.synchronize() around each timed region — CUDA kernels
#    launch asynchronously, so reading time.time() without a sync measures
#    kernel-launch overhead, not execution time.
if torch.cuda.is_available():
    # b, seq, g*head, emb
    qq = torch.ones(1, 256, 8, 64).to("cuda")
    # b, seq, head, emb
    kk = torch.ones(1, 256, 2, 64).to("cuda")
    vv = torch.ones(1, 256, 2, 64).to("cuda")
    num_head_groups = qq.shape[2] // kk.shape[2]  # query heads per KV head
    scale = qq.size(-1) ** 0.5                    # sqrt(emb) softmax temperature

    torch.cuda.synchronize()  # drain pending work before starting the clock
    start_time = time.time()
    q = rearrange(qq, "b s (g h) e -> b s g h e", g=num_head_groups)
    scores = einsum(q, kk, "b s g h e, b s m e -> b s h m")
    att = F.softmax(scores / scale, dim=-1)
    out1 = einsum(att, vv, "b s h m, b s m e -> b s h e")
    torch.cuda.synchronize()  # wait for kernels to finish before reading the clock
    end_time = time.time() - start_time
    print("baseline runtime (s)", end_time)

    torch.cuda.synchronize()
    start_time = time.time()
    q = rearrange(qq, "b s g e -> b g s e")
    v = rearrange(vv, "b s h e -> b h s e")
    k = rearrange(kk, "b s h e -> b h s e")
    q = rearrange(q, "b (g h) s e -> b g h s e", g=num_head_groups)
    scores = einsum(q, k, "b g h s e, b h ss e -> b h s ss")
    att = F.softmax(scores / scale, dim=-1)
    out = einsum(att, v, "b h s ss, b h ss e -> b h s e")
    out = rearrange(out, "b h s e -> b s h e")
    torch.cuda.synchronize()
    end_time = time.time() - start_time
    print("with reshape runtime (s)", end_time)

    torch.testing.assert_close(out1, out)
    print("@", out.device)
else:
    print("CUDA not available; skipping GPU benchmark")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment