@Birch-san
Last active April 28, 2025 01:22
How to implement mm, bmm and matmul in PyTorch via vmap
import torch
from torch import FloatTensor

def mm(a: FloatTensor, b: FloatTensor) -> FloatTensor:
    assert a.ndim == 2
    assert b.ndim == 2
    # inner dimensions must agree: (m, k) @ (k, n)
    assert a.size(-1) == b.size(-2)
    # batched dot product: dot a single row of a with every column of b
    def bdp(a_row: FloatTensor, b: FloatTensor) -> FloatTensor:
        return torch.vmap(torch.dot, in_dims=(None, -1))(a_row, b)
    # map the batched dot product over the rows of a
    return torch.vmap(bdp, in_dims=(-2, None))(a, b)

def bmm(a: FloatTensor, b: FloatTensor) -> FloatTensor:
    assert a.ndim == 3
    assert b.ndim == 3
    # map mm over the leading batch dimension
    return torch.vmap(mm)(a, b)

def matmul(a: FloatTensor, b: FloatTensor) -> FloatTensor:
    assert a.ndim >= 2
    assert b.ndim >= 2
    # broadcast the batch dimensions of both operands against each other
    batch_dims = torch.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    if not batch_dims:
        # no batch dimensions: plain matrix multiply
        return mm(a, b)
    # collapse all batch dimensions into one, bmm, then restore them
    a = a.broadcast_to((*batch_dims, *a.shape[-2:])).flatten(end_dim=-3)
    b = b.broadcast_to((*batch_dims, *b.shape[-2:])).flatten(end_dim=-3)
    return bmm(a, b).unflatten(-3, batch_dims)
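
As a quick sanity check (an addition of mine, not part of the original gist), the mm and bmm helpers above can be compared against PyTorch's built-in torch.mm and torch.bmm on small CPU tensors, assuming the definitions above are in scope:

import torch

x = torch.randn(5, 7)
y = torch.randn(7, 3)
assert torch.allclose(mm(x, y), torch.mm(x, y), atol=1e-6)

xb = torch.randn(4, 5, 7)
yb = torch.randn(4, 7, 3)
assert torch.allclose(bmm(xb, yb), torch.bmm(xb, yb), atol=1e-6)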
@Birch-san (Author) commented:

device = torch.device('cuda')
dtype = torch.float16
gen = torch.Generator(device=device).manual_seed(0)
a = torch.randn(2, 4, 320, 64, dtype=dtype, device=device, generator=gen)
b = torch.randn(2, 4, 320, 64, dtype=dtype, device=device, generator=gen)

assert matmul(a, b.mT).equal(a @ b.mT) # passes
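
Since matmul broadcasts mismatched batch dimensions via torch.broadcast_shapes, it also handles operands whose batch shapes differ. A small additional check of my own (the tensors p and q and the tolerance are illustrative, not from the gist), in float32 so an ordinary allclose tolerance suffices:

p = torch.randn(2, 4, 8, 16, device=device)
q = torch.randn(4, 16, 8, device=device)  # batch dims (4,) broadcast against (2, 4)
assert torch.allclose(matmul(p, q), p @ q, atol=1e-4)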
