Basic PyTorch implementation of masked self-attention with a single attention head.
""" | |
Source: https://github.com/karpathy/nanoGPT/blob/master/model.py | |
""" | |
import math | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
class MaskedSelfAttention(nn.Module): | |
def __init__( | |
self, | |
d, | |
T, | |
bias=False, | |
dropout=0.2, | |
): | |
""" | |
Arguments: | |
d: size of embedding dimension | |
T: maximum length of input sequences (in tokens) | |
bias: whether or not to use bias in linear layers | |
dropout: probability of dropout | |
""" | |
super().__init__() | |
self.d = d | |
# key, query, value projections for all heads, but in a batch | |
# output is 3X the dimension because it includes key, query and value | |
self.c_attn = nn.Linear(d, 3*d, bias=bias) | |
# causal mask to ensure that attention is only applied to | |
# the left in the input sequence | |
self.register_buffer("mask", torch.tril(torch.ones(T, T)) | |
.view(1, 1, T, T)) | |
def forward(self, x): | |
B, T, _ = x.size() # batch size, sequence length, embedding dimensionality | |
# compute query, key, and value vectors for all heads in batch | |
# split the output into separate query, key, and value tensors | |
q, k, v = self.c_attn(x).split(self.d, dim=2) # [B, T, d] | |
# compute the attention matrix, perform masking, and apply dropout | |
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # [B, T, T] | |
att = att.masked_fill(self.bias[:,:T,:T] == 0, float('-inf')) | |
att = F.softmax(att, dim=-1) | |
# compute output vectors for each token | |
y = att @ v # [B, T, d] | |
return y |
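A minimal usage sketch (not part of the original gist): the embedding dimension, maximum sequence length, batch size, and dropout value below are illustrative assumptions, chosen only to show the expected input and output shapes.

# usage sketch with assumed hyperparameters
import torch

d, T = 64, 128                       # assumed embedding dimension and maximum sequence length
attn = MaskedSelfAttention(d=d, T=T, dropout=0.1)

x = torch.randn(2, 16, d)            # batch of 2 sequences, 16 tokens each, embedding size d
y = attn(x)                          # causal (masked) self-attention output
print(y.shape)                       # torch.Size([2, 16, 64])

Because the causal mask zeroes out attention to future positions, each output vector in y depends only on the current token and the tokens to its left.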