@OneAdder
Last active February 14, 2025 17:29
MultiHead Attention (Reference Implementation)
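The script below runs PyTorch's nn.MultiheadAttention with constant weights to produce reference outputs, attention weights, and input gradients, plus the output after one SGD step. All values are printed flattened in Fortran (column-major) order, which suggests they are intended for comparison against a column-major implementation of multi-head attention.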
import torch
import numpy as np

# Reference multi-head attention: 2 heads, embedding size 4, batch-first layout.
mha = torch.nn.MultiheadAttention(num_heads=2, embed_dim=4, batch_first=True)

# Fill the packed Q/K/V projection and the output projection with constants
# so the reference values are easy to reproduce in another implementation.
mha.in_proj_weight.data = torch.zeros(12, 4) + 0.1
mha.in_proj_bias.data = torch.zeros(12) + 0.11
mha.out_proj.weight.data = torch.zeros(4, 4) + 0.1
mha.out_proj.bias.data = torch.zeros(4) + 0.11

optim = torch.optim.SGD(mha.parameters(), lr=0.01)

# Input of shape (batch=1, seq=3, embed=4), built from a flat array in
# Fortran (column-major) order so it maps directly onto column-major code.
x = torch.tensor(
    np.array(
        [0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12],
    ).reshape(3, 4, 1, order='F').transpose(2, 0, 1),
    dtype=torch.float32,
    requires_grad=True,
)

# Forward pass (self-attention: query = key = value = x).
out, weights = mha(x, x, x)
print('Output:', np.array(np.nditer(out.detach().numpy(), order='F')))
print('Attention Weights:', np.array(np.nditer(weights.detach().numpy(), order='F')))

# Backward pass with a fixed upstream gradient of shape (1, 3, 4).
gradient = torch.tensor(
    [.1, .1, .1, 3., 3., 3., 2., .1, 2., 3., .1, 3.],
).reshape(1, 3, 4)
out.backward(gradient=gradient)
print('Gradient:', np.array(np.nditer(x.grad.numpy(), order='F')))

# One SGD step on the attention parameters, then a second forward pass.
optim.step()
out, weights = mha(x, x, x)
print('Output after one step (SGD):', np.array(np.nditer(out.detach().numpy(), order='F')))
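For readers porting these reference values to another framework, here is a minimal hand-rolled sketch of the same forward pass. It rests on two assumptions that hold for nn.MultiheadAttention with its default settings: in_proj_weight stacks the query, key, and value projections row-wise, and the returned attention weights are averaged over heads. The helper name manual_mha is hypothetical and not part of the original gist.

import torch.nn.functional as F

def manual_mha(x, in_proj_weight, in_proj_bias, out_weight, out_bias, num_heads):
    batch, seq, embed = x.shape
    head_dim = embed // num_heads
    # Unpack the stacked Q/K/V projections: each chunk is (embed, embed).
    w_q, w_k, w_v = in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = in_proj_bias.chunk(3, dim=0)

    # Project, then split the embedding into heads:
    # (batch, seq, embed) -> (batch, num_heads, seq, head_dim)
    def heads(t):
        return t.view(batch, seq, num_heads, head_dim).transpose(1, 2)

    q = heads(x @ w_q.T + b_q)
    k = heads(x @ w_k.T + b_k)
    v = heads(x @ w_v.T + b_v)

    # Scaled dot-product attention per head.
    attn = F.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)

    # Concatenate heads and apply the output projection.
    context = (attn @ v).transpose(1, 2).reshape(batch, seq, embed)
    return context @ out_weight.T + out_bias, attn.mean(dim=1)

# Cross-check against the nn.MultiheadAttention output computed above.
ref_out, ref_weights = manual_mha(
    x, mha.in_proj_weight, mha.in_proj_bias,
    mha.out_proj.weight, mha.out_proj.bias, num_heads=2,
)
print('Matches nn.MultiheadAttention:', torch.allclose(ref_out, out, atol=1e-6))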