Skip to content

Instantly share code, notes, and snippets.

@theobjectivedad
Created April 4, 2024 12:25
Show Gist options
  • Save theobjectivedad/11440f1aa65ee6e26cb1e23ae3424a3d to your computer and use it in GitHub Desktop.
Save theobjectivedad/11440f1aa65ee6e26cb1e23ae3424a3d to your computer and use it in GitHub Desktop.
AutoGPTQ Issue 459 Workaround

This is a monkeypatch to work around AutoGPTQ issue 459:

import os
import time
from logging import getLogger

import torch
import torch.nn as nn
import transformers

logger = getLogger(__name__)


def _fasterquant(
    self,
    blocksize=128,
    percdamp=0.01,
    group_size=-1,
    actorder=False,
    static_groups=False,
):
    """Quantize ``self.layer``'s weight in place and return its quantization parameters.

    Intended to be installed as ``GPTQ.fasterquant`` (see ``hijack_fasterquant``).
    The one deliberate deviation is that a failure while logging the average
    loss is caught and reported instead of aborting the quantization
    (AutoGPTQ issue 459).

    Args:
        blocksize: Number of columns handled per outer-loop block.
        percdamp: Fraction of the mean diagonal of ``self.H`` that is added to
            the diagonal before the Cholesky factorisations.
        group_size: Number of columns sharing one set of quantizer parameters;
            ``-1`` means a single group spanning all columns.
        actorder: If true, columns are processed in order of descending
            ``diag(H)`` and the permutation is undone before returning.
        static_groups: If true, per-group quantizer parameters are computed up
            front from the unmodified weights instead of during the sweep.

    Returns:
        ``(scale, zero, g_idx)``: the recorded quantizer scales and zeros, each
        concatenated along dim 1, and an int32 tensor giving the group index of
        every column.

    Side effects:
        Deletes ``self.H``, overwrites ``self.layer.weight.data`` with the
        quantized weights, and (with ``static_groups``) rebinds
        ``self.quantizer``. When the ``DEBUG`` environment variable is set,
        intermediate weights are written into the layer and debug errors are
        logged using ``self.inp1`` / ``self.out1`` (expected to be set
        elsewhere).
    """
    # Work on a float32 copy that is indexed as (rows, columns) below.
    W = self.layer.weight.data.clone()
    if isinstance(self.layer, nn.Conv2d):
        W = W.flatten(1)
    if isinstance(self.layer, transformers.Conv1D):
        W = W.t()
    W = W.float()

    tick = time.time()

    if not self.quantizer.ready():
        self.quantizer.find_params(W, weight=True)

    # self.H is consumed here (presumably the Hessian estimate accumulated
    # elsewhere -- confirm against the GPTQ class).
    H = self.H
    del self.H
    # Columns whose diagonal entry in H is zero: give them a unit diagonal so
    # the factorisation below is defined, and zero the corresponding weights.
    dead = torch.diag(H) == 0
    H[dead, dead] = 1
    W[:, dead] = 0

    scale = []
    zero = []
    now_idx = 1

    if static_groups:
        import copy

        # One independent quantizer per group, fitted before the sweep
        # modifies any weights.
        groups = []
        for i in range(0, self.columns, group_size):
            quantizer = copy.deepcopy(self.quantizer)
            quantizer.find_params(W[:, i : (i + group_size)], weight=True)
            scale.append(quantizer.scale)
            zero.append(quantizer.zero)
            groups.append(quantizer)

    if actorder:
        # Reorder columns by descending diag(H); invperm restores the
        # original order at the end.
        perm = torch.argsort(torch.diag(H), descending=True)
        W = W[:, perm]
        H = H[perm][:, perm]
        invperm = torch.argsort(perm)

    Losses = torch.zeros_like(W)
    Q = torch.zeros_like(W)

    # Dampen the diagonal, then replace H by the upper Cholesky factor of its
    # inverse.
    damp = percdamp * torch.mean(torch.diag(H))
    diag = torch.arange(self.columns, device=self.dev)
    H[diag, diag] += damp
    H = torch.linalg.cholesky(H)
    H = torch.cholesky_inverse(H)
    H = torch.linalg.cholesky(H, upper=True)
    Hinv = H

    for i1 in range(0, self.columns, blocksize):
        i2 = min(i1 + blocksize, self.columns)
        count = i2 - i1

        W1 = W[:, i1:i2].clone()
        Q1 = torch.zeros_like(W1)
        Err1 = torch.zeros_like(W1)
        Losses1 = torch.zeros_like(W1)
        Hinv1 = Hinv[i1:i2, i1:i2]

        for i in range(count):
            w = W1[:, i]
            d = Hinv1[i, i]

            if group_size != -1:
                if not static_groups:
                    # Refit the shared quantizer at the first column of each
                    # group, using the current (already updated) weights.
                    if (i1 + i) % group_size == 0:
                        self.quantizer.find_params(
                            W[:, (i1 + i) : (i1 + i + group_size)], weight=True
                        )

                    # now_idx is one ahead of the last recorded group, so this
                    # records scale/zero exactly once per group.
                    if ((i1 + i) // group_size) - now_idx == -1:
                        scale.append(self.quantizer.scale)
                        zero.append(self.quantizer.zero)
                        now_idx += 1
                else:
                    # Static groups are defined on original column indices, so
                    # map the permuted position back before the lookup.
                    idx = i1 + i
                    if actorder:
                        idx = perm[idx]
                    self.quantizer = groups[idx // group_size]

            q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
            Q1[:, i] = q
            Losses1[:, i] = (w - q) ** 2 / d**2

            # Spread this column's quantization error over the remaining
            # columns of the current block.
            err1 = (w - q) / d
            W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
            Err1[:, i] = err1

        Q[:, i1:i2] = Q1
        Losses[:, i1:i2] = Losses1 / 2

        # Propagate the block's accumulated error to all later columns.
        W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        if os.environ.get("DEBUG"):
            self.layer.weight.data[:, :i2] = Q[:, :i2]
            self.layer.weight.data[:, i2:] = W[:, i2:]
            logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
            logger.debug(torch.sum(Losses))

    # Presumably here so the logged duration includes queued GPU work.
    torch.cuda.synchronize()
    logger.info(f"duration: {(time.time() - tick)}")

    # Workaround for https://github.com/PanQiWei/AutoGPTQ/issues/459:
    # computing the average loss can fail (presumably a ZeroDivisionError when
    # self.nsamples is 0 -- TODO confirm). The value is only logged, so report
    # the failure and continue. Catching Exception rather than using a bare
    # ``except`` lets KeyboardInterrupt / SystemExit propagate.
    try:
        logger.info(f"avg loss: {torch.sum(Losses).item() / self.nsamples}")
    except Exception:
        logger.warning(
            "could not compute avg loss; skipping (see AutoGPTQ issue 459)",
            exc_info=True,
        )

    # Build the per-column group index; with a single group every column maps
    # to group 0.
    group_size = group_size if group_size != -1 else self.columns
    if static_groups and actorder:
        g_idx = [perm[i] // group_size for i in range(self.columns)]
    else:
        g_idx = [i // group_size for i in range(self.columns)]
    g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
    if actorder:
        # Undo the activation-order permutation.
        Q = Q[:, invperm]
        g_idx = g_idx[invperm]

    if isinstance(self.layer, transformers.Conv1D):
        Q = Q.t()
    self.layer.weight.data = Q.reshape(self.layer.weight.shape).type_as(
        self.layer.weight.data
    )
    if os.environ.get("DEBUG"):
        logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))

    # Nothing was recorded during the sweep (no grouping): fall back to the
    # single quantizer's parameters.
    if not scale:
        scale.append(self.quantizer.scale)
        zero.append(self.quantizer.zero)
    scale = torch.cat(scale, dim=1)
    zero = torch.cat(zero, dim=1)
    return scale, zero, g_idx



def hijack_fasterquant():
    """Install ``_fasterquant`` as ``GPTQ.fasterquant`` in ``auto_gptq``.

    The import is done lazily so ``auto_gptq`` is only required when the
    patch is actually applied. The method is replaced on the class itself.
    """
    from auto_gptq.quantization.gptq import GPTQ as gptq_cls

    setattr(gptq_cls, "fasterquant", _fasterquant)

Obviously, this isn't a best practice, so use at your own risk!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment