@JyotinderSingh
Created August 6, 2025 08:18
Modularized GPTQ logic
class Dense(Layer):
    ...

    def get_gptq_handler(self) -> "GPTQHandler":
        """Provides a GPTQHandler for a standard Dense layer."""
        return GPTQHandler(
            kernel=self.kernel,
            bias=self.bias,
            rows=self.kernel.shape[0],
            columns=self.kernel.shape[1],
            original_kernel_shape=self.kernel.shape,
        )


class EinsumDense(Layer):
    ...

    def get_gptq_handler(self) -> "GPTQHandler | None":
        """Provides a GPTQHandler for an EinsumDense layer."""
        # Case 1: The kernel is already 2D.
        if self.kernel.ndim == 2:
            return GPTQHandler(
                kernel=self.kernel,
                bias=self.bias,
                rows=self.kernel.shape[0],
                columns=self.kernel.shape[1],
                original_kernel_shape=self.kernel.shape,
            )

        # Case 2: The kernel is 3D (common in attention mechanisms).
        if self.kernel.ndim == 3:
            shape = list(self.kernel.shape)
            try:
                # Heuristic to find the model dimension (d_model), which is
                # typically the largest dimension in the shape.
                d_model_dim_index = shape.index(max(shape))
            except ValueError:
                # This case should be rare.
                return None

            # Determine the effective 2D shape for quantization.
            if d_model_dim_index == 0:
                # QKV projection: (d_model, heads, head_dim).
                in_features, heads, head_dim = shape
                rows, columns = in_features, heads * head_dim
            else:
                # Attention output: (heads, head_dim, d_model).
                heads, head_dim, out_features = shape
                rows, columns = heads * head_dim, out_features

            # The kernel is reshaped to 2D for the GPTQ algorithm.
            reshaped_kernel = ops.reshape(self.kernel, (rows, columns))
            return GPTQHandler(
                kernel=reshaped_kernel,
                bias=self.bias,
                rows=rows,
                columns=columns,
                original_kernel_shape=self.kernel.shape,
            )
        # Unsupported kernel dimensions (e.g., ndim=1 or >3).
        raise ValueError("...")
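
To make the 3D heuristic above concrete, here is a small sketch tracing both branches with made-up attention shapes (d_model=512, heads=8, head_dim=64 are illustrative values, not taken from this gist):

# QKV-style kernel: (d_model, heads, head_dim) = (512, 8, 64).
shape = [512, 8, 64]
assert shape.index(max(shape)) == 0
# -> rows = 512 input features, columns = 8 * 64 = 512 flattened outputs.

# Attention-output-style kernel: (heads, head_dim, d_model) = (8, 64, 512).
shape = [8, 64, 512]
assert shape.index(max(shape)) == 2
# -> rows = 8 * 64 = 512 flattened inputs, columns = 512 output features.
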
import dataclasses
from typing import Any

from keras.src import ops

from .quant import quantize


@dataclasses.dataclass
class GPTQHandler:
    """A data class to hold a layer's effective 2D representation for GPTQ.

    This handler acts as a standardized interface between a Keras layer and
    the GPTQ algorithm. It provides the GPTQ class with a consistent, 2D view
    of the kernel to be quantized, abstracting away the specific
    implementation details of different layer types (e.g., Dense vs.
    EinsumDense).
    """

    kernel: Any
    bias: Any
    rows: int
    columns: int
    original_kernel_shape: tuple

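
As a usage illustration, a handler for a plain 2D kernel might be built like this; the 768/3072 sizes and the NumPy placeholder weights are assumptions for the example, not values from the gist:

import numpy as np

# Hypothetical 2D kernel: 768 input features -> 3072 output features.
kernel = np.random.randn(768, 3072).astype("float32")
bias = np.zeros(3072, dtype="float32")

handler = GPTQHandler(
    kernel=kernel,
    bias=bias,
    rows=768,        # input-feature dimension; also the Hessian size
    columns=3072,    # number of output columns quantized one by one
    original_kernel_shape=kernel.shape,
)
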
class GPTQ:
    def __init__(self, layer):
        self.original_layer = layer
        self.nsamples = 0
        self.quantizer = None

        # Delegate the configuration logic to the layer itself.
        handler = layer.get_gptq_handler()
        if not isinstance(handler, GPTQHandler):
            raise TypeError(
                f"Layer '{type(layer).__name__}' does not support GPTQ. "
                "To add support, implement the `get_gptq_handler()` method "
                "on the layer class to return a valid GPTQHandler object."
            )

        # The handler provides a consistent, simplified view of the layer.
        self.handler = handler
        self.rows = handler.rows
        self.columns = handler.columns

        # The Hessian is initialized based on the handler's dimensions.
        self.H = ops.zeros((self.rows, self.rows), dtype="float32")

    def update_hessian_with_batch(self, inp):
        """Updates the Hessian matrix using a batch of input data."""
        if len(inp.shape) > 2:
            inp = ops.reshape(inp, (-1, inp.shape[-1]))
        inp = ops.cast(inp, "float32")

        if self.H.shape[0] != inp.shape[-1]:
            raise ValueError(
                f"Hessian dimensions ({self.H.shape[0]}) do not "
                f"match input features ({inp.shape[-1]})."
            )

        current_H = 2 * ops.matmul(ops.transpose(inp), inp)

        # Update the Hessian with a weighted average over all samples seen.
        if self.nsamples == 0:
            self.H = current_H
        else:
            total_samples = self.nsamples + inp.shape[0]
            self.H = self.H * (self.nsamples / total_samples)
            self.H += current_H * (inp.shape[0] / total_samples)
        self.nsamples += inp.shape[0]
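
The method accepts activations with extra leading axes; a minimal sketch with made-up shapes, assuming `layer` is a supported layer with 768 input features:

import numpy as np

# gptq.H starts as a (768, 768) zero matrix for this hypothetical layer.
gptq = GPTQ(layer)

# A (batch=4, seq_len=64, features=768) activation tensor is flattened
# internally to (256, 768) before the 2 * X^T X term is accumulated.
activations = np.random.randn(4, 64, 768).astype("float32")
gptq.update_hessian_with_batch(activations)
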

    def quantize_and_correct_block(
        self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False
    ):
        # Use the effective 2D kernel from the handler.
        W = ops.transpose(ops.cast(self.handler.kernel, "float32"))
        H = ops.cast(self.H, "float32")

        if actorder:
            perm = ops.argsort(-ops.diagonal(H))
            W = ops.take(W, perm, axis=1)
            H = ops.take(ops.take(H, perm, axis=0), perm, axis=1)
            invperm = ops.argsort(perm)

        # Dampen the Hessian matrix for numerical stability.
        diag_H = ops.diagonal(H)
        dead = ops.equal(diag_H, 0.0)
        diag_H = ops.where(dead, 1.0, diag_H)
        H = H + ops.diag(ops.where(dead, 1.0, ops.zeros_like(diag_H)))
        damp = percdamp * ops.mean(diag_H)
        diag_H = diag_H + damp
        H = (H - ops.diag(ops.diagonal(H))) + ops.diag(diag_H)
        Hinv = ops.linalg.inv(H)

        Q = ops.zeros_like(W)

        # Process the kernel in blocks.
        for i1 in range(0, self.rows, blocksize):
            i2 = min(i1 + blocksize, self.rows)
            count = i2 - i1

            W1 = W[:, i1:i2]
            Q1 = ops.zeros_like(W1)
            Err1 = ops.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                # Find quantization parameters.
                if groupsize != -1:
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(
                            W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
                        )
                else:
                    self.quantizer.find_params(
                        ops.expand_dims(w, 1), weight=True
                    )

                # Quantize the current column.
                q = quantize(
                    ops.expand_dims(w, 1),
                    self.quantizer.scale,
                    self.quantizer.zero,
                    self.quantizer.maxq,
                )[:, 0]
                Q1 = ops.slice_update(Q1, (0, i), ops.expand_dims(q, axis=1))

                err = (w - q) / d
                Err1 = ops.slice_update(
                    Err1, (0, i), ops.expand_dims(err, axis=1)
                )

                # Update the remaining weights in the block with the error.
                if i < count - 1:
                    update = ops.matmul(
                        ops.expand_dims(err, 1),
                        ops.expand_dims(Hinv1[i, i + 1 :], 0),
                    )
                    slice_to_update = W1[:, i + 1 :]
                    updated_slice = slice_to_update - update
                    W1 = ops.slice_update(W1, (0, i + 1), updated_slice)

            Q = ops.concatenate([Q[:, :i1], Q1, Q[:, i2:]], axis=1)

            # Update the remaining weights outside the block.
            if i2 < self.rows:
                update_total = ops.matmul(Err1, Hinv[i1:i2, i2:])
                W = ops.concatenate(
                    [W[:, :i2], W[:, i2:] - update_total], axis=1
                )

        if actorder:
            Q = ops.take(Q, invperm, axis=1)

        Q = ops.transpose(Q)

        # Reshape the quantized kernel back to its original shape.
        Q = ops.reshape(Q, self.handler.original_kernel_shape)
        new_weights = [ops.convert_to_numpy(Q)]
        if self.original_layer.bias is not None:
            new_weights.append(ops.convert_to_numpy(self.original_layer.bias))
        self.original_layer.set_weights(new_weights)

    def free(self):
        """Frees the memory used by the Hessian matrix."""
        self.H = None
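
A possible end-to-end flow, sketched under assumptions: `layer` is a supported Dense/EinsumDense instance, `calibration_batches` yields the activations feeding that layer, and `Quantizer` stands in for whatever object provides `find_params`, `scale`, `zero`, and `maxq` (it is not defined in this gist):

gptq = GPTQ(layer)
gptq.quantizer = Quantizer(bits=4)  # hypothetical quantizer object

# 1. Accumulate the Hessian from calibration activations.
for activations in calibration_batches:
    gptq.update_hessian_with_batch(activations)

# 2. Quantize the kernel block by block; the result is written back into
#    the original layer via set_weights().
gptq.quantize_and_correct_block(blocksize=128, percdamp=0.01, groupsize=128)

# 3. Release the Hessian once quantization is done.
gptq.free()
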
class Layer:
    ...

    def get_gptq_handler(self) -> "GPTQHandler | None":
        """Returns a handler with the necessary information for GPTQ.

        Layers supporting GPTQ must override this method. The handler should
        provide an effective 2D view of the kernel to be quantized. The
        default implementation indicates that the layer is not supported.
        """
        # By default, a layer does not support GPTQ.
        return None
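
For a layer type not covered above, GPTQ support can be added by overriding this hook; a minimal sketch for a hypothetical layer that stores a 2D kernel in `self.w` and its bias in `self.b`:

class MyCustomLayer(Layer):
    ...

    def get_gptq_handler(self) -> "GPTQHandler | None":
        """Exposes this layer's 2D kernel to the GPTQ algorithm."""
        # `self.w` and `self.b` are attributes of this example layer only.
        return GPTQHandler(
            kernel=self.w,
            bias=self.b,
            rows=self.w.shape[0],
            columns=self.w.shape[1],
            original_kernel_shape=self.w.shape,
        )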