Modularized GPTQ logic
class Dense(Layer):
    ...

    def get_gptq_handler(self) -> "GPTQHandler":
        """Provides a GPTQHandler for a standard Dense layer."""
        return GPTQHandler(
            kernel=self.kernel,
            bias=self.bias,
            rows=self.kernel.shape[0],
            columns=self.kernel.shape[1],
            original_kernel_shape=self.kernel.shape,
        )
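For a Dense layer the mapping is direct, since the kernel is already 2D. As a rough sketch (a hypothetical layer with 128 input features and 64 units, not part of the gist), the handler simply mirrors the kernel's shape:

# Hypothetical example: Dense(64) built on 128 input features has a
# kernel of shape (128, 64), so the handler exposes it unchanged.
layer = Dense(64)
layer.build((None, 128))

handler = layer.get_gptq_handler()
# handler.rows == 128                       (input features)
# handler.columns == 64                     (output units)
# handler.original_kernel_shape == (128, 64)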
class EinsumDense(Layer):
    ...

    def get_gptq_handler(self) -> "GPTQHandler | None":
        """Provides a GPTQHandler for an EinsumDense layer."""
        # Case 1: The kernel is already 2D.
        if self.kernel.ndim == 2:
            return GPTQHandler(
                kernel=self.kernel,
                bias=self.bias,
                rows=self.kernel.shape[0],
                columns=self.kernel.shape[1],
                original_kernel_shape=self.kernel.shape,
            )

        # Case 2: The kernel is 3D (common in attention mechanisms).
        if self.kernel.ndim == 3:
            shape = list(self.kernel.shape)
            try:
                # Heuristic to find the model dimension (d_model), which is
                # typically the largest dimension in the shape.
                d_model_dim_index = shape.index(max(shape))
            except ValueError:
                # This case should be rare.
                return None

            # Determine the effective 2D shape for quantization.
            if d_model_dim_index == 0:
                # QKV projection: (d_model, heads, head_dim).
                in_features, heads, head_dim = shape
                rows, columns = in_features, heads * head_dim
            else:
                # Attention output: (heads, head_dim, d_model).
                heads, head_dim, out_features = shape
                rows, columns = heads * head_dim, out_features

            # The kernel is reshaped to 2D for the GPTQ algorithm.
            reshaped_kernel = ops.reshape(self.kernel, (rows, columns))
            return GPTQHandler(
                kernel=reshaped_kernel,
                bias=self.bias,
                rows=rows,
                columns=columns,
                original_kernel_shape=self.kernel.shape,
            )
        # Unsupported kernel dimensions (e.g., ndim=1 or >3) are rejected.
        raise ValueError("...")
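To make the largest-dimension heuristic concrete, here is a small sketch (the shapes are illustrative, not taken from the gist) of how each 3D layout collapses to an effective 2D view:

import numpy as np
from keras import ops

# Hypothetical QKV projection kernel: (d_model, heads, head_dim).
qkv_kernel = np.zeros((512, 8, 64))
shape = list(qkv_kernel.shape)
assert shape.index(max(shape)) == 0
rows, columns = shape[0], shape[1] * shape[2]   # 512, 8 * 64 = 512
qkv_2d = ops.reshape(qkv_kernel, (rows, columns))

# Hypothetical attention-output kernel: (heads, head_dim, d_model).
out_kernel = np.zeros((8, 64, 512))
shape = list(out_kernel.shape)
assert shape.index(max(shape)) == 2
rows, columns = shape[0] * shape[1], shape[2]   # 8 * 64 = 512, 512
out_2d = ops.reshape(out_kernel, (rows, columns))

In both branches the handler also records original_kernel_shape, so the quantized weights can be reshaped back to the 3D layout after GPTQ runs.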
import dataclasses
from typing import Any

from keras.src import ops

from .quant import quantize


@dataclasses.dataclass
class GPTQHandler:
    """A data class holding a layer's effective 2D representation for GPTQ.

    This handler acts as a standardized interface between a Keras layer and
    the GPTQ algorithm. It provides the GPTQ class with a consistent, 2D
    view of the kernel to be quantized, abstracting away the specific
    implementation details of different layer types (e.g., Dense vs.
    EinsumDense).
    """

    kernel: Any
    bias: Any
    rows: int
    columns: int
    original_kernel_shape: tuple


class GPTQ:
    def __init__(self, layer):
        self.original_layer = layer
        self.nsamples = 0
        self.quantizer = None

        # Delegate the configuration logic to the layer itself.
        handler = layer.get_gptq_handler()
        if not isinstance(handler, GPTQHandler):
            raise TypeError(
                f"Layer '{type(layer).__name__}' does not support GPTQ. "
                f"To add support, implement the `get_gptq_handler()` method "
                f"on the layer class to return a valid GPTQHandler object."
            )

        # The handler provides a consistent, simplified view of the layer.
        self.handler = handler
        self.rows = handler.rows
        self.columns = handler.columns

        # The Hessian is initialized based on the handler's dimensions.
        self.H = ops.zeros((self.rows, self.rows), dtype="float32")

    def update_hessian_with_batch(self, inp):
        """Updates the Hessian matrix using a batch of input data."""
        if len(inp.shape) > 2:
            inp = ops.reshape(inp, (-1, inp.shape[-1]))
        inp = ops.cast(inp, "float32")

        if self.H.shape[0] != inp.shape[-1]:
            raise ValueError(
                f"Hessian dimensions ({self.H.shape[0]}) do not "
                f"match input features ({inp.shape[-1]})."
            )

        current_H = 2 * ops.matmul(ops.transpose(inp), inp)

        # Update the Hessian with a weighted average.
        if self.nsamples == 0:
            self.H = current_H
        else:
            total_samples = self.nsamples + inp.shape[0]
            self.H = self.H * (self.nsamples / total_samples)
            self.H += current_H * (inp.shape[0] / total_samples)

        self.nsamples += inp.shape[0]

    def quantize_and_correct_block(
        self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False
    ):
        # Use the effective 2D kernel from the handler.
        W = ops.transpose(ops.cast(self.handler.kernel, "float32"))
        H = ops.cast(self.H, "float32")

        if actorder:
            perm = ops.argsort(-ops.diagonal(H))
            W = ops.take(W, perm, axis=1)
            H = ops.take(ops.take(H, perm, axis=0), perm, axis=1)
            invperm = ops.argsort(perm)

        # Dampen the Hessian matrix for numerical stability.
        diag_H = ops.diagonal(H)
        dead = ops.equal(diag_H, 0.0)
        diag_H = ops.where(dead, 1.0, diag_H)
        H = H + ops.diag(ops.where(dead, 1.0, ops.zeros_like(diag_H)))
        damp = percdamp * ops.mean(diag_H)
        diag_H = diag_H + damp
        H = (H - ops.diag(ops.diagonal(H))) + ops.diag(diag_H)

        Hinv = ops.linalg.inv(H)
        Q = ops.zeros_like(W)

        # Process the kernel in blocks.
        for i1 in range(0, self.rows, blocksize):
            i2 = min(i1 + blocksize, self.rows)
            count = i2 - i1

            W1 = W[:, i1:i2]
            Q1 = ops.zeros_like(W1)
            Err1 = ops.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                # Find quantization parameters.
                if groupsize != -1:
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(
                            W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
                        )
                else:
                    self.quantizer.find_params(
                        ops.expand_dims(w, 1), weight=True
                    )

                # Quantize the current column.
                q = quantize(
                    ops.expand_dims(w, 1),
                    self.quantizer.scale,
                    self.quantizer.zero,
                    self.quantizer.maxq,
                )[:, 0]
                Q1 = ops.slice_update(Q1, (0, i), ops.expand_dims(q, axis=1))

                err = (w - q) / d
                Err1 = ops.slice_update(
                    Err1, (0, i), ops.expand_dims(err, axis=1)
                )

                # Update the remaining weights in the block with the error.
                if i < count - 1:
                    update = ops.matmul(
                        ops.expand_dims(err, 1),
                        ops.expand_dims(Hinv1[i, i + 1 :], 0),
                    )
                    slice_to_update = W1[:, i + 1 :]
                    updated_slice = slice_to_update - update
                    W1 = ops.slice_update(W1, (0, i + 1), updated_slice)

            Q = ops.concatenate([Q[:, :i1], Q1, Q[:, i2:]], axis=1)

            # Update the remaining weights outside the block.
            if i2 < self.rows:
                update_total = ops.matmul(Err1, Hinv[i1:i2, i2:])
                W = ops.concatenate(
                    [W[:, :i2], W[:, i2:] - update_total], axis=1
                )

        if actorder:
            Q = ops.take(Q, invperm, axis=1)

        Q = ops.transpose(Q)

        # Reshape the quantized kernel back to its original shape.
        Q = ops.reshape(Q, self.handler.original_kernel_shape)

        new_weights = [ops.convert_to_numpy(Q)]
        if self.original_layer.bias is not None:
            new_weights.append(ops.convert_to_numpy(self.original_layer.bias))
        self.original_layer.set_weights(new_weights)

    def free(self):
        """Frees the memory used by the Hessian matrix."""
        self.H = None
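A hedged sketch of how these pieces might be wired together during calibration. The quantizer object (`my_quantizer`) and the calibration loop are assumptions for illustration; the gist only requires that the quantizer expose `find_params`, `scale`, `zero`, and `maxq`:

# Sketch only, assuming a GPTQ-compatible quantizer and calibration data.
gptq = GPTQ(layer)                  # layer must implement get_gptq_handler()
gptq.quantizer = my_quantizer       # assumed: has find_params/scale/zero/maxq

# Accumulate second-order input statistics (the Hessian) from the
# calibration activations that feed this layer.
for batch in calibration_inputs:    # assumed iterable of (..., input_features)
    gptq.update_hessian_with_batch(batch)

# Quantize the kernel column by column with error correction, write the
# result back into the original layer, then release the Hessian.
gptq.quantize_and_correct_block(blocksize=128, percdamp=0.01, groupsize=-1)
gptq.free()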
class Layer:
    ...

    def get_gptq_handler(self) -> "GPTQHandler | None":
        """Returns a handler with the necessary information for GPTQ.

        Layers supporting GPTQ must override this method. The handler should
        provide an effective 2D view of the kernel to be quantized. The
        default implementation indicates that the layer is not supported.
        """
        # By default, a layer does not support GPTQ.
        return None
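Any other layer can opt in the same way. A minimal sketch of a hypothetical custom layer (assuming it stores a 2D `self.kernel` and a `self.bias`) mirrors the Dense implementation above:

class MyCustomLayer(Layer):
    # Hypothetical example, not part of the gist.
    def get_gptq_handler(self) -> "GPTQHandler | None":
        # Expose the 2D kernel directly; GPTQ.__init__ raises a TypeError
        # for layers whose hook does not return a GPTQHandler.
        return GPTQHandler(
            kernel=self.kernel,
            bias=self.bias,
            rows=self.kernel.shape[0],
            columns=self.kernel.shape[1],
            original_kernel_shape=self.kernel.shape,
        )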