HQQ Tinygemm vs BitBlas Benchmark
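Benchmarks an HQQ-quantized 7168x4096 linear layer (4-bit, group size 128) served three ways: the stock PyTorch tinygemm int4 kernel (aten._weight_int4pack_mm), a fused unpack-dequant-matmul path, and BitBLAS 4-bit/2-bit kernels, across batch sizes from 4 to 1024. Per-call timings in milliseconds are listed at the bottom of the script.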
import torch
import numpy as np
from hqq.core.quantize import HQQLinear, BaseQuantizeConfig, Quantizer, HQQBackend
from hqq.backends.torchao import HQQLinearTorchWeightOnlynt4, patch_hqq_to_aoint4
# from unpack_int4.ops import unpack_int4_packed
import torchao
import bitblas
# unpack_cuda_compiled = torch.compile(torchao.ops.unpack_int4_to_int, mode="default", fullgraph=True)
from bitblas.cache import global_operator_cache, get_database_path
from bitblas.module import BITBLAS_TARGET, BITBLAS_DATABASE_PATH

def _get_or_create_bitblas_operator(config):
    if global_operator_cache.size() == 0:
        global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)

    bitblas_matmul = global_operator_cache.get(config)
    if bitblas_matmul is None:
        # Operator not in the cache: build it, run hardware-aware tuning, and persist it to the database.
        bitblas_matmul = bitblas.Matmul(config)
        bitblas_matmul.hardware_aware_finetune(topk=20)
        global_operator_cache.add(config, bitblas_matmul)
        global_operator_cache.save_into_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
        print("BitBLAS Tuning done, appended operator to global_operator_cache.")
    else:
        print("BitBLAS Operator found in global_operator_cache.")
    return bitblas_matmul

def timed(fn):
    # Time a CUDA workload with events; returns (result, elapsed time in seconds).
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

# @torch.compile(fullgraph=False)
def hqq_quants_to_torch_quants(W_q, scales, zeros, shape, nbits=4):
    # Convert HQQ qparams (W_r = (W_q - zero) * scale) into the tinygemm convention
    # (W_r = (W_q - 2**(nbits - 1)) * scale + new_zero).
    # W_q = W_q.to(dtype=self.compute_dtype, device=self.device)
    # scales = scales.to(dtype=self.compute_dtype, device=self.device)
    # zeros = zeros.to(dtype=self.compute_dtype, device=self.device)
    max_int = 2**nbits - 1
    min_int = 0
    dump = 2 ** (nbits - 1)

    # HQQ -> torch logic
    new_zeros = (scales * dump) - zeros * scales
    min_val = new_zeros - scales * dump

    # group_quantize_tensor_from_qparams
    W_r = (W_q - zeros) * scales
    W_q = (
        W_r.sub(min_val)
        .div(scales)
        .round()
        .clamp_(min_int, max_int)
        .to(torch.int32)
        .reshape(shape)
        .contiguous()
    )

    # group_dequantize_tensor_from_qparams
    # W_r = W_q*scales + min_val
    scales = scales.contiguous().reshape(shape[0], -1)
    new_zeros = new_zeros.contiguous().reshape(shape[0], -1)
    return W_q, scales, new_zeros

def pack_scales_and_zeros(scales, zeros):
    return (
        torch.cat(
            [
                scales.reshape(scales.size(0), scales.size(1), 1),
                zeros.reshape(zeros.size(0), zeros.size(1), 1),
            ],
            2,
        )
        .transpose(0, 1)
        .contiguous()
    )

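# Hedged shape sketch (not part of the original gist): for scales/zeros of shape
# (out_features, n_groups), pack_scales_and_zeros produces (n_groups, out_features, 2),
# which is the qScaleAndZeros layout this script later passes to aten._weight_int4pack_mm.
_s = torch.full((4, 2), 0.5, dtype=torch.bfloat16, device="cuda")
_z = torch.zeros(4, 2, dtype=torch.bfloat16, device="cuda")
assert pack_scales_and_zeros(_s, _z).shape == (2, 4, 2)
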
def reshape_packed(packed_tensor):
    inner_k_tiles = packed_tensor.size(-1) * 2
    return packed_tensor.permute(0, 1, 3, 2).reshape(packed_tensor.size(0),
                                                     packed_tensor.size(1) * (inner_k_tiles // 2),
                                                     packed_tensor.size(2),
                                                     1).contiguous()


def _unpack_shifting(packed_tensor):
    # Extract the eight 4-bit nibbles stored in each packed int32.
    return [(packed_tensor >> (i * 4)) & 15 for i in range(8)]


@torch.compile(fullgraph=True)
def unpack_int4_32_pack_fast(packed_tensor, shape):
    reshaped_tensor = reshape_packed(packed_tensor)
    unpacked_tensors = _unpack_shifting(reshaped_tensor)

    # use torch.cat
    cat_tensors = [torch.cat(unpacked_tensors[i::4], dim=-1).view(-1, 8) for i in range(4)]
    concatenated = torch.cat(cat_tensors, dim=-1)

    # # pre-allocate
    # concatenated = torch.empty(shape[0]*shape[1]//32, 32, device=reshaped_tensor.device, dtype=reshaped_tensor.dtype)
    # for i in range(4):
    #     concatenated[:,i*8:(i+1)*8] = torch.cat(unpacked_tensors[i::4], dim=-1).view(-1, 8)

    group_size = shape[1] // 32
    chunked_o = concatenated.view(-1, 8).unsqueeze(0).view(concatenated.size(0) // 8, 8, -1).unsqueeze(0)
    res = chunked_o.view(-1, group_size, chunked_o.size(2), chunked_o.size(3)).permute(0, 2, 1, 3).reshape(shape)
    return res

def unpack_scales_and_zeros(scales_and_zeros):
    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
    # assert scales_and_zeros.dtype == torch.float
    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)

def group_dequantize_tensor_from_qparams(
    w_int32, scales, zeros, n_bit=4, groupsize=128
):
    assert groupsize > 1
    # needed for GPTQ single column dequantize
    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
        groupsize = w_int32.shape[-1]
    assert w_int32.shape[-1] % groupsize == 0
    assert w_int32.dim() == 2

    w_int32_grouped = w_int32.reshape(-1, groupsize)
    scales = scales.reshape(-1, 1)
    zeros = zeros.reshape(-1, 1)

    w_dq = (
        w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
    )
    return w_dq

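# Hedged sanity sketch (not part of the original gist): check on random float32 data that
# HQQ-style dequantization, (W_q - zero) * scale, matches tinygemm-style dequantization,
# (W_q_torch - 2**(nbits - 1)) * scale + new_zero, after hqq_quants_to_torch_quants.
# Shapes and tolerances here are illustrative assumptions, not values from the benchmark.
_shape = (64, 128)
_n_groups = _shape[0] * _shape[1] // 128
_Wq = torch.randint(0, 16, (_n_groups, 128), device="cuda").float()
_scales = torch.rand(_n_groups, 1, device="cuda") + 0.1
_zeros = torch.rand(_n_groups, 1, device="cuda") * 16
_Wq_t, _scales_t, _zeros_t = hqq_quants_to_torch_quants(_Wq, _scales, _zeros, _shape)
_W_hqq = ((_Wq - _zeros) * _scales).reshape(_shape)
_W_tiny = group_dequantize_tensor_from_qparams(_Wq_t, _scales_t, _zeros_t, groupsize=128)
assert torch.allclose(_W_hqq, _W_tiny, atol=1e-4, rtol=1e-4)
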
@torch.compile(mode="default", fullgraph=True)
def tinygemm_unpack_dequant_matmul_naive(x, weight_int4pack, scales_and_zeros, groupsize, shape):
    unpacked_W_q = unpack_int4_32_pack_fast(weight_int4pack, shape)
    return x @ group_dequantize_tensor_from_qparams(unpacked_W_q, *unpack_scales_and_zeros(scales_and_zeros), groupsize=groupsize).T


# @torch.compile(fullgraph=True)
def tinygemm_unpack_dequant_matmul(x, weight_int4pack, scales_and_zeros, groupsize, shape):
    inner_k_tiles = weight_int4pack.size(-1) * 2
    unpacked_W_q = torchao.ops.dequantize_tensor_core_tiled_layout(weight_int4pack, scales_and_zeros, groupsize, inner_k_tiles)
    return x @ unpacked_W_q.T

# Round-trip test: unpack_int4_32_pack_fast should recover the original int4 values
# from aten._convert_weight_to_int4pack for inner_k_tiles = 2, 4, and 8.
W_q_torch = torch.randint(0, 16, (8192, 8192), dtype=torch.int32, device="cuda")
weight_int4pack_inner_tile2 = torch.ops.aten._convert_weight_to_int4pack(W_q_torch, 2)
weight_int4pack_inner_tile4 = torch.ops.aten._convert_weight_to_int4pack(W_q_torch, 4)
weight_int4pack_inner_tile8 = torch.ops.aten._convert_weight_to_int4pack(W_q_torch, 8)

unpacked_W_q = unpack_int4_32_pack_fast(weight_int4pack_inner_tile2, W_q_torch.shape)
assert torch.equal(unpacked_W_q, W_q_torch)
unpacked_W_q = unpack_int4_32_pack_fast(weight_int4pack_inner_tile4, W_q_torch.shape)
assert torch.equal(unpacked_W_q, W_q_torch)
unpacked_W_q = unpack_int4_32_pack_fast(weight_int4pack_inner_tile8, W_q_torch.shape)
assert torch.equal(unpacked_W_q, W_q_torch)

GROUP_SIZE = 128
quant_config = BaseQuantizeConfig(nbits=4,
                                  group_size=GROUP_SIZE,
                                  quant_zero=False,
                                  quant_scale=False,
                                  offload_meta=False,
                                  view_as_float=False,
                                  axis=1)

in_features = 4096
out_features = 7168

W = torch.randn(out_features, in_features, dtype=torch.bfloat16, device="cuda")  # output x input
m = torch.nn.Linear(*W.T.shape, bias=False)
m.weight.data.copy_(W)

hqq_linear = HQQLinear(m, quant_config, compute_dtype=torch.bfloat16, del_orig=True)
HQQLinear.set_backend(HQQBackend.PYTORCH)

# HQQ to Tinygemm conversion (4-bit).
W_q_unpacked = Quantizer.unpack[hqq_linear.meta['packing']](hqq_linear.W_q)
scale, zero, shape = hqq_linear.meta['scale'], hqq_linear.meta['zero'], hqq_linear.meta['shape']
W_q_torch, scales_torch, zeros_torch = hqq_quants_to_torch_quants(W_q_unpacked, scale, zero, shape)
scales_and_zeros = pack_scales_and_zeros(scales_torch, zeros_torch)
weight_int4pack_inner_tile8 = torch.ops.aten._convert_weight_to_int4pack(W_q_torch, 8)
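# Hedged sanity sketch (not part of the original gist): compare the tinygemm int4 kernel
# against an explicit dequantize-then-matmul reference on a small batch. Differences should
# be small and come from bf16 rounding and accumulation order, so this prints the gap
# instead of asserting a tolerance.
_x = torch.randn(8, in_features, dtype=torch.bfloat16, device="cuda")
_out_kernel = torch.ops.aten._weight_int4pack_mm(_x, weight_int4pack_inner_tile8, GROUP_SIZE, scales_and_zeros)
_W_ref = group_dequantize_tensor_from_qparams(W_q_torch, *unpack_scales_and_zeros(scales_and_zeros), groupsize=GROUP_SIZE)
_out_ref = _x @ _W_ref.T
print("tinygemm kernel vs dequant reference, max abs diff:", (_out_kernel - _out_ref).abs().max().item())
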
INPUT_SIZES = [4, 8, 16, 32, 64, 80, 96, 128, 256, 512, 1024]
BITBLAS_OPT_M = [1, 16, 32, 64, 128, 256, 512]
# BITBLAS_OPT_M = [1]

# HQQ to bitblas conversion (4-bit).
quant_config = BaseQuantizeConfig(nbits=4,
                                  group_size=GROUP_SIZE,
                                  quant_zero=False,
                                  quant_scale=False,
                                  offload_meta=False,
                                  view_as_float=False,
                                  axis=1)

W = torch.randn(out_features, in_features, dtype=torch.half, device="cuda")  # output x input
m = torch.nn.Linear(*W.T.shape, bias=False)
m.weight.data.copy_(W)

hqq_linear = HQQLinear(m, quant_config, compute_dtype=torch.half, del_orig=True)
HQQLinear.set_backend(HQQBackend.PYTORCH)

W_q_unpacked = Quantizer.unpack[hqq_linear.meta['packing']](hqq_linear.W_q)
scale, zero, shape = hqq_linear.meta['scale'], hqq_linear.meta['zero'], hqq_linear.meta['shape']

matmul_config = bitblas.MatmulConfig(
    M=BITBLAS_OPT_M,
    N=out_features,
    K=in_features,
    A_dtype="float16",
    W_dtype="uint4",
    accum_dtype="float16",
    out_dtype="float16",
    layout="nt",
    with_bias=False,
    group_size=GROUP_SIZE,
    with_scaling=True,
    with_zeros=True,
    zeros_mode="original",
    # fast_decoding=True,
)
matmul_eng_4bit = _get_or_create_bitblas_operator(matmul_config)

Wq_bitblas_4bit = matmul_eng_4bit.transform_weight(W_q_unpacked.reshape(shape))
meta_shape_bitblas = (hqq_linear.out_features, hqq_linear.in_features // GROUP_SIZE)
scales_bitblas_4bit = scale.view(meta_shape_bitblas)
zeros_bitblas_4bit = zero.view(meta_shape_bitblas)
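# Hedged sanity sketch (not part of the original gist): compare the BitBLAS 4-bit kernel
# against an explicit HQQ dequantize-then-matmul reference in fp16. With zeros_mode="original"
# the kernel is expected to dequantize as (W_q - zeros) * scales, i.e. HQQ's convention for
# axis=1 grouping; this prints the gap rather than asserting a tolerance.
_x = torch.randn(8, in_features, dtype=torch.half, device="cuda")
_W_ref = ((W_q_unpacked.reshape(-1, GROUP_SIZE).half() - zero) * scale).reshape(shape)
_out_ref = _x @ _W_ref.T
_out_bitblas = matmul_eng_4bit(_x, Wq_bitblas_4bit, scale=scales_bitblas_4bit, zeros=zeros_bitblas_4bit)
print("bitblas 4-bit kernel vs dequant reference, max abs diff:", (_out_bitblas - _out_ref).abs().max().item())
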
# HQQ to bitblas conversion (2-bit).
quant_config = BaseQuantizeConfig(nbits=2,
                                  group_size=GROUP_SIZE,
                                  quant_zero=False,
                                  quant_scale=False,
                                  offload_meta=False,
                                  view_as_float=False,
                                  axis=1)

W = torch.randn(out_features, in_features, dtype=torch.half, device="cuda")  # output x input
m = torch.nn.Linear(*W.T.shape, bias=False)
m.weight.data.copy_(W)

hqq_linear = HQQLinear(m, quant_config, compute_dtype=torch.half, del_orig=True)
HQQLinear.set_backend(HQQBackend.PYTORCH)

# Unpack the 2-bit quantized weight and its meta for the BitBLAS transform below.
W_q_unpacked = Quantizer.unpack[hqq_linear.meta['packing']](hqq_linear.W_q)
scale, zero, shape = hqq_linear.meta['scale'], hqq_linear.meta['zero'], hqq_linear.meta['shape']

matmul_config = bitblas.MatmulConfig(
    M=BITBLAS_OPT_M,
    N=out_features,
    K=in_features,
    A_dtype="float16",
    W_dtype="uint2",
    accum_dtype="float16",
    out_dtype="float16",
    layout="nt",
    with_bias=False,
    group_size=GROUP_SIZE,
    with_scaling=True,
    with_zeros=True,
    zeros_mode="original",
    # fast_decoding=True,
)
matmul_eng_2bit = _get_or_create_bitblas_operator(matmul_config)

Wq_bitblas_2bit = matmul_eng_2bit.transform_weight(W_q_unpacked.reshape(shape))
meta_shape_bitblas = (hqq_linear.out_features, hqq_linear.in_features // GROUP_SIZE)
scales_bitblas_2bit = scale.view(meta_shape_bitblas)
zeros_bitblas_2bit = zero.view(meta_shape_bitblas)

for bs in INPUT_SIZES:  # think of it as bs x seqlen
    x = torch.randn(bs, in_features, dtype=torch.bfloat16, device="cuda")
    x_fp16 = torch.randn(bs, in_features, dtype=torch.half, device="cuda")
    print(bs)

    # tinygemm matmul time (ms)
    times = []
    for i in range(30):
        tinygemm_out, time = timed(lambda: torch.ops.aten._weight_int4pack_mm(x,
                                                                              weight_int4pack_inner_tile8,
                                                                              GROUP_SIZE,
                                                                              scales_and_zeros))
        if i > 5:  # skip warm-up iterations
            times.append(time * 1000)
    print(f"tinygemm orig matmul: {np.mean(times)}")

    # tinygemm unpack-dequant-matmul time (ms)
    times = []
    for i in range(30):
        unpacked_tinygemm_out, time = timed(lambda: tinygemm_unpack_dequant_matmul(x,
                                                                                   weight_int4pack_inner_tile8,
                                                                                   scales_and_zeros,
                                                                                   groupsize=GROUP_SIZE,
                                                                                   shape=shape))
        if i > 5:
            times.append(time * 1000)
    print(f"tinygemm fused unpack-dequant-matmul: {np.mean(times)}")

    # bitblas 4-bit matmul time (ms)
    times = []
    for i in range(30):
        bitblas_out, time = timed(lambda: matmul_eng_4bit(x_fp16,
                                                          Wq_bitblas_4bit,
                                                          scale=scales_bitblas_4bit,
                                                          zeros=zeros_bitblas_4bit))
        if i > 5:
            times.append(time * 1000)
    print(f"bitblas 4-bit matmul: {np.mean(times)}")

    # bitblas 2-bit matmul time (ms)
    times = []
    for i in range(30):
        bitblas_out, time = timed(lambda: matmul_eng_2bit(x_fp16,
                                                          Wq_bitblas_2bit,
                                                          scale=scales_bitblas_2bit,
                                                          zeros=zeros_bitblas_2bit))
        if i > 5:
            times.append(time * 1000)
    print(f"bitblas 2-bit matmul: {np.mean(times)}")

_get_or_create_bitblas_operator(matmul_config)

# BitBLAS Tuning done, appended operator to global_operator_cache.
# BitBLAS Tuning done, appended operator to global_operator_cache.
# 4
# tinygemm orig matmul: 0.038186667331804834
# tinygemm fused unpack-dequant-matmul: 0.27805466825763386
# bitblas 4-bit matmul: 0.0710400016978383
# bitblas 2-bit matmul: 0.10022266364345948
# 8
# tinygemm orig matmul: 0.044628000197311245
# tinygemm fused unpack-dequant-matmul: 0.2707599997520447
# bitblas 4-bit matmul: 0.06920533441007137
# bitblas 2-bit matmul: 0.10069333016872406
# 16
# tinygemm orig matmul: 0.07483733631670475
# tinygemm fused unpack-dequant-matmul: 0.2706800041099389
# bitblas 4-bit matmul: 0.07005466800183058
# bitblas 2-bit matmul: 0.10090666357427835
# 32
# tinygemm orig matmul: 0.12526933290064335
# tinygemm fused unpack-dequant-matmul: 0.26960799594720203
# bitblas 4-bit matmul: 0.09228533475349347
# bitblas 2-bit matmul: 0.08145066661139329
# 64
# tinygemm orig matmul: 0.23359866812825203
# tinygemm fused unpack-dequant-matmul: 0.26743199800451595
# bitblas 4-bit matmul: 0.09467333555221558
# bitblas 2-bit matmul: 0.1003906639913718
# 80
# tinygemm orig matmul: 0.28740132972598076
# tinygemm fused unpack-dequant-matmul: 0.2831786771615346
# bitblas 4-bit matmul: 0.13239066861569881
# bitblas 2-bit matmul: 0.13367333138982454
# 96
# tinygemm orig matmul: 0.345684003084898
# tinygemm fused unpack-dequant-matmul: 0.28288133814930916
# bitblas 4-bit matmul: 0.13734399837752184
# bitblas 2-bit matmul: 0.1394786648452282
# 128
# tinygemm orig matmul: 0.45201199998458225
# tinygemm fused unpack-dequant-matmul: 0.2844146713614464
# bitblas 4-bit matmul: 0.15479466691613197
# bitblas 2-bit matmul: 0.1588479975859324
# 256
# tinygemm orig matmul: 0.8864426662524542
# tinygemm fused unpack-dequant-matmul: 0.3223479986190796
# bitblas 4-bit matmul: 0.2660679966211319
# bitblas 2-bit matmul: 0.24575733207166195
# 512
# tinygemm orig matmul: 1.8986239979664485
# tinygemm fused unpack-dequant-matmul: 0.4362240085999171
# bitblas 4-bit matmul: 0.3694506672521432
# bitblas 2-bit matmul: 0.3685973385969798
# 1024
# tinygemm orig matmul: 3.9093759953975677
# tinygemm fused unpack-dequant-matmul: 0.6917120044430097
# bitblas 4-bit matmul: 0.6495146602392197
# bitblas 2-bit matmul: 0.6729813342293104
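# Summary of the numbers above:
# - For small batches (bs <= 16) the stock tinygemm int4 kernel is fastest (~0.04-0.07 ms).
# - From bs >= 32 the BitBLAS 4-bit and 2-bit kernels take the lead and scale much more
#   gracefully than the tinygemm kernel, whose time grows roughly linearly with batch size.
# - The fused unpack-dequant-matmul path stays nearly flat (~0.27-0.69 ms) and overtakes the
#   stock tinygemm kernel at around bs = 80-96.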