Skip to content

Instantly share code, notes, and snippets.

@vankesteren
Created May 11, 2023 10:04
Show Gist options
  • Save vankesteren/8501c28e4094f1922fc69f7dadad946c to your computer and use it in GitHub Desktop.
LSTM in R with only one hidden unit
library(torch)

# Simulate a univariate series with five consecutive segments that
# alternate between N(0, 1) and N(5, 2) noise — a simple
# "regime-switching" signal for the LSTM to track.
tot_obs <- 1000
x <- c(
  rnorm(0.2 * tot_obs),        # segment 1: N(0, 1)
  rnorm(0.2 * tot_obs, 5, 2),  # segment 2: N(5, 2)
  rnorm(0.3 * tot_obs),        # segment 3: N(0, 1)
  rnorm(0.1 * tot_obs, 5, 2),  # segment 4: N(5, 2)
  rnorm(0.2 * tot_obs)         # segment 5: N(0, 1)
)

# One-column matrix -> tensor of shape (tot_obs, 1): sequence length by
# input size (presumably treated as unbatched input by nn_lstm — confirm)
x_torch <- torch_tensor(matrix(x))
plot(x, type = "l")
# Create a very simple lstm, with only one hidden node and linear activation for the output
# A minimal LSTM: one recurrent layer feeding a linear readout, with
# linear (identity) activation on the output.
#
# Fix: the original forward() stored its intermediates as module fields
# (self$lstm_out, self$hidden, self$final). That mutates the module on
# every call and keeps the intermediate tensors — and their autograd
# graph — alive after forward() returns. Plain locals avoid both issues.
SimpleLSTM <- nn_module(
  "simplelstm",
  initialize = function(obs_size, hidden_size) {
    # obs_size: number of input features per time step (also output size)
    # hidden_size: number of LSTM hidden units
    self$lstm <- nn_lstm(
      input_size = obs_size,
      hidden_size = hidden_size,
      num_layers = 1
    )
    self$linear <- nn_linear(hidden_size, obs_size)
  },
  forward = function(input) {
    # nn_lstm returns a list: element 1 is the per-time-step output
    # sequence; element 2 holds the final (h_n, c_n) state, which we
    # do not need here.
    lstm_out <- self$lstm(input)
    hidden_seq <- lstm_out[[1]]
    # Linear readout per time step, squeezed down to a vector
    torch_squeeze(self$linear(hidden_seq))
  }
)
# Try it out: one input feature, a single hidden unit. A forward pass
# over the whole series checks that the shapes line up.
net <- SimpleLSTM(obs_size = 1, hidden_size = 1)
net(x_torch)
# Train the model: minimize MSE between the observed series and the
# network output, re-plotting the fit every 10 epochs.
loss <- nn_mse_loss()
nepochs <- 150
opt <- optim_adam(net$parameters, lr = 0.1)

# Hoist the loop-invariant target out of the training loop
target <- torch_squeeze(x_torch)

for (i in seq_len(nepochs)) {
  opt$zero_grad()
  l <- loss(target, net(x_torch))
  l$backward()
  opt$step()
  cat("iter", i, "loss", l$item(), "\r")
  if (i %% 10 == 0) {
    plot(x, type = "l", main = i, lwd = 1.5)
    # Plotting-only forward pass: skip gradient tracking so no autograd
    # graph is built, and coerce the tensor to numeric for lines()
    with_no_grad(
      lines(as.numeric(net(x_torch)), col = "blue")
    )
  }
}

# It does this with only 18 parameters!
net$parameters
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment