@airMeng
Last active April 17, 2025 03:09
Static quantization in AO

https://github.com/pytorch/ao/blob/main/tutorials/calibration_flow/static_quant.py

The tutorial above shows two methods of static quantization.

1. Regular Linear with quantized tensor-subclass

# Regular Linear
linear = torch.nn.Linear(
    observed_linear.in_features,
    observed_linear.out_features,
    False,  # bias is attached separately below
    device=observed_linear.weight.device,
    dtype=observed_linear.weight.dtype,
)

# weight and bias copied from the observed module (e.g. calibrated by INC)
linear.weight = observed_linear.weight
linear.bias = observed_linear.bias

# quantize the weight into a tensor subclass
linear.weight = torch.nn.Parameter(
    weight_quant_func(linear.weight), requires_grad=False
)

# activation quantization
act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()

# activation quantization function (applied as a pre-hook)
input_quant_func = lambda x: to_affine_quantized_floatx_static(
    x, act_scale, x.shape, target_dtype, Float8Layout(mm_config=None)
)

# wrap the quantized weight so the activation is quantized on the fly
linear.weight = torch.nn.Parameter(
    to_linear_activation_quantized(linear.weight, input_quant_func),
    requires_grad=False,
)
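
After this conversion the module is still called like a plain nn.Linear: the wrapped weight applies input_quant_func to the incoming activation before the matmul runs. A minimal usage sketch (the batch size 16 is arbitrary):

# usage sketch: the tensor subclass quantizes the activation on the fly and
# then dispatches the (float8) matmul
x = torch.randn(
    16,
    observed_linear.in_features,
    device=observed_linear.weight.device,
    dtype=observed_linear.weight.dtype,
)
y = linear(x)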

Conv should be similar to Linear (the tutorial has no Conv demo); a rough sketch of what it could look like is below.
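
As an illustration only, a minimal sketch of the analogous Conv path. observed_conv and conv_weight_quant_func are hypothetical names (the conv analogues of observed_linear and weight_quant_func above), the activation pre-hook corresponds to the implementation-plan items below, and whether the float8 subclass already dispatches convolution is exactly what that plan has to add:

# hedged sketch (not from the tutorial): conv weight wrapped in the quantized
# tensor subclass, activation quantized by a module forward pre-hook
conv = torch.nn.Conv2d(
    observed_conv.in_channels,
    observed_conv.out_channels,
    observed_conv.kernel_size,
    stride=observed_conv.stride,
    padding=observed_conv.padding,
    bias=observed_conv.bias is not None,
    device=observed_conv.weight.device,
    dtype=observed_conv.weight.dtype,
)
conv.weight = torch.nn.Parameter(
    conv_weight_quant_func(observed_conv.weight), requires_grad=False
)
conv.bias = observed_conv.bias

act_scale, act_zero_point = observed_conv.act_obs.calculate_qparams()
input_quant_func = lambda x: to_affine_quantized_floatx_static(
    x, act_scale, x.shape, target_dtype, Float8Layout(mm_config=None)
)

# forward pre-hook: quantize the activation before it reaches the conv
conv.register_forward_pre_hook(lambda mod, args: (input_quant_func(args[0]),))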

SDPA using the native implementation (bmm + softmax + bmm) would dequantize first:

# float8_layout.py, definition of the Float8 tensor sub-class layout
@register_layout(Float8Layout)
class Float8AQTTensorImpl(AQTTensorImpl):
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        kwargs = {} if kwargs is None else kwargs

        if func is aten.bmm.default:
            # no float8 bmm kernel: dequantize the operands first, then run
            # the op in high precision (illustrative)
            return func(
                *[a.dequantize() if hasattr(a, "dequantize") else a for a in args],
                **kwargs,
            )
        ...


# sdpa part
score = torch.bmm(q, k.transpose(-2, -1))
...

# in fact, the computation flow would be
score = torch.bmm(q.dequantize(), k.dequantize().transpose(-2, -1))

SDPA using a dedicated SDPA backend would also dequantize first:

with sdpa_kernel(backends=[SDPBackend.MATH]):
    F.scaled_dot_product_attention(
        query.dequantize(), key.dequantize(), value.dequantize(), enable_gqa=True
    )

Implementation plan

  1. Meta to fix serialization first
  2. Add a pre-hook for Convolution
  3. Add dispatch of bmm and SDPA to the float8 tensor subclass

2. Quantized Linear

class QuantizedLinear(torch.nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        act_obs: torch.nn.Module,
        weight_obs: torch.nn.Module,
        weight: torch.Tensor,
        bias: torch.Tensor,
        target_dtype: torch.dtype,
    ):
        super().__init__()
        self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
        weight_scale, weight_zero_point = weight_obs.calculate_qparams()
        assert weight.dim() == 2
        block_size = (1, weight.shape[1])  # per-row weight quantization granularity
        self.target_dtype = target_dtype
        self.bias = bias
        mm_config = Float8MMConfig(use_fast_accum=True)
        self.qweight = to_affine_quantized_floatx_static(
            weight,
            weight_scale,
            block_size,
            target_dtype,
            Float8Layout(mm_config=mm_config),
        )

    def forward(self, input: Tensor):
        block_size = input.shape  # per-tensor activation quantization
        qinput = to_affine_quantized_floatx_static(
            input,
            self.act_scale,
            block_size,
            self.target_dtype,
            Float8Layout(mm_config=None),
        )
        return F.linear(qinput, self.qweight, self.bias)

    @classmethod
    def from_observed(cls, observed_linear, target_dtype):
        quantized_linear = cls(
            observed_linear.in_features,
            observed_linear.out_features,
            observed_linear.act_obs,
            observed_linear.weight_obs,
            observed_linear.weight,
            observed_linear.bias,
            target_dtype,
        )
        return quantized_linear

# the real quantization process (simplified)
for name, child in list(model.named_children()):
    ...
    # check whether the child is an observed Linear and, if so, swap it out
    new_child = QuantizedLinear.from_observed(child, config.target_dtype)
    if new_child is not child:
        setattr(model, name, new_child)

The same approach applies to Convolution; a rough sketch of what that could look like follows.
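
A hypothetical sketch of such a quantized Conv module (the QConvolution of the implementation plan below), mirroring QuantizedLinear above. The class name, constructor signature, per-output-channel block size, and observer attributes are assumptions, and a float8 convolution kernel (or a dequantize fallback) is assumed to exist:

# hypothetical sketch, mirroring QuantizedLinear above; not an AO API
class QuantizedConv2d(torch.nn.Module):
    def __init__(self, act_obs, weight_obs, weight, bias, stride, padding, target_dtype):
        super().__init__()
        self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
        weight_scale, weight_zero_point = weight_obs.calculate_qparams()
        assert weight.dim() == 4
        # one block per output channel, covering (C_in, kH, kW)
        block_size = (1, *weight.shape[1:])
        self.target_dtype = target_dtype
        self.bias = bias
        self.stride, self.padding = stride, padding
        self.qweight = to_affine_quantized_floatx_static(
            weight,
            weight_scale,
            block_size,
            target_dtype,
            Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True)),
        )

    def forward(self, input: Tensor):
        # per-tensor activation quantization
        qinput = to_affine_quantized_floatx_static(
            input,
            self.act_scale,
            input.shape,
            self.target_dtype,
            Float8Layout(mm_config=None),
        )
        return F.conv2d(
            qinput, self.qweight, self.bias, stride=self.stride, padding=self.padding
        )

    @classmethod
    def from_observed(cls, observed_conv, target_dtype):
        return cls(
            observed_conv.act_obs,
            observed_conv.weight_obs,
            observed_conv.weight,
            observed_conv.bias,
            observed_conv.stride,
            observed_conv.padding,
            target_dtype,
        )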

However, for SDPA we can't do this conversion even when the implementation is already fused into the SDPA operator, because SDPA is not a module, not to mention the native (bmm + softmax + bmm) implementation:

# assuming an FP8 SDPA kernel is already defined in PyTorch
with sdpa_kernel(backends=[SDPBackend.MATH]):
    F.scaled_dot_product_attention(
        query, key, value, enable_gqa=True,
        q_scale=q_scale, k_scale=k_scale, v_scale=v_scale,
    )

# However, we can't convert the following to the SDPA kernels above
with sdpa_kernel(backends=[SDPBackend.MATH]):
    F.scaled_dot_product_attention(query, key, value, enable_gqa=True)

Implementation plan

  1. Add an nn.Module QConvolution
  2. Register SDPA as an nn.Module in stock PyTorch
  3. Add Q_SDPA if 2 works
  4. Fall back to option 1 for SDPA if 2 doesn't work

Comparison

Option 1:

Pros:

  • No additional requirements for the PyTorch frontend; all work can be done on the AO side
  • Scalable to broader operators / models

Cons:

  • Additional effort required on the Inductor side

Option 2:

Pros:

  • The model structure is explicit to end users: linear => qlinear, conv => qconv

Cons:

  • Depends on the PyTorch frontend to register an SDPA module, or needs to fall back to option 1 for SDPA


@leslie-fang-intel

For these two options, can you provide high-level pseudocode for how users will use them?

@airMeng (Author) commented Apr 17, 2025

@leslie-fang-intel
From the AO API side, the API should be the same: https://github.com/pytorch/ao/blob/7fa9c69dc0999023add31d000d4750e0ac2cd799/tutorials/calibration_flow/static_quant.py#L315

# quantized linear represented as an nn.Linear with modified tensor subclass weights
# for both activation and weight quantization
quantize_(m, StaticQuantConfig(target_dtype), is_observed_linear)

# quantized linear as a standalone module
quantize_(m2, StaticQuantConfig2(target_dtype), is_observed_linear)

If you want to go a little deeper: option 2 replaces the module directly.

# the real quantization process
named_children_list = list(model.named_children())
for name, child in named_children_list:
    new_child = QuantizedLinear.from_observed(child, config.target_dtype)
    if new_child is not child:
        setattr(model, name, new_child)

@leslie-fang-intel commented Apr 17, 2025

Thanks, I understand the implementation details. I mean, from a model developer's point of view: they have a model in higher precision, so how do they leverage the APIs step by step to get the final FP8-optimized model? Let's give high-level pseudocode instead of checking the details in AO's example.
