@youkaichao
Created October 31, 2024 21:15
custom op overhead
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    out.copy_(q)
    out += k
    out += v


use_custom_op = True

if use_custom_op:
    silly_attention = torch.library.custom_op(
        "silly::attention", mutates_args=["out"])(silly_attention)

    @silly_attention.register_fake
    def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
          out: torch.Tensor) -> None:
        return


@dataclass
class LlamaConfig:
    hidden_size: int = 128
    mlp_size: int = 256
    vocab_size: int = 128
    num_layers: int = 2


class LlamaMLP(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.gate_up_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.mlp_size * 2,
            bias=False,
        )
        self.down_projection = nn.Linear(
            in_features=config.mlp_size,
            out_features=config.hidden_size,
            bias=False,
        )
        self.gate_up_projection.weight.data.fill_(0.0)
        self.down_projection.weight.data.fill_(0.0)

    def forward(self, x):
        x = self.gate_up_projection(x)
        x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
            x[:, x.size(1) // 2:])
        x = self.down_projection(x)
        return x


class LlamaAttention(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.qkv_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
        )
        self.output_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )
        self.qkv_projection.weight.data.fill_(0.0)
        self.output_projection.weight.data.fill_(0.0)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv = self.qkv_projection(hidden_states)
        hidden_size = qkv.size(-1) // 3
        q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)
        q = q + positions.unsqueeze(1)
        k = k + positions.unsqueeze(1)
        attn_output = torch.empty_like(q)
        if use_custom_op:
            torch.ops.silly.attention(q, k, v, attn_output)
        else:
            silly_attention(q, k, v, attn_output)
        output = self.output_projection(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.self_attention = LlamaAttention(config)
        self.mlp = LlamaMLP(config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = hidden_states / 2
        else:
            hidden_states = hidden_states + residual
            residual = hidden_states
            hidden_states = hidden_states / 2
        hidden_states = self.self_attention(positions=positions,
                                            hidden_states=hidden_states)
        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states / 2
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.embedding_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_layers)])
        self.embedding_tokens.weight.data.fill_(0.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embedding_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        return hidden_states


@torch.inference_mode
def benchmark():
    from triton.testing import do_bench
    cls = LlamaModel

    # similar to llama 3.1-8B
    llama_config = LlamaConfig(hidden_size=4096,
                               mlp_size=14336,
                               vocab_size=128 * 1024,
                               num_layers=32)

    # a tiny model to measure the overhead
    # of piecewise cudagraph
    llama_config = LlamaConfig(hidden_size=40,
                               mlp_size=80,
                               vocab_size=128,
                               num_layers=2)

    cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]

    eager_time = {}
    full_cudagraph_time = {}

    pool = torch.cuda.graph_pool_handle()

    model = cls(llama_config).eval().cuda().to(torch.bfloat16)

    B = 256  # max batch size
    input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
    positions = torch.arange(B).cuda().to(torch.bfloat16)

    graphs = {}

    model(input_ids, positions)
    for b in cudagraph_sizes[::-1]:
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            output = model(input_ids[:b], positions[:b])
        graphs[b] = (graph, output)

    for b in cudagraph_sizes:
        runtime = do_bench(lambda: graphs[b][0].replay())  # noqa
        eager_runtime = do_bench(
            lambda: model(input_ids[:b], positions[:b]))  # noqa
        full_cudagraph_time[b] = runtime
        eager_time[b] = eager_runtime

    # print in tabular format
    print("batch size\teager mode\tfull cudagraph")
    for b in cudagraph_sizes:
        print((f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"))


if __name__ == "__main__":
    benchmark()
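
As a side note (not part of the original gist), one way to see where the remaining per-call cost goes is to profile a few small-batch eager forward passes with torch.profiler and compare the CPU-side rows for silly::attention against the surrounding aten ops. A minimal sketch, assuming model, input_ids, and positions have been built as in the script above:

    from torch.profiler import profile, ProfilerActivity

    # warm up once so lazy initialization does not show up in the trace
    model(input_ids[:8], positions[:8])

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(10):
            model(input_ids[:8], positions[:8])

    # CPU time spent in aten ops vs. the silly::attention custom op
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))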
@youkaichao
the no-mutation variant in https://gist.github.com/youkaichao/dcd04fcc42b276f5480c43b3690e51ea is basically the same.
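
For reference, the no-mutation registration presumably looks something like the sketch below (a hedged reconstruction following this gist's pattern; the exact names in the linked gist may differ). The op returns a fresh tensor instead of writing into out, and the fake impl just allocates an empty tensor of the same shape so tracing can proceed:

    @torch.library.custom_op("silly::attention_no_mut", mutates_args=())
    def silly_attention_no_mut(q: torch.Tensor, k: torch.Tensor,
                               v: torch.Tensor) -> torch.Tensor:
        return q + k + v

    @silly_attention_no_mut.register_fake
    def _(q: torch.Tensor, k: torch.Tensor,
          v: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(q)

    # call site: attn_output = torch.ops.silly.attention_no_mut(q, k, v)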

@youkaichao
running a large model with:

    # similar to llama 3.1-8B
    llama_config = LlamaConfig(hidden_size=4096,
                               mlp_size=14336,
                               vocab_size=128 * 1024,
                               num_layers=32)

run with use_custom_op = True (times in ms):

batch size      eager mode      full cudagraph
1       7.129   6.474
2       7.192   6.614
4       7.282   6.668
8       7.339   6.665
16      7.399   6.792
24      7.551   6.842
32      7.631   6.975
40      7.580   6.912
48      7.583   6.932
56      7.623   6.979
64      7.701   7.033
72      8.214   7.535
80      8.290   7.594
88      8.345   7.642
96      8.367   7.691
104     8.344   7.642
112     8.391   7.682
120     8.430   7.725
128     8.538   7.778
136     9.142   8.482
144     9.092   8.444
152     9.121   8.458
160     9.156   8.492
168     9.075   8.421
176     9.124   8.465
184     9.158   8.514
192     9.217   8.564
200     10.093  9.349
208     10.191  9.402
216     10.172  9.463
224     10.307  9.539
232     10.184  9.513
240     10.222  9.549
248     10.262  9.597
256     10.312  9.636

run with use_custom_op = False (times in ms):

batch size      eager mode      full cudagraph
1       7.108   6.462
2       7.191   6.609
4       7.273   6.650
8       7.321   6.677
16      7.373   6.792
24      7.525   6.841
32      7.655   6.951
40      7.583   6.912
48      7.562   6.933
56      7.621   6.978
64      7.691   7.033
72      8.197   7.535
80      8.271   7.592
88      8.325   7.641
96      8.361   7.691
104     8.350   7.644
112     8.385   7.678
120     8.428   7.726
128     8.488   7.777
136     9.142   8.476
144     9.104   8.444
152     9.105   8.450
160     9.162   8.493
168     9.075   8.419
176     9.128   8.465
184     9.162   8.511
192     9.230   8.565
200     10.118  9.347
208     10.186  9.409
216     10.234  9.475
224     10.273  9.506
232     10.182  9.530
240     10.224  9.544
248     10.259  9.598
256     10.300  9.636

for this large model, the overhead of the custom op is not significant.

however, since the overhead of a custom op increases with the number of arguments (and we also have small 1B/3B models), we still need to take care of it.
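
To get a feel for how the dispatch cost scales with arity, one can time no-op custom ops that differ only in their number of tensor arguments. This is a sketch, not part of the gist; the bench::noop* ops are made up for illustration, and the measurement is CPU-side time per call:

    import time
    import torch

    # same registration pattern as silly::attention above, just with no body
    @torch.library.custom_op("bench::noop1", mutates_args=["out"])
    def noop1(out: torch.Tensor) -> None:
        return

    @torch.library.custom_op("bench::noop4", mutates_args=["out"])
    def noop4(out: torch.Tensor, a: torch.Tensor,
              b: torch.Tensor, c: torch.Tensor) -> None:
        return

    def per_call_us(fn, iters=20000):
        for _ in range(1000):  # warmup
            fn()
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        return (time.perf_counter() - start) / iters * 1e6

    x = torch.zeros(1)
    t1 = per_call_us(lambda: torch.ops.bench.noop1(x))
    t4 = per_call_us(lambda: torch.ops.bench.noop4(x, x, x, x))
    print(f"1 arg : {t1:.2f} us/call")
    print(f"4 args: {t4:.2f} us/call")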
