Created April 17, 2025 13:00
why can't I invoke vmapped attention with a mask? why doesn't vmap unbind my mask's batch dim?
from typing import Optional

import torch
from torch import FloatTensor, BoolTensor, Tensor, inference_mode
from torch.func import functional_call, stack_module_state
from torch.nn import Module, Linear
from torch.nn.functional import scaled_dot_product_attention
from einops import rearrange


class Attention(Module):
    def __init__(
        self,
        in_features: int,
        head_dim: int,
        n_heads: int,
        device: Optional[torch.device | str | int] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.qkv = Linear(in_features=in_features, out_features=3 * head_dim * n_heads, bias=False, **factory_kwargs)

    def init_weights(self, generator: Optional[torch.Generator] = None) -> None:
        self.qkv.weight.normal_(generator=generator)

    def forward(self, x: FloatTensor, mask: Optional[BoolTensor] = None) -> FloatTensor:
        """
        Args:
            x: tensor (batch seq chan)
            mask: tensor (batch seq seq)
        """
        qkv: FloatTensor = self.qkv(x)
        q, k, v = rearrange(qkv, "... seq (proj n_head head_dim) -> proj ... n_head seq head_dim", proj=3, n_head=self.n_heads).unbind()
        # if we follow the torch docs and don't specify a head dim in our mask, we get this error for both vmapped and non-vmapped invocation:
        #   q.shape    == torch.Size([4, 8, 128, 64])
        #   k.shape    == torch.Size([4, 8, 128, 64])
        #   mask.shape == torch.Size([4, 128, 128])
        #   RuntimeError: The expanded size of the tensor (8) must match the existing size (4) at non-singleton dimension 1. Target sizes: [4, 8, 128, 128]. Tensor sizes: [4, 128, 128]
        # if we *don't* follow the torch docs and *do* specify a head dim in our mask (which is what I'm more familiar with),
        # non-vmapped works fine, but vmapped gives us:
        #   q.shape    == torch.Size([4, 8, 128, 64])
        #   k.shape    == torch.Size([4, 8, 128, 64])
        #   mask.shape == torch.Size([4, 1, 128, 128])
        #   RuntimeError: attn_bias: wrong shape (batch dimension)
        # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
        a = scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
        )
        return a
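
# a quick, self-contained sanity check of the broadcasting the comments above rely on (my addition for
# illustration; cpu + default dtype so it doesn't depend on anything defined later in this script):
# with an explicit singleton head dim, a (batch, 1, seq, seq) bool mask broadcasts against
# (batch, heads, seq, head_dim) q/k/v in a plain, non-vmapped scaled_dot_product_attention call.
_q = torch.randn(4, 8, 128, 64)
_mask = torch.ones(4, 1, 128, 128, dtype=torch.bool)
_ = scaled_dot_product_attention(_q, _q, _q, attn_mask=_mask)  # broadcasts fine outside vmap
del _q, _mask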

in_features = 320
head_dim = 64
n_heads = 8
dtype = torch.float16
device = torch.device('cuda')
ensemble_size = 3

# construct ensemble_size + 1 modules on the meta device, then materialize them: the first serves as the
# stateless template for functional_call, the rest are the experts whose weights get stacked.
with torch.device('meta'):
    attns: list[Attention] = [Attention(
        in_features=in_features,
        head_dim=head_dim,
        n_heads=n_heads,
        dtype=dtype,
    ) for _ in range(ensemble_size + 1)]
for attn in attns:
    attn.to_empty(device=device)
    attn.eval()
    attn.requires_grad_(False)
stateless_attn, *experts = attns

gen = torch.Generator(device=device)
for ix, expert in enumerate(experts):
    expert.init_weights(generator=gen.manual_seed(ix))

params, buffers = stack_module_state(experts)
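
# what I expect stack_module_state to have given me (my own check, to state an assumption explicitly):
# each parameter stacked along a new leading ensemble dim, which is the dim torch.vmap is supposed to
# map over. buffers is empty because Attention registers none.
assert params['qkv.weight'].shape == (ensemble_size, 3 * head_dim * n_heads, in_features)  # (3, 1536, 320)
assert not buffers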

def bound_vmappable_fwd(
    params: dict[str, Tensor],
    buffers: dict[str, Tensor],
    x: FloatTensor,
    mask: Optional[BoolTensor] = None,
) -> FloatTensor:
    return functional_call(
        stateless_attn,
        (params, buffers),
        (x,),
        {'mask': mask},
        tie_weights=False,
        strict=True,
    )
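
# as far as I can tell, the params/buffers half of this plumbing does work under vmap: functional_call
# substitutes the supplied dicts into the stateless template, and each ensemble member ends up seeing its
# own slice of the stacked weights. it's only the mask that isn't sliced (see dispatch_attn below).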

def dispatch_attn(x: FloatTensor, mask: Optional[BoolTensor] = None) -> FloatTensor:
    # broadcast x over the ensemble
    x_broadcast = x.unsqueeze(0).expand(ensemble_size, *(-1,)*x.ndim)
    # ... I *should* broadcast mask over the ensemble too, right?
    # but for some reason, it's only for x that torch.vmap unbinds-and-dispatches dim 0.
    # for the mask kwarg, torch.vmap just passes the same mask to all ensemble members.
    # why are they treated differently? (I take a guess at an answer, and sketch a workaround, right after this function.)
    mask_broadcast = mask
    # if mask is None:
    #     mask_broadcast: Optional[BoolTensor] = None
    # else:
    #     mask_broadcast = mask.unsqueeze(0).expand(ensemble_size, *(-1,)*mask.ndim)
    return torch.vmap(bound_vmappable_fwd)(params, buffers, x_broadcast, mask=mask_broadcast)
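
# my best guess at an answer, stated as an assumption rather than a verified fact: torch.vmap maps over the
# *positional* arguments of the wrapped function (that's what in_dims describes), while keyword arguments
# appear to be passed through to every call unbatched. that would explain why x_broadcast has its ensemble
# dim unbound but the mask kwarg is handed to every member whole.
# a sketch of a workaround under that assumption (hypothetical name, not tested against every SDPA backend):
# give the mask its own ensemble dim and pass it positionally, so the default in_dims=0 covers it too.
def dispatch_attn_positional_mask(x: FloatTensor, mask: Optional[BoolTensor] = None) -> FloatTensor:
    x_broadcast = x.unsqueeze(0).expand(ensemble_size, *(-1,)*x.ndim)
    if mask is None:
        return torch.vmap(bound_vmappable_fwd)(params, buffers, x_broadcast)
    mask_broadcast = mask.unsqueeze(0).expand(ensemble_size, *(-1,)*mask.ndim)
    # mask is bound_vmappable_fwd's 4th positional parameter, so it reaches forward() as `mask`
    return torch.vmap(bound_vmappable_fwd)(params, buffers, x_broadcast, mask_broadcast)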

batch_size = 4
seq_len = 128

with inference_mode():
    x = torch.randn((batch_size, seq_len, in_features), device=device, dtype=dtype)
    # the 1 is a singleton dim for broadcasting over heads
    mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool)
    # non-vmapped works fine, so long as I specify a head dim in the mask
    out_novmap: FloatTensor = experts[0].forward(x, mask=mask)
    out: FloatTensor = dispatch_attn(x, mask=mask)
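    # for reference, the shapes I'd expect if the vmapped dispatch succeeded (which it currently doesn't,
    # per the attn_bias error described in forward()):
    #   out_novmap.shape == (batch_size, n_heads, seq_len, head_dim)                == (4, 8, 128, 64)
    #   out.shape        == (ensemble_size, batch_size, n_heads, seq_len, head_dim) == (3, 4, 8, 128, 64)
    # and out[ix] should match experts[ix](x, mask=mask), since expert ix's weights sit at index ix of the stack.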