# transformer simple.R
# Bastiaan Quast

softmax <- function(x) {
  e_x <- exp(x - max(x))  # subtract the max to avoid numerical overflow
  return(e_x / sum(e_x))
}
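# Quick sanity check (added as an illustration; not part of the original gist):
# thanks to the max subtraction, softmax stays finite even for inputs where
# exp() alone would overflow, and its entries still sum to 1.
stopifnot(all(is.finite(softmax(c(1000, 1001)))),
          abs(sum(softmax(c(1000, 1001))) - 1) < 1e-12)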
# Initialize the input query, keys, and values
query  <- rnorm(5)
keys   <- matrix(rnorm(25), nrow = 5, ncol = 5)
values <- matrix(rnorm(25), nrow = 5, ncol = 5)

# Initialize weights for a simple feed-forward layer
weights <- matrix(rnorm(25), nrow = 5, ncol = 5)

# Assume we have some target values
target <- c(1.0, 1.5, 2.0, 2.5, 3.0)

learning_rate <- 0.01  # a hyperparameter that you may need to tune
# Training loop
for (epoch in 1:100) {  # run for 100 epochs

  # Forward pass through the attention mechanism
  # (the canonical transformer scales the dot products by 1/sqrt(d_k);
  # that scaling is omitted here for simplicity)
  dot_products <- query %*% keys
  attention_weights <- softmax(dot_products)
  attention_output <- attention_weights %*% values

  # Forward pass through the feed-forward layer
  output <- attention_output %*% weights

  # Calculate the loss: mean squared error
  loss <- mean((output - target)^2)
  print(paste0("Epoch ", epoch, ", Loss: ", loss))

  # Backward pass through the feed-forward layer
  output_grad <- 2.0 * (output - target) / length(output)  # derivative of the MSE loss
  weights_grad <- crossprod(attention_output, output_grad)  # t(attention_output) %*% output_grad, a 5x5 matrix matching weights

  # Update the parameters using gradient descent
  weights <- weights - learning_rate * weights_grad
}
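One way to gain confidence in the analytic weight gradient is a finite-difference check. The sketch below is not part of the original gist; it assumes the script above has just run, so attention_output, weights, and target are still in scope. It compares the analytic gradient for a single weight entry against a two-sided numerical estimate, which should agree to several decimal places.

# Finite-difference check of the analytic weight gradient (illustrative sketch)
eps <- 1e-6
loss_at <- function(W) mean((attention_output %*% W - target)^2)

# analytic gradient at the current weights: t(attention_output) %*% dL/d(output)
out <- attention_output %*% weights
grad_analytic <- crossprod(attention_output, 2 * (out - target) / length(out))

# two-sided numerical estimate for the [1, 1] entry
W_hi <- weights; W_hi[1, 1] <- W_hi[1, 1] + eps
W_lo <- weights; W_lo[1, 1] <- W_lo[1, 1] - eps
grad_numeric <- (loss_at(W_hi) - loss_at(W_lo)) / (2 * eps)

print(paste("analytic:", grad_analytic[1, 1], "numeric:", grad_numeric))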