
@Birch-san
Created April 17, 2025 13:00
why can't I invoke vmapped attention with a mask? why doesn't vmap unbind my mask's batch dim?
from typing import Optional
import torch
from torch import FloatTensor, BoolTensor, Tensor, inference_mode
from torch.func import functional_call, stack_module_state
from torch.nn import Module, Linear
from torch.nn.functional import scaled_dot_product_attention
from einops import rearrange


class Attention(Module):
    def __init__(
        self,
        in_features: int,
        head_dim: int,
        n_heads: int,
        device: Optional[torch.device | str | int] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.qkv = Linear(in_features=in_features, out_features=3 * head_dim * n_heads, bias=False, **factory_kwargs)

    def init_weights(self, generator: Optional[torch.Generator] = None) -> None:
        self.qkv.weight.normal_(generator=generator)

    def forward(self, x: FloatTensor, mask: Optional[BoolTensor] = None) -> FloatTensor:
        """
        Args:
            x: tensor (batch seq chan)
            mask: tensor (batch seq seq)
        """
        qkv: FloatTensor = self.qkv(x)
        q, k, v = rearrange(qkv, "... seq (proj n_head head_dim) -> proj ... n_head seq head_dim", proj=3, n_head=self.n_heads).unbind()
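        # q, k, v: each (batch, n_heads, seq, head_dim)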
        # if we follow the torch docs and don't specify a head dim in our mask, we get this error for both vmapped and non-vmapped invocation:
        # q.shape
        # torch.Size([4, 8, 128, 64])
        # k.shape
        # torch.Size([4, 8, 128, 64])
        # mask.shape
        # torch.Size([4, 128, 128])
        # RuntimeError: The expanded size of the tensor (8) must match the existing size (4) at non-singleton dimension 1. Target sizes: [4, 8, 128, 128]. Tensor sizes: [4, 128, 128]
        #
        # if we *don't* follow the torch docs and *do* specify a head dim in our mask (which is what I'm more familiar with),
        # non-vmapped works fine,
        # but vmapped gives us:
        # q.shape
        # torch.Size([4, 8, 128, 64])
        # k.shape
        # torch.Size([4, 8, 128, 64])
        # mask.shape
        # torch.Size([4, 1, 128, 128])
        # RuntimeError: attn_bias: wrong shape (batch dimension)
        # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
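        # a hedged sketch (an assumption, not verified against every SDPA backend): the first error looks
        # like plain broadcasting, so giving a (batch, seq, seq) mask a singleton head dim would let it
        # broadcast against the (batch, n_heads, seq, seq) attention weights. as noted above, that keeps
        # the non-vmapped path happy, but the vmapped 'attn_bias: wrong shape' error remains either way.
        # if mask is not None and mask.ndim == 3:
        #     mask = mask.unsqueeze(1)  # (batch, seq, seq) -> (batch, 1, seq, seq)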
        a = scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
        )
        return a


in_features = 320
head_dim = 64
n_heads = 8
dtype = torch.float16
device = torch.device('cuda')

ensemble_size = 3
with torch.device('meta'):
    attns: list[Attention] = [Attention(
        in_features=in_features,
        head_dim=head_dim,
        n_heads=n_heads,
        dtype=dtype,
    ) for _ in range(ensemble_size + 1)]
for attn in attns:
    attn.to_empty(device=device)
    attn.eval()
    attn.requires_grad_(False)

stateless_attn, *experts = attns

gen = torch.Generator(device=device)
for ix, expert in enumerate(experts):
    expert.init_weights(generator=gen.manual_seed(ix))

params, buffers = stack_module_state(experts)
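# params now maps each parameter name (e.g. 'qkv.weight') to a single tensor stacked along a new
# leading ensemble dim of size ensemble_size; buffers is empty here, since Attention registers none.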

def bound_vmappable_fwd(
    params: dict[str, Tensor],
    buffers: dict[str, Tensor],
    x: FloatTensor,
    mask: Optional[BoolTensor] = None,
) -> FloatTensor:
    return functional_call(
        stateless_attn,
        (params, buffers),
        (x,),
        {'mask': mask},
        tie_weights=False,
        strict=True,
    )

def dispatch_attn(x: FloatTensor, mask: Optional[BoolTensor] = None) -> FloatTensor:
    # broadcast x over the ensemble
    x_broadcast = x.unsqueeze(0).expand(ensemble_size, *(-1,)*x.ndim)
    # ... I *should* broadcast mask over the ensemble too, right?
    # but for some reason, it's only for x that torch.vmap unbinds-and-dispatches dim 0.
    # for the mask kwarg, torch.vmap just passes the same mask to all ensemble members.
    # why are they treated differently? (see the sketch after this function)
    mask_broadcast = mask
    # if mask is None:
    #     mask_broadcast: Optional[BoolTensor] = None
    # else:
    #     mask_broadcast = mask.unsqueeze(0).expand(ensemble_size, *(-1,)*mask.ndim)
    return torch.vmap(bound_vmappable_fwd)(params, buffers, x_broadcast, mask=mask_broadcast)
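
# sketch of a possible explanation + workaround (an assumption, untested here): torch.vmap appears to
# apply in_dims only to positional arguments, while keyword arguments are passed through to every call
# un-batched. if that's right, handing the mask to vmap positionally — with its own leading ensemble
# dim — should let vmap unbind it just like x. dispatch_attn_positional_mask and fwd_positional are
# names invented for this sketch, not part of the original.
def dispatch_attn_positional_mask(x: FloatTensor, mask: BoolTensor) -> FloatTensor:
    def fwd_positional(
        params: dict[str, Tensor],
        buffers: dict[str, Tensor],
        x: FloatTensor,
        mask: BoolTensor,
    ) -> FloatTensor:
        # same as bound_vmappable_fwd, but the mask rides along as a positional arg
        return functional_call(
            stateless_attn,
            (params, buffers),
            (x, mask),
            tie_weights=False,
            strict=True,
        )
    x_broadcast = x.unsqueeze(0).expand(ensemble_size, *(-1,)*x.ndim)
    mask_broadcast = mask.unsqueeze(0).expand(ensemble_size, *(-1,)*mask.ndim)
    # default in_dims=0 maps over the leading (ensemble) dim of every positional argument,
    # including the pytrees of stacked params/buffers
    return torch.vmap(fwd_positional)(params, buffers, x_broadcast, mask_broadcast)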

batch_size = 4
seq_len = 128
with inference_mode():
    x = torch.randn((batch_size, seq_len, in_features), device=device, dtype=dtype)
    # the 1 is a singleton dim for broadcasting over heads
    mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool)
    # non-vmapped works fine, so long as I specify a head dim in the mask
    out_novmap: FloatTensor = experts[0].forward(x, mask=mask)
    out: FloatTensor = dispatch_attn(x, mask=mask)
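    # sanity-check sketch (left commented out, since the line above is the one that currently raises):
    # if the positional-mask dispatch behaves as hoped, its output should match a plain loop over the
    # experts, up to fp16 tolerance.
    # out_positional: FloatTensor = dispatch_attn_positional_mask(x, mask)
    # out_loop = torch.stack([expert.forward(x, mask=mask) for expert in experts])
    # torch.testing.assert_close(out_positional, out_loop, rtol=1e-3, atol=1e-3)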