### An efficient, batched LSTM in R
###
### This is a batched LSTM forward and backward pass. Written by Andrej Karpathy (@karpathy)
### BSD License
### Re-written in R by @georgeblck
###
rm(list=ls(all=TRUE))
LSTM.init <- function(input_size, hidden_size, fancy_forget_bias_init = 3){
    # Initialize parameters of the LSTM (both weights and biases in one matrix)
    # One may want to use a positive fancy_forget_bias_init number (e.g. maybe even up to 5, in some papers)
    # +1 for the biases, which will be the first row of WLSTM
    WLSTM <- matrix(rnorm((input_size + hidden_size + 1) * (4 * hidden_size)),
                    input_size + hidden_size + 1, 4 * hidden_size) / sqrt(input_size + hidden_size)
    WLSTM[1, ] <- 0 # initialize biases to zero
    if (fancy_forget_bias_init != 0){
        # forget gates get a bit of positive bias initially to encourage them to stay on (i.e. to remember)
        # remember that due to the Xavier initialization above, the raw output activations from gates before
        # the nonlinearity are zero mean and on the order of standard deviation ~1
        WLSTM[1, (hidden_size + 1):(2 * hidden_size)] <- fancy_forget_bias_init
    }
    return(WLSTM)
}
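# A minimal usage sketch, wrapped in `if (FALSE)` so that sourcing this file
# still only runs the checks at the bottom. The *.demo names are placeholders.
if (FALSE) {
    W.demo <- LSTM.init(10, 4)
    dim(W.demo)    # 15 16 -- (input_size + hidden_size + 1) rows, 4 * hidden_size columns
    W.demo[1, 5:8] # forget-gate biases, all set to fancy_forget_bias_init (3)
}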
LSTM.forward <- function(X, WLSTM, c0 = NULL, h0 = NULL){
    # X should be of shape (n, b, input_size)
    # where n = length of sequence, b = batch size
    n <- dim(X)[1]; b <- dim(X)[2]; input_size <- dim(X)[3]
    d <- dim(WLSTM)[2] / 4 # hidden size
    if (is.null(c0)){
        c0 <- matrix(0, b, d)
    }
    if (is.null(h0)){
        h0 <- matrix(0, b, d)
    }
    # Perform the LSTM forward pass with X as the input
    xphpb <- dim(WLSTM)[1] # x plus h plus bias, lol
    Hin <- array(0, c(n, b, xphpb))   # input [1, xt, ht-1] to each tick of the LSTM
    Hout <- array(0, c(n, b, d))      # hidden representation of the LSTM (gated cell content)
    IFOG <- array(0, c(n, b, d * 4))  # input, forget, output, gate (IFOG)
    IFOGf <- array(0, c(n, b, d * 4)) # after nonlinearity
    C <- array(0, c(n, b, d))         # cell content
    Ct <- C                           # tanh of cell content
    for (t in 1:n){
        if (t > 1) {
            prevh <- Hout[t-1, , ]
        } else {
            prevh <- h0
        }
        Hin[t, , 1] <- 1 # bias
        Hin[t, , 2:(input_size+1)] <- X[t, , ]
        Hin[t, , (input_size+2):xphpb] <- prevh
        # compute all gate activations (most of the work is in this line)
        IFOG[t, , ] <- Hin[t, , ] %*% WLSTM
        # non-linearities
        IFOGf[t, , 1:(3*d)] <- 1.0 / (1.0 + exp(-IFOG[t, , 1:(3*d)])) # sigmoids; these are the gates
        IFOGf[t, , (3*d+1):(d*4)] <- tanh(IFOG[t, , (3*d+1):(d*4)])   # tanh
        # compute the cell activation
        if (t > 1){
            prevc <- C[t-1, , ]
        } else {
            prevc <- c0
        }
        C[t, , ] <- IFOGf[t, , 1:d] * IFOGf[t, , (3*d+1):(4*d)] + IFOGf[t, , (d+1):(2*d)] * prevc
        Ct[t, , ] <- tanh(C[t, , ])
        Hout[t, , ] <- IFOGf[t, , (2*d+1):(3*d)] * Ct[t, , ]
    }
    cached <- list(WLSTM = WLSTM, Hout = Hout, IFOGf = IFOGf, IFOG = IFOG,
                   C = C, Ct = Ct, Hin = Hin, c0 = c0, h0 = h0)
    return(list(Hout = Hout, C = C[n, , ], Hout_t = Hout[n, , ], cached = cached))
}
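# A minimal usage sketch (assumed shapes, placeholder names; not executed when
# sourcing). Hout holds the hidden states for all n steps; C and Hout_t are the
# final-step cell and hidden states, which can be fed back in as c0/h0 to
# continue the sequence.
if (FALSE) {
    X.demo <- array(rnorm(5 * 3 * 10), c(5, 3, 10)) # n = 5, b = 3, input_size = 10
    fwd.demo <- LSTM.forward(X.demo, LSTM.init(10, 4))
    dim(fwd.demo$Hout)   # 5 3 4
    dim(fwd.demo$Hout_t) # 3 4
}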
LSTM.backward <- function(dHout_in, cache, dcn = NULL, dhn = NULL){
    WLSTM <- cache$WLSTM
    Hout <- cache$Hout
    IFOGf <- cache$IFOGf
    IFOG <- cache$IFOG
    C <- cache$C
    Ct <- cache$Ct
    Hin <- cache$Hin
    c0 <- cache$c0
    h0 <- cache$h0
    n <- dim(Hout)[1]; b <- dim(Hout)[2]; d <- dim(Hout)[3]
    input_size <- dim(WLSTM)[1] - d - 1 # -1 due to bias
    # backprop the LSTM
    dIFOG <- array(0, dim(IFOG))
    dIFOGf <- array(0, dim(IFOGf))
    dWLSTM <- array(0, dim(WLSTM))
    dHin <- array(0, dim(Hin))
    dC <- array(0, dim(C))
    dX <- array(0, c(n, b, input_size))
    dh0 <- matrix(0, b, d)
    dc0 <- matrix(0, b, d)
    dHout <- dHout_in
    if (!is.null(dcn)){ # carry over gradients from later
        dC[n, , ] <- dC[n, , ] + dcn
    }
    if (!is.null(dhn)){
        dHout[n, , ] <- dHout[n, , ] + dhn
    }
    # Do the backprop through time
    for (t in n:1){
        tanhCt <- Ct[t, , ]
        dIFOGf[t, , (2*d+1):(3*d)] <- tanhCt * dHout[t, , ]
        # backprop the tanh non-linearity first, then continue the backprop
        dC[t, , ] <- dC[t, , ] + (1 - tanhCt^2) * (IFOGf[t, , (2*d+1):(3*d)] * dHout[t, , ])
        if (t > 1){
            dIFOGf[t, , (d+1):(2*d)] <- C[t-1, , ] * dC[t, , ]
            dC[t-1, , ] <- dC[t-1, , ] + IFOGf[t, , (d+1):(2*d)] * dC[t, , ]
        } else {
            dIFOGf[t, , (d+1):(2*d)] <- c0 * dC[t, , ]
            dc0 <- IFOGf[t, , (d+1):(2*d)] * dC[t, , ]
        }
        dIFOGf[t, , 1:d] <- IFOGf[t, , (3*d+1):(4*d)] * dC[t, , ]
        dIFOGf[t, , (3*d+1):(4*d)] <- IFOGf[t, , 1:d] * dC[t, , ]
        dIFOG[t, , (3*d+1):(4*d)] <- (1 - IFOGf[t, , (3*d+1):(4*d)]^2) * dIFOGf[t, , (3*d+1):(4*d)]
        y <- IFOGf[t, , 1:(3*d)]
        dIFOG[t, , 1:(3*d)] <- (y * (1.0 - y)) * dIFOGf[t, , 1:(3*d)]
        # backprop the matrix multiply
        dWLSTM <- dWLSTM + t(Hin[t, , ]) %*% dIFOG[t, , ]
        dHin[t, , ] <- dIFOG[t, , ] %*% t(WLSTM)
        # backprop the identity transforms into Hin
        dX[t, , ] <- dHin[t, , 2:(input_size+1)]
        if (t > 1){
            dHout[t-1, , ] <- dHout[t-1, , ] + dHin[t, , (input_size+2):dim(dHin)[3]]
        } else {
            dh0 <- dh0 + dHin[t, , (input_size+2):dim(dHin)[3]]
        }
    }
    return(list(dX = dX, dWLSTM = dWLSTM, dc0 = dc0, dh0 = dh0))
}
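# A minimal forward/backward sketch using the same random "hash" loss as the
# checks below: loss = sum(Hout * wrand), so dHout is simply wrand. Placeholder
# names; dcn/dhn stay NULL since no later time steps feed gradients back here.
if (FALSE) {
    X.demo <- array(rnorm(5 * 3 * 10), c(5, 3, 10))
    W.demo <- LSTM.init(10, 4)
    fwd.demo <- LSTM.forward(X.demo, W.demo)
    wrand.demo <- array(rnorm(prod(dim(fwd.demo$Hout))), dim(fwd.demo$Hout))
    bwd.demo <- LSTM.backward(wrand.demo, fwd.demo$cached)
    str(bwd.demo) # dX (5, 3, 10), dWLSTM (15, 16), dc0 and dh0 (3, 4)
}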
checkSequentialMatchesBatch <- function(){
    n <- 5; b <- 3; d <- 4 # sequence length, batch size, hidden size
    input_size <- 10
    WLSTM <- LSTM.init(input_size, d) # input size, hidden size
    X <- array(rnorm(prod(n, b, input_size)), c(n, b, input_size))
    h0 <- matrix(rnorm(b * d), b, d)
    c0 <- matrix(rnorm(b * d), b, d)
    # sequential forward
    cprev <- c0
    hprev <- h0
    caches <- vector("list", n)
    Hcat <- array(0, c(n, b, d))
    for (t in 1:n){
        xt <- X[t, , , drop = FALSE]
        seq_res <- LSTM.forward(xt, WLSTM, cprev, hprev)
        cprev <- seq_res$C; hprev <- seq_res$Hout_t; caches[[t]] <- seq_res$cached
        Hcat[t, , ] <- hprev
    }
    # sanity check: perform batch forward to check that we get the same thing
    batch_res <- LSTM.forward(X, WLSTM, c0, h0)
    H <- batch_res$Hout; batch_cache <- batch_res$cached
    if (!isTRUE(all.equal(H, Hcat))){
        print("Sequential and batch forward don't match!")
    } else {
        print("All good with the forwarding")
    }
    # eval loss
    wrand <- array(rnorm(prod(dim(Hcat))), dim(Hcat))
    loss <- sum(Hcat * wrand)
    dH <- wrand
    # get the batched version gradients
    batch_bwd <- LSTM.backward(dH, batch_cache)
    BdX <- batch_bwd$dX; BdWLSTM <- batch_bwd$dWLSTM
    Bdh0 <- batch_bwd$dh0; Bdc0 <- batch_bwd$dc0
    # now perform sequential backward
    dX <- array(0, dim(X))
    dWLSTM <- array(0, dim(WLSTM))
    dc0 <- array(0, dim(c0))
    dh0 <- array(0, dim(h0))
    dcnext <- NULL
    dhnext <- NULL
    for (t in n:1){
        dht <- dH[t, , , drop = FALSE]
        seq_bwd <- LSTM.backward(dht, caches[[t]], dcnext, dhnext)
        dx <- seq_bwd$dX; dWLSTMt <- seq_bwd$dWLSTM
        dcprev <- seq_bwd$dc0; dhprev <- seq_bwd$dh0
        dhnext <- dhprev
        dcnext <- dcprev
        dWLSTM <- dWLSTM + dWLSTMt # accumulate LSTM gradient
        dX[t, , ] <- dx[1, , ]
        if (t == 1){
            dc0 <- dcprev
            dh0 <- dhprev
        }
    }
    # and make sure the gradients match
    print("Making sure batched version agrees with sequential version: (should all be TRUE)")
    print(all.equal(BdX, dX))
    print(all.equal(BdWLSTM, dWLSTM))
    print(all.equal(Bdc0, dc0))
    print(all.equal(Bdh0, dh0))
}
checkBatchGradient <- function(){
    ### check that the batch gradient is correct ###
    # let's gradient check this beast
    n <- 5; b <- 3; d <- 4 # sequence length, batch size, hidden size
    input_size <- 10
    WLSTM <- LSTM.init(input_size, d) # input size, hidden size
    X <- array(rnorm(prod(c(n, b, input_size))), c(n, b, input_size))
    h0 <- matrix(rnorm(b * d), b, d)
    c0 <- matrix(rnorm(b * d), b, d)
    # batch forward backward
    batch_res <- LSTM.forward(X, WLSTM, c0, h0)
    H <- batch_res$Hout; Ct <- batch_res$C
    Ht <- batch_res$Hout_t; cache <- batch_res$cached
    wrand <- array(rnorm(prod(dim(H))), dim(H))
    loss <- sum(H * wrand) # weighted sum is a nice hash to use I think
    dH <- wrand
    batch_bwd <- LSTM.backward(dH, cache)
    dX <- batch_bwd$dX
    dWLSTM <- batch_bwd$dWLSTM
    dc0 <- batch_bwd$dc0
    dh0 <- batch_bwd$dh0
    fwd <- function(we){
        h <- LSTM.forward(we$X, we$WLSTM, we$c0, we$h0)$Hout
        return(sum(h * wrand))
    }
    # now gradient check all
    delta <- 0.00001
    rel_error_thr_warning <- 0.01
    rel_error_thr_error <- 1
    tocheck <- list(X = X, WLSTM = WLSTM, c0 = c0, h0 = h0)
    grads_analytic <- list(dX = dX, dWLSTM = dWLSTM, dc0 = dc0, dh0 = dh0)
    names <- c("X", "WLSTM", "c0", "h0")
    doit <- mapply(FUN = function(w, dw, names, we.true){
        w.temp <- w
        cat("\nWeight:", names, "\n")
        for (i in 1:length(w)){
            temped <- c(w)
            old_val <- temped[i]
            # evaluate the loss at [x + delta] and [x - delta]
            temped[i] <- old_val + delta
            we.true[[names]] <- array(temped, dim(w.temp))
            loss0 <- fwd(we.true)
            temped[i] <- old_val - delta
            we.true[[names]] <- array(temped, dim(w.temp))
            loss1 <- fwd(we.true)
            grad_analytic <- c(dw)[i]
            grad_numerical <- (loss0 - loss1) / (2 * delta)
            if (grad_analytic == grad_numerical){
                rel_error <- 0 # both are (likely) zero, OK
                status <- "OK"
            } else if (abs(grad_numerical) < 1e-7 && abs(grad_analytic) < 1e-7){
                rel_error <- 0 # not enough precision to check this
                status <- "VAL SMALL WARNING"
            } else {
                rel_error <- abs(grad_analytic - grad_numerical) / abs(grad_numerical + grad_analytic)
                status <- "OK"
                if (rel_error > rel_error_thr_warning)
                    status <- "WARNING"
                if (rel_error > rel_error_thr_error)
                    status <- "!!!NOTOK"
            }
            cat(status, "checking param", names, "index", i, "of", length(w),
                "\n(val = ", old_val, "), analytic = ", grad_analytic,
                ", numerical = ", grad_numerical, "\nrelative error = ", rel_error, "\n\n")
        }
    }, w = tocheck, dw = grads_analytic, names = names,
    MoreArgs = list(we.true = tocheck))
}
checkSequentialMatchesBatch()
checkBatchGradient()
print("Every line should start with OK. Have a nice day!")