torch.compile integration plan
import torch
from typing import Optional
from torch._dynamo.backends.common import aot_autograd
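
# Wrapping attention in a custom op keeps Dynamo from tracing into its
# data-dependent control flow: the prefill/decode split is handled inside
# the op, so the captured graph sees a single opaque call.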
@torch.library.custom_op("custom::unified_attention", mutates_args=[])
def unified_attention(x: torch.Tensor, num_prefill_tokens: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    if cache.numel() == 0:
        # profile run: no kv cache has been allocated yet
        return x * 2
    output = x.clone()
    bs = x.size(0)  # batch size, for the real attention kernels
    if num_prefill_tokens == 0:
        ...  # call decode attention
    else:
        ...  # call prefill attention with x[:num_prefill_tokens]
        ...  # call decode attention with x[num_prefill_tokens:]
    output[:num_prefill_tokens] = 3 * x[:num_prefill_tokens]
    output[num_prefill_tokens:] = 4 * x[num_prefill_tokens:]
    return output

@unified_attention.register_fake
def _(x: torch.Tensor, num_prefill_tokens: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)
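# Note: the fake (meta) implementation above only needs to describe the output
# shape and dtype; it is what torch.compile traces in place of the real kernel.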

eager_model = True


def custom_compiler(gm, inputs):
    # compilation options:
    # option 1: pass the full graph to inductor
    # option 2: run the model in eager mode
    # option 3: find subgraphs and replace them with kernels inside vLLM
    print(gm.graph.python_code(root_module="self", verbose=True).src)

    # selection logic
    static_shape_graphs = dict()
    dynamic_shape_graph = None

    def forward(*args, **kwargs):
        nonlocal static_shape_graphs, dynamic_shape_graph
        batchsize = ...  # Question: how to get batchsize from args?
        if dynamic_shape_graph is None:
            # if the input has symbolic shape, compile with dynamic shape support
            # (gm.forward stands in for the actual compilation here)
            dynamic_shape_graph = gm.forward
        if eager_model:
            return dynamic_shape_graph(*args, **kwargs)
        if batchsize not in static_shape_graphs:
            # if the input has static shape, compile with static shape support
            # (gm.forward stands in for the actual compilation here)
            static_shape_graphs[batchsize] = gm.forward
        return static_shape_graphs[batchsize](*args, **kwargs)

    return forward
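# Note: custom_compiler returns a call-time dispatcher rather than a fixed
# compiled artifact, so it can pick between the dynamic-shape graph and a
# per-batch-size static-shape graph on every call.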

def target_fn(x: torch.Tensor, num_prefill_tokens: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    # cache is empty: multiply by 50
    # cache is not empty: multiply by 75 for prefill, and by 100 for decode
    x = (x + 1) * 5 - 5
    x = torch.ops.custom.unified_attention(x, num_prefill_tokens, cache)
    x = (x + 2) * 5 - 10
    return x

compiled_target_fn = torch.compile(backend=aot_autograd(fw_compiler=custom_compiler))(target_fn)
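# The aot_autograd backend runs AOTAutograd on the Dynamo-captured graph and
# hands the resulting forward graph to custom_compiler via fw_compiler.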

compiled_codes = []


def hook(old_code, new_code):
    if old_code is target_fn.__code__:
        compiled_codes.append(new_code)


torch._dynamo.convert_frame.register_bytecode_hook(hook)
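# The bytecode hook records the code object Dynamo compiles for target_fn, so
# the dispatcher below can install it directly and skip Dynamo frame
# evaluation on later calls.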

def dispatcher(x, num_prefill_tokens: torch.Tensor, cache):
    if len(compiled_codes) < 1:
        # first call: run through torch.compile to trigger compilation
        return compiled_target_fn(x, num_prefill_tokens, cache)
    else:
        # later calls: run the captured bytecode directly, bypassing Dynamo
        target_fn.__code__ = compiled_codes[0]
        return target_fn(x, num_prefill_tokens, cache)

def test():
    global eager_model  # needed so the assignment below updates the module-level flag
    # profile run, without kv cache, fully static shape, max size
    num_prefill_tokens = torch.tensor(0, dtype=torch.int64)
    cache = torch.tensor([], dtype=torch.int64)
    x = torch.ones(8, 1)
    torch._dynamo.mark_dynamic(x, 0)
    out = dispatcher(x, num_prefill_tokens, cache)
    print(out)
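    # expected: all 50s (profile path: empty cache multiplies by 50)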
    # create cache
    cache = torch.ones(1, 1)
    # the following run will not trigger Dynamo/AOTAutograd;
    # if we are using `--enforce-eager`, we want this to run directly
    # with a compiled kernel that can handle dynamic shapes
    y = torch.ones(5, 1)
    num_prefill_tokens = torch.tensor(3, dtype=torch.int64)
    out = dispatcher(y, num_prefill_tokens, cache)
    print(out)
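    # expected: first 3 rows are 75 (prefill), last 2 rows are 100 (decode)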
    eager_model = False
    # if we are using cudagraphs, this is an additional warmup to capture the cuda graphs
    for i in [1, 2, 4]:
        y = torch.ones(i, 1)
        num_prefill_tokens = torch.tensor(i, dtype=torch.int64)
        out = dispatcher(y, num_prefill_tokens, cache)
    # and then, for later runs, we can run directly with the compiled kernel if
    # the shape matches a recorded shape; if not, run with dynamic shape
    y = torch.ones(4, 1)
    num_prefill_tokens = torch.tensor(2, dtype=torch.int64)
    out = dispatcher(y, num_prefill_tokens, cache)
    print(out)
    # expected: first 2 rows are 75, last 2 rows are 100

if __name__ == "__main__":
    test()
rough idea:
for profiling: multiply the input by 50
for prefill: multiply the input by 75
for decode: multiply the input by 100
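
These constants fall out of the toy math in the code: target_fn computes
(x + 1) * 5 - 5 = 5x before the attention op and (y + 2) * 5 - 10 = 5y after
it, so the end-to-end factor is 25 times the attention multiplier: 25 * 2 = 50
for profiling (empty cache), 25 * 3 = 75 for prefill, and 25 * 4 = 100 for
decode.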