
@youkaichao
Created October 31, 2024 21:27
custom op overhead (no mutation)
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    out = q.clone()
    out += k
    out += v
    return out


use_custom_op = False

if use_custom_op:
    # Wrap the plain Python function as a PyTorch custom op (torch >= 2.4);
    # mutates_args=[] declares that it does not mutate its inputs.
    silly_attention = torch.library.custom_op("silly::attention",
                                              mutates_args=[])(silly_attention)

    # Fake (meta) implementation so the op can be traced and compiled.
    @silly_attention.register_fake
    def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(q)
@dataclass
class LlamaConfig:
    hidden_size: int = 128
    mlp_size: int = 256
    vocab_size: int = 128
    num_layers: int = 2


class LlamaMLP(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.gate_up_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.mlp_size * 2,
            bias=False,
        )
        self.down_projection = nn.Linear(
            in_features=config.mlp_size,
            out_features=config.hidden_size,
            bias=False,
        )
        self.gate_up_projection.weight.data.fill_(0.0)
        self.down_projection.weight.data.fill_(0.0)

    def forward(self, x):
        x = self.gate_up_projection(x)
        x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
            x[:, x.size(1) // 2:])
        x = self.down_projection(x)
        return x
class LlamaAttention(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.qkv_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
        )
        self.output_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )
        self.qkv_projection.weight.data.fill_(0.0)
        self.output_projection.weight.data.fill_(0.0)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv = self.qkv_projection(hidden_states)
        hidden_size = qkv.size(-1) // 3
        q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)
        q = q + positions.unsqueeze(1)
        k = k + positions.unsqueeze(1)
        if use_custom_op:
            # dispatch through the registered custom op
            attn_output = torch.ops.silly.attention(q, k, v)
        else:
            # call the plain Python function directly
            attn_output = silly_attention(q, k, v)
        output = self.output_projection(attn_output)
        return output
class LlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.self_attention = LlamaAttention(config)
        self.mlp = LlamaMLP(config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = hidden_states / 2
        else:
            hidden_states = hidden_states + residual
            residual = hidden_states
            hidden_states = hidden_states / 2
        hidden_states = self.self_attention(positions=positions,
                                            hidden_states=hidden_states)
        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states / 2
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
class LlamaModel(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.embedding_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_layers)])
        self.embedding_tokens.weight.data.fill_(0.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embedding_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        return hidden_states
@torch.inference_mode
def benchmark():
    from triton.testing import do_bench
    cls = LlamaModel

    # similar to llama 3.1-8B
    llama_config = LlamaConfig(hidden_size=4096,
                               mlp_size=14336,
                               vocab_size=128 * 1024,
                               num_layers=32)

    # a tiny model to measure the overhead
    # of piecewise cudagraph
    llama_config = LlamaConfig(hidden_size=40,
                               mlp_size=80,
                               vocab_size=128,
                               num_layers=2)

    cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]

    eager_time = {}
    full_cudagraph_time = {}

    pool = torch.cuda.graph_pool_handle()

    model = cls(llama_config).eval().cuda().to(torch.bfloat16)

    B = 256  # max batch size
    input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
    positions = torch.arange(B).cuda().to(torch.bfloat16)

    graphs = {}

    # warm up once before capturing cudagraphs
    model(input_ids, positions)
    for b in cudagraph_sizes[::-1]:
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            output = model(input_ids[:b], positions[:b])
        graphs[b] = (graph, output)

    for b in cudagraph_sizes:
        # do_bench reports times in milliseconds
        runtime = do_bench(lambda: graphs[b][0].replay())  # noqa
        eager_runtime = do_bench(
            lambda: model(input_ids[:b], positions[:b]))  # noqa
        full_cudagraph_time[b] = runtime
        eager_time[b] = eager_runtime

    # print in tabular format
    print("batch size\teager mode (ms)\tfull cudagraph (ms)")
    for b in cudagraph_sizes:
        print((f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"))


if __name__ == "__main__":
    benchmark()
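
The benchmark above times whole-model forward passes. As a side note, a minimal sketch like the following (not part of the original gist) can isolate just the per-call dispatch cost of a registered custom op versus a plain Python function, using only the APIs already used above. The op name "silly::noop", the helper bench_dispatch, and the tensor shapes are illustrative assumptions.

    # Hypothetical micro-benchmark: compare calling a plain Python function
    # against dispatching through an equivalent registered custom op.
    import torch
    from triton.testing import do_bench


    def plain_noop(x: torch.Tensor) -> torch.Tensor:
        # trivial body so the measurement is dominated by call overhead
        return x.clone()


    # register the same function as a non-mutating custom op under a made-up name
    noop_op = torch.library.custom_op("silly::noop", mutates_args=[])(plain_noop)


    @noop_op.register_fake
    def _(x: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(x)


    def bench_dispatch():
        x = torch.randn(256, 40, device="cuda", dtype=torch.bfloat16)
        plain_ms = do_bench(lambda: plain_noop(x))
        custom_ms = do_bench(lambda: torch.ops.silly.noop(x))
        print(f"plain python: {plain_ms:.4f} ms, custom op: {custom_ms:.4f} ms")

The difference between the two timings approximates the extra Python/dispatcher work added by the custom-op wrapper for a single call.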
@tlrmchlsmth

I'm on an AMD EPYC 7763 64-Core Processor
torch 2.5.1, Python 3.12.7

Run with use_custom_op = True:

batch size  eager mode (ms)  full cudagraph (ms)
1   0.443   0.081
2   0.476   0.078
4   0.474   0.077
8   0.472   0.086
16  0.475   0.093
24  0.493   0.090
32  0.479   0.089
40  0.478   0.089
48  0.473   0.089
56  0.478   0.090
64  0.478   0.089
72  0.482   0.091
80  0.478   0.090
88  0.479   0.091
96  0.477   0.090
104 0.478   0.090
112 0.477   0.091
120 0.483   0.091
128 0.477   0.091
136 0.478   0.091
144 0.480   0.091
152 0.476   0.092
160 0.480   0.091
168 0.477   0.091
176 0.475   0.092
184 0.481   0.091
192 0.479   0.091
200 0.477   0.092
208 0.474   0.092
216 0.472   0.093
224 0.479   0.092
232 0.473   0.092
240 0.474   0.092
248 0.474   0.092
256 0.484   0.092

Run with use_custom_op = False:

batch size	eager mode (ms)	full cudagraph (ms)
1	0.384	0.081
2	0.413	0.077
4	0.414	0.077
8	0.414	0.086
16	0.416	0.093
24	0.415	0.090
32	0.417	0.088
40	0.416	0.089
48	0.437	0.089
56	0.416	0.090
64	0.418	0.089
72	0.418	0.091
80	0.417	0.090
88	0.422	0.091
96	0.416	0.090
104	0.417	0.090
112	0.417	0.091
120	0.418	0.091
128	0.419	0.091
136	0.417	0.091
144	0.423	0.091
152	0.419	0.092
160	0.418	0.091
168	0.418	0.091
176	0.418	0.092
184	0.418	0.091
192	0.419	0.091
200	0.418	0.092
208	0.420	0.091
216	0.417	0.093
224	0.419	0.092
232	0.417	0.092
240	0.419	0.092
248	0.419	0.092
256	0.419	0.092
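
Reading the two runs together: for this 2-layer toy model, enabling the custom op makes eager mode roughly 0.06 ms slower per forward pass (about 0.03 ms per custom-op call), while the full-cudagraph times are essentially unchanged, since CPU-side dispatch overhead is not replayed by the captured graph.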
