custom op overhead
Created October 31, 2024 21:15
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    # toy "attention": write q + k + v into the pre-allocated output buffer
    out.copy_(q)
    out += k
    out += v


use_custom_op = True

if use_custom_op:
    # register the function as a PyTorch custom op that mutates `out`
    silly_attention = torch.library.custom_op("silly::attention", mutates_args=["out"])(silly_attention)

    # fake (meta) implementation so the op can be traced/compiled
    @silly_attention.register_fake
    def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
          out: torch.Tensor) -> None:
        return


@dataclass
class LlamaConfig:
    hidden_size: int = 128
    mlp_size: int = 256
    vocab_size: int = 128
    num_layers: int = 2


class LlamaMLP(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.gate_up_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.mlp_size * 2,
            bias=False,
        )
        self.down_projection = nn.Linear(
            in_features=config.mlp_size,
            out_features=config.hidden_size,
            bias=False,
        )
        self.gate_up_projection.weight.data.fill_(0.0)
        self.down_projection.weight.data.fill_(0.0)

    def forward(self, x):
        x = self.gate_up_projection(x)
        x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
            x[:, x.size(1) // 2:])
        x = self.down_projection(x)
        return x


class LlamaAttention(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.qkv_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
        )
        self.output_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )
        self.qkv_projection.weight.data.fill_(0.0)
        self.output_projection.weight.data.fill_(0.0)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv = self.qkv_projection(hidden_states)
        hidden_size = qkv.size(-1) // 3
        q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)
        q = q + positions.unsqueeze(1)
        k = k + positions.unsqueeze(1)
        attn_output = torch.empty_like(q)
        if use_custom_op:
            torch.ops.silly.attention(q, k, v, attn_output)
        else:
            silly_attention(q, k, v, attn_output)
        output = self.output_projection(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.self_attention = LlamaAttention(config)
        self.mlp = LlamaMLP(config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = hidden_states / 2
        else:
            hidden_states = hidden_states + residual
            residual = hidden_states
            hidden_states = hidden_states / 2
        hidden_states = self.self_attention(positions=positions,
                                            hidden_states=hidden_states)
        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states / 2
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.embedding_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_layers)])
        self.embedding_tokens.weight.data.fill_(0.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embedding_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        return hidden_states


@torch.inference_mode
def benchmark():
    from triton.testing import do_bench

    cls = LlamaModel

    # similar to llama 3.1-8B
    llama_config = LlamaConfig(hidden_size=4096,
                               mlp_size=14336,
                               vocab_size=128 * 1024,
                               num_layers=32)

    # a tiny model to measure the overhead
    # of piecewise cudagraph
    llama_config = LlamaConfig(hidden_size=40,
                               mlp_size=80,
                               vocab_size=128,
                               num_layers=2)

    cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]

    eager_time = {}
    full_cudagraph_time = {}

    pool = torch.cuda.graph_pool_handle()

    model = cls(llama_config).eval().cuda().to(torch.bfloat16)

    B = 256  # max batch size
    input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
    positions = torch.arange(B).cuda().to(torch.bfloat16)

    graphs = {}

    # warm up once before capturing cudagraphs
    model(input_ids, positions)

    # capture one cudagraph per batch size, sharing a single memory pool
    for b in cudagraph_sizes[::-1]:
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            output = model(input_ids[:b], positions[:b])
        graphs[b] = (graph, output)

    for b in cudagraph_sizes:
        runtime = do_bench(lambda: graphs[b][0].replay())  # noqa
        eager_runtime = do_bench(
            lambda: model(input_ids[:b], positions[:b]))  # noqa
        full_cudagraph_time[b] = runtime
        eager_time[b] = eager_runtime

    # print in tabular format
    print("batch size\teager mode\tfull cudagraph")
    for b in cudagraph_sizes:
        print((f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"))


if __name__ == "__main__":
    benchmark()
my cpu:
$ lscpu
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 224
On-line CPU(s) list: 0-223
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8480C
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 56
Socket(s): 2
Stepping: 8
CPU max MHz: 3800.0000
CPU min MHz: 800.0000
BogoMIPS: 4000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization features:
Virtualization: VT-x
Caches (sum of all):
L1d: 5.3 MiB (112 instances)
L1i: 3.5 MiB (112 instances)
L2: 224 MiB (112 instances)
L3: 210 MiB (2 instances)
NUMA:
NUMA node(s): 2
NUMA node0 CPU(s): 0-55,112-167
NUMA node1 CPU(s): 56-111,168-223
Vulnerabilities:
Gather data sampling: Not affected
Itlb multihit: Not affected
L1tf: Not affected
Mds: Not affected
Meltdown: Not affected
Mmio stale data: Not affected
Retbleed: Not affected
Spec rstack overflow: Not affected
Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Spectre v2: Mitigation; Enhanced IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Srbds: Not affected
Tsx async abort: Not affected
the no-mutation variant in https://gist.github.com/youkaichao/dcd04fcc42b276f5480c43b3690e51ea performs basically the same.
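for reference, a rough sketch of what such a no-mutation registration might look like (the exact code lives in the linked gist; the op name silly::attention_nomut is just for illustration): the op allocates and returns a fresh tensor instead of writing into a pre-allocated out buffer, and the fake implementation returns an empty tensor of the same shape.

import torch

@torch.library.custom_op("silly::attention_nomut", mutates_args=[])
def silly_attention_nomut(q: torch.Tensor, k: torch.Tensor,
                          v: torch.Tensor) -> torch.Tensor:
    # same toy computation, but the result is a newly allocated tensor
    return q + k + v

@silly_attention_nomut.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(q)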
running a large model with:
# similar to llama 3.1-8B
llama_config = LlamaConfig(hidden_size=4096,
                           mlp_size=14336,
                           vocab_size=128 * 1024,
                           num_layers=32)
run with use_custom_op = True (times in ms):
batch size eager mode full cudagraph
1 7.129 6.474
2 7.192 6.614
4 7.282 6.668
8 7.339 6.665
16 7.399 6.792
24 7.551 6.842
32 7.631 6.975
40 7.580 6.912
48 7.583 6.932
56 7.623 6.979
64 7.701 7.033
72 8.214 7.535
80 8.290 7.594
88 8.345 7.642
96 8.367 7.691
104 8.344 7.642
112 8.391 7.682
120 8.430 7.725
128 8.538 7.778
136 9.142 8.482
144 9.092 8.444
152 9.121 8.458
160 9.156 8.492
168 9.075 8.421
176 9.124 8.465
184 9.158 8.514
192 9.217 8.564
200 10.093 9.349
208 10.191 9.402
216 10.172 9.463
224 10.307 9.539
232 10.184 9.513
240 10.222 9.549
248 10.262 9.597
256 10.312 9.636
run with use_custom_op = False (times in ms):
batch size eager mode full cudagraph
1 7.108 6.462
2 7.191 6.609
4 7.273 6.650
8 7.321 6.677
16 7.373 6.792
24 7.525 6.841
32 7.655 6.951
40 7.583 6.912
48 7.562 6.933
56 7.621 6.978
64 7.691 7.033
72 8.197 7.535
80 8.271 7.592
88 8.325 7.641
96 8.361 7.691
104 8.350 7.644
112 8.385 7.678
120 8.428 7.726
128 8.488 7.777
136 9.142 8.476
144 9.104 8.444
152 9.105 8.450
160 9.162 8.493
168 9.075 8.419
176 9.128 8.465
184 9.162 8.511
192 9.230 8.565
200 10.118 9.347
208 10.186 9.409
216 10.234 9.475
224 10.273 9.506
232 10.182 9.530
240 10.224 9.544
248 10.259 9.598
256 10.300 9.636
for this large model, the overhead of the custom op is not significant.
however, the custom op overhead grows with the number of arguments (and we have small 1B/3B models too, where per-call overhead matters more), so we still need to take care of it.
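to see that per-call overhead in isolation, one option (a minimal sketch, not from this gist; the op name silly::noop and the perf_counter timing loop are my own choices) is to register a no-op custom op with the same signature and time its dispatch against a plain Python call:

import time

import torch

def python_noop(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                out: torch.Tensor) -> None:
    pass

# same registration pattern as in the script above, but the op does nothing,
# so any measured time is pure call/dispatch overhead
noop = torch.library.custom_op("silly::noop", mutates_args=["out"])(python_noop)

@noop.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
      out: torch.Tensor) -> None:
    return

def per_call_us(fn, args, iters=10000):
    # CPU-side wall-clock time per call, in microseconds
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) / iters * 1e6

q, k, v, out = (torch.zeros(1, device="cuda") for _ in range(4))
print("plain python call:", per_call_us(python_noop, (q, k, v, out)), "us")
print("custom op call   :", per_call_us(torch.ops.silly.noop, (q, k, v, out)), "us")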
comparing the runs with use_custom_op = True and use_custom_op = False: full cudagraph can be treated as the zero-overhead baseline. while eager mode already has significant overhead, using the custom op brings an additional 20~40 us of overhead. (note: the times in the tables are in ms.) from https://gist.github.com/youkaichao/14bde1121aab06a50872a5ec0227b1d2?permalink_comment_id=5262683#gistcomment-5262683 , it seems directly registering the op can remove the overhead:
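the snippet from that comment is not reproduced above, but direct registration with the lower-level torch.library.Library API looks roughly like the following (a sketch, not the exact code from the comment; the schema string and the CUDA dispatch key are assumptions). note this is an alternative to the torch.library.custom_op registration in the script, so only one of the two should be used for a given op name, and silly_attention here refers to the plain Python function before any wrapping:

import torch

# a library object owning (a fragment of) the "silly" namespace
silly_lib = torch.library.Library("silly", "FRAGMENT")

# declare the schema once; "(a!)" marks `out` as mutated in place
silly_lib.define(
    "attention(Tensor q, Tensor k, Tensor v, Tensor(a!) out) -> ()")

# bind the real implementation for CUDA and a fake one for tracing
silly_lib.impl("attention", silly_attention, "CUDA")
torch.library.register_fake("silly::attention",
                            lambda q, k, v, out: None)

# called the same way as before:
# torch.ops.silly.attention(q, k, v, attn_output)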