Grouped linear layer using https://github.com/tgale96/grouped_gemm.
"""Grouped linear layer using https://github.com/tgale96/grouped_gemm.""" | |
from dataclasses import dataclass | |
import warnings | |
import torch | |
from torch import nn | |
try: | |
import grouped_gemm | |
_gmm_kernel = torch.compiler.disable(grouped_gemm.ops.gmm) | |
except ImportError: | |
warnings.warn("grouped_gemm not available, falling back to PyTorch implementation.") | |
_gmm_kernel = None | |
@torch.compiler.disable
def gmm_pytorch(a, b, batch_sizes, trans_b=False):
    """Grouped matrix multiplication using PyTorch."""
    if a.ndim != 2:
        raise ValueError("a must be a 2D tensor")
    if b.ndim != 3:
        raise ValueError("b must be a 3D tensor")
    if batch_sizes.ndim != 1:
        raise ValueError("batch_sizes must be a 1D tensor")
    if b.shape[0] != batch_sizes.shape[0]:
        raise ValueError("b and batch_sizes must have the same number of groups")
    a_split = torch.split(a, batch_sizes.tolist())
    b_split = torch.unbind(b.mT if trans_b else b)
    c = [a_part @ b_part for a_part, b_part in zip(a_split, b_split)]
    return torch.cat(c)


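# A minimal illustration (not part of the original gist) of the grouped matmul
# semantics: rows of `a` are split along dim 0 according to `batch_sizes`, each
# chunk is multiplied by the matching matrix in `b`, and the results are
# concatenated. The shapes below are arbitrary example values.
def _example_gmm_pytorch():
    a = torch.randn(5, 8)               # 5 rows, feature dim 8
    b = torch.randn(2, 8, 16)           # 2 groups of (8, 16) weight matrices
    batch_sizes = torch.tensor([2, 3])  # rows 0:2 use b[0], rows 2:5 use b[1]
    c = gmm_pytorch(a, b, batch_sizes)
    assert c.shape == (5, 16)
    # Equivalent to multiplying each chunk separately and concatenating:
    expected = torch.cat([a[:2] @ b[0], a[2:] @ b[1]])
    assert torch.allclose(c, expected)

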
def gmm(a, b, batch_sizes, trans_b=False):
    """Grouped matrix multiplication."""
    device_ok = a.device.type == "cuda" and b.device.type == "cuda"
    can_cast = torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16
    cast_not_needed = a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16
    if _gmm_kernel is not None and device_ok and (can_cast or cast_not_needed):
        return _gmm_kernel(a.bfloat16(), b.bfloat16(), batch_sizes.cpu(), trans_b)
    return gmm_pytorch(a, b, batch_sizes, trans_b)


@dataclass
class GroupInfo:
    """Group information."""

    shape: torch.Size
    ids_sorted: torch.Tensor
    ids_indices: torch.Tensor
    batch_sizes: torch.Tensor


def group(x, ids, n_groups):
    """Group a tensor by group IDs.

    Args:
        x: The input tensor.
        ids: The group IDs.
        n_groups: The number of groups.

    Returns:
        x: The grouped tensor.
        info: The group information.
    """
    if x.shape[:-1] != ids.shape:
        raise ValueError(
            f"shape mismatch: x.shape[:-1] is {tuple(x.shape[:-1])}, ids.shape is {tuple(ids.shape)}"
        )
    shape = ids.shape
    x = x.flatten(0, -2)
    ids = ids.flatten()
    ids_sorted, ids_indices = torch.sort(ids, stable=True)
    batch_sizes = torch.bincount(ids_sorted, minlength=n_groups).cpu()
    return x[ids_indices], GroupInfo(shape, ids_sorted, ids_indices, batch_sizes)


def ungroup(x, info):
    """Ungroup a tensor.

    Args:
        x: The grouped tensor.
        info: The group information.

    Returns:
        The ungrouped tensor.
    """
    return torch.empty_like(x).index_put_((info.ids_indices,), x).view(*info.shape, x.shape[-1])


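# A small sketch (not from the original gist) of the group/ungroup round trip:
# `group` sorts rows by their group ID so each group's rows are contiguous, and
# `ungroup` scatters them back into the original order and shape.
def _example_group_ungroup():
    x = torch.randn(4, 3, 8)                       # e.g. (batch, seq, features)
    ids = torch.randint(0, 2, (4, 3))              # one group ID per row of x
    grouped, info = group(x, ids, n_groups=2)
    assert grouped.shape == (12, 8)                # rows flattened and sorted by ID
    assert torch.equal(ungroup(grouped, info), x)  # round trip recovers x

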
class GroupedLinear(nn.Module):
    """Grouped linear layer."""

    def __init__(self, in_features, out_features, n_groups, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.n_groups = n_groups
        self.weight = nn.Parameter(torch.empty(n_groups, out_features, in_features))
        self.bias = nn.Parameter(torch.empty(n_groups, out_features)) if bias else None
        bound = in_features**-0.5
        nn.init.uniform_(self.weight, -bound, bound)
        if bias:
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self):
        return f"in_features={self.in_features}, out_features={self.out_features}, n_groups={self.n_groups}, bias={self.bias is not None}"

    def forward(self, x, info):
        x = gmm(x, self.weight, info.batch_sizes, trans_b=True)
        if self.bias is not None:
            x = x + self.bias.to(x.dtype)[info.ids_sorted]
        return x


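# A hedged end-to-end sketch (not part of the original gist) of how the pieces
# might be combined for mixture-of-experts style routing: each token gets an
# expert ID, tokens are grouped so each expert sees a contiguous block, the
# grouped linear layer applies that expert's weights, and the output is
# scattered back into the original token order. Names and sizes here are
# illustrative assumptions, not from the source.
if __name__ == "__main__":
    n_experts, d_model, d_ff = 4, 64, 128
    layer = GroupedLinear(d_model, d_ff, n_experts)

    x = torch.randn(2, 10, d_model)                    # (batch, seq, features)
    expert_ids = torch.randint(0, n_experts, (2, 10))  # one expert per token

    x_grouped, info = group(x, expert_ids, n_experts)  # sort tokens by expert
    y_grouped = layer(x_grouped, info)                 # per-expert matmul
    y = ungroup(y_grouped, info)                       # restore token order
    print(y.shape)                                     # torch.Size([2, 10, 128])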