@youkaichao
Last active August 30, 2024 19:45
torch.compile integration plan
import torch
from typing import Optional
from torch._dynamo.backends.common import aot_autograd
@torch.library.custom_op("custom::unified_attention", mutates_args=[])
def unified_attention(x: torch.Tensor, num_prefill_tokens: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
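    # Placeholder math standing in for real attention kernels (see the author's note below):
    #   empty cache  -> profiling path, returns x * 2
    #   prefill rows -> 3 * x (where the prefill attention kernel would run)
    #   decode rows  -> 4 * x (where the decode attention kernel would run)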
    if cache.numel() == 0:
        return x * 2
    output = x.clone()
    bs = x.size(0)
    if num_prefill_tokens == 0:
        ...  # call decode attention
    else:
        ...  # call prefill attention with x[:num_prefill_tokens]
        ...  # call decode attention with x[num_prefill_tokens:]
    output[:num_prefill_tokens] = 3 * x[:num_prefill_tokens]
    output[num_prefill_tokens:] = 4 * x[num_prefill_tokens:]
    return output
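
# A fake (meta) implementation is needed so torch.compile can trace through the custom op:
# it only has to produce an output with the right shape and dtype, so the real kernel
# (with its data-dependent control flow) never has to be traceable.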
@unified_attention.register_fake
def _(x: torch.Tensor, num_prefill_tokens: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)

eager_model = True
def custom_compiler(gm, inputs):
    # compilation options:
    # option 1: pass the full graph to Inductor
    # option 2: run the model in eager mode
    # option 3: find subgraphs and replace them with kernels inside vLLM
    print(gm.graph.python_code(root_module="self", verbose=True).src)
    # selection logic
    static_shape_graphs = dict()
    dynamic_shape_graph = None

    def forward(*args, **kwargs):
        nonlocal static_shape_graphs, dynamic_shape_graph
        batchsize = ...  # Question: how to get batchsize from args?
        if dynamic_shape_graph is None:
            # if the input has a symbolic shape, compile with dynamic shape support
            # (here, gm.forward stands in for the actual compilation)
            dynamic_shape_graph = gm.forward
        if eager_model:
            return dynamic_shape_graph(*args, **kwargs)
        if batchsize not in static_shape_graphs:
            # if the input has a static shape, compile with static shape support
            # (here, gm.forward stands in for the actual compilation)
            static_shape_graphs[batchsize] = gm.forward
        return static_shape_graphs[batchsize](*args, **kwargs)

    return forward
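
# One possible answer to the "how to get batchsize" question above (an assumption, not
# part of the original plan): the runtime args reaching forward() are the flattened
# inputs of the AOTAutograd forward graph, so the batch size could be read off the
# first non-scalar tensor argument.
def _infer_batchsize(args):
    # hypothetical helper; assumes dim 0 of the first non-scalar tensor is the batch dim
    for a in args:
        if isinstance(a, torch.Tensor) and a.dim() > 0:
            return a.size(0)
    return None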
def target_fn(x: torch.Tensor, num_prefill_tokens: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    # cache is empty: multiply by 50
    # cache is not empty: multiply by 75 for prefill, and by 100 for decode
    x = (x + 1) * 5 - 5
    x = torch.ops.custom.unified_attention(x, num_prefill_tokens, cache)
    x = (x + 2) * 5 - 10
    return x
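
# aot_autograd(fw_compiler=custom_compiler) builds a Dynamo backend: Dynamo captures the
# FX graph of target_fn, AOTAutograd functionalizes it, and the resulting forward graph
# is handed to custom_compiler, whose returned callable is what runs at runtime.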
compiled_target_fn = torch.compile(backend=aot_autograd(fw_compiler=custom_compiler))(target_fn)
compiled_codes = []
def hook(old_code, new_code):
    if old_code is target_fn.__code__:
        compiled_codes.append(new_code)
torch._dynamo.convert_frame.register_bytecode_hook(hook)
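
# Dispatch strategy: the first call goes through torch.compile (Dynamo traces target_fn
# and the hook above records the compiled bytecode); every later call patches that
# bytecode into target_fn.__code__ and calls it directly, bypassing Dynamo.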
def dispatcher(x, num_prefill_tokens: torch.Tensor, cache):
    if len(compiled_codes) < 1:
        return compiled_target_fn(x, num_prefill_tokens, cache)
    else:
        target_fn.__code__ = compiled_codes[0]
        return target_fn(x, num_prefill_tokens, cache)

def test():
    global eager_model

    # profile run, without kv cache, fully static shape, max size
    num_prefill_tokens = torch.tensor(0, dtype=torch.int64)
    cache = torch.tensor([], dtype=torch.int64)
    x = torch.ones(8, 1)
    torch._dynamo.mark_dynamic(x, 0)
    out = dispatcher(x, num_prefill_tokens, cache)
    print(out)

    # create cache
    cache = torch.ones(1, 1)

    # the following run will not trigger Dynamo/AOTAutograd.
    # if we are using `--enforce-eager`, we want this to run directly
    # with the compiled kernel that can handle dynamic shapes
    y = torch.ones(5, 1)
    num_prefill_tokens = torch.tensor(3, dtype=torch.int64)
    out = dispatcher(y, num_prefill_tokens, cache)
    print(out)

    eager_model = False

    # if we are using cudagraph, this is an additional warmup to capture cuda graphs
    for i in [1, 2, 4]:
        y = torch.ones(i, 1)
        num_prefill_tokens = torch.tensor(i, dtype=torch.int64)
        out = dispatcher(y, num_prefill_tokens, cache)

    # for later runs, we can directly run with the compiled static-shape graph if the
    # shape matches a recorded shape; if not, fall back to the dynamic-shape graph
    y = torch.ones(4, 1)
    num_prefill_tokens = torch.tensor(2, dtype=torch.int64)
    out = dispatcher(y, num_prefill_tokens, cache)
    print(out)
if __name__ == "__main__":
    test()
@youkaichao (Author):
rough idea:

for profiling: multiply the input by 50
for prefill: multiply the input by 75
for decode: multiply the input by 100
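
These factors come from composing the wrappers in target_fn with the placeholder multipliers in the custom op: (x + 1) * 5 - 5 is just 5x, and (y + 2) * 5 - 10 is just 5y, so the end-to-end result is 25 * m * x, where m is the op's multiplier (2 for the profiling path with an empty cache, 3 for prefill rows, 4 for decode rows), i.e. 50x, 75x, and 100x. With all-ones inputs, the test should therefore print 50s for the profile run, a 75/100 split after row 3 for the 5-token run, and a 75/100 split after row 2 for the final 4-token run.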
