# python pytorch/xla/test/quantized_ops/test_quantized_matmul.py -k test_blockwise_matmul_op
# python pytorch/xla/test/quantized_ops/test_quantized_matmul.py -k test_asymmetric_per_channel
import torch
import torch.nn.functional as F
import torch_xla
from torch.library import impl
from torch_xla.core.xla_model import XLA_LIB
XLA_LIB.define(
"quantized_matmul(Tensor x, Tensor w, Tensor scale, Tensor? zero_point=None, int? block_size=-1, bool? int4_weight=False, bool? quantize_activation=False) -> Tensor"
)
def _check_per_channel_quant_weight_dtype_shapes(input_dim, output_dim, w,
w_scaler, zero_point):
assert w.dtype == torch.int8, f"Weight dtype is expected to be torch.int8, got {w.dtype}."
assert w.dim(
) == 2, f"Weight tensor is expected to be 2D, got {w.dim()}D Tensor."
w_shape = list(w.shape)
assert output_dim == w_shape[0] and input_dim == w_shape[
1], f"Weight shape is expected to be [output_dim, input_dim], output_dim: {output_dim}, input_dim: {input_dim}, but got {w_shape}."
assert w_scaler.dim() == 1 and w_scaler.shape[0] == w_shape[
0], f"Weight scaler shape is expected to be [out_channel,], got {w_scaler.shape}, weight shape {w_shape}."
if zero_point is not None:
assert zero_point.dim() == 1 and zero_point.shape[0] == w_shape[
0], f"Zero point shape is expected to be [out_channel,], got {zero_point.shape}, weight shape {w_shape}."
def _check_blockwise_quant_weight_dtype_shapes(input_dim, output_dim,
block_size, w, w_scaler,
zero_point):
assert w.dtype == torch.int8, (
f"Weight dtype is expected to be torch.int8, got {w.dtype}.")
assert w.dim() == 3, (
f"Weight tensor is expected to be 3D, got {w.dim()}D Tensor.")
w_shape = list(w.shape)
assert input_dim % block_size == 0, (
f"input_dim should be divisible by block_size, "
f"got input_dim: {input_dim}, block_size: {block_size}.")
assert w_shape[0] == input_dim // block_size and w_shape[1] == block_size, (
f"Weight shape is expected to be [input_dim / block_size, block_size, output_dim], "
f"input_dim: {input_dim}, block_size: {block_size}, output_dim: {output_dim}, "
f"but got {w_shape}.")
assert w_scaler.dim() == 2, (
f"weight scaler is expected to be 2D, got {w_scaler.dim()}D Tensor.")
assert w_scaler.shape[0] == w_shape[0] and w_scaler.shape[1] == w_shape[-1], (
f"weight scaler shape is expect to be [in_channel / block_size, out_channel], "
f"got {w_scaler.shape}, weight shape {w_shape}.")
if zero_point is not None:
assert zero_point.dim() == 2, (
f"zero_point is expected to be 2D, got {zero_point.dim()}D Tensor.")
assert zero_point.shape[0] == w_shape[0] and zero_point.shape[1] == w_shape[
-1], (
f"zero_point shape is expect to be [in_channel / block_size, out_channel], "
f"got {zero_point.shape}, weight shape {w_shape}.")
def _quantize_tensor(x: torch.Tensor, n_bits: int = 8, dim: int = -1):
"""
Quantizes a tensor to a lower bit representation.
Args:
x (torch.Tensor): The input tensor to be quantized.
n_bits (int, optional): The number of bits to represent the quantized tensor. Defaults to 8.
dim (int, optional): The dimension along which to compute the maximum value for scaling. Defaults to -1.
Returns:
torch.Tensor: The quantized tensor. (In int8 container)
torch.Tensor: The scaling factor used for quantization. (Same dtype as x)
"""
max_val = torch.amax(torch.abs(x), dim=dim, keepdim=True)
int_min = -2**(n_bits - 1)
int_max = 2**(n_bits - 1) - 1
scale = max_val / int_max
x_int = torch.clamp(torch.round(x / scale), int_min, int_max).to(torch.int8)
return x_int, scale.to(x.dtype)
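# Illustrative sketch (not part of the original gist): a tiny demo of what
# _quantize_tensor returns. With the default n_bits=8 and dim=-1, each row is
# scaled by max(|row|) / 127 and rounded into an int8 container; multiplying the
# int8 values back by the returned scale approximately recovers the input.
def _example_quantize_tensor():
  x = torch.tensor([[0.5, -1.0], [2.0, 4.0]])
  x_int, scale = _quantize_tensor(x)     # x_int dtype: torch.int8, scale shape: [2, 1]
  x_approx = x_int.to(x.dtype) * scale   # dequantized approximation of x
  return x_int, scale, x_approx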
@impl(XLA_LIB, "quantized_matmul", "XLA")
def quantized_matmul_xla(x: torch.Tensor,
w: torch.Tensor,
scaler: torch.Tensor,
zero_point: torch.Tensor = None,
block_size: int = -1,
int4_weight: bool = False,
quantize_activation: bool = False):
"""Quantized Matrix Multiply op on XLA devices.
Args:
x: torch.Tensor - Activation of Matmul [..., in_channel].
w: torch.Tensor - Weight Tensor.
per-channel quant: torch.int8 x [out_channel, in_channel].
block_wise quant: torch.int8 x [in_channel / block_size, block_size, out_channel].
scaler: torch.Tensor - Weight scaler.
per-channel quant: [out_channel,].
blockwise quant: [in_channel / block_size, out_channel].
zero_point: Optional[torch.Tensor] - Zero point tensor.
per-channel quant: [out_channel,].
blockwise quant: [in_channel / block_size, out_channel].
block_size: The block size for blockwise quantization; -1 means per-channel quantization.
int4_weight: If the weights are int4, they need to be stored unpacked in an int8
container.
quantize_activation: If True, the activation is dynamically quantized to int8
before the matmul and the output is rescaled by the activation scale.
"""
if int4_weight:
# Reinterpret cast the weight to s4 dtype in XLA.
w = torch_xla._XLAC._xla_cast_int4(w, w.cpu().flatten().numpy().tolist())
if block_size == -1:
# Per-channel quant.
_check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0],
w, scaler, zero_point)
if quantize_activation:
x, x_scale = _quantize_tensor(x)
out = torch_xla._XLAC._xla_dot_general(
x, w, (([-1], [-1]), ()), preferred_element_type=torch.int32)
else:
out = F.linear(x, w)
out = out * scaler
else:
# Blockwise quant.
assert quantize_activation == False, (
"Blockwise quantization does not support activation quantization.")
_check_blockwise_quant_weight_dtype_shapes(x.shape[-1], w.shape[-1],
block_size, w, scaler,
zero_point)
x = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size)
# Example shapes from the blockwise test: w: [3, 2, 8], x (reshaped): [3, 3, 2].
out = torch.einsum('scn,...sc->...sn', w, x)
# Example shapes: scaler: [3, 8], out: [3, 3, 8].
out = torch.einsum('sn,...sn->...n', scaler, out)
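# Taken together, the two einsums compute, for each output channel n:
#   out[..., n] = sum_s scaler[s, n] * (sum_c w[s, c, n] * x[..., s, c])
# i.e. each block s is contracted against its activation slice, then rescaled by
# its per-block, per-output-channel scaler before the block dimension is reduced.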
if zero_point is not None:
if block_size == -1:
# Per-channel quant.
# Example shapes from the asymmetric per-channel test: x: [3, 6], zero_point: [8].
zp_out = torch.einsum("...c,z->...z", x, zero_point)
else:
# Blockwise quant.
zp_out = x.sum(dim=-1)
zp_out = torch.matmul(zp_out, zero_point)
out -= zp_out
if quantize_activation:
out = out * x_scale
return out
@impl(XLA_LIB, "quantized_matmul", "CompositeExplicitAutograd")
def quantized_matmul(x: torch.Tensor,
w: torch.Tensor,
scaler: torch.Tensor,
zero_point: torch.Tensor = None,
block_size: int = -1,
int4_weight: bool = False,
quantize_activation: bool = False):
if block_size == -1:
# Per-channel quant.
_check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0],
w, scaler, zero_point)
w = w.to(x.dtype)
if quantize_activation:
x, x_scale = _quantize_tensor(x)
# Upcast to torch.int32 to make sure the matmul output is int32.
# Otherwise it would be int8, which can overflow.
x = x.to(torch.int32)
w = w.to(torch.int32)
out = F.linear(x, w)
out = out * scaler
else:
# Blockwise quant.
assert quantize_activation == False, (
"Blockwise quantization does not support activation quantization.")
_check_blockwise_quant_weight_dtype_shapes(x.shape[-1], w.shape[-1],
block_size, w, scaler,
zero_point)
x = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size)
w = w.to(x.dtype)
out = torch.einsum('scn,...sc->...sn', w, x)
out = torch.einsum('sn,...sn->...n', scaler, out)
if zero_point is not None:
if block_size == -1:
# Per-channel quant.
zp_out = torch.einsum("...c,z->...z", x, zero_point)
else:
# Blockwise quant.
zp_out = x.sum(dim=-1)
zp_out = torch.matmul(zp_out, zero_point)
out -= zp_out
if quantize_activation:
out = out * x_scale
return out
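# Illustrative sketch (assumption, not part of the original gist): producing
# per-channel symmetric weights for quantized_matmul from a float
# [out_channel, in_channel] weight. _quantize_tensor with dim=-1 returns a scale
# of shape [out_channel, 1], so it is squeezed to the expected [out_channel,] shape.
def _example_per_channel_quantized_matmul():
  w_fp = torch.randn(8, 6)               # [out_channel, in_channel]
  w_int8, w_scale = _quantize_tensor(w_fp, dim=-1)
  w_scale = w_scale.squeeze(-1)          # [out_channel,]
  x = torch.randn(3, 6)                  # [..., in_channel]
  return torch.ops.xla.quantized_matmul(x, w_int8, w_scale)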
class XlaQuantizedLinear(torch.nn.Module):
def __init__(self,
input_dim,
output_dim,
is_symmetric: bool = False,
block_size: int = -1,
int4_weight: bool = False,
quantize_activation: bool = False):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.is_symmetric = is_symmetric
self.block_size = block_size
self.int4_weight = int4_weight
self.quantize_activation = quantize_activation
self.register_buffer('weight',
torch.zeros(output_dim, input_dim).to(torch.int8))
self.register_buffer('weight_scaler', torch.zeros(output_dim))
if not self.is_symmetric:
self.register_buffer('zero_point', torch.zeros(output_dim))
else:
self.zero_point = None
def load_quantized_weight(self, weight, weight_scaler, zero_point=None):
'''
weight (Tensor):
per-channel quant: [out_channel, in_channel].
block_wise quant: [in_channel / block_size, block_size, out_channel].
weight_scaler (Tensor):
per-channel quant: [out_channel,].
blockwise quant: [in_channel / block_size, out_channel].
'''
if self.block_size == -1:
# Per-channel quant.
_check_per_channel_quant_weight_dtype_shapes(self.input_dim,
self.output_dim, weight,
weight_scaler, zero_point)
else:
# Blockwise quant.
_check_blockwise_quant_weight_dtype_shapes(self.input_dim,
self.output_dim,
self.block_size, weight,
weight_scaler, zero_point)
self.weight = weight
self.weight_scaler = weight_scaler
self.zero_point = zero_point
def forward(self, x):
return torch.ops.xla.quantized_matmul(
x,
self.weight,
self.weight_scaler,
self.zero_point,
block_size=self.block_size,
int4_weight=self.int4_weight,
quantize_activation=self.quantize_activation)
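# Illustrative sketch (assumption, not part of the original gist): loading
# per-channel symmetric weights into XlaQuantizedLinear. On XLA tensors the
# forward pass dispatches to quantized_matmul_xla; on CPU tensors it uses the
# CompositeExplicitAutograd implementation above.
if __name__ == "__main__":
  layer = XlaQuantizedLinear(input_dim=6, output_dim=8, is_symmetric=True)
  w_int8, w_scale = _quantize_tensor(torch.randn(8, 6), dim=-1)
  layer.load_quantized_weight(w_int8, w_scale.squeeze(-1))
  y = layer(torch.randn(3, 6))  # [..., out_channel] = [3, 8]
  print(y.shape)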