https://github.com/pytorch/ao/blob/main/tutorials/calibration_flow/static_quant.py
See the tutorial above for two methods of static quantization.
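For context, both methods assume observers that have been calibrated and expose calculate_qparams(); a toy observer along these lines (a sketch only, torchao ships its own observer classes):

import torch

class ToyMinMaxObserver(torch.nn.Module):
    # tracks the running abs-max during calibration and derives a float8 scale
    def __init__(self, fp8_max=448.0):  # 448 = max value of float8_e4m3fn
        super().__init__()
        self.register_buffer("amax", torch.zeros(()))
        self.fp8_max = fp8_max

    def forward(self, x):
        self.amax = torch.maximum(self.amax.to(x.device), x.detach().abs().max())
        return x

    def calculate_qparams(self):
        scale = self.amax / self.fp8_max
        zero_point = torch.zeros_like(scale)  # float8 quantization is symmetric
        return scale, zero_point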
# Regular Linear
linear = torch.nn.Linear(
    observed_linear.in_features,
    observed_linear.out_features,
    False,
    device=observed_linear.weight.device,
    dtype=observed_linear.weight.dtype,
)
# quantized weight from INC
linear.weight = observed_linear.weight
linear.bias = observed_linear.bias
# quantized weight is a tensor subclass
linear.weight = torch.nn.Parameter(
    weight_quant_func(linear.weight), requires_grad=False
)
# activation quantization
act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()
# pre-hook for activation
input_quant_func = lambda x: to_affine_quantized_floatx_static(
    x, act_scale, x.shape, target_dtype, Float8Layout(mm_config=None)
)
# pre-hook for activation, quantized tensor subclass as weight
linear.weight = torch.nn.Parameter(
    to_linear_activation_quantized(linear.weight, input_quant_func),
    requires_grad=False,
)
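With the wrapped weight in place no explicit pre-hook is registered; the activation quantization runs inside the F.linear dispatch. A usage sketch:

x = torch.randn(
    16, observed_linear.in_features,
    device=observed_linear.weight.device,
    dtype=observed_linear.weight.dtype,
)
y = linear(x)  # input_quant_func fires first, then the float8 weight path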
Conv should be similar to Linear; a minimal pre-hook sketch is shown below.
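A sketch of what such a Conv pre-hook could look like, using a plain forward_pre_hook for static activation quantize/dequantize (fake-quant style; act_scale would come from a calibrated observer as above):

import torch

conv = torch.nn.Conv2d(16, 32, 3)   # stands in for an observed conv
act_scale = torch.tensor(0.02)       # from the conv's activation observer

def act_pre_hook(module, args):
    (x,) = args
    # statically quantize to float8, then dequantize so the conv kernel sees
    # high precision; a real flow would instead dispatch a float8 conv
    xq = (x / act_scale).to(torch.float8_e4m3fn)
    return (xq.to(x.dtype) * act_scale,)

conv.register_forward_pre_hook(act_pre_hook)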
SDPA using the native implementation (bmm + softmax + bmm) would dequantize first:
# float8_layout.py, definition of the Float8 tensor subclass
@register_layout(Float8Layout)
class Float8AQTTensorImpl(AQTTensorImpl):
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        kwargs = {} if kwargs is None else kwargs
        if func is aten.bmm.default:
            # dequantize to high precision first, then run the original op
            return func(args[0].dequantize(), args[1].dequantize(), **kwargs)
        ...
# sdpa part
score = torch.bmm(q, k.transpose(-2, -1))
...
# in fact the computation flow would be
score = torch.bmm(q.dequantize(), k.dequantize().transpose(-2, -1))
SDPA using an SDPA backend would also dequantize first:
with sdpa_kernel(backends=[SDPBackend.MATH]):
    F.scaled_dot_product_attention(
        query.dequantize(), key.dequantize(), value.dequantize(), enable_gqa=True
    )
- Meta to fix serialization first
- Add pre-hook for Convolution (sketched above)
- Add dispatch of BMM and SDPA to the float8 tensor subclass (see the sketch below)
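For the last item, the end state is running the matmul in float8 instead of dequantizing. A hedged sketch of the 2-D case on top of torch._scaled_mm (its signature has shifted across PyTorch releases, and a bmm dispatch would still need to handle the batch dimension):

import torch

def fp8_mm(a_fp8, b_fp8, a_scale, b_scale, out_dtype=torch.bfloat16):
    # per-tensor-scaled FP8 matmul; note torch._scaled_mm expects the second
    # operand in column-major layout on current PyTorch versions
    return torch._scaled_mm(
        a_fp8, b_fp8, scale_a=a_scale, scale_b=b_scale, out_dtype=out_dtype
    )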
class QuantizedLinear(torch.nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        act_obs: torch.nn.Module,
        weight_obs: torch.nn.Module,
        weight: torch.Tensor,
        bias: torch.Tensor,
        target_dtype: torch.dtype,
    ):
        super().__init__()
        self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
        weight_scale, weight_zero_point = weight_obs.calculate_qparams()
        assert weight.dim() == 2
        block_size = (1, weight.shape[1])
        self.target_dtype = target_dtype
        self.bias = bias
        mm_config = Float8MMConfig(use_fast_accum=True)
        self.qweight = to_affine_quantized_floatx_static(
            weight,
            weight_scale,
            block_size,
            target_dtype,
            Float8Layout(mm_config=mm_config),
        )

    def forward(self, input: Tensor):
        block_size = input.shape
        qinput = to_affine_quantized_floatx_static(
            input,
            self.act_scale,
            block_size,
            self.target_dtype,
            Float8Layout(mm_config=None),
        )
        return F.linear(qinput, self.qweight, self.bias)

    @classmethod
    def from_observed(cls, observed_linear, target_dtype):
        quantized_linear = cls(
            observed_linear.in_features,
            observed_linear.out_features,
            observed_linear.act_obs,
            observed_linear.weight_obs,
            observed_linear.weight,
            observed_linear.bias,
            target_dtype,
        )
        return quantized_linear
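A usage sketch for the class above, assuming observed_linear has already been calibrated:

qlinear = QuantizedLinear.from_observed(observed_linear, torch.float8_e4m3fn)
x = torch.randn(8, observed_linear.in_features, dtype=observed_linear.weight.dtype)
y = qlinear(x)  # forward statically quantizes x, then runs F.linear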
# the real quantization process
for module in model.modules():
    ...
    # check whether it is an observed Linear, then swap it
    new_linear = QuantizedLinear.from_observed(module, config.target_dtype)
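Replacing a module in place needs a handle on its parent, so the loop above is typically written as a recursive swap; a minimal sketch (ObservedLinear being the observed-module class from the tutorial):

def swap_linear_(model, target_dtype):
    # recursively replace every observed Linear with a QuantizedLinear
    for name, child in model.named_children():
        if isinstance(child, ObservedLinear):
            setattr(model, name, QuantizedLinear.from_observed(child, target_dtype))
        else:
            swap_linear_(child, target_dtype)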
The same approach applies to Convolution.
However, for SDPA, even when the computation is already fused into the SDPA operator, we cannot do this conversion because SDPA is a function rather than a module, not to mention the native (bmm + softmax + bmm) implementation.
# assuming an FP8 SDPA is already defined in PyTorch
with sdpa_kernel(backends=[SDPBackend.MATH]):
    F.scaled_dot_product_attention(
        query, key, value, enable_gqa=True,
        q_scale=q_scale, k_scale=k_scale, v_scale=v_scale,
    )
# However, we can't convert the following call to the FP8 SDPA kernel above
with sdpa_kernel(backends=[SDPBackend.MATH]):
    F.scaled_dot_product_attention(query, key, value, enable_gqa=True)
- Add an nn.Module QConvolution
- Register SDPA as an nn.Module in stock PyTorch (see the wrapper sketch below)
- Add Q_SDPA if 2 works
- Fall back to option 1 if 2 doesn't work
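A minimal sketch of item 2: an nn.Module wrapper around the functional SDPA, which gives the module-swap flow something to observe and later replace with a Q_SDPA variant (the wrapper itself is trivial; the open question is getting models to adopt it):

import torch
import torch.nn.functional as F

class SDPA(torch.nn.Module):
    # plain module wrapper around the functional SDPA so the quantization
    # flow can observe and swap it like Linear/Conv
    def forward(self, query, key, value, **kwargs):
        return F.scaled_dot_product_attention(query, key, value, **kwargs)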
Option 1:
Pros:
- No additional requirements on the PyTorch frontend; all work can be done on the AO side
- Scalable to broader operators / models
Cons:
- Additional effort on the Inductor side
Option 2:
Pros:
- The model structure is explicit to end users: linear => qlinear, conv => qconv
Cons:
- Depends on the PyTorch frontend to register the SDPA module, or needs to fall back to option 1 for SDPA
For these two options, can you provide high-level pseudocode for how users will use them?
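One possible answer, as high-level pseudocode; insert_observers_, calibrate, and swap_modules_ are illustrative helper names, and the quantize_ config name is only an assumption:

# Option 1: tensor subclass. Module types stay the same; weights become
# quantized tensor subclasses and linear/bmm/SDPA are handled in dispatch.
insert_observers_(model)                      # hypothetical helper
calibrate(model, calibration_data)            # hypothetical helper
quantize_(model, float8_static_activation_float8_weight(act_scale))  # assumed config
out = model(x)

# Option 2: module swap. linear => QLinear, conv => QConv, SDPA => Q_SDPA.
insert_observers_(model)
calibrate(model, calibration_data)
swap_modules_(model, target_dtype=torch.float8_e4m3fn)  # hypothetical helper
out = model(x)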