@youkaichao
Created October 31, 2024 22:51
direct custom op
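# Benchmark of a toy Llama-style model comparing eager execution against
# full CUDA graph replay. The attention step is a dummy "silly" op that can be
# called as a plain Python function, wrapped with torch.library.custom_op, or
# registered directly through torch.library.Library (the "direct custom op"
# variant this gist demonstrates).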
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    out.copy_(q)
    out += k
    out += v


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    return
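# silly_attention is a stand-in for real attention: it simply writes q + k + v
# into the preallocated `out` buffer. silly_attention_fake is the matching
# "fake" (meta) implementation used during tracing; it performs no computation.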
NO_CUSTOM_OP = 0
USE_CUSTOM_OP = 1
DIRECT_CUSTOM_OP = 2

use_custom_op = DIRECT_CUSTOM_OP

if use_custom_op == USE_CUSTOM_OP:
    silly_attention = torch.library.custom_op(
        "silly::attention", mutates_args=["out"])(silly_attention)
    silly_attention.register_fake(silly_attention_fake)
elif use_custom_op == DIRECT_CUSTOM_OP:
    from torch.library import Library

    my_lib = Library("silly", "FRAGMENT")
    my_lib.define("attention(Tensor q, Tensor k, Tensor v, Tensor(a3!) out) -> ()")
    my_lib.impl("attention", silly_attention, "CUDA")
    my_lib._register_fake("attention", silly_attention_fake)
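# Three ways to expose the op: call the Python function directly (NO_CUSTOM_OP),
# wrap it with the high-level torch.library.custom_op API (USE_CUSTOM_OP), or
# register it by hand on a torch.library.Library (DIRECT_CUSTOM_OP, used above).
# In the direct path the schema marks `out` as mutated in place via
# `Tensor(a3!)`, the implementation is registered for the CUDA dispatch key,
# and _register_fake (a private helper) supplies the meta implementation.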
@dataclass
class LlamaConfig:
    hidden_size: int = 128
    mlp_size: int = 256
    vocab_size: int = 128
    num_layers: int = 2


class LlamaMLP(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.gate_up_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.mlp_size * 2,
            bias=False,
        )
        self.down_projection = nn.Linear(
            in_features=config.mlp_size,
            out_features=config.hidden_size,
            bias=False,
        )
        self.gate_up_projection.weight.data.fill_(0.0)
        self.down_projection.weight.data.fill_(0.0)

    def forward(self, x):
        x = self.gate_up_projection(x)
        x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
            x[:, x.size(1) // 2:])
        x = self.down_projection(x)
        return x
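# LlamaMLP is a gated MLP: the fused gate/up projection produces 2 * mlp_size
# features, the first half is gated by a ReLU of the second half, and the
# result is projected back to hidden_size. Weights are zero-filled; the
# benchmark only measures runtime, not outputs.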
class LlamaAttention(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.qkv_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
        )
        self.output_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )
        self.qkv_projection.weight.data.fill_(0.0)
        self.output_projection.weight.data.fill_(0.0)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv = self.qkv_projection(hidden_states)
        hidden_size = qkv.size(-1) // 3
        q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)
        q = q + positions.unsqueeze(1)
        k = k + positions.unsqueeze(1)
        attn_output = torch.empty_like(q)
        if use_custom_op:
            torch.ops.silly.attention(q, k, v, attn_output)
        else:
            silly_attention(q, k, v, attn_output)
        output = self.output_projection(attn_output)
        return output
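# LlamaAttention computes a fused QKV projection and adds positions to q and k,
# but the "attention" itself is the silly op above, writing into a preallocated
# attn_output buffer. Note that `if use_custom_op:` is truthy for both
# USE_CUSTOM_OP and DIRECT_CUSTOM_OP, so both registered variants dispatch
# through torch.ops.silly.attention; only NO_CUSTOM_OP calls the Python
# function directly.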
class LlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.self_attention = LlamaAttention(config)
        self.mlp = LlamaMLP(config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = hidden_states / 2
        else:
            hidden_states = hidden_states + residual
            residual = hidden_states
            hidden_states = hidden_states / 2
        hidden_states = self.self_attention(positions=positions,
                                            hidden_states=hidden_states)
        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states / 2
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
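# The decoder layer follows the usual pre/post-attention residual pattern,
# with a simple divide-by-2 apparently standing in where a normalization
# layer would normally go.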
class LlamaModel(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.embedding_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_layers)])
        self.embedding_tokens.weight.data.fill_(0.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embedding_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        return hidden_states
@torch.inference_mode
def benchmark():
    from triton.testing import do_bench

    cls = LlamaModel

    # similar to llama 3.1-8B
    llama_config = LlamaConfig(hidden_size=4096,
                               mlp_size=14336,
                               vocab_size=128 * 1024,
                               num_layers=32)

    # a tiny model to measure the overhead
    # of piecewise cudagraph
    llama_config = LlamaConfig(hidden_size=40,
                               mlp_size=80,
                               vocab_size=128,
                               num_layers=2)

    cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]

    eager_time = {}
    full_cudagraph_time = {}

    pool = torch.cuda.graph_pool_handle()

    model = cls(llama_config).eval().cuda().to(torch.bfloat16)

    B = 256  # max batch size
    input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
    positions = torch.arange(B).cuda().to(torch.bfloat16)

    graphs = {}

    model(input_ids, positions)
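    # Warm up once (above), then capture one CUDA graph per batch size into a
    # shared memory pool. Capturing from the largest batch size down lets the
    # smaller graphs reuse the pool's allocations.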
    for b in cudagraph_sizes[::-1]:
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            output = model(input_ids[:b], positions[:b])
        graphs[b] = (graph, output)

    for b in cudagraph_sizes:
        runtime = do_bench(lambda: graphs[b][0].replay())  # noqa
        eager_runtime = do_bench(
            lambda: model(input_ids[:b], positions[:b]))  # noqa
        full_cudagraph_time[b] = runtime
        eager_time[b] = eager_runtime

    # print in tabular format
    print("batch size\teager mode\tfull cudagraph")
    for b in cudagraph_sizes:
        print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}")


if __name__ == "__main__":
    benchmark()
@youkaichao

running it gives:

batch size      eager mode      full cudagraph
1       0.158   0.070
2       0.196   0.073
4       0.167   0.076
8       0.165   0.082
16      0.167   0.088
24      0.182   0.092
32      0.178   0.091
40      0.177   0.091
48      0.177   0.093
56      0.190   0.093
64      0.187   0.092
72      0.177   0.094
80      0.177   0.094
88      0.183   0.095
96      0.177   0.095
104     0.177   0.095
112     0.177   0.095
120     0.192   0.096
128     0.177   0.096
136     0.181   0.096
144     0.175   0.096
152     0.191   0.096
160     0.183   0.097
168     0.176   0.096
176     0.177   0.096
184     0.175   0.096
192     0.179   0.096
200     0.182   0.097
208     0.176   0.097
216     0.173   0.097
224     0.174   0.097
232     0.181   0.097
240     0.191   0.097
248     0.174   0.097
256     0.174   0.097
