
@youkaichao
Created October 31, 2024 21:27
custom op overhead (no mutation)
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    out = q.clone()
    out += k
    out += v
    return out


# toggle to wrap silly_attention as a torch.library custom op
use_custom_op = False

if use_custom_op:
    silly_attention = torch.library.custom_op("silly::attention", mutates_args=[])(silly_attention)

    # fake (meta) implementation so the op can be traced with fake tensors
    @silly_attention.register_fake
    def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(q)


@dataclass
class LlamaConfig:
    hidden_size: int = 128
    mlp_size: int = 256
    vocab_size: int = 128
    num_layers: int = 2


class LlamaMLP(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.gate_up_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.mlp_size * 2,
            bias=False,
        )
        self.down_projection = nn.Linear(
            in_features=config.mlp_size,
            out_features=config.hidden_size,
            bias=False,
        )
        self.gate_up_projection.weight.data.fill_(0.0)
        self.down_projection.weight.data.fill_(0.0)

    def forward(self, x):
        x = self.gate_up_projection(x)
        x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
            x[:, x.size(1) // 2:])
        x = self.down_projection(x)
        return x


class LlamaAttention(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.qkv_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
        )
        self.output_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )
        self.qkv_projection.weight.data.fill_(0.0)
        self.output_projection.weight.data.fill_(0.0)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv = self.qkv_projection(hidden_states)
        hidden_size = qkv.size(-1) // 3
        q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)
        q = q + positions.unsqueeze(1)
        k = k + positions.unsqueeze(1)
        if use_custom_op:
            attn_output = torch.ops.silly.attention(q, k, v)
        else:
            attn_output = silly_attention(q, k, v)
        output = self.output_projection(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.self_attention = LlamaAttention(config)
        self.mlp = LlamaMLP(config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = hidden_states / 2
        else:
            hidden_states = hidden_states + residual
            residual = hidden_states
            hidden_states = hidden_states / 2
        hidden_states = self.self_attention(positions=positions,
                                            hidden_states=hidden_states)
        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states / 2
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.embedding_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_layers)])
        self.embedding_tokens.weight.data.fill_(0.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embedding_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        return hidden_states


@torch.inference_mode
def benchmark():
    from triton.testing import do_bench

    cls = LlamaModel

    # similar to llama 3.1-8B
    llama_config = LlamaConfig(hidden_size=4096,
                               mlp_size=14336,
                               vocab_size=128 * 1024,
                               num_layers=32)

    # a tiny model to measure the overhead
    # of piecewise cudagraph
    llama_config = LlamaConfig(hidden_size=40,
                               mlp_size=80,
                               vocab_size=128,
                               num_layers=2)

    cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]

    eager_time = {}
    full_cudagraph_time = {}

    pool = torch.cuda.graph_pool_handle()

    model = cls(llama_config).eval().cuda().to(torch.bfloat16)

    B = 256  # max batch size
    input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
    positions = torch.arange(B).cuda().to(torch.bfloat16)

    graphs = {}

    # warm up eager mode once before capturing graphs
    model(input_ids, positions)
    for b in cudagraph_sizes[::-1]:
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            output = model(input_ids[:b], positions[:b])
        graphs[b] = (graph, output)

    # do_bench reports times in milliseconds
    for b in cudagraph_sizes:
        runtime = do_bench(lambda: graphs[b][0].replay())  # noqa
        eager_runtime = do_bench(
            lambda: model(input_ids[:b], positions[:b]))  # noqa
        full_cudagraph_time[b] = runtime
        eager_time[b] = eager_runtime

    # print in tabular format
    print("batch size\teager mode\tfull cudagraph")
    for b in cudagraph_sizes:
        print((f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"))


if __name__ == "__main__":
    benchmark()
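
For reference, here is a minimal standalone sketch (an editor's addition, not part of the original gist) that isolates the dispatch overhead the script above is measuring: it registers the same no-mutation kernel under a separate, hypothetical name silly::attention_direct and times the plain Python call against the torch.ops dispatch path with do_bench. Exact numbers depend on the machine; only the relative gap matters.

import torch
from triton.testing import do_bench


def plain_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # same no-mutation kernel as silly_attention above
    out = q.clone()
    out += k
    out += v
    return out


# wrap the same function as a custom op; the original function stays directly callable
wrapped_attention = torch.library.custom_op("silly::attention_direct", mutates_args=[])(plain_attention)


@wrapped_attention.register_fake
def _(q, k, v):
    return torch.empty_like(q)


if __name__ == "__main__":
    x = torch.randn(8, 40, device="cuda", dtype=torch.bfloat16)
    t_plain = do_bench(lambda: plain_attention(x, x, x))                    # direct Python call
    t_custom = do_bench(lambda: torch.ops.silly.attention_direct(x, x, x))  # dispatcher path
    print(f"plain call: {t_plain:.3f} ms, custom-op dispatch: {t_custom:.3f} ms")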
@youkaichao (Author)

no-mutation version of https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5

run with use_custom_op = True:

batch size      eager mode      full cudagraph
1       0.188   0.070
2       0.208   0.073
4       0.198   0.077
8       0.197   0.082
16      0.200   0.087
24      0.213   0.091
32      0.213   0.091
40      0.214   0.091
48      0.215   0.092
56      0.215   0.092
64      0.215   0.091
72      0.215   0.093
80      0.214   0.093
88      0.217   0.094
96      0.215   0.094
104     0.214   0.095
112     0.214   0.095
120     0.214   0.096
128     0.214   0.095
136     0.214   0.096
144     0.214   0.096
152     0.215   0.096
160     0.215   0.096
168     0.214   0.095
176     0.214   0.095
184     0.214   0.096
192     0.213   0.096
200     0.213   0.097
208     0.214   0.096
216     0.214   0.097
224     0.217   0.096
232     0.214   0.096
240     0.215   0.097
248     0.213   0.097
256     0.214   0.097

run with use_custom_op = False:

batch size      eager mode      full cudagraph
1       0.168   0.070
2       0.188   0.073
4       0.186   0.077
8       0.181   0.082
16      0.169   0.087
24      0.181   0.092
32      0.206   0.091
40      0.243   0.091
48      0.214   0.092
56      0.183   0.092
64      0.189   0.092
72      0.230   0.094
80      0.285   0.094
88      0.192   0.095
96      0.187   0.094
104     0.243   0.095
112     0.204   0.096
120     0.178   0.096
128     0.179   0.096
136     0.177   0.096
144     0.178   0.096
152     0.177   0.096
160     0.177   0.097
168     0.179   0.096
176     0.179   0.096
184     0.178   0.096
192     0.178   0.096
200     0.200   0.097
208     0.202   0.096
216     0.180   0.097
224     0.178   0.097
232     0.176   0.097
240     0.176   0.097
248     0.178   0.097
256     0.176   0.097
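
A rough back-of-envelope from the two tables above (an editor's note; the values are the batch-256 rows, in milliseconds as reported by do_bench):

eager_custom_op = 0.214     # use_custom_op = True, batch 256
eager_plain = 0.176         # use_custom_op = False, batch 256
num_custom_op_calls = 2     # tiny config: 2 decoder layers, one silly::attention call each
overhead_us = (eager_custom_op - eager_plain) / num_custom_op_calls * 1000
print(f"~{overhead_us:.0f} us eager-mode overhead per custom-op call")  # roughly 19 us

With full cudagraph the two runs are essentially identical, since replay only re-launches the recorded kernels and skips Python-side dispatch.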

@tlrmchlsmth

I'm on an AMD EPYC 7763 64-Core Processor
torch 2.5.1, Python 3.12.7

run with use_custom_op = True

batch size  eager mode  full cudagraph
1   0.443   0.081
2   0.476   0.078
4   0.474   0.077
8   0.472   0.086
16  0.475   0.093
24  0.493   0.090
32  0.479   0.089
40  0.478   0.089
48  0.473   0.089
56  0.478   0.090
64  0.478   0.089
72  0.482   0.091
80  0.478   0.090
88  0.479   0.091
96  0.477   0.090
104 0.478   0.090
112 0.477   0.091
120 0.483   0.091
128 0.477   0.091
136 0.478   0.091
144 0.480   0.091
152 0.476   0.092
160 0.480   0.091
168 0.477   0.091
176 0.475   0.092
184 0.481   0.091
192 0.479   0.091
200 0.477   0.092
208 0.474   0.092
216 0.472   0.093
224 0.479   0.092
232 0.473   0.092
240 0.474   0.092
248 0.474   0.092
256 0.484   0.092

run with use_custom_op = False

batch size	eager mode	full cudagraph
1	0.384	0.081
2	0.413	0.077
4	0.414	0.077
8	0.414	0.086
16	0.416	0.093
24	0.415	0.090
32	0.417	0.088
40	0.416	0.089
48	0.437	0.089
56	0.416	0.090
64	0.418	0.089
72	0.418	0.091
80	0.417	0.090
88	0.422	0.091
96	0.416	0.090
104	0.417	0.090
112	0.417	0.091
120	0.418	0.091
128	0.419	0.091
136	0.417	0.091
144	0.423	0.091
152	0.419	0.092
160	0.418	0.091
168	0.418	0.091
176	0.418	0.092
184	0.418	0.091
192	0.419	0.091
200	0.418	0.092
208	0.420	0.091
216	0.417	0.093
224	0.419	0.092
232	0.417	0.092
240	0.419	0.092
248	0.419	0.092
256	0.419	0.092
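
Applying the same back-of-envelope to these numbers (an editor's note): at batch 256, (0.484 − 0.419) ms over 2 custom-op calls is roughly 30 µs of eager-mode overhead per call on this machine, while the full-cudagraph column again matches the use_custom_op = False run almost exactly.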
