from triton.testing import do_bench
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(2**14, 2**14).to(torch.bfloat16)

    def forward(self, x):
        return self.lin(x)

model = MyModule().to("cuda")
model_bf16 = MyModule().to("cuda")
from torchao.quantization import quantize_, int4_weight_only

group_size = 32
# You can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization,
# which is expected to improve accuracy, via the use_hqq flag of int4_weight_only.
use_hqq = False
quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
x = torch.rand(1, 2**14, dtype=torch.bfloat16, device="cuda")
time_int4 = do_bench(lambda: model(x))
time_bf16 = do_bench(lambda: model_bf16(x))

print(f"int4 : {time_int4}")
print(f"bf16 : {time_bf16}")
print(f"speedup : {time_bf16 / time_int4}")

# On A100:
#
# int4 : 0.16697485744953156
# bf16 : 0.36696189641952515
# speedup : 2.197707499348757
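
# Not part of the original gist: a minimal follow-up sketch. torch.compile can
# usually be layered on top of the torchao-quantized model; whether it adds
# further speedup over eager int4 depends on the GPU and the PyTorch/torchao
# versions, so treat it as something to measure rather than assume.
model_compiled = torch.compile(model, mode="max-autotune")
time_int4_compiled = do_bench(lambda: model_compiled(x))
print(f"int4 + compile : {time_int4_compiled}")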