[Half-Quadratic Quantization of Large Machine Learning Models](https://mobiusml.github.io/hqq_blog/)
# Reference: [Half-Quadratic Quantization of Large Machine Learning Models](https://mobiusml.github.io/hqq_blog/)

import numpy as np

# Define the shrinkage function for soft-thresholding
def shrink(x, beta, p):
    return np.sign(x) * np.maximum(np.abs(x) - (np.abs(x)**(p - 1)) / beta, 0)
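
# Illustrative check (toy numbers, not taken from the HQQ post): with p = 1 the
# shrink operator reduces to plain soft-thresholding, zeroing entries with
# |x| <= 1/beta and pulling larger entries toward zero by 1/beta.
_x_demo = np.array([-2.0, -0.4, 0.0, 0.4, 2.0])
print(f"shrink demo (beta=1, p=1): {shrink(_x_demo, 1.0, 1)}")  # ~[-1, 0, 0, 0, 1]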

# Define the quantization and dequantization operators
def quantize(W, s, z):
    return np.round(W / s + z)

def dequantize(Wq, s, z):
    return s * (Wq - z)
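
# Quick round-trip illustration (arbitrary values): dequantize(quantize(w))
# recovers w only up to a rounding error of at most s/2 per element; this
# residual is exactly what the HQQ error variable We models below.
_w_demo = np.array([0.12, -0.37, 0.81])
_rt_err = _w_demo - dequantize(quantize(_w_demo, 0.1, 0.0), 0.1, 0.0)
print(f"round-trip error with s=0.1, z=0 (|err| <= 0.05): {_rt_err}")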

# Initialize parameters
W = np.random.randn(10, 10)  # Replace with actual weights
print(f"W = {W}")
'''
The choice of scale factor (s) and zero-point (z) strongly affects how accurately
dequantization recovers the original, unquantized weights. Key points:
- The scale factor s sets the "step size" between quantization levels. A larger s
  means coarser steps and a less faithful representation of the weight distribution.
- The zero-point z sets the offset of the quantization range. A poorly chosen z can
  clip part of the weight distribution and lose information.
- With a fixed number of levels, an overly large s maps many distinct weights onto
  the same level, while a very small s preserves detail but covers only a narrow
  range and clips outliers.
- A zero-point shifted far from the center of the weight distribution clips values
  on one side; centering z over the distribution preserves it better.
- The optimal s and z therefore depend on the statistics of the weights and should
  be chosen to retain as much precision as possible.
- For a fixed bit width there is a trade-off: a larger s covers a wider range with
  coarser steps, and z positions that range over the weights.
- Dequantization accuracy ultimately depends on how well s and z let the round trip
  quantize -> dequantize reproduce the original unquantized weights.
In summary, s and z should be optimized from the weight statistics to minimize the
dequantization error and retain as much information as possible.
Credit: Claude2 AI chatbot
'''
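
# One common way to choose s and z is asymmetric min-max quantization (a sketch;
# the bit width and the *_minmax names below are illustrative assumptions, not
# values taken from the HQQ post, and they are not used by the optimization loop
# that follows): map [W.min(), W.max()] onto the integer range [0, 2**b - 1].
b_demo = 4  # assumed bit width for the illustration
s_minmax = (W.max() - W.min()) / (2**b_demo - 1)
z_minmax = -W.min() / s_minmax
W_rec_minmax = dequantize(quantize(W, s_minmax, z_minmax), s_minmax, z_minmax)
print(f"min-max (s, z) reconstruction error: {np.linalg.norm(W - W_rec_minmax):.4f}")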

s = 1  # Scale factor (kept fixed here; it could be learned as well)
z = 0  # Zero-point (initially 0, optimized below)
beta = 0.001  # Beta penalty for the HQQ half-quadratic splitting
k = 0.9  # Multiplicative update factor for beta
p = 1  # p-norm of the shrinkage operator (p = 1 gives plain soft-thresholding)
num_iterations = 100  # Number of iterations for optimization
tolerance = 1e-5  # Tolerance for convergence

# Initialize the extra (error) variable We to the original weights W
We = np.copy(W)

# Optimization loop: alternate between updating We (shrinkage applied to the
# quantization residual) and updating the zero-point z; s stays fixed
for i in range(num_iterations):
    prev_We = np.copy(We)

    # Quantize/dequantize with the current (s, z), then update We by shrinking
    # the quantization residual W - Wdq
    Wq = quantize(W, s, z)
    Wdq = dequantize(Wq, s, z)
    We = shrink(W - Wdq, beta, p)

    # Update z in closed form from the new We
    z = np.mean(Wq - (W - We) / s)

    # Update beta
    beta *= k

    # Check for convergence of We
    if np.linalg.norm(We - prev_We) < tolerance:
        break

# Output the quantized weights with the optimized zero-point
Wq = quantize(W, s, z)
print(f"Wq = {Wq}")