from triton.testing import do_bench
import torch
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(2**14, 2**14).to(torch.bfloat16)

    def forward(self, x):
        return self.lin(x)
model = MyModule().to("cuda")
model_bf16 = MyModule().to("cuda")
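# Optional: each MyModule() call above draws its own random weights, which is
# fine for timing but makes the two models' outputs incomparable. Sharing the
# weights here makes the accuracy spot-check at the end meaningful.
model_bf16.load_state_dict(model.state_dict())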
from torchao.quantization import quantize_, int4_weight_only
# Smaller groups mean finer-grained quantization scales (better accuracy,
# slightly more overhead).
group_size = 32
# You can enable HQQ (https://github.com/mobiusml/hqq/tree/master) quantization,
# which is expected to improve accuracy, via the `use_hqq` flag of
# `int4_weight_only`.
use_hqq = False
quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
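# quantize_ mutates the model in place, replacing the Linear's weight with a
# torchao quantized tensor subclass; printing its type is a quick sanity check
# (the exact subclass name depends on the torchao version).
print(type(model.lin.weight))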
# Batch size 1: the matmul is memory-bandwidth-bound, which is where int4
# weight-only quantization helps most.
x = torch.rand(1, 2**14, dtype=torch.bfloat16, device="cuda")
# do_bench returns the mean runtime in milliseconds.
time_int4 = do_bench(lambda: model(x))
time_bf16 = do_bench(lambda: model_bf16(x))
print(f"int4 : {time_int4}")
print(f"bf16 : {time_bf16}")
print(f" speedup : { time_bf16 / time_int4 }")
# On A100:
#
# int4 : 0.16697485744953156
# bf16 : 0.36696189641952515
# speedup : 2.197707499348757
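
# Optional accuracy spot-check (meaningful only if the weights were shared
# above). int4 weight-only quantization is lossy, so expect a small nonzero
# difference from bf16 rather than exact agreement.
with torch.no_grad():
    max_err = (model(x) - model_bf16(x)).abs().max().item()
print(f"max abs error vs bf16: {max_err}")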