Created
October 31, 2024 22:51
-
-
Save youkaichao/14bde1121aab06a50872a5ec0227b1d2 to your computer and use it in GitHub Desktop.
direct custom op
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from dataclasses import dataclass | |
from typing import Optional, Tuple | |
import torch | |
from torch import nn | |
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    """Toy attention kernel: write ``q + k + v`` into ``out`` in place.

    ``out`` is the only argument that is modified (the custom-op registrations
    in this file declare it as mutated); nothing is returned.
    """
    # copy_ / add_ return the tensor they modify, so the in-place steps chain.
    out.copy_(q).add_(k).add_(v)
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    """Fake implementation registered for ``silly::attention``.

    The real op only writes into the caller-provided ``out`` and returns
    nothing, so there is nothing to allocate or compute here.
    """
    return None
# Ways of exposing `silly_attention` to the model, selected by `use_custom_op`.
NO_CUSTOM_OP = 0      # call the plain Python function directly
USE_CUSTOM_OP = 1     # wrap it with the torch.library.custom_op API
DIRECT_CUSTOM_OP = 2  # define/impl the op through a torch.library.Library

use_custom_op = DIRECT_CUSTOM_OP

if use_custom_op == DIRECT_CUSTOM_OP:
    from torch.library import Library

    # NOTE: kept at module scope on purpose -- torch.library registrations are
    # tied to the lifetime of the Library object that made them.
    my_lib = Library("silly", "FRAGMENT")
    # `Tensor(a3!) out` marks `out` as written in place; the op returns nothing.
    my_lib.define(
        "attention(Tensor q, Tensor k, Tensor v, Tensor(a3!) out) -> ()")
    # Only a CUDA kernel is registered, so CPU tensors have no implementation.
    my_lib.impl("attention", silly_attention, "CUDA")
    my_lib._register_fake("attention", silly_attention_fake)
elif use_custom_op == USE_CUSTOM_OP:
    silly_attention = torch.library.custom_op(
        "silly::attention", mutates_args=["out"])(silly_attention)
    silly_attention.register_fake(silly_attention_fake)
@dataclass
class LlamaConfig:
    """Sizes of the toy Llama-style model defined in this file.

    Fields may be passed positionally, so their order must stay stable.
    """

    # Width of token embeddings and of the attention projections.
    hidden_size: int = 128
    # MLP intermediate width; the fused up-projection emits twice this.
    mlp_size: int = 256
    # Number of rows in the token-embedding table.
    vocab_size: int = 128
    # Number of stacked decoder layers.
    num_layers: int = 2
