Created October 31, 2024 21:27
custom op overhead (no mutation)
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    out = q.clone()
    out += k
    out += v
    return out


# Toggle between calling the Python function directly and going through the
# registered custom op (torch.ops.silly.attention) to measure dispatch overhead.
use_custom_op = False

if use_custom_op:
    silly_attention = torch.library.custom_op("silly::attention", mutates_args=[])(silly_attention)

    @silly_attention.register_fake
    def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(q)


@dataclass
class LlamaConfig:
    hidden_size: int = 128
    mlp_size: int = 256
    vocab_size: int = 128
    num_layers: int = 2


class LlamaMLP(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.gate_up_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.mlp_size * 2,
            bias=False,
        )
        self.down_projection = nn.Linear(
            in_features=config.mlp_size,
            out_features=config.hidden_size,
            bias=False,
        )
        self.gate_up_projection.weight.data.fill_(0.0)
        self.down_projection.weight.data.fill_(0.0)

    def forward(self, x):
        x = self.gate_up_projection(x)
        x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
            x[:, x.size(1) // 2:])
        x = self.down_projection(x)
        return x


class LlamaAttention(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.qkv_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
        )
        self.output_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )
        self.qkv_projection.weight.data.fill_(0.0)
        self.output_projection.weight.data.fill_(0.0)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv = self.qkv_projection(hidden_states)
        hidden_size = qkv.size(-1) // 3
        q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)
        q = q + positions.unsqueeze(1)
        k = k + positions.unsqueeze(1)
        if use_custom_op:
            attn_output = torch.ops.silly.attention(q, k, v)
        else:
            attn_output = silly_attention(q, k, v)
        output = self.output_projection(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.self_attention = LlamaAttention(config)
        self.mlp = LlamaMLP(config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = hidden_states / 2
        else:
            hidden_states = hidden_states + residual
            residual = hidden_states
            hidden_states = hidden_states / 2
        hidden_states = self.self_attention(positions=positions,
                                            hidden_states=hidden_states)
        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states / 2
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.embedding_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_layers)])
        self.embedding_tokens.weight.data.fill_(0.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embedding_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        return hidden_states


@torch.inference_mode
def benchmark():
    from triton.testing import do_bench
    cls = LlamaModel

    # similar to llama 3.1-8B (overridden by the tiny config below)
    llama_config = LlamaConfig(hidden_size=4096,
                               mlp_size=14336,
                               vocab_size=128 * 1024,
                               num_layers=32)

    # a tiny model to measure the overhead
    # of piecewise cudagraph
    llama_config = LlamaConfig(hidden_size=40,
                               mlp_size=80,
                               vocab_size=128,
                               num_layers=2)

    cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]

    eager_time = {}
    full_cudagraph_time = {}

    pool = torch.cuda.graph_pool_handle()

    model = cls(llama_config).eval().cuda().to(torch.bfloat16)

    B = 256  # max batch size
    input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
    positions = torch.arange(B).cuda().to(torch.bfloat16)

    graphs = {}

    # warm up once before capturing CUDA graphs
    model(input_ids, positions)
    for b in cudagraph_sizes[::-1]:
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            output = model(input_ids[:b], positions[:b])
        graphs[b] = (graph, output)

    for b in cudagraph_sizes:
        runtime = do_bench(lambda: graphs[b][0].replay())  # noqa
        eager_runtime = do_bench(
            lambda: model(input_ids[:b], positions[:b]))  # noqa
        full_cudagraph_time[b] = runtime
        eager_time[b] = eager_runtime

    # print in tabular format (times in ms, as returned by do_bench)
    print("batch size\teager mode\tfull cudagraph")
    for b in cudagraph_sizes:
        print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}")


if __name__ == "__main__":
    benchmark()
I'm on an AMD EPYC 7763 64-Core Processor, torch 2.5.1, Python 3.12.7.
run with use_custom_op = True
batch size  eager mode (ms)  full cudagraph (ms)
1 0.443 0.081
2 0.476 0.078
4 0.474 0.077
8 0.472 0.086
16 0.475 0.093
24 0.493 0.090
32 0.479 0.089
40 0.478 0.089
48 0.473 0.089
56 0.478 0.090
64 0.478 0.089
72 0.482 0.091
80 0.478 0.090
88 0.479 0.091
96 0.477 0.090
104 0.478 0.090
112 0.477 0.091
120 0.483 0.091
128 0.477 0.091
136 0.478 0.091
144 0.480 0.091
152 0.476 0.092
160 0.480 0.091
168 0.477 0.091
176 0.475 0.092
184 0.481 0.091
192 0.479 0.091
200 0.477 0.092
208 0.474 0.092
216 0.472 0.093
224 0.479 0.092
232 0.473 0.092
240 0.474 0.092
248 0.474 0.092
256 0.484 0.092
run with use_custom_op = False
batch size  eager mode (ms)  full cudagraph (ms)
1 0.384 0.081
2 0.413 0.077
4 0.414 0.077
8 0.414 0.086
16 0.416 0.093
24 0.415 0.090
32 0.417 0.088
40 0.416 0.089
48 0.437 0.089
56 0.416 0.090
64 0.418 0.089
72 0.418 0.091
80 0.417 0.090
88 0.422 0.091
96 0.416 0.090
104 0.417 0.090
112 0.417 0.091
120 0.418 0.091
128 0.419 0.091
136 0.417 0.091
144 0.423 0.091
152 0.419 0.092
160 0.418 0.091
168 0.418 0.091
176 0.418 0.092
184 0.418 0.091
192 0.419 0.091
200 0.418 0.092
208 0.420 0.091
216 0.417 0.093
224 0.419 0.092
232 0.417 0.092
240 0.419 0.092
248 0.419 0.092
256 0.419 0.092
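A rough reading of the two tables: in eager mode the custom-op run sits around 0.48 ms per forward pass versus about 0.42 ms for the plain Python call, while the full-cudagraph columns are essentially the same in both runs, since graph replay bypasses Python and the dispatcher. The snippet below is illustrative arithmetic only; the 0.478/0.418 inputs are approximate steady-state values I read from the tables, not extra measurements.

# Illustrative arithmetic over approximate steady-state values from the tables
# above (ms per forward pass of the 2-layer toy model).
eager_with_custom_op = 0.478     # use_custom_op = True
eager_plain_python = 0.418       # use_custom_op = False
attention_calls_per_forward = 2  # num_layers = 2, one silly::attention call each

overhead_per_call_ms = (eager_with_custom_op - eager_plain_python) / attention_calls_per_forward
print(f"~{overhead_per_call_ms * 1000:.0f} us of extra dispatch cost per custom op call")
# -> roughly 30 us per call; the full-cudagraph columns match in both runs
#    because replaying a captured graph skips the Python/dispatcher work.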
This is the no-mutation version of https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5.
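For contrast with the mutation-based gist linked above: this file registers the op with mutates_args=[] and returns a fresh tensor, whereas a mutating variant passes the output buffer in and declares it in mutates_args. The sketch below is a minimal illustration of that pattern, not the code from the linked gist; the silly::attention_inplace name and the out parameter are made up for the example, and the fake registration is omitted.

import torch

# Hypothetical mutating counterpart of silly::attention: the caller provides
# the output buffer and the op declares it as mutated.
@torch.library.custom_op("silly::attention_inplace", mutates_args=["out"])
def silly_attention_inplace(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                            out: torch.Tensor) -> None:
    out.copy_(q)
    out += k
    out += v

# Usage: the buffer lives outside the op, so its address can stay stable
# across CUDA graph replays.
# out = torch.empty_like(q)
# torch.ops.silly.attention_inplace(q, k, v, out)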