R implementation of attention, see blog post: https://qua.st/attention-in-R
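The script implements scaled dot-product attention, Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, where Q, K and V are the query, key and value matrices and d_k is the dimension of the key vectors.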
# attention.R
# Bastiaan Quast
# [email protected]
# based on:
# https://machinelearningmastery.com/the-attention-mechanism-from-scratch/

# encoder representations of four different words (3-dimensional embeddings)
word_1 = matrix(c(1,0,0), nrow=1)
word_2 = matrix(c(0,1,0), nrow=1)
word_3 = matrix(c(1,1,0), nrow=1)
word_4 = matrix(c(0,0,1), nrow=1)

# stacking the word embeddings into a single 4x3 matrix (one row per word)
words = rbind(word_1,
              word_2,
              word_3,
              word_4)

# generating the weight matrices: floor(runif(9, min=0, max=3)) draws integers
# in 0..2, mirroring the random integer weights in the original Python tutorial
set.seed(42)
W_Q = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)

# redefine the matrices with hard-coded values so the results match the random
# numbers generated by Python in the original code (R's RNG differs from NumPy's)
W_Q = matrix(c(2,0,2,
               2,0,0,
               2,1,2),
             nrow=3,
             ncol=3,
             byrow = TRUE)
W_K = matrix(c(2,2,2,
               0,2,1,
               0,1,1),
             nrow=3,
             ncol=3,
             byrow = TRUE)
W_V = matrix(c(1,1,0,
               0,1,1,
               0,0,0),
             nrow=3,
             ncol=3,
             byrow = TRUE)

# generating the queries, keys and values (each a 4x3 matrix, one row per word)
Q = words %*% W_Q
K = words %*% W_K
V = words %*% W_V

# scoring the query vectors against all key vectors: scores[i,j] is the
# dot product of query i and key j
scores = Q %*% t(K)
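
# for reference, with the hard-coded weights above the raw scores work out to
# (row i = query for word i, column j = key for word j):
#      [,1] [,2] [,3] [,4]
# [1,]    8    2   10    2
# [2,]    4    0    4    0
# [3,]   12    2   14    2
# [4,]   10    4   14    3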

# calculate the max of each row of the scores matrix; subtracting it before
# exponentiating keeps the softmax numerically stable (softmax is invariant
# to a constant shift per row)
maxs = as.matrix(apply(scores, MARGIN=1, FUN=max))

# initialize the 4x4 weights matrix
weights = matrix(0, nrow=4, ncol=4)

# computing the weights by a softmax over the scores scaled by sqrt(d_k)
for (i in 1:dim(scores)[1]) {
  scaled = (scores[i,] - maxs[i,]) / sqrt(ncol(K))
  weights[i,] = exp(scaled) / sum(exp(scaled))
}
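
# an equivalent vectorized softmax (a sketch, not part of the original script);
# column-major recycling subtracts each row's max and divides each row of
# exp_scores by its own sum
scaled_all = scores / sqrt(ncol(K))
exp_scores = exp(scaled_all - apply(scaled_all, MARGIN=1, FUN=max))
weights_vec = exp_scores / rowSums(exp_scores)
stopifnot(isTRUE(all.equal(weights, weights_vec)))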

# computing the attention by a weighted sum of the value vectors
attention = weights %*% V
print(attention)
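
As a usage sketch, the whole computation can be wrapped into a reusable helper; the function name and its arguments below are illustrative, not part of the original gist:

scaled_dot_product_attention = function(words, W_Q, W_K, W_V) {
  # project the embeddings into query, key and value spaces
  Q = words %*% W_Q
  K = words %*% W_K
  V = words %*% W_V
  # scale the scores by sqrt(d_k), stabilize, and softmax row-wise
  scaled = (Q %*% t(K)) / sqrt(ncol(K))
  exp_s = exp(scaled - apply(scaled, MARGIN=1, FUN=max))
  weights = exp_s / rowSums(exp_s)
  # weighted sum of the value vectors
  weights %*% V
}

# should reproduce the matrix printed above
print(scaled_dot_product_attention(words, W_Q, W_K, W_V))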