@youkaichao
Created September 12, 2024 05:38
import torch
from typing import Optional, Tuple, Union
torch.cuda.is_available()

def report_memory(prefix):
    free, total = torch.cuda.mem_get_info()
    used = total - free
    print(f"{prefix}: Used: {used / 1024 / 1024} MB, Free: {free / 1024 / 1024} MB, Total: {total / 1024 / 1024} MB")
output_parallel = torch.randn(8192, 4096, dtype=torch.bfloat16, device="cuda") # 64 MB
residual: "bf16[8192, 4096]" = torch.empty_like(output_parallel) # 64 MB
weight = torch.randn(4096, dtype=torch.bfloat16, device="cuda") # 8KB
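
# Importing vllm loads its compiled extension, which registers the custom ops
# under torch.ops._C (including fused_add_rms_norm, used below).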
import vllm

def forward_native(
    variance_epsilon,
    weight,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """PyTorch-native implementation equivalent to forward()."""
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)
        residual = x.to(orig_dtype)

    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + variance_epsilon)
    x = x.to(orig_dtype) * weight
    if residual is None:
        return x
    else:
        return x, residual

@torch.compile
def f_compile(output_parallel, residual, weight):
    out, residual = forward_native(1e-5, weight, output_parallel, residual)
    out.add_(1)
    residual.add_(1)
    return out, residual

def f_custom_op(output_parallel, residual, weight):
    global torch
    # fused_add_rms_norm updates output_parallel and residual in place,
    # so no new output tensors are allocated.
    torch.ops._C.fused_add_rms_norm(output_parallel, residual, weight, 1e-5)
    output_parallel.add_(1)
    residual.add_(1)
    return output_parallel, residual

report_memory("before")
f_custom_op(output_parallel, residual, weight) # takes 0 MB
# torch.compile(f_custom_op)(output_parallel, residual, weight) # takes 128 MB
# f_compile(output_parallel, residual, weight) # takes 128 MB
report_memory("after")