class LlamaMLP(nn.Module):
    """Gated MLP: fused up-projection, element-wise ReLU gate, down-projection.

    Both projections are bias-free and their weights are zero-filled at
    construction, so for finite inputs the output is all zeros.
    """

    def __init__(self, config: "LlamaConfig") -> None:
        super().__init__()
        # One matmul produces both halves used by the gate in `forward`.
        self.gate_up_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.mlp_size * 2,
            bias=False,
        )
        self.down_projection = nn.Linear(
            in_features=config.mlp_size,
            out_features=config.hidden_size,
            bias=False,
        )
        self.gate_up_projection.weight.data.fill_(0.0)
        self.down_projection.weight.data.fill_(0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP along the last dimension of ``x``.

        Args:
            x: tensor whose last dimension is ``hidden_size``. Any number of
                leading dimensions is accepted (e.g. ``(tokens, hidden)`` or
                ``(batch, seq, hidden)``).

        Returns:
            Tensor with the same leading dimensions as ``x`` and last
            dimension ``hidden_size``.
        """
        x = self.gate_up_projection(x)
        # Split on the feature (last) dim, like LlamaAttention's qkv split,
        # instead of dim 1 -- dim 1 is only the feature dim for 2-D inputs.
        value, gate = x.chunk(2, dim=-1)
        return self.down_projection(value * torch.nn.functional.relu(gate))
class LlamaAttention(nn.Module):
    """Stand-in attention block.

    Fused QKV projection, a position offset added to q and k, the ``silly``
    attention op, then an output projection. Weights are zero-filled; both
    projections keep nn.Linear's default bias.
    """

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        width = config.hidden_size
        self.qkv_projection = nn.Linear(width, width * 3)
        self.output_projection = nn.Linear(width, width)
        for proj in (self.qkv_projection, self.output_projection):
            proj.weight.data.fill_(0.0)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the toy attention on ``hidden_states``.

        NOTE(review): ``positions.unsqueeze(1)`` presumes 1-D positions that
        broadcast against 2-D ``(tokens, hidden)`` q/k -- confirm before
        feeding other shapes.
        """
        qkv = self.qkv_projection(hidden_states)
        # Three equal slices along the last dim: q, k, v.
        q, k, v = qkv.split(qkv.size(-1) // 3, dim=-1)

        offset = positions.unsqueeze(1)
        q = q + offset
        k = k + offset

        attn_output = torch.empty_like(q)
        # Any non-zero mode goes through the registered `silly::attention` op.
        attention_fn = (torch.ops.silly.attention
                        if use_custom_op else silly_attention)
        attention_fn(q, k, v, attn_output)
        return self.output_projection(attn_output)
class LlamaDecoderLayer(nn.Module):
    """One decoder block: attention then MLP, threading a residual stream.

    Activations are halved before each sub-module is applied.
    """

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.self_attention = LlamaAttention(config)
        self.mlp = LlamaMLP(config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer.

        Args:
            positions: forwarded unchanged to the attention block.
            hidden_states: output of the previous layer (or the embeddings).
            residual: residual stream from the previous layer, or ``None`` for
                the first layer.

        Returns:
            ``(mlp_output, residual)`` -- the residual is passed to the next layer.
        """
        # Later layers fold the previous layer's residual back in; the first
        # layer starts the residual stream from its input.
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states

        attn_out = self.self_attention(positions=positions,
                                       hidden_states=residual / 2)
        residual = attn_out + residual

        return self.mlp(residual / 2), residual
class LlamaModel(nn.Module):
    """Token embedding followed by ``num_layers`` LlamaDecoderLayer blocks."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        # Construction order (embedding, then layers) is kept as-is; the
        # embedding weights are zero-filled afterwards.
        self.embedding_tokens = nn.Embedding(config.vocab_size,
                                             config.hidden_size)
        self.layers = nn.ModuleList(
            LlamaDecoderLayer(config) for _ in range(config.num_layers))
        self.embedding_tokens.weight.data.fill_(0.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run every decoder layer in order.

        NOTE(review): ``input_ids`` is annotated Optional, but it is passed
        straight to nn.Embedding, which needs a tensor.

        Returns:
            The last layer's hidden states; the residual it returns is dropped.
        """
        residual: Optional[torch.Tensor] = None
        hidden_states = self.embedding_tokens(input_ids)
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        return hidden_states
@torch.inference_mode
def benchmark(llama_config: Optional[LlamaConfig] = None) -> None:
    """Time eager execution vs. full CUDA-graph replay of LlamaModel.

    Prints a tab-separated table with one row per captured batch size. Times
    come from ``triton.testing.do_bench``, which reports milliseconds.

    Args:
        llama_config: model configuration. Defaults to a tiny model
            (hidden_size=40, mlp_size=80, vocab_size=128, num_layers=2), so
            per-launch overhead dominates the measurement. For a model shaped
            like Llama 3.1-8B, pass ``LlamaConfig(hidden_size=4096,
            mlp_size=14336, vocab_size=128 * 1024, num_layers=32)``.
    """
    from triton.testing import do_bench

    if llama_config is None:
        llama_config = LlamaConfig(hidden_size=40,
                                   mlp_size=80,
                                   vocab_size=128,
                                   num_layers=2)

    cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]
    # Inputs are allocated once at the largest captured size and every graph
    # reads a prefix slice. Deriving the size here keeps `[:b]` from silently
    # returning fewer than `b` rows if a larger size is ever added.
    max_batch_size = max(cudagraph_sizes)

    model = LlamaModel(llama_config).eval().cuda().to(torch.bfloat16)
    input_ids = torch.randint(0, llama_config.vocab_size,
                              (max_batch_size, )).cuda()
    # NOTE: bfloat16 represents integers exactly only up to 256.
    positions = torch.arange(max_batch_size).cuda().to(torch.bfloat16)

    # Warm up on a side stream before capturing, as the PyTorch CUDA-graph
    # documentation recommends.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        model(input_ids, positions)
    torch.cuda.current_stream().wait_stream(side_stream)

    # Capture from the largest batch size down, all into one shared memory
    # pool (presumably so smaller graphs can reuse the larger ones' memory).
    pool = torch.cuda.graph_pool_handle()
    graphs = {}
    for b in sorted(cudagraph_sizes, reverse=True):
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            output = model(input_ids[:b], positions[:b])
        graphs[b] = (graph, output)

    eager_time = {}
    full_cudagraph_time = {}
    for b in cudagraph_sizes:
        graph, _ = graphs[b]
        full_cudagraph_time[b] = do_bench(graph.replay)
        # do_bench runs the lambda before the loop advances, so the late
        # binding of `b` is harmless here.
        eager_time[b] = do_bench(
            lambda: model(input_ids[:b], positions[:b]))  # noqa

    # print in tabular format
    print("batch size\teager mode\tfull cudagraph")
    for b in cudagraph_sizes:
        print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}")
if __name__ == "__main__":
    # Executed as a script: run the eager vs. CUDA-graph timing comparison.
    benchmark()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Running it gives: