-
-
Save sdwfrost/85cd798f7cad865a6bb999ca9be71817 to your computer and use it in GitHub Desktop.
An efficient, batched LSTM in R
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
###
### This is a batched LSTM forward and backward pass. Written by Andrej Karpathy (@karpathy)
### BSD License
### Re-written in R by @georgeblck
###
# NOTE: this script used to start with rm(list = ls(all = TRUE)). Wiping the
# caller's global environment is a surprising side effect for a sourced file,
# and nothing below depends on an empty workspace, so it has been dropped.

# Initialize the parameters of a batched LSTM.
#
# All weights and biases live in one matrix WLSTM with
# (1 + input_size + hidden_size) rows and 4 * hidden_size columns:
#   * row 1 holds the biases,
#   * the next input_size rows multiply the input x_t,
#   * the last hidden_size rows multiply the previous hidden state h_{t-1}.
# Column blocks of width hidden_size are (input gate, forget gate, output gate,
# cell candidate), in the order LSTM.forward slices them.
#
# input_size: number of input features per time step.
# hidden_size: number of hidden units.
# fancy_forget_bias_init: value given to the forget-gate biases; 0 leaves all
#   biases at zero. Some papers use values as large as ~5.
# Returns the (input_size + hidden_size + 1) x (4 * hidden_size) matrix.
LSTM.init <- function(input_size, hidden_size, fancy_forget_bias_init = 3) {
  n_rows <- input_size + hidden_size + 1 # +1 for the bias row
  n_cols <- 4 * hidden_size
  # Xavier-style scaling: raw gate pre-activations start out roughly
  # zero-mean with standard deviation ~1.
  WLSTM <- matrix(rnorm(n_rows * n_cols), n_rows, n_cols) /
    sqrt(input_size + hidden_size)
  WLSTM[1, ] <- 0 # all biases start at zero
  if (fancy_forget_bias_init != 0) {
    # A positive forget-gate bias pushes sigmoid(forget) towards 1 at the
    # start of training, so the previous cell state is kept by default.
    WLSTM[1, (hidden_size + 1):(2 * hidden_size)] <- fancy_forget_bias_init
  }
  WLSTM
}
# Batched LSTM forward pass.
#
# X: array of shape (n, b, input_size) -- n time steps, batch size b.
# WLSTM: parameter matrix as built by LSTM.init,
#   (1 + input_size + d) x (4 * d), where d is the hidden size.
# c0, h0: initial cell / hidden state with b * d values (b x d matrices);
#   zeros when NULL.
#
# Returns a list with
#   Hout   -- (n, b, d) array of hidden states for every time step,
#   C      -- b x d cell state after the last step,
#   Hout_t -- b x d hidden state after the last step,
#   cached -- intermediate values consumed by LSTM.backward.
# For an empty sequence (n == 0) the final states are c0 and h0.
# Stops with an error when the shapes of X, WLSTM, c0 and h0 disagree.
LSTM.forward <- function(X, WLSTM, c0 = NULL, h0 = NULL) {
  if (length(dim(X)) != 3) {
    stop("X must be a 3-d array of shape (n, b, input_size)", call. = FALSE)
  }
  n <- dim(X)[1]
  b <- dim(X)[2]
  input_size <- dim(X)[3]
  d <- ncol(WLSTM) / 4 # hidden size
  xphpb <- nrow(WLSTM) # x plus h plus bias
  # Validate up front: otherwise R's recycling rules can silently fill Hin
  # with wrong values instead of raising an error.
  if (d != round(d) || xphpb != input_size + d + 1) {
    stop("WLSTM must be (1 + input_size + d) x (4 * d) to match X",
         call. = FALSE)
  }
  if (is.null(c0)) {
    c0 <- matrix(0, b, d)
  }
  if (is.null(h0)) {
    h0 <- matrix(0, b, d)
  }
  if (length(c0) != b * d || length(h0) != b * d) {
    stop("c0 and h0 must each contain b * d values", call. = FALSE)
  }

  Hin <- array(0, c(n, b, xphpb))   # input [1, x_t, h_{t-1}] to each step
  Hout <- array(0, c(n, b, d))      # hidden state (gated cell content)
  IFOG <- array(0, c(n, b, 4 * d))  # input, forget, output, gate (IFOG)
  IFOGf <- array(0, c(n, b, 4 * d)) # IFOG after the nonlinearities
  C <- array(0, c(n, b, d))         # cell content
  Ct <- C                           # tanh of cell content

  # Column index sets, built with seq_len() so empty ranges stay empty.
  i_idx <- seq_len(d)               # input gate
  f_idx <- d + seq_len(d)           # forget gate
  o_idx <- 2 * d + seq_len(d)       # output gate
  g_idx <- 3 * d + seq_len(d)       # cell candidate
  sig_idx <- seq_len(3 * d)         # the three sigmoid gates
  x_idx <- 1 + seq_len(input_size)  # columns of Hin holding x_t
  h_idx <- input_size + 1 + seq_len(d) # columns of Hin holding h_{t-1}

  prevh <- h0
  prevc <- c0
  for (t in seq_len(n)) {
    Hin[t, , 1] <- 1 # bias
    Hin[t, , x_idx] <- X[t, , ]
    Hin[t, , h_idx] <- prevh
    # All gate pre-activations in one product (most of the work). matrix()
    # keeps the b x xphpb shape even when b == 1 drops a dimension.
    IFOG[t, , ] <- matrix(Hin[t, , ], nrow = b) %*% WLSTM
    IFOGf[t, , sig_idx] <- 1 / (1 + exp(-IFOG[t, , sig_idx])) # sigmoid gates
    IFOGf[t, , g_idx] <- tanh(IFOG[t, , g_idx])
    C[t, , ] <- IFOGf[t, , i_idx] * IFOGf[t, , g_idx] +
      IFOGf[t, , f_idx] * prevc
    Ct[t, , ] <- tanh(C[t, , ])
    Hout[t, , ] <- IFOGf[t, , o_idx] * Ct[t, , ]
    prevh <- Hout[t, , ]
    prevc <- C[t, , ]
  }

  cached <- list(WLSTM = WLSTM, Hout = Hout, IFOGf = IFOGf, IFOG = IFOG,
                 C = C, Ct = Ct, Hin = Hin, c0 = c0, h0 = h0)
  # prevc / prevh hold the last step's state (or c0 / h0 when n == 0);
  # reshape so the final states are always b x d matrices.
  list(Hout = Hout, C = matrix(prevc, b, d), Hout_t = matrix(prevh, b, d),
       cached = cached)
}
# Batched LSTM backward pass (backpropagation through time).
#
# dHout_in: (n, b, d) array, gradient of the loss w.r.t. every hidden state.
# cache: the `cached` list returned by LSTM.forward.
# dcn, dhn: optional gradients (b * d values) flowing into the final cell /
#   hidden state from a later chunk of the sequence; NULL means none.
#
# Returns a list with
#   dX     -- (n, b, input_size) gradient w.r.t. the inputs,
#   dWLSTM -- gradient w.r.t. WLSTM (same shape),
#   dc0    -- b x d gradient w.r.t. the initial cell state,
#   dh0    -- b x d gradient w.r.t. the initial hidden state.
LSTM.backward <- function(dHout_in, cache, dcn = NULL, dhn = NULL) {
  WLSTM <- cache$WLSTM
  IFOGf <- cache$IFOGf
  C <- cache$C
  Ct <- cache$Ct
  Hin <- cache$Hin
  c0 <- cache$c0 # the initial hidden state h0 is not needed here
  n <- dim(cache$Hout)[1]
  b <- dim(cache$Hout)[2]
  d <- dim(cache$Hout)[3]
  input_size <- nrow(WLSTM) - d - 1 # -1 for the bias row

  # Column index sets (same layout as LSTM.forward).
  i_idx <- seq_len(d)
  f_idx <- d + seq_len(d)
  o_idx <- 2 * d + seq_len(d)
  g_idx <- 3 * d + seq_len(d)
  sig_idx <- seq_len(3 * d)
  x_idx <- 1 + seq_len(input_size)
  h_idx <- input_size + 1 + seq_len(d)

  dIFOG <- array(0, dim(IFOGf))
  dIFOGf <- array(0, dim(IFOGf))
  dWLSTM <- array(0, dim(WLSTM))
  dHin <- array(0, dim(Hin))
  dC <- array(0, dim(C))
  dX <- array(0, c(n, b, input_size))
  dh0 <- matrix(0, b, d)
  dc0 <- matrix(0, b, d)

  if (n == 0) {
    # Empty sequence: gradients on the final state reach the initial state
    # unchanged.
    if (!is.null(dcn)) dc0 <- dc0 + dcn
    if (!is.null(dhn)) dh0 <- dh0 + dhn
    return(list(dX = dX, dWLSTM = dWLSTM, dc0 = dc0, dh0 = dh0))
  }

  dHout <- dHout_in
  # Carry over gradients from a later chunk of the sequence.
  if (!is.null(dcn)) {
    dC[n, , ] <- dC[n, , ] + dcn
  }
  if (!is.null(dhn)) {
    dHout[n, , ] <- dHout[n, , ] + dhn
  }

  for (t in rev(seq_len(n))) {
    tanhCt <- Ct[t, , ]
    dIFOGf[t, , o_idx] <- tanhCt * dHout[t, , ]
    # Backprop through tanh(c_t) into the cell state.
    dC[t, , ] <- dC[t, , ] + (1 - tanhCt^2) * (IFOGf[t, , o_idx] * dHout[t, , ])
    if (t > 1) {
      dIFOGf[t, , f_idx] <- C[t - 1, , ] * dC[t, , ]
      dC[t - 1, , ] <- dC[t - 1, , ] + IFOGf[t, , f_idx] * dC[t, , ]
    } else {
      dIFOGf[t, , f_idx] <- c0 * dC[t, , ]
      # matrix() keeps dc0 b x d even when b or d is 1.
      dc0 <- matrix(IFOGf[t, , f_idx] * dC[t, , ], b, d)
    }
    dIFOGf[t, , i_idx] <- IFOGf[t, , g_idx] * dC[t, , ]
    dIFOGf[t, , g_idx] <- IFOGf[t, , i_idx] * dC[t, , ]
    # Backprop through the nonlinearities: tanh, then the sigmoid gates.
    dIFOG[t, , g_idx] <- (1 - IFOGf[t, , g_idx]^2) * dIFOGf[t, , g_idx]
    y <- IFOGf[t, , sig_idx]
    dIFOG[t, , sig_idx] <- (y * (1 - y)) * dIFOGf[t, , sig_idx]
    # Backprop the matrix multiply. Re-shape the time slices explicitly:
    # with b == 1, `[` drops them to vectors and `t(vector) %*% vector`
    # is non-conformable.
    Hin_t <- matrix(Hin[t, , ], nrow = b)
    dIFOG_t <- matrix(dIFOG[t, , ], nrow = b)
    dWLSTM <- dWLSTM + t(Hin_t) %*% dIFOG_t
    dHin[t, , ] <- dIFOG_t %*% t(WLSTM)
    # Backprop the identity transforms into Hin.
    dX[t, , ] <- dHin[t, , x_idx]
    if (t > 1) {
      dHout[t - 1, , ] <- dHout[t - 1, , ] + dHin[t, , h_idx]
    } else {
      dh0 <- dh0 + dHin[t, , h_idx]
    }
  }
  list(dX = dX, dWLSTM = dWLSTM, dc0 = dc0, dh0 = dh0)
}
# Self-check: stepping LSTM.forward / LSTM.backward one time step at a time,
# while threading the states (and their gradients) between calls, must give
# the same hidden states and gradients as one batched call over the sequence.
# Prints the outcome of every comparison. all.equal() prints TRUE on a match,
# or otherwise a description of the difference.
checkSequentialMatchesBatch <- function() {
  n <- 5 # sequence length
  b <- 3 # batch size
  d <- 4 # hidden size
  input_size <- 10
  WLSTM <- LSTM.init(input_size, d)
  X <- array(rnorm(n * b * input_size), c(n, b, input_size))
  h0 <- matrix(rnorm(b * d), b, d)
  c0 <- matrix(rnorm(b * d), b, d)

  # Sequential forward: one time step per call, carrying (c, h) along.
  cprev <- c0
  hprev <- h0
  caches <- vector("list", n)
  Hcat <- array(0, c(n, b, d))
  for (t in seq_len(n)) {
    seq_res <- LSTM.forward(X[t, , , drop = FALSE], WLSTM, cprev, hprev)
    cprev <- seq_res$C
    hprev <- seq_res$Hout_t
    caches[[t]] <- seq_res$cached
    Hcat[t, , ] <- hprev
  }

  # A single batched forward over the whole sequence must agree.
  batch_res <- LSTM.forward(X, WLSTM, c0, h0)
  H <- batch_res$Hout
  # all.equal() returns a character vector (not FALSE) on a mismatch, so it
  # has to be wrapped in isTRUE() before being negated in a condition.
  if (!isTRUE(all.equal(H, Hcat))) {
    print('Sequential and Batch forward dont match!')
  } else {
    print("All good with the forwarding")
  }

  # Loss = sum(Hcat * wrand), hence dLoss/dH = wrand.
  dH <- array(rnorm(length(Hcat)), dim(Hcat))
  batch_bwd <- LSTM.backward(dH, batch_res$cached)

  # Sequential backward: walk time in reverse, feeding each step the state
  # gradients produced by the step after it.
  dX <- array(0, dim(X))
  dWLSTM <- array(0, dim(WLSTM))
  dcnext <- NULL
  dhnext <- NULL
  for (t in rev(seq_len(n))) {
    seq_bwd <- LSTM.backward(dH[t, , , drop = FALSE], caches[[t]],
                             dcnext, dhnext)
    dWLSTM <- dWLSTM + seq_bwd$dWLSTM # accumulate the LSTM gradient
    dX[t, , ] <- seq_bwd$dX[1, , ]
    dcnext <- seq_bwd$dc0
    dhnext <- seq_bwd$dh0
  }
  # After the loop the carried gradients belong to the initial states.
  dc0 <- dcnext
  dh0 <- dhnext

  print('Making sure batched version agrees with sequential version: (should all be True)')
  print(all.equal(batch_bwd$dX, dX))
  print(all.equal(batch_bwd$dWLSTM, dWLSTM))
  print(all.equal(batch_bwd$dc0, dc0))
  print(all.equal(batch_bwd$dh0, dh0))
}
# Self-check: compare the analytic gradients from LSTM.backward against
# central finite differences of the loss sum(Hout * wrand), for every element
# of X, WLSTM, c0 and h0.
#
# For each element it prints a status (OK, VAL SMALL WARNING, WARNING or
# !!!NOTOK), the analytic and numerical gradients, and their relative error.
# Returns (invisibly) a named list of relative-error vectors, one per
# parameter, in the parameters' linear (column-major) order.
checkBatchGradient <- function() {
  n <- 5 # sequence length
  b <- 3 # batch size
  d <- 4 # hidden size
  input_size <- 10
  WLSTM <- LSTM.init(input_size, d)
  X <- array(rnorm(n * b * input_size), c(n, b, input_size))
  h0 <- matrix(rnorm(b * d), b, d)
  c0 <- matrix(rnorm(b * d), b, d)

  # Batch forward/backward. The weighted sum is a cheap scalar "hash" of H,
  # and its gradient w.r.t. H is simply wrand.
  batch_res <- LSTM.forward(X, WLSTM, c0, h0)
  H <- batch_res$Hout
  wrand <- array(rnorm(length(H)), dim(H))
  batch_bwd <- LSTM.backward(wrand, batch_res$cached)

  loss_at <- function(params) {
    h <- LSTM.forward(params$X, params$WLSTM, params$c0, params$h0)$Hout
    sum(h * wrand)
  }

  delta <- 1e-5
  rel_error_thr_warning <- 0.01
  rel_error_thr_error <- 1
  params <- list(X = X, WLSTM = WLSTM, c0 = c0, h0 = h0)
  grads <- list(X = batch_bwd$dX, WLSTM = batch_bwd$dWLSTM,
                c0 = batch_bwd$dc0, h0 = batch_bwd$dh0)

  rel_errors <- lapply(setNames(nm = names(params)), function(param_name) {
    w <- params[[param_name]]
    dw <- grads[[param_name]]
    cat("\nWeight:", param_name, "\n")
    errs <- numeric(length(w))
    for (i in seq_along(w)) {
      old_val <- w[i]
      perturbed <- params
      perturbed[[param_name]][i] <- old_val + delta
      loss_plus <- loss_at(perturbed)
      perturbed[[param_name]][i] <- old_val - delta
      loss_minus <- loss_at(perturbed)
      grad_analytic <- dw[i]
      grad_numerical <- (loss_plus - loss_minus) / (2 * delta)
      if (grad_analytic == grad_numerical) {
        rel_error <- 0
        status <- "OK"
      } else if (abs(grad_numerical) < 1e-7 && abs(grad_analytic) < 1e-7) {
        rel_error <- 0 # not enough precision to check this element
        status <- "VAL SMALL WARNING"
      } else {
        rel_error <- abs(grad_analytic - grad_numerical) /
          abs(grad_numerical + grad_analytic)
        # The status must reflect the error so the final "every line should
        # start with OK" instruction is meaningful.
        status <- "OK"
        if (rel_error > rel_error_thr_warning) status <- "WARNING"
        if (rel_error > rel_error_thr_error) status <- "!!!NOTOK"
      }
      errs[i] <- rel_error
      cat(status, "checking param", param_name, "index", i, "von", dim(w),
          "\n(val = ", old_val, "), analytic = ", grad_analytic,
          ", numerical = ", grad_numerical, "\nrelative error = ", rel_error,
          "\n\n")
    }
    errs
  })
  invisible(rel_errors)
}
# Script entry point: run both self-checks in order (results stay unprinted
# at top level), then print a closing reminder about reading the output.
invisible(lapply(
  list(checkSequentialMatchesBatch, checkBatchGradient),
  function(run_check) run_check()
))
print("every line should start with OK. Have a nice day!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment