@HDCharles
Created October 7, 2024 16:25
#OMP_NUM_THREADS=16 CUDA_VISIBLE_DEVICES=0 ipython3 benchmark_triton.py #select the right number of threads based on your machine
#You can change the matmul_type: GEMM, GEMV or AUTO
#Note: bfloat16 only supported in GEMM mode with float32 accumulation
#################################################################################################################################
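#Note (added): the BitBLAS backend may warn "TVM target not found" (as in the results below); if so, set the TVM target first,
#e.g. export TVM_TARGET=<target> (available targets are listed by tools/get_available_targets.py)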
import torch
import numpy as np
device = 'cuda:0'
compute_dtype = torch.float16
in_feature_list = [4096, 4096*2, 4096*4, 4096*8]
out_feature_list = [4096, 4096*2, 4096*4, 4096*8]
W_nbits_list = [4, 2]
group_size_list = [128]
for in_features, out_features in zip(in_feature_list, out_feature_list):
    for W_nbits in W_nbits_list:
        for group_size in group_size_list:
            #in_features, out_features = 4096, 4096
            #in_features, out_features = 4096*2, 4096*2
            #in_features, out_features = 4096*4, 4096*4
            #in_features, out_features = 4096*8, 4096*8
            #W_nbits, group_size = 8, in_features
            #W_nbits, group_size = 4, 128
            #W_nbits, group_size = 2, 128
            matmul_type = "AUTO" #GEMM, GEMV, "AUTO"
            #################################################################################################################################
            from triton.testing import do_bench

            def eval_time(fct, params):
                #best-of-200 latency in milliseconds
                return do_bench(lambda: fct(**params), warmup=25, rep=200, fast_flush=True, return_mode='min')

            def check_valid(x, W, quant_linear, tol=1e-3):
                #compare the quantized layer against a dense matmul with the dequantized weights
                y_ref = torch.matmul(x, W.T)
                y_q = quant_linear(x)
                try:
                    assert (y_ref - y_q).abs().mean() < tol
                except AssertionError:
                    raise RuntimeError('Assertion Failed: quantized output deviates from the reference beyond tol')
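            #Illustrative usage only (not part of the original gist): eval_time returns the minimum latency in ms,
            #so timing a plain fp16 matmul for one shape could look like:
            #  x_demo = torch.randn(1, in_features, dtype=compute_dtype, device=device)
            #  W_demo = torch.randn(out_features, in_features, dtype=compute_dtype, device=device)
            #  print(eval_time(torch.matmul, {'input': x_demo, 'other': W_demo.T}))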
            #################################################################################################################################
            #TorchAO Int8 settings
            torch._dynamo.config.capture_scalar_outputs = True
            torch._inductor.config.coordinate_descent_tuning = True

            @torch.compile()
            def matmul_torch_A16W8SYM(x, W_q, scales, out_features):
                #A16W8: fp16/bf16 activations x int8 symmetric weights, rescaled per output channel
                out_shape = x.shape[:-1] + (out_features,)
                out = ((x.view((-1, x.shape[-1])) @ W_q.T.to(x.dtype)) / scales.view(1, -1)).view(out_shape)
                return out

            class Torch_A16W8SYM(torch.nn.Module):
                def __init__(self, in_features, out_features, W_q, scales, bias=None):
                    super().__init__()
                    self.W_q = W_q
                    self.in_features = in_features
                    self.out_features = out_features
                    self.scales = scales
                    self.bias = bias
                    self.device = W_q.device
                    self.compute_dtype = scales.dtype

                def forward(self, x):
                    out = matmul_torch_A16W8SYM(x.to(self.device), self.W_q, self.scales, self.out_features)
                    if(self.bias is not None):
                        out += self.bias
                    return out
            #HQQ
            from hqq.core.quantize import *
            from hqq.backends.bitblas import patch_hqq_to_bitblas, HQQLinearBitBlas
            from hqq.backends.torchao import patch_hqq_to_aoint4
            HQQLinearBitBlas.check = lambda hqq_layer: True #always accept the layer (bypass HQQLinearBitBlas.check)
            HQQLinearBitBlas.BIT_TO_DTYPE = {8: "uint8", 4: "uint4", 2: "uint2", 1: "uint1"}

            #GemLite
            from gemlite.core import GemLiteLinearTriton, DType

            class empty_linear(torch.nn.Module):
                #bare container used to hold the Marlin weights/metadata
                def __init__(self, in_features, out_features, compute_dtype, device):
                    super().__init__()
                    self.in_features = in_features
                    self.out_features = out_features
                    self.device = device
                    self.compute_dtype = compute_dtype
            def gen_data(in_features, out_features, W_nbits, group_size, device=device):
                linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False, device='cpu')

                quant_config = BaseQuantizeConfig(nbits=W_nbits, group_size=group_size, quant_zero=False, quant_scale=False, axis=1)
                quant_config['weight_quant_params']['optimize'] = False
                hqq_layer = HQQLinear(linear, quant_config=quant_config, compute_dtype=compute_dtype, device=device, del_orig=False) #bfloat16

                orig_shape = (out_features, in_features)
                W = hqq_layer.dequantize().reshape(orig_shape)

                gemlite_linear, torchao_linear, bitblas_linear, marlin_linear = [None]*4

                #GemLite
                if(W_nbits in [8, 4, 2, 1]):
                    input_dtype, output_dtype, acc_dtype = None, None, None
                    if(compute_dtype == torch.float16):
                        input_dtype, output_dtype, acc_dtype = DType.FP16, DType.FP16, DType.FP16
                    if(compute_dtype == torch.bfloat16):
                        input_dtype, output_dtype, acc_dtype = DType.BF16, DType.BF16, DType.FP32 #FP16 acc not supported with bfloat16

                    if(None in (input_dtype, output_dtype, acc_dtype)):
                        raise Exception('Unsupported compute config', (input_dtype, output_dtype, acc_dtype))

                    gemlite_linear = GemLiteLinearTriton(W_nbits=W_nbits,
                                                         group_size=group_size, in_features=in_features, out_features=out_features,
                                                         input_dtype=input_dtype, output_dtype=output_dtype, acc_dtype=acc_dtype)

                    gemlite_linear.pack(hqq_layer.unpack().view(orig_shape), hqq_layer.meta['scale'], hqq_layer.meta['zero'], None);

                #TorchAO
                if(W_nbits == 8):
                    #W_q / scales were undefined in the original snippet; assumed to come from the HQQ layer (group_size == in_features)
                    W_q, scales = hqq_layer.unpack().view(orig_shape), hqq_layer.meta['scale']
                    torchao_linear = Torch_A16W8SYM(in_features, out_features, (W_q.int() - 127).to(torch.int8), scales, bias=None)

                if(W_nbits == 4):
                    hqq_layer.compute_dtype = torch.bfloat16
                    hqq_layer.meta['scale'] = hqq_layer.meta['scale'].to(torch.bfloat16).view((-1, 1))
                    hqq_layer.meta['zero'] = hqq_layer.meta['zero'].to(torch.bfloat16).view((-1, 1))
                    torchao_linear = patch_hqq_to_aoint4(hqq_layer, None)
                    torch.cuda.empty_cache()

                #Bitblas
                if(W_nbits in [8, 4, 2]):
                    bitblas_linear = patch_hqq_to_bitblas(HQQLinear(linear, quant_config=quant_config, compute_dtype=torch.float16, device=device, del_orig=False), None)
                    torch.cuda.empty_cache()

                # #################################################################
                # #Marlin
                # from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinLinearMethod as MarlinLinearMethod
                # from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig as MarlinConfig
                # if(W_nbits==4):
                #     _marlin_linear = MarlinLinearMethod(MarlinConfig(weight_bits=W_nbits, group_size=group_size, has_zp=True, lm_head_quantized=False))
                #     marlin_linear = empty_linear(in_features, out_features, compute_dtype=torch.float16, device='cuda:0')
                #     _marlin_linear.create_weights(layer=marlin_linear,
                #                                   input_size_per_partition=in_features,
                #                                   output_partition_sizes=[out_features],
                #                                   input_size=in_features,
                #                                   output_size=out_features,
                #                                   params_dtype=torch.float16)
                #     marlin_linear = marlin_linear.cuda()
                #     _marlin_linear.process_weights_after_loading(marlin_linear)
                #     marlin_linear.scales.data = torch.zeros_like(marlin_linear.scales.data) + 1
                #     marlin_linear.bias = None
                #     marlin_linear.forward = lambda x: _marlin_linear.apply(layer=marlin_linear, x=x, bias=marlin_linear.bias)
                #     torch.cuda.empty_cache()
                # #################################################################

                return W, gemlite_linear, torchao_linear, bitblas_linear, marlin_linear
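            #Illustrative usage only (not in the original gist): gen_data can also be called directly for a single config,
            #e.g. W, gl, ao, bb, ml = gen_data(4096, 4096, W_nbits=4, group_size=128) and then gl(x) / ao(x) / bb(x).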
            #############################################################################################################
            W, gemlite_linear, torchao_linear, bitblas_linear, marlin_linear = gen_data(in_features, out_features, W_nbits, group_size)

            if(matmul_type == "AUTO"):
                BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
                HQQLinearBitBlas.DEFAULT_BATCHSIZE = [1, 16]

            if(matmul_type == "GEMV"):
                BATCH_SIZES = [1, 2, 4, 8]
                gemlite_linear.forward = lambda x: gemlite_linear.forward_manual(x, matmul_type="GEMV")
                HQQLinearBitBlas.DEFAULT_BATCHSIZE = [1]

            if(matmul_type == "GEMM"):
                BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
                gemlite_linear.forward = lambda x: gemlite_linear.forward_manual(x, matmul_type="GEMM")
                HQQLinearBitBlas.DEFAULT_BATCHSIZE = [16]
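            #Note (added comment, not in the original gist): "GEMV" and "GEMM" pin gemlite to forward_manual with that kernel
            #(GEMV is only swept over small batches), while "AUTO" leaves gemlite's default forward untouched and sweeps 1-1024.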
print("W_nbits", W_nbits, "group_size", group_size, "matmul_type", matmul_type)
for batch_size in BATCH_SIZES:
try:
x = torch.randn((batch_size, in_features), dtype=gemlite_linear.compute_dtype, device='cuda:0')/10.
check_valid(x, W, gemlite_linear)
except Exception as e:
print(e)
ref_time = eval_time(lambda x: torch.matmul(x, W.T), {'x':x.to(W.dtype)})
print("ref_time", ref_time)
if(gemlite_linear is not None):
try:
triton_time = eval_time(lambda x: gemlite_linear(x), {'x':x.to(gemlite_linear.compute_dtype)})
print((batch_size, in_features, out_features), 'Triton Speed-up vs. torch.matmul', np.round(ref_time/triton_time, 2), 'time', triton_time)
except Exception as e:
print(e)
if(torchao_linear is not None):
torchao_time = eval_time(lambda x: torchao_linear(x), {'x':x.to(torchao_linear.compute_dtype)})
print((batch_size, in_features, out_features), 'Torchao Speed-up vs. torch.matmul', np.round(ref_time/torchao_time, 2))
if(bitblas_linear is not None):
bitblas_time = eval_time(lambda x: bitblas_linear(x), {'x':x.to(bitblas_linear.compute_dtype)})
print((batch_size, in_features, out_features), 'Bitblas Speed-up vs. torch.matmul', np.round(ref_time/bitblas_time, 2))
if(marlin_linear is not None):
marlin_time = eval_time(lambda x: marlin_linear.forward(x), {'x':x.to(marlin_linear.compute_dtype)})
print((batch_size, in_features, out_features), 'Marlin Speed-up vs. torch.matmul', np.round(ref_time/marlin_time, 2))
print('----------------------------------------------')
HDCharles commented Oct 7, 2024

2024-10-04 08:30:24 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where is one of the available targets can be found in the output of tools/get_available_targets.py.
W_nbits 4 group_size 128 matmul_type AUTO
ref_time 0.03884800150990486
(1, 4096, 4096) Triton Speed-up vs. torch.matmul 1.54 time 0.025280000641942024
(1, 4096, 4096) Torchao Speed-up vs. torch.matmul 2.19
(1, 4096, 4096) Bitblas Speed-up vs. torch.matmul 2.47
ref_time 0.038784001022577286
(2, 4096, 4096) Triton Speed-up vs. torch.matmul 1.04 time 0.03728000074625015
(2, 4096, 4096) Torchao Speed-up vs. torch.matmul 2.14
(2, 4096, 4096) Bitblas Speed-up vs. torch.matmul 1.81
ref_time 0.038784001022577286
(4, 4096, 4096) Triton Speed-up vs. torch.matmul 0.79 time 0.04918399825692177
(4, 4096, 4096) Torchao Speed-up vs. torch.matmul 1.86
(4, 4096, 4096) Bitblas Speed-up vs. torch.matmul 1.14
ref_time 0.03884800150990486
(8, 4096, 4096) Triton Speed-up vs. torch.matmul 0.79 time 0.04902400076389313
(8, 4096, 4096) Torchao Speed-up vs. torch.matmul 1.34
(8, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.66
ref_time 0.03888000175356865
(16, 4096, 4096) Triton Speed-up vs. torch.matmul 0.79 time 0.04931199923157692
(16, 4096, 4096) Torchao Speed-up vs. torch.matmul 0.87
(16, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.36
ref_time 0.040192000567913055
(32, 4096, 4096) Triton Speed-up vs. torch.matmul 0.75 time 0.053568001836538315
(32, 4096, 4096) Torchao Speed-up vs. torch.matmul 0.58
(32, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.2
ref_time 0.03948799893260002
(64, 4096, 4096) Triton Speed-up vs. torch.matmul 0.66 time 0.05967999994754791
(64, 4096, 4096) Torchao Speed-up vs. torch.matmul 0.32
(64, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.1
ref_time 0.044096000492572784
(128, 4096, 4096) Triton Speed-up vs. torch.matmul 0.58 time 0.07587199658155441
(128, 4096, 4096) Torchao Speed-up vs. torch.matmul 0.18
(128, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.05
ref_time 0.061184000223875046
(256, 4096, 4096) Triton Speed-up vs. torch.matmul 0.55 time 0.11116799712181091
(256, 4096, 4096) Torchao Speed-up vs. torch.matmul 0.13
(256, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.04
ref_time 0.09014400094747543
(512, 4096, 4096) Triton Speed-up vs. torch.matmul 0.56 time 0.1605760008096695
(512, 4096, 4096) Torchao Speed-up vs. torch.matmul 0.1
(512, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.03
ref_time 0.16927999258041382
(1024, 4096, 4096) Triton Speed-up vs. torch.matmul 0.6 time 0.2815679907798767
(1024, 4096, 4096) Torchao Speed-up vs. torch.matmul 0.09
(1024, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.02
2024-10-04 08:34:33 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where is one of the available targets can be found in the output of tools/get_available_targets.py.
W_nbits 2 group_size 128 matmul_type AUTO
ref_time 0.038816001266241074
(1, 4096, 4096) Triton Speed-up vs. torch.matmul 1.61 time 0.024064000695943832
(1, 4096, 4096) Bitblas Speed-up vs. torch.matmul 2.61
ref_time 0.038975998759269714
(2, 4096, 4096) Triton Speed-up vs. torch.matmul 1.07 time 0.03651199862360954
(2, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.57
ref_time 0.03907199949026108
(4, 4096, 4096) Triton Speed-up vs. torch.matmul 0.77 time 0.05104000121355057
(4, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.57
ref_time 0.03888000175356865
(8, 4096, 4096) Triton Speed-up vs. torch.matmul 0.76 time 0.05129599943757057
(8, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.58
ref_time 0.03903999924659729
(16, 4096, 4096) Triton Speed-up vs. torch.matmul 0.76 time 0.05110400170087814
(16, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.57
ref_time 0.039903998374938965
(32, 4096, 4096) Triton Speed-up vs. torch.matmul 0.76 time 0.052639998495578766
(32, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.42
ref_time 0.0398080013692379
(64, 4096, 4096) Triton Speed-up vs. torch.matmul 0.67 time 0.05910399928689003
(64, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.3
ref_time 0.043776001781225204
(128, 4096, 4096) Triton Speed-up vs. torch.matmul 0.59 time 0.07446400076150894
(128, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.2
ref_time 0.06128000095486641
(256, 4096, 4096) Triton Speed-up vs. torch.matmul 0.64 time 0.09590400010347366
(256, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.15
ref_time 0.09091199934482574
(512, 4096, 4096) Triton Speed-up vs. torch.matmul 0.6 time 0.1523520052433014
(512, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.12
ref_time 0.1709440052509308
(1024, 4096, 4096) Triton Speed-up vs. torch.matmul 0.65 time 0.26291200518608093
(1024, 4096, 4096) Bitblas Speed-up vs. torch.matmul 0.11
2024-10-04 08:38:22 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where is one of the available targets can be found in the output of tools/get_available_targets.py.
W_nbits 4 group_size 128 matmul_type AUTO
ref_time 0.10550399869680405
(1, 8192, 8192) Triton Speed-up vs. torch.matmul 1.54 time 0.06831999868154526
(1, 8192, 8192) Torchao Speed-up vs. torch.matmul 2.38
(1, 8192, 8192) Bitblas Speed-up vs. torch.matmul 2.23
ref_time 0.10697600245475769
(2, 8192, 8192) Triton Speed-up vs. torch.matmul 1.03 time 0.10355199873447418
(2, 8192, 8192) Torchao Speed-up vs. torch.matmul 2.36
(2, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.54
ref_time 0.10704000294208527
(4, 8192, 8192) Triton Speed-up vs. torch.matmul 1.03 time 0.1035199984908104
(4, 8192, 8192) Torchao Speed-up vs. torch.matmul 2.13
(4, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.55
ref_time 0.10732799768447876
(8, 8192, 8192) Triton Speed-up vs. torch.matmul 1.04 time 0.10361599922180176
(8, 8192, 8192) Torchao Speed-up vs. torch.matmul 1.27
(8, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.57
ref_time 0.10793600231409073
(16, 8192, 8192) Triton Speed-up vs. torch.matmul 1.05 time 0.10284800082445145
(16, 8192, 8192) Torchao Speed-up vs. torch.matmul 0.83
(16, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.58
ref_time 0.10697600245475769
(32, 8192, 8192) Triton Speed-up vs. torch.matmul 0.94 time 0.11427199840545654
(32, 8192, 8192) Torchao Speed-up vs. torch.matmul 0.43
(32, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.31
ref_time 0.10755199939012527
(64, 8192, 8192) Triton Speed-up vs. torch.matmul 0.76 time 0.1414719969034195
(64, 8192, 8192) Torchao Speed-up vs. torch.matmul 0.22
(64, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.21
ref_time 0.125791996717453
(128, 8192, 8192) Triton Speed-up vs. torch.matmul 0.58 time 0.217056006193161
(128, 8192, 8192) Torchao Speed-up vs. torch.matmul 0.13
(128, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.15
ref_time 0.17155200242996216
(256, 8192, 8192) Triton Speed-up vs. torch.matmul 0.54 time 0.31753599643707275
(256, 8192, 8192) Torchao Speed-up vs. torch.matmul 0.09
(256, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.1
ref_time 0.3309760093688965
(512, 8192, 8192) Triton Speed-up vs. torch.matmul 0.53 time 0.6275200247764587
(512, 8192, 8192) Torchao Speed-up vs. torch.matmul 0.09
(512, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.1
ref_time 0.6511359810829163
(1024, 8192, 8192) Triton Speed-up vs. torch.matmul 0.57 time 1.1463680267333984
(1024, 8192, 8192) Torchao Speed-up vs. torch.matmul 0.09
(1024, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.1
2024-10-04 08:42:59 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where is one of the available targets can be found in the output of tools/get_available_targets.py.
W_nbits 2 group_size 128 matmul_type AUTO
ref_time 0.10675200074911118
(1, 8192, 8192) Triton Speed-up vs. torch.matmul 1.61 time 0.06646399945020676
(1, 8192, 8192) Bitblas Speed-up vs. torch.matmul 2.58
ref_time 0.10672000050544739
(2, 8192, 8192) Triton Speed-up vs. torch.matmul 1.0 time 0.1066880002617836
(2, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.56
ref_time 0.10732799768447876
(4, 8192, 8192) Triton Speed-up vs. torch.matmul 1.01 time 0.10665600001811981
(4, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.57
ref_time 0.1074879989027977
(8, 8192, 8192) Triton Speed-up vs. torch.matmul 1.01 time 0.10675200074911118
(8, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.59
ref_time 0.1080000028014183
(16, 8192, 8192) Triton Speed-up vs. torch.matmul 1.01 time 0.10716799646615982
(16, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.6
ref_time 0.10649599879980087
(32, 8192, 8192) Triton Speed-up vs. torch.matmul 0.95 time 0.11209599673748016
(32, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.43
ref_time 0.10780800133943558
(64, 8192, 8192) Triton Speed-up vs. torch.matmul 0.77 time 0.1404159963130951
(64, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.26
ref_time 0.12700800597667694
(128, 8192, 8192) Triton Speed-up vs. torch.matmul 0.67 time 0.19008000195026398
(128, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.16
ref_time 0.17235200107097626
(256, 8192, 8192) Triton Speed-up vs. torch.matmul 0.57 time 0.30425599217414856
(256, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.11
ref_time 0.3341119885444641
(512, 8192, 8192) Triton Speed-up vs. torch.matmul 0.64 time 0.5241600275039673
(512, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.11
ref_time 0.6436799764633179
(1024, 8192, 8192) Triton Speed-up vs. torch.matmul 0.61 time 1.05404794216156
(1024, 8192, 8192) Bitblas Speed-up vs. torch.matmul 0.1
2024-10-04 08:47:33 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where is one of the available targets can be found in the output of tools/get_available_targets.py.
W_nbits 4 group_size 128 matmul_type AUTO
ref_time 0.3417600095272064
(1, 16384, 16384) Triton Speed-up vs. torch.matmul 1.43 time 0.2396160066127777
(1, 16384, 16384) Torchao Speed-up vs. torch.matmul 2.76
(1, 16384, 16384) Bitblas Speed-up vs. torch.matmul 2.53
ref_time 0.3420160114765167
(2, 16384, 16384) Triton Speed-up vs. torch.matmul 1.22 time 0.28147199749946594
(2, 16384, 16384) Torchao Speed-up vs. torch.matmul 2.63
(2, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.62
ref_time 0.3447999954223633
(4, 16384, 16384) Triton Speed-up vs. torch.matmul 1.23 time 0.2813439965248108
(4, 16384, 16384) Torchao Speed-up vs. torch.matmul 2.1
(4, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.63
ref_time 0.34703999757766724
(8, 16384, 16384) Triton Speed-up vs. torch.matmul 1.22 time 0.28415998816490173
(8, 16384, 16384) Torchao Speed-up vs. torch.matmul 1.27
(8, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.66
ref_time 0.35043200850486755
(16, 16384, 16384) Triton Speed-up vs. torch.matmul 1.2 time 0.29129600524902344
(16, 16384, 16384) Torchao Speed-up vs. torch.matmul 0.74
(16, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.67
ref_time 0.35417601466178894
(32, 16384, 16384) Triton Speed-up vs. torch.matmul 1.0 time 0.35420799255371094
(32, 16384, 16384) Torchao Speed-up vs. torch.matmul 0.38
(32, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.49
ref_time 0.35763201117515564
(64, 16384, 16384) Triton Speed-up vs. torch.matmul 0.78 time 0.45686399936676025
(64, 16384, 16384) Torchao Speed-up vs. torch.matmul 0.2
(64, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.28
ref_time 0.4419519901275635
(128, 16384, 16384) Triton Speed-up vs. torch.matmul 0.73 time 0.601855993270874
(128, 16384, 16384) Torchao Speed-up vs. torch.matmul 0.12
(128, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.18
ref_time 0.6263999938964844
(256, 16384, 16384) Triton Speed-up vs. torch.matmul 0.56 time 1.124351978302002
(256, 16384, 16384) Torchao Speed-up vs. torch.matmul 0.09
(256, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.13
ref_time 1.1796480417251587
(512, 16384, 16384) Triton Speed-up vs. torch.matmul 0.52 time 2.2871038913726807
(512, 16384, 16384) Torchao Speed-up vs. torch.matmul 0.08
(512, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.12
ref_time 2.451647996902466
(1024, 16384, 16384) Triton Speed-up vs. torch.matmul 0.56 time 4.351327896118164
(1024, 16384, 16384) Torchao Speed-up vs. torch.matmul 0.08
(1024, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.13
2024-10-04 08:53:11 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where is one of the available targets can be found in the output of tools/get_available_targets.py.
W_nbits 2 group_size 128 matmul_type AUTO
ref_time 0.34540799260139465
(1, 16384, 16384) Triton Speed-up vs. torch.matmul 1.47 time 0.23452800512313843
(1, 16384, 16384) Bitblas Speed-up vs. torch.matmul 2.71
ref_time 0.3437120020389557
(2, 16384, 16384) Triton Speed-up vs. torch.matmul 1.27 time 0.27132800221443176
(2, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.65
ref_time 0.34672001004219055
(4, 16384, 16384) Triton Speed-up vs. torch.matmul 1.28 time 0.27113598585128784
(4, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.67
ref_time 0.3494400084018707
(8, 16384, 16384) Triton Speed-up vs. torch.matmul 1.28 time 0.2725760042667389
(8, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.7
ref_time 0.35254400968551636
(16, 16384, 16384) Triton Speed-up vs. torch.matmul 1.28 time 0.27478399872779846
(16, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.74
ref_time 0.35417601466178894
(32, 16384, 16384) Triton Speed-up vs. torch.matmul 1.2 time 0.2962239980697632
(32, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.44
ref_time 0.3561280071735382
(64, 16384, 16384) Triton Speed-up vs. torch.matmul 0.97 time 0.36822399497032166
(64, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.22
ref_time 0.44419199228286743
(128, 16384, 16384) Triton Speed-up vs. torch.matmul 0.73 time 0.6117759943008423
(128, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.15
ref_time 0.640608012676239
(256, 16384, 16384) Triton Speed-up vs. torch.matmul 0.61 time 1.0531200170516968
(256, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.11
ref_time 1.2030080556869507
(512, 16384, 16384) Triton Speed-up vs. torch.matmul 0.57 time 2.12825608253479
(512, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.1
ref_time 2.5157759189605713
(1024, 16384, 16384) Triton Speed-up vs. torch.matmul 0.62 time 4.054175853729248
(1024, 16384, 16384) Bitblas Speed-up vs. torch.matmul 0.1
2024-10-04 08:58:45 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where is one of the available targets can be found in the output of tools/get_available_targets.py.
W_nbits 4 group_size 128 matmul_type AUTO
ref_time 1.2765120267868042
(1, 32768, 32768) Triton Speed-up vs. torch.matmul 1.35 time 0.9481279850006104
(1, 32768, 32768) Torchao Speed-up vs. torch.matmul 2.76
(1, 32768, 32768) Bitblas Speed-up vs. torch.matmul 2.56
ref_time 1.3089280128479004
(2, 32768, 32768) Triton Speed-up vs. torch.matmul 1.37 time 0.9542400240898132
(2, 32768, 32768) Torchao Speed-up vs. torch.matmul 2.63
(2, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.86
ref_time 1.3063360452651978
(4, 32768, 32768) Triton Speed-up vs. torch.matmul 1.37 time 0.9519360065460205
(4, 32768, 32768) Torchao Speed-up vs. torch.matmul 2.02
(4, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.86
ref_time 1.3099839687347412
(8, 32768, 32768) Triton Speed-up vs. torch.matmul 1.37 time 0.9582080245018005
(8, 32768, 32768) Torchao Speed-up vs. torch.matmul 1.21
(8, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.89
ref_time 1.3158719539642334
(16, 32768, 32768) Triton Speed-up vs. torch.matmul 1.34 time 0.9845439791679382
(16, 32768, 32768) Torchao Speed-up vs. torch.matmul 0.72
(16, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.89
ref_time 1.3440959453582764
(32, 32768, 32768) Triton Speed-up vs. torch.matmul 1.13 time 1.1911360025405884
(32, 32768, 32768) Torchao Speed-up vs. torch.matmul 0.37
(32, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.53
ref_time 1.3771519660949707
(64, 32768, 32768) Triton Speed-up vs. torch.matmul 0.87 time 1.5784000158309937
(64, 32768, 32768) Torchao Speed-up vs. torch.matmul 0.19
(64, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.27
ref_time 1.5383360385894775
(128, 32768, 32768) Triton Speed-up vs. torch.matmul 0.63 time 2.423840045928955
(128, 32768, 32768) Torchao Speed-up vs. torch.matmul 0.11
(128, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.16
ref_time 2.5936319828033447
(256, 32768, 32768) Triton Speed-up vs. torch.matmul 0.56 time 4.658976078033447
(256, 32768, 32768) Torchao Speed-up vs. torch.matmul 0.09
(256, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.14
ref_time 5.009151935577393
(512, 32768, 32768) Triton Speed-up vs. torch.matmul 0.57 time 8.741791725158691
(512, 32768, 32768) Torchao Speed-up vs. torch.matmul 0.08
(512, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.13
ref_time 9.679488182067871
(1024, 32768, 32768) Triton Speed-up vs. torch.matmul 0.52 time 18.480607986450195
(1024, 32768, 32768) Torchao Speed-up vs. torch.matmul 0.08
(1024, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.13
2024-10-04 09:07:17 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using export TVM_TARGET=<target>, where is one of the available targets can be found in the output of tools/get_available_targets.py.
W_nbits 2 group_size 128 matmul_type AUTO
ref_time 1.2772480249404907
(1, 32768, 32768) Triton Speed-up vs. torch.matmul 1.69 time 0.7576320171356201
(1, 32768, 32768) Bitblas Speed-up vs. torch.matmul 2.82
ref_time 1.3076159954071045
(2, 32768, 32768) Triton Speed-up vs. torch.matmul 1.73 time 0.7579519748687744
(2, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.84
ref_time 1.3032000064849854
(4, 32768, 32768) Triton Speed-up vs. torch.matmul 1.72 time 0.7576000094413757
(4, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.85
ref_time 1.3141759634017944
(8, 32768, 32768) Triton Speed-up vs. torch.matmul 1.71 time 0.7696319818496704
(8, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.88
ref_time 1.3204799890518188
(16, 32768, 32768) Triton Speed-up vs. torch.matmul 1.68 time 0.7845759987831116
(16, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.9
ref_time 1.3463040590286255
(32, 32768, 32768) Triton Speed-up vs. torch.matmul 1.28 time 1.0521600246429443
(32, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.51
ref_time 1.3624639511108398
(64, 32768, 32768) Triton Speed-up vs. torch.matmul 0.98 time 1.3894399404525757
(64, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.27
ref_time 1.546463966369629
(128, 32768, 32768) Triton Speed-up vs. torch.matmul 0.71 time 2.176255941390991
(128, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.16
ref_time 2.424799919128418
(256, 32768, 32768) Triton Speed-up vs. torch.matmul 0.56 time 4.307519912719727
(256, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.13
ref_time 5.034048080444336
(512, 32768, 32768) Triton Speed-up vs. torch.matmul 0.62 time 8.100383758544922
(512, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.13
ref_time 9.681568145751953
(1024, 32768, 32768) Triton Speed-up vs. torch.matmul 0.58 time 16.774080276489258
(1024, 32768, 32768) Bitblas Speed-up vs. torch.matmul 0.13
