#OMP_NUM_THREADS=16 CUDA_VISIBLE_DEVICES=0 ipython3 benchmark_triton.py  #select the right number of threads based on your machine
#You can change matmul_type: GEMM, GEMV or AUTO
#Note: bfloat16 is only supported in GEMM mode with float32 accumulation
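#For example (sketch only, not exercised in the run below), a bfloat16 setup would combine GEMM mode with fp32 accumulation:
#   compute_dtype = torch.bfloat16
#   matmul_type   = "GEMM"   #gen_data() below then selects DType.BF16 inputs/outputs with DType.FP32 accumulation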
#################################################################################################################################
import torch
import numpy as np

device = 'cuda:0'
compute_dtype = torch.float16

in_feature_list = [4096, 4096*2, 4096*4, 4096*8]
out_feature_list = [4096, 4096*2, 4096*4, 4096*8]
W_nbits_list = [4, 2]
group_size_list = [128]

for in_features, out_features in zip(in_feature_list, out_feature_list):
    for W_nbits in W_nbits_list:
        for group_size in group_size_list:
            #in_features, out_features = 4096, 4096
            #in_features, out_features = 4096*2, 4096*2
            #in_features, out_features = 4096*4, 4096*4
            #in_features, out_features = 4096*8, 4096*8
            #W_nbits, group_size = 8, in_features
            #W_nbits, group_size = 4, 128
            #W_nbits, group_size = 2, 128
            matmul_type = "AUTO" #GEMM, GEMV, "AUTO"
            #################################################################################################################################
            from triton.testing import do_bench

            def eval_time(fct, params):
                #minimum latency in milliseconds (do_bench's warmup/rep arguments are also expressed in ms)
                return do_bench(lambda: fct(**params), warmup=25, rep=200, fast_flush=True, return_mode='min')

            def check_valid(x, W, quant_linear, tol=1e-3):
                #compare the quantized layer against a dense matmul on the dequantized weights
                y_ref = torch.matmul(x, W.T)
                y_q = quant_linear(x)
                if((y_ref - y_q).abs().mean() >= tol):
                    raise AssertionError('Assertion Failed: quantized output deviates from the reference beyond tol')
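            #Illustration (hypothetical tensors, not part of the sweep): eval_time takes a callable plus a kwargs dict, e.g.
            #   A = torch.randn((1, 4096), dtype=torch.float16, device=device)
            #   B = torch.randn((4096, 4096), dtype=torch.float16, device=device)
            #   t_min_ms = eval_time(lambda x: torch.matmul(x, B.T), {'x': A})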
            #################################################################################################################################
            #TorchAO Int8 settings
            torch._dynamo.config.capture_scalar_outputs = True
            torch._inductor.config.coordinate_descent_tuning = True

            @torch.compile()
            def matmul_torch_A16W8SYM(x, W_q, scales, out_features):
                out_shape = x.shape[:-1] + (out_features,)
                out = ((x.view((-1, x.shape[-1])) @ W_q.T.to(x.dtype)) / scales.view(1, -1)).view(out_shape)
                return out

            class Torch_A16W8SYM(torch.nn.Module):
                def __init__(self, in_features, out_features, W_q, scales, bias=None):
                    super().__init__()
                    self.W_q = W_q
                    self.in_features = in_features
                    self.out_features = out_features
                    self.scales = scales
                    self.bias = bias
                    self.device = W_q.device
                    self.compute_dtype = scales.dtype

                def forward(self, x):
                    out = matmul_torch_A16W8SYM(x.to(self.device), self.W_q, self.scales, self.out_features)
                    if(self.bias is not None):
                        out += self.bias
                    return out
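            #A minimal sketch (hypothetical tensors) of the symmetric int8 convention assumed by Torch_A16W8SYM above,
            #where W_q stores W scaled per output channel, so the matmul result is rescaled by dividing by `scales`:
            #   scales = 127.0 / W.abs().amax(dim=1, keepdim=True)              #(out_features, 1)
            #   W_q    = (W * scales).round().clamp(-127, 127).to(torch.int8)
            #   x @ W.T ≈ (x @ W_q.T.to(x.dtype)) / scales.view(1, -1)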
            #HQQ
            from hqq.core.quantize import *
            from hqq.backends.bitblas import patch_hqq_to_bitblas, HQQLinearBitBlas
            from hqq.backends.torchao import patch_hqq_to_aoint4

            HQQLinearBitBlas.check = lambda hqq_layer: True
            HQQLinearBitBlas.BIT_TO_DTYPE = {8: "uint8", 4: "uint4", 2: "uint2", 1: "uint1"}

            #GemLite
            from gemlite.core import GemLiteLinearTriton, DType

            class empty_linear(torch.nn.Module):
                def __init__(self, in_features, out_features, compute_dtype, device):
                    super().__init__()
                    self.in_features = in_features
                    self.out_features = out_features
                    self.device = device
                    self.compute_dtype = compute_dtype
            def gen_data(in_features, out_features, W_nbits, group_size, device=device):
                linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False, device='cpu')

                quant_config = BaseQuantizeConfig(nbits=W_nbits, group_size=group_size, quant_zero=False, quant_scale=False, axis=1)
                quant_config['weight_quant_params']['optimize'] = False
                hqq_layer = HQQLinear(linear, quant_config=quant_config, compute_dtype=compute_dtype, device=device, del_orig=False) #bfloat16

                orig_shape = (out_features, in_features)
                W = hqq_layer.dequantize().reshape(orig_shape)

                gemlite_linear, torchao_linear, bitblas_linear, marlin_linear = [None]*4

                #GemLite
                if(W_nbits in [8, 4, 2, 1]):
                    input_dtype, output_dtype, acc_dtype = None, None, None
                    if(compute_dtype == torch.float16):
                        input_dtype, output_dtype, acc_dtype = DType.FP16, DType.FP16, DType.FP16
                    if(compute_dtype == torch.bfloat16):
                        input_dtype, output_dtype, acc_dtype = DType.BF16, DType.BF16, DType.FP32 #FP16 accumulation not supported with bfloat16
                    if(None in (input_dtype, output_dtype, acc_dtype)):
                        raise Exception('Unsupported compute config', (input_dtype, output_dtype, acc_dtype))

                    gemlite_linear = GemLiteLinearTriton(W_nbits=W_nbits,
                                                         group_size=group_size, in_features=in_features, out_features=out_features,
                                                         input_dtype=input_dtype, output_dtype=output_dtype, acc_dtype=acc_dtype)
                    gemlite_linear.pack(hqq_layer.unpack().view(orig_shape), hqq_layer.meta['scale'], hqq_layer.meta['zero'], None)

                #TorchAO
                if(W_nbits == 8):
                    #note: this branch is not exercised with W_nbits_list = [4, 2] above
                    torchao_linear = Torch_A16W8SYM(in_features, out_features, (W_q.int() - 127).to(torch.int8), scales, bias=None)
                if(W_nbits == 4):
                    hqq_layer.compute_dtype = torch.bfloat16
                    hqq_layer.meta['scale'] = hqq_layer.meta['scale'].to(torch.bfloat16).view((-1, 1))
                    hqq_layer.meta['zero'] = hqq_layer.meta['zero'].to(torch.bfloat16).view((-1, 1))
                    torchao_linear = patch_hqq_to_aoint4(hqq_layer, None)
                torch.cuda.empty_cache()

                #BitBlas
                if(W_nbits in [8, 4, 2]):
                    bitblas_linear = patch_hqq_to_bitblas(HQQLinear(linear, quant_config=quant_config, compute_dtype=torch.float16, device=device, del_orig=False), None)
                torch.cuda.empty_cache()

                # #################################################################
                # #Marlin
                # from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinLinearMethod as MarlinLinearMethod
                # from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig as MarlinConfig
                # if(W_nbits == 4):
                #     _marlin_linear = MarlinLinearMethod(MarlinConfig(weight_bits=W_nbits, group_size=group_size, has_zp=True, lm_head_quantized=False))
                #     marlin_linear = empty_linear(in_features, out_features, compute_dtype=torch.float16, device='cuda:0')
                #     _marlin_linear.create_weights(layer=marlin_linear,
                #                                   input_size_per_partition=in_features,
                #                                   output_partition_sizes=[out_features],
                #                                   input_size=in_features,
                #                                   output_size=out_features,
                #                                   params_dtype=torch.float16)
                #     marlin_linear = marlin_linear.cuda()
                #     _marlin_linear.process_weights_after_loading(marlin_linear)
                #     marlin_linear.scales.data = torch.zeros_like(marlin_linear.scales.data) + 1
                #     marlin_linear.bias = None
                #     marlin_linear.forward = lambda x: _marlin_linear.apply(layer=marlin_linear, x=x, bias=marlin_linear.bias)
                #     torch.cuda.empty_cache()
                # #################################################################

                return W, gemlite_linear, torchao_linear, bitblas_linear, marlin_linear
            #############################################################################################################
            W, gemlite_linear, torchao_linear, bitblas_linear, marlin_linear = gen_data(in_features, out_features, W_nbits, group_size)

            if(matmul_type == "AUTO"):
                BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
                HQQLinearBitBlas.DEFAULT_BATCHSIZE = [1, 16]

            if(matmul_type == "GEMV"):
                BATCH_SIZES = [1, 2, 4, 8]
                gemlite_linear.forward = lambda x: gemlite_linear.forward_manual(x, matmul_type="GEMV")
                HQQLinearBitBlas.DEFAULT_BATCHSIZE = [1]

            if(matmul_type == "GEMM"):
                BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
                gemlite_linear.forward = lambda x: gemlite_linear.forward_manual(x, matmul_type="GEMM")
                HQQLinearBitBlas.DEFAULT_BATCHSIZE = [16]
print("W_nbits", W_nbits, "group_size", group_size, "matmul_type", matmul_type) | |
for batch_size in BATCH_SIZES: | |
try: | |
x = torch.randn((batch_size, in_features), dtype=gemlite_linear.compute_dtype, device='cuda:0')/10. | |
check_valid(x, W, gemlite_linear) | |
except Exception as e: | |
print(e) | |
ref_time = eval_time(lambda x: torch.matmul(x, W.T), {'x':x.to(W.dtype)}) | |
print("ref_time", ref_time) | |
if(gemlite_linear is not None): | |
try: | |
triton_time = eval_time(lambda x: gemlite_linear(x), {'x':x.to(gemlite_linear.compute_dtype)}) | |
print((batch_size, in_features, out_features), 'Triton Speed-up vs. torch.matmul', np.round(ref_time/triton_time, 2), 'time', triton_time) | |
except Exception as e: | |
print(e) | |
if(torchao_linear is not None): | |
torchao_time = eval_time(lambda x: torchao_linear(x), {'x':x.to(torchao_linear.compute_dtype)}) | |
print((batch_size, in_features, out_features), 'Torchao Speed-up vs. torch.matmul', np.round(ref_time/torchao_time, 2)) | |
if(bitblas_linear is not None): | |
bitblas_time = eval_time(lambda x: bitblas_linear(x), {'x':x.to(bitblas_linear.compute_dtype)}) | |
print((batch_size, in_features, out_features), 'Bitblas Speed-up vs. torch.matmul', np.round(ref_time/bitblas_time, 2)) | |
if(marlin_linear is not None): | |
marlin_time = eval_time(lambda x: marlin_linear.forward(x), {'x':x.to(marlin_linear.compute_dtype)}) | |
print((batch_size, in_features, out_features), 'Marlin Speed-up vs. torch.matmul', np.round(ref_time/marlin_time, 2)) | |
print('----------------------------------------------') |
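#Optional sketch (not part of the original run): to keep the numbers machine-readable, one could collect each
#measurement into a list of dicts inside the batch_size loop and dump it to CSV once the sweep finishes, e.g.
#   results.append({'batch': batch_size, 'in': in_features, 'out': out_features, 'W_nbits': W_nbits,
#                   'group_size': group_size, 'ref_ms': ref_time, 'triton_ms': triton_time})
#   ...
#   import csv
#   with open('benchmark_results.csv', 'w', newline='') as f:
#       writer = csv.DictWriter(f, fieldnames=results[0].keys()); writer.writeheader(); writer.writerows(results)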
2024-10-04 08:30:24 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where <target> is one of the available targets listed by tools/get_available_targets.py.
W_nbits 4 group_size 128 matmul_type AUTO
ref_time 0.03884800150990486
(1, 4096, 4096) Triton Speed-up vs. torch.matmul 1.54 time 0.025280000641942024
(1, 4096, 4096) Torchao Speed-up vs. torch.matmul 2.19
(1, 4096, 4096) Bitblas Speed-up vs. torch.matmul 2.47
ref_time 0.038784001022577286
(2, 4096, 4096) Triton Speed-up vs. torch.matmul 1.04 time 0.03728000074625015
(2, 4096, 4096) Torchao Speed-up vs. torch.matmul 2.14
(2, 4096, 4096) Bitblas Speed-up vs. torch.matmul 1.81
ref_time 0.038784001022577286
(4, 4096, 4096) Triton Speed-up vs. torch.matmul 0.79 time 0.04918399825692177
(4, 4096, 4096) Torchao Speed-up vs. torch.matmul 1.86
(4, 4096, 4096) Bitblas Speed-up vs. torch.matmul 1.14
ref_time 0.03884800150990486
(8, 4096, 4096) Triton Speed-up vs. torch.matmul 0.79 time 0.04902400076389313
(8, 4096, 4096) Torchao Speed-up vs. torch.matmul 1.34
(8, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.66
ref_time 0.03888000175356865
(16, 4096, 4096) Triton Speed-up vs. torch.matmul 0.79 time 0.04931199923157692
(16, 4096, 4096) Torchao Speed-up vs. torch.matmul 0.87
(16, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.36
ref_time 0.040192000567913055
(32, 4096, 4096) Triton Speed-up vs. torch.matmul 0.75 time 0.053568001836538315
(32, 4096, 4096) Torchao Speed-up vs. torch.matmul 0.58
(32, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.2
ref_time 0.03948799893260002
(64, 4096, 4096) Triton Speed-up vs. torch.matmul 0.66 time 0.05967999994754791
(64, 4096, 4096) Torchao Speed-up vs. torch.matmul 0.32
(64, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.1
ref_time 0.044096000492572784
(128, 4096, 4096) Triton Speed-up vs. torch.matmul 0.58 time 0.07587199658155441
(128, 4096, 4096) Torchao Speed-up vs. torch.matmul 0.18
(128, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.05
ref_time 0.061184000223875046
(256, 4096, 4096) Triton Speed-up vs. torch.matmul 0.55 time 0.11116799712181091
(256, 4096, 4096) Torchao Speed-up vs. torch.matmul 0.13
(256, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.04
ref_time 0.09014400094747543
(512, 4096, 4096) Triton Speed-up vs. torch.matmul 0.56 time 0.1605760008096695
(512, 4096, 4096) Torchao Speed-up vs. torch.matmul 0.1
(512, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.03
ref_time 0.16927999258041382
(1024, 4096, 4096) Triton Speed-up vs. torch.matmul 0.6 time 0.2815679907798767
(1024, 4096, 4096) Torchao Speed-up vs. torch.matmul 0.09
(1024, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.02
2024-10-04 08:34:33 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where <target> is one of the available targets listed by tools/get_available_targets.py.
W_nbits 2 group_size 128 matmul_type AUTO
ref_time 0.038816001266241074
(1, 4096, 4096) Triton Speed-up vs. torch.matmul 1.61 time 0.024064000695943832
(1, 4096, 4096) Bitblas Speed-up vs. torch.matmul 2.61
ref_time 0.038975998759269714
(2, 4096, 4096) Triton Speed-up vs. torch.matmul 1.07 time 0.03651199862360954
(2, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.57
ref_time 0.03907199949026108
(4, 4096, 4096) Triton Speed-up vs. torch.matmul 0.77 time 0.05104000121355057
(4, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.57
ref_time 0.03888000175356865
(8, 4096, 4096) Triton Speed-up vs. torch.matmul 0.76 time 0.05129599943757057
(8, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.58
ref_time 0.03903999924659729
(16, 4096, 4096) Triton Speed-up vs. torch.matmul 0.76 time 0.05110400170087814
(16, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.57
ref_time 0.039903998374938965
(32, 4096, 4096) Triton Speed-up vs. torch.matmul 0.76 time 0.052639998495578766
(32, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.42
ref_time 0.0398080013692379
(64, 4096, 4096) Triton Speed-up vs. torch.matmul 0.67 time 0.05910399928689003
(64, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.3
ref_time 0.043776001781225204
(128, 4096, 4096) Triton Speed-up vs. torch.matmul 0.59 time 0.07446400076150894
(128, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.2
ref_time 0.06128000095486641
(256, 4096, 4096) Triton Speed-up vs. torch.matmul 0.64 time 0.09590400010347366
(256, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.15
ref_time 0.09091199934482574
(512, 4096, 4096) Triton Speed-up vs. torch.matmul 0.6 time 0.1523520052433014
(512, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.12
ref_time 0.1709440052509308
(1024, 4096, 4096) Triton Speed-up vs. torch.matmul 0.65 time 0.26291200518608093
(1024, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.11
2024-10-04 08:38:22 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where <target> is one of the available targets listed by tools/get_available_targets.py.
W_nbits 4 group_size 128 matmul_type AUTO
ref_time 0.10550399869680405
(1, 8192, 8192) Triton Speed-up vs. torch.matmul 1.54 time 0.06831999868154526
(1, 8192, 8192) Torchao Speed-up vs. torch.matmul 2.38
(1, 8192, 8192) Bitblas Speed-up vs. torch.matmul 2.23
ref_time 0.10697600245475769
(2, 8192, 8192) Triton Speed-up vs. torch.matmul 1.03 time 0.10355199873447418
(2, 8192, 8192) Torchao Speed-up vs. torch.matmul 2.36
(2, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.54
ref_time 0.10704000294208527
(4, 8192, 8192) Triton Speed-up vs. torch.matmul 1.03 time 0.1035199984908104
(4, 8192, 8192) Torchao Speed-up vs. torch.matmul 2.13
(4, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.55
ref_time 0.10732799768447876
(8, 8192, 8192) Triton Speed-up vs. torch.matmul 1.04 time 0.10361599922180176
(8, 8192, 8192) Torchao Speed-up vs. torch.matmul 1.27
(8, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.57
ref_time 0.10793600231409073
(16, 8192, 8192) Triton Speed-up vs. torch.matmul 1.05 time 0.10284800082445145
(16, 8192, 8192) Torchao Speed-up vs. torch.matmul 0.83
(16, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.58
ref_time 0.10697600245475769
(32, 8192, 8192) Triton Speed-up vs. torch.matmul 0.94 time 0.11427199840545654
(32, 8192, 8192) Torchao Speed-up vs. torch.matmul 0.43
(32, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.31
ref_time 0.10755199939012527
(64, 8192, 8192) Triton Speed-up vs. torch.matmul 0.76 time 0.1414719969034195
(64, 8192, 8192) Torchao Speed-up vs. torch.matmul 0.22
(64, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.21
ref_time 0.125791996717453
(128, 8192, 8192) Triton Speed-up vs. torch.matmul 0.58 time 0.217056006193161
(128, 8192, 8192) Torchao Speed-up vs. torch.matmul 0.13
(128, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.15
ref_time 0.17155200242996216
(256, 8192, 8192) Triton Speed-up vs. torch.matmul 0.54 time 0.31753599643707275
(256, 8192, 8192) Torchao Speed-up vs. torch.matmul 0.09
(256, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.1
ref_time 0.3309760093688965
(512, 8192, 8192) Triton Speed-up vs. torch.matmul 0.53 time 0.6275200247764587
(512, 8192, 8192) Torchao Speed-up vs. torch.matmul 0.09
(512, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.1
ref_time 0.6511359810829163
(1024, 8192, 8192) Triton Speed-up vs. torch.matmul 0.57 time 1.1463680267333984
(1024, 8192, 8192) Torchao Speed-up vs. torch.matmul 0.09
(1024, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.1
2024-10-04 08:42:59 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where <target> is one of the available targets listed by tools/get_available_targets.py.
W_nbits 2 group_size 128 matmul_type AUTO
ref_time 0.10675200074911118
(1, 8192, 8192) Triton Speed-up vs. torch.matmul 1.61 time 0.06646399945020676
(1, 8192, 8192) Bitblas Speed-up vs. torch.matmul 2.58
ref_time 0.10672000050544739
(2, 8192, 8192) Triton Speed-up vs. torch.matmul 1.0 time 0.1066880002617836
(2, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.56
ref_time 0.10732799768447876
(4, 8192, 8192) Triton Speed-up vs. torch.matmul 1.01 time 0.10665600001811981
(4, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.57
ref_time 0.1074879989027977
(8, 8192, 8192) Triton Speed-up vs. torch.matmul 1.01 time 0.10675200074911118
(8, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.59
ref_time 0.1080000028014183
(16, 8192, 8192) Triton Speed-up vs. torch.matmul 1.01 time 0.10716799646615982
(16, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.6
ref_time 0.10649599879980087
(32, 8192, 8192) Triton Speed-up vs. torch.matmul 0.95 time 0.11209599673748016
(32, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.43
ref_time 0.10780800133943558
(64, 8192, 8192) Triton Speed-up vs. torch.matmul 0.77 time 0.1404159963130951
(64, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.26
ref_time 0.12700800597667694
(128, 8192, 8192) Triton Speed-up vs. torch.matmul 0.67 time 0.19008000195026398
(128, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.16
ref_time 0.17235200107097626
(256, 8192, 8192) Triton Speed-up vs. torch.matmul 0.57 time 0.30425599217414856
(256, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.11
ref_time 0.3341119885444641
(512, 8192, 8192) Triton Speed-up vs. torch.matmul 0.64 time 0.5241600275039673
(512, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.11
ref_time 0.6436799764633179
(1024, 8192, 8192) Triton Speed-up vs. torch.matmul 0.61 time 1.05404794216156
(1024, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.1
2024-10-04 08:47:33 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where <target> is one of the available targets listed by tools/get_available_targets.py.
W_nbits 4 group_size 128 matmul_type AUTO
ref_time 0.3417600095272064
(1, 16384, 16384) Triton Speed-up vs. torch.matmul 1.43 time 0.2396160066127777
(1, 16384, 16384) Torchao Speed-up vs. torch.matmul 2.76
(1, 16384, 16384) Bitblas Speed-up vs. torch.matmul 2.53
ref_time 0.3420160114765167
(2, 16384, 16384) Triton Speed-up vs. torch.matmul 1.22 time 0.28147199749946594
(2, 16384, 16384) Torchao Speed-up vs. torch.matmul 2.63
(2, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.62
ref_time 0.3447999954223633
(4, 16384, 16384) Triton Speed-up vs. torch.matmul 1.23 time 0.2813439965248108
(4, 16384, 16384) Torchao Speed-up vs. torch.matmul 2.1
(4, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.63
ref_time 0.34703999757766724
(8, 16384, 16384) Triton Speed-up vs. torch.matmul 1.22 time 0.28415998816490173
(8, 16384, 16384) Torchao Speed-up vs. torch.matmul 1.27
(8, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.66
ref_time 0.35043200850486755
(16, 16384, 16384) Triton Speed-up vs. torch.matmul 1.2 time 0.29129600524902344
(16, 16384, 16384) Torchao Speed-up vs. torch.matmul 0.74
(16, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.67
ref_time 0.35417601466178894
(32, 16384, 16384) Triton Speed-up vs. torch.matmul 1.0 time 0.35420799255371094
(32, 16384, 16384) Torchao Speed-up vs. torch.matmul 0.38
(32, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.49
ref_time 0.35763201117515564
(64, 16384, 16384) Triton Speed-up vs. torch.matmul 0.78 time 0.45686399936676025
(64, 16384, 16384) Torchao Speed-up vs. torch.matmul 0.2
(64, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.28
ref_time 0.4419519901275635
(128, 16384, 16384) Triton Speed-up vs. torch.matmul 0.73 time 0.601855993270874
(128, 16384, 16384) Torchao Speed-up vs. torch.matmul 0.12
(128, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.18
ref_time 0.6263999938964844
(256, 16384, 16384) Triton Speed-up vs. torch.matmul 0.56 time 1.124351978302002
(256, 16384, 16384) Torchao Speed-up vs. torch.matmul 0.09
(256, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.13
ref_time 1.1796480417251587
(512, 16384, 16384) Triton Speed-up vs. torch.matmul 0.52 time 2.2871038913726807
(512, 16384, 16384) Torchao Speed-up vs. torch.matmul 0.08
(512, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.12
ref_time 2.451647996902466
(1024, 16384, 16384) Triton Speed-up vs. torch.matmul 0.56 time 4.351327896118164
(1024, 16384, 16384) Torchao Speed-up vs. torch.matmul 0.08
(1024, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.13
2024-10-04 08:53:11 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where <target> is one of the available targets listed by tools/get_available_targets.py.
W_nbits 2 group_size 128 matmul_type AUTO
ref_time 0.34540799260139465
(1, 16384, 16384) Triton Speed-up vs. torch.matmul 1.47 time 0.23452800512313843
(1, 16384, 16384) Bitblas Speed-up vs. torch.matmul 2.71
ref_time 0.3437120020389557
(2, 16384, 16384) Triton Speed-up vs. torch.matmul 1.27 time 0.27132800221443176
(2, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.65
ref_time 0.34672001004219055
(4, 16384, 16384) Triton Speed-up vs. torch.matmul 1.28 time 0.27113598585128784
(4, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.67
ref_time 0.3494400084018707
(8, 16384, 16384) Triton Speed-up vs. torch.matmul 1.28 time 0.2725760042667389
(8, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.7
ref_time 0.35254400968551636
(16, 16384, 16384) Triton Speed-up vs. torch.matmul 1.28 time 0.27478399872779846
(16, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.74
ref_time 0.35417601466178894
(32, 16384, 16384) Triton Speed-up vs. torch.matmul 1.2 time 0.2962239980697632
(32, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.44
ref_time 0.3561280071735382
(64, 16384, 16384) Triton Speed-up vs. torch.matmul 0.97 time 0.36822399497032166
(64, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.22
ref_time 0.44419199228286743
(128, 16384, 16384) Triton Speed-up vs. torch.matmul 0.73 time 0.6117759943008423
(128, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.15
ref_time 0.640608012676239
(256, 16384, 16384) Triton Speed-up vs. torch.matmul 0.61 time 1.0531200170516968
(256, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.11
ref_time 1.2030080556869507
(512, 16384, 16384) Triton Speed-up vs. torch.matmul 0.57 time 2.12825608253479
(512, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.1
ref_time 2.5157759189605713
(1024, 16384, 16384) Triton Speed-up vs. torch.matmul 0.62 time 4.054175853729248
(1024, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.1
2024-10-04 08:58:45 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where <target> is one of the available targets listed by tools/get_available_targets.py.
W_nbits 4 group_size 128 matmul_type AUTO
ref_time 1.2765120267868042
(1, 32768, 32768) Triton Speed-up vs. torch.matmul 1.35 time 0.9481279850006104
(1, 32768, 32768) Torchao Speed-up vs. torch.matmul 2.76
(1, 32768, 32768) Bitblas Speed-up vs. torch.matmul 2.56
ref_time 1.3089280128479004
(2, 32768, 32768) Triton Speed-up vs. torch.matmul 1.37 time 0.9542400240898132
(2, 32768, 32768) Torchao Speed-up vs. torch.matmul 2.63
(2, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.86
ref_time 1.3063360452651978
(4, 32768, 32768) Triton Speed-up vs. torch.matmul 1.37 time 0.9519360065460205
(4, 32768, 32768) Torchao Speed-up vs. torch.matmul 2.02
(4, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.86
ref_time 1.3099839687347412
(8, 32768, 32768) Triton Speed-up vs. torch.matmul 1.37 time 0.9582080245018005
(8, 32768, 32768) Torchao Speed-up vs. torch.matmul 1.21
(8, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.89
ref_time 1.3158719539642334
(16, 32768, 32768) Triton Speed-up vs. torch.matmul 1.34 time 0.9845439791679382
(16, 32768, 32768) Torchao Speed-up vs. torch.matmul 0.72
(16, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.89
ref_time 1.3440959453582764
(32, 32768, 32768) Triton Speed-up vs. torch.matmul 1.13 time 1.1911360025405884
(32, 32768, 32768) Torchao Speed-up vs. torch.matmul 0.37
(32, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.53
ref_time 1.3771519660949707
(64, 32768, 32768) Triton Speed-up vs. torch.matmul 0.87 time 1.5784000158309937
(64, 32768, 32768) Torchao Speed-up vs. torch.matmul 0.19
(64, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.27
ref_time 1.5383360385894775
(128, 32768, 32768) Triton Speed-up vs. torch.matmul 0.63 time 2.423840045928955
(128, 32768, 32768) Torchao Speed-up vs. torch.matmul 0.11
(128, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.16
ref_time 2.5936319828033447
(256, 32768, 32768) Triton Speed-up vs. torch.matmul 0.56 time 4.658976078033447
(256, 32768, 32768) Torchao Speed-up vs. torch.matmul 0.09
(256, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.14
ref_time 5.009151935577393
(512, 32768, 32768) Triton Speed-up vs. torch.matmul 0.57 time 8.741791725158691
(512, 32768, 32768) Torchao Speed-up vs. torch.matmul 0.08
(512, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.13
ref_time 9.679488182067871
(1024, 32768, 32768) Triton Speed-up vs. torch.matmul 0.52 time 18.480607986450195
(1024, 32768, 32768) Torchao Speed-up vs. torch.matmul 0.08
(1024, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.13
2024-10-04 09:07:17 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where <target> is one of the available targets listed by tools/get_available_targets.py.
W_nbits 2 group_size 128 matmul_type AUTO
ref_time 1.2772480249404907
(1, 32768, 32768) Triton Speed-up vs. torch.matmul 1.69 time 0.7576320171356201
(1, 32768, 32768) Bitblas Speed-up vs. torch.matmul 2.82
ref_time 1.3076159954071045
(2, 32768, 32768) Triton Speed-up vs. torch.matmul 1.73 time 0.7579519748687744
(2, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.84
ref_time 1.3032000064849854
(4, 32768, 32768) Triton Speed-up vs. torch.matmul 1.72 time 0.7576000094413757
(4, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.85
ref_time 1.3141759634017944
(8, 32768, 32768) Triton Speed-up vs. torch.matmul 1.71 time 0.7696319818496704
(8, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.88
ref_time 1.3204799890518188
(16, 32768, 32768) Triton Speed-up vs. torch.matmul 1.68 time 0.7845759987831116
(16, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.9
ref_time 1.3463040590286255
(32, 32768, 32768) Triton Speed-up vs. torch.matmul 1.28 time 1.0521600246429443
(32, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.51
ref_time 1.3624639511108398
(64, 32768, 32768) Triton Speed-up vs. torch.matmul 0.98 time 1.3894399404525757
(64, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.27
ref_time 1.546463966369629
(128, 32768, 32768) Triton Speed-up vs. torch.matmul 0.71 time 2.176255941390991
(128, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.16
ref_time 2.424799919128418
(256, 32768, 32768) Triton Speed-up vs. torch.matmul 0.56 time 4.307519912719727
(256, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.13
ref_time 5.034048080444336
(512, 32768, 32768) Triton Speed-up vs. torch.matmul 0.62 time 8.100383758544922
(512, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.13
ref_time 9.681568145751953
(1024, 32768, 32768) Triton Speed-up vs. torch.matmul 0.58 time 16.774080276489258
(1024, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.13