import torch
from typing import Optional, Tuple, Union

torch.cuda.is_available()

def report_memory(prefix):
    free, total = torch.cuda.mem_get_info()
    used = total - free
    print(f"{prefix}: Used: {used / 1024 / 1024} MB, Free: {free / 1024 / 1024} MB, Total: {total / 1024 / 1024} MB")

output_parallel = torch.randn(8192, 4096, dtype=torch.bfloat16, device="cuda")  # 64 MB
residual: "bf16[8192, 4096]" = torch.empty_like(output_parallel)  # 64 MB
weight = torch.randn(4096, dtype=torch.bfloat16, device="cuda")  # 8 KB
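# Size check for the comments above (bf16 = 2 bytes per element):
#   output_parallel / residual: 8192 * 4096 * 2 B = 64 MB each
#   weight:                     4096 * 2 B        = 8 KB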
import vllm  # registers the torch.ops._C custom ops used below

def forward_native(
    variance_epsilon,
    weight,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """PyTorch-native implementation equivalent to forward()."""
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)
        residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + variance_epsilon)
    x = x.to(orig_dtype) * weight
    if residual is None:
        return x
    else:
        return x, residual
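# A minimal numerical sanity check (not part of the original gist): verify
# that forward_native matches vLLM's fused in-place kernel. It is defined
# but not called here, because running it would warm PyTorch's caching
# allocator and skew the memory numbers measured below. Assumes the
# `import vllm` above registered the torch.ops._C ops; the tolerances are
# a loose guess for bf16.
def check_native_matches_fused():
    x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
    res = torch.randn_like(x)
    w = torch.randn(4096, dtype=torch.bfloat16, device="cuda")
    ref_out, ref_res = forward_native(1e-5, w, x.clone(), res.clone())
    fused_x, fused_res = x.clone(), res.clone()
    # the fused op mutates its first two arguments in place
    torch.ops._C.fused_add_rms_norm(fused_x, fused_res, w, 1e-5)
    torch.testing.assert_close(fused_x, ref_out, atol=2e-2, rtol=2e-2)
    torch.testing.assert_close(fused_res, ref_res, atol=2e-2, rtol=2e-2)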
@torch.compile
def f_compile(output_parallel, residual, weight):
    out, residual = forward_native(1e-5, weight, output_parallel, residual)
    out.add_(1)
    residual.add_(1)
    return out, residual
def f_custom_op(output_parallel, residual, weight):
    # in-place: writes the normalized output into output_parallel and the
    # updated residual into residual
    torch.ops._C.fused_add_rms_norm(output_parallel, residual, weight, 1e-5)
    output_parallel.add_(1)
    residual.add_(1)
    return output_parallel, residual
report_memory("before") | |
f_custom_op(output_parallel, residual, weight) # takes 0 MB | |
# torch.compile(f_custom_op)(output_parallel, residual, weight) # takes 128 MB | |
# f_compile(output_parallel, residual, weight) # takes 128 MB | |
report_memory("after") |