Created
March 30, 2014 13:09
-
-
Save chiral/9872579 to your computer and use it in GitHub Desktop.
Restricted Boltzmann Machine implementation in R and Julia (the Julia version is much faster than the R version)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Distributions | |
# Logistic function: squashes its argument to a value between 0 and 1.
# rbm also calls it on whole arrays, relying on `exp` and scalar/array
# arithmetic being applied element-wise by this Julia version.
function sigmoid(x)
    denom = 1+exp(-x)
    return 1/denom
end
# Parameter bundle of the RBM. rbm also reuses it to carry gradient
# estimates, which have the same shapes.
#   W  : weight matrix (rows = visible units, columns = hidden units)
#   bn : biases of the visible units
#   bm : biases of the hidden units
type Theta
    W::Matrix{Float64}
    bn::Vector{Float64}
    bm::Vector{Float64}
end
# Train a Restricted Boltzmann Machine on 0/1 data with contrastive
# divergence (CD-k).
#
# Arguments
#   obs      : N x L matrix, one observation per column (N visible units).
#   n_hidden : number of hidden units (M).
# Keywords
#   eta               : learning rate applied to the CD gradient summed
#                       over all observations.
#   epsilon           : training stops once the reconstruction error is
#                       no longer above this value.
#   maxiter           : maximum number of update steps.
#   CD_k              : number of Gibbs sweeps per CD estimate.
#   reconstruct_trial : stochastic reconstructions per observation used
#                       to estimate the error; must be at least 1.
#   verbose           : >=1 reports accepted steps, >=2 every step,
#                       >=3 per-observation progress. All output goes to
#                       STDERR; "init OK." is printed at every level.
#
# Returns a Dict with the keys
#   "theta"       : Theta holding the final W, bn and bm,
#   "learn_info"  : text describing the last accepted step ("" if none),
#   "reconstruct" : closure v -> sampled reconstruction of v.
function rbm(obs,n_hidden;
             eta=0.05,
             epsilon=0.05,
             maxiter=100,
             CD_k=1,
             reconstruct_trial=1,
             verbose=0)
    N,L = size(obs)
    M = n_hidden

    # Fail early: without data or without reconstruction trials the error
    # estimate is undefined (0/0) and training would stop without notice.
    if N<1 || L<1
        error("rbm: obs must have at least one row and one column")
    end
    if reconstruct_trial<1
        error("rbm: reconstruct_trial must be at least 1")
    end

    # Visible biases start at the log-odds of each unit's mean activation,
    # capped at 0.9 (Hinton, "A Practical Guide to Training RBMs", ch. 8).
    pn = map(i->min(0.9,sum(obs[i,:])/L),1:N)
    bn = log(pn./(1-pn))
    bm = zeros(M)
    W = rand(Normal(0,0.01),N,M)

    # Conditional activation probabilities of a whole layer. The closures
    # read W, bn and bm at call time, so they follow every later update.
    pv_h_array(h) = sigmoid(W*h+bn)
    ph_v_array(v) = sigmoid(W'*v+bm)

    # Gibbs sampling: draw a 0.0/1.0 state for every unit of a layer.
    unif = Uniform(0,1)
    gs_v(h) = 1.0*(rand(unif,N).<pv_h_array(h))
    gs_h(v) = 1.0*(rand(unif,M).<ph_v_array(v))

    # CD-k gradient estimate for a single observation v.
    function cd_k(v)
        v1 = v
        for i=1:CD_k
            h1 = gs_h(v1)
            v1 = gs_v(h1)
        end
        ph = ph_v_array(v)
        ph1 = ph_v_array(v1)
        return Theta(v*ph'-v1*ph1',v-v1,ph-ph1)
    end

    # Sum of the CD-k estimates over all observations.
    function theta_step()
        t = Theta(zeros(N,M),zeros(N),zeros(M))
        for i=1:L
            if verbose>=3
                print(STDERR,"CD_k for ",i,"\n")
            end
            d = cd_k(obs[:,i])
            t.W += d.W
            t.bn += d.bn
            t.bm += d.bm
        end
        return t
    end

    # One visible -> hidden -> visible sampling pass.
    reconstruct(v) = gs_v(gs_h(v))

    # Mean absolute difference between observations and their sampled
    # reconstructions, averaged over units, observations and trials.
    function recon_error()
        r = 0.0    # Float64 from the start keeps the accumulator type-stable
        for t=1:reconstruct_trial,i=1:L
            if verbose>=3
                print(STDERR,"recon trial ",i,"\n")
            end
            v = obs[:,i]
            v1 = reconstruct(v)
            r += sum(abs(v-v1))
        end
        return r/(N*L*reconstruct_trial)
    end

    err = 1.0
    count = 0
    learn_info = ""
    print(STDERR,"init OK.\n")
    while err>epsilon && count<maxiter
        count += 1
        if verbose>=2
            print(STDERR,"step ",count,"\n")
        end
        d = theta_step()
        # `x += y` rebinds x to a fresh array, so the previous arrays stay
        # intact and can be restored if the step makes things worse.
        prev_err,prev_W,prev_bn,prev_bm = err,W,bn,bm
        W += eta*d.W
        bn += eta*d.bn
        bm += eta*d.bm
        err = recon_error()
        if err>prev_err
            # Reject the step: roll back parameters and error.
            err = prev_err
            W,bn,bm = prev_W,prev_bn,prev_bm
        else
            learn_info = string("step ",count," : err=",err)
            if verbose>=1
                print(STDERR,learn_info,"\n")
            end
        end
    end
    return ["theta"=>Theta(W,bn,bm),
            "learn_info"=>learn_info,
            "reconstruct"=>reconstruct]
end
### test program | |
# Smoke test: train a 3-hidden-unit RBM on nine 4-bit patterns, print the
# returned object, then print one sampled reconstruction per pattern.
function test()
    samples = [1 0 1 0;
               1 1 0 0;
               0 1 0 1;
               0 0 1 1;
               1 1 1 1;
               0 0 0 0;
               1 0 0 1;
               1 1 0 0;
               1 0 1 0]
    # rbm expects one observation per column, hence the transpose.
    model = rbm(1.0*samples',3,maxiter=1000,reconstruct_trial=10,verbose=1)
    print(model,"\n")
    resample = model["reconstruct"]
    for row = 1:size(samples,1)
        pattern = samples[row,:]'
        print(resample(pattern)')
    end
end
test()
### mnist character sign recognition
# The snippet below is only printed, not executed. It shows how to train on
# CSV files under mnist/ and serialize the result to theta.dump; the dump
# file is opened in write mode so that serialize can write to it.
print("""
test_labels = readcsv("mnist/t10k-labels-idx1-ubyte.csv")
train_labels = readcsv("mnist/train-labels-idx1-ubyte.csv")
test_images = readcsv("mnist/t10k-images-idx3-ubyte.csv")
train_images = readcsv("mnist/train-images-idx3-ubyte.csv")
obj = rbm(1.0*train_images[:,:]',100,reconstruct_trial=3,maxiter=1,verbose=3)
print(obj)
f=open("theta.dump","w")
serialize(f,obj)
close(f)
""")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### Restricted Boltzmann Machine implementation by isobe | |
# Logistic function; vectorised over numeric input.
sigmoid <- function(x) {
  1 / (exp(-x) + 1)
}
# Train a Restricted Boltzmann Machine on 0/1 data with contrastive
# divergence (CD-k).
#
# Args:
#   obs:               L x N matrix, one observation per row.
#   n_hidden:          number of hidden units (M).
#   eta:               learning rate applied to the CD gradient summed over
#                      all observations.
#   epsilon:           training stops once the reconstruction error is no
#                      longer above this value.
#   maxiter:           maximum number of update steps.
#   CD_k:              Gibbs sweeps per contrastive-divergence estimate.
#   reconstruct_trial: stochastic reconstructions per observation used to
#                      estimate the error; must be at least 1.
#   verbose:           >=1 (or TRUE) prints accepted steps, >=2 every step,
#                      >=3 per-observation progress.
#
# Returns:
#   A list of class 'rbm' with the weights W (N x M), visible biases bn,
#   hidden biases bm, a learn_info string and the closures hidden_prob,
#   hidden_sample and reconstruct.
rbm <- function(obs,n_hidden,eta=0.05,
                epsilon=0.05,maxiter=100,
                CD_k=1,reconstruct_trial=10,
                verbose=0) {
  L <- nrow(obs)
  N <- ncol(obs)
  M <- n_hidden
  # Fail early: without data or reconstruction trials the error estimate
  # is undefined and the training loop cannot make a meaningful decision.
  if (length(dim(obs)) != 2 || L < 1 || N < 1) {
    stop("rbm: obs must be a non-empty matrix with one observation per row")
  }
  if (reconstruct_trial < 1) {
    stop("rbm: reconstruct_trial must be at least 1")
  }

  # initial values assignment
  # cf) Chapter 8 in
  # http://www.cs.toronto.edu/~hinton/absps/guideTR.pdf
  pn <- apply(obs,2,function(x) min(0.9,sum(x)/L))
  bn <- log(pn/(1-pn))
  bm <- rep(0,M)
  W <- matrix(rnorm(N*M,0,0.01),N,M)

  # Activation probability of visible unit i given hidden state h, and of
  # hidden unit i given visible state v. Both read W, bn and bm from this
  # function's frame at call time, so they follow every later update.
  pv_h <- function(i,h) {
    sigmoid(sum(W[i,]*h)+bn[i])
  }
  ph_v <- function(i,v) {
    sigmoid(sum(W[,i]*v)+bm[i])
  }

  # Draw one Bernoulli state per unit, in unit order (one rbinom call per
  # unit). seq_len() makes n == 0 yield an empty result instead of
  # iterating over c(1, 0).
  gs_step <- function(x,n,p_func) {
    vapply(seq_len(n), function(i) rbinom(1,1,p_func(i,x)), integer(1))
  }
  gs_v <- function(h) gs_step(h,N,pv_h)
  gs_h <- function(v) gs_step(v,M,ph_v)

  # Probabilities of all M hidden units for a visible state v.
  hidden_prob <- function(v) {
    vapply(seq_len(M), function(j) ph_v(j,v), numeric(1))
  }

  # CD-k gradient estimate for a single observation v.
  cd_k <- function(v) {
    v1 <- v
    for (k in seq_len(CD_k)) {
      h1 <- gs_h(v1)
      v1 <- gs_v(h1)
    }
    # Each hidden probability is computed once and reused for all rows of
    # the weight gradient.
    ph <- hidden_prob(v)
    ph1 <- hidden_prob(v1)
    dW <- outer(v,ph)-outer(v1,ph1)
    dimnames(dW) <- NULL  # keep the result free of names inherited from v
    return(list(W=dW,bn=v-v1,bm=ph-ph1))
  }

  # Sum of the CD-k estimates over all observations.
  theta_step <- function() {
    acc <- list(W=matrix(0,N,M),bn=rep(0,N),bm=rep(0,M))
    for (i in seq_len(L)) {
      if (verbose>=3) cat(paste("theta for obs ",i,"\n"))
      d <- cd_k(obs[i,])
      acc$W <- acc$W+d$W
      acc$bn <- acc$bn+d$bn
      acc$bm <- acc$bm+d$bm
    }
    return(acc)
  }

  # One visible -> hidden -> visible sampling pass.
  reconstruct <- function(v) gs_v(gs_h(v))

  # Mean absolute difference between observations and their sampled
  # reconstructions, averaged over units, observations and trials.
  recon_error <- function() {
    r <- 0
    for (t in seq_len(reconstruct_trial)) for (i in seq_len(L)) {
      v <- obs[i,]
      v1 <- reconstruct(v)
      r <- r+sum(abs(v-v1))
    }
    return(r/(N*L*reconstruct_trial))
  }

  err <- 1
  count <- 0
  cat("init OK. \n")
  while (err>epsilon && count<maxiter) {
    if (verbose>=2) cat(paste("step =",count,"\n"))
    d <- theta_step()
    backup <- list(W=W,bn=bn,bm=bm,err=err)
    W <- W + eta*d$W
    bn <- bn + eta*d$bn
    bm <- bm + eta*d$bm
    count <- count+1
    err <- recon_error()
    if (backup$err<err) {
      # Reject the step: roll back parameters and error.
      W <- backup$W
      bn <- backup$bn
      bm <- backup$bm
      err <- backup$err
    } else if (verbose>=1) {
      print(paste("step",count,": err=",err))
    }
  }

  learn_info=paste("step",count,": err=",err)
  obj <- list(W=W,bn=bn,bm=bm,
              learn_info=learn_info,
              hidden_prob=hidden_prob,
              hidden_sample=gs_h,
              reconstruct=reconstruct)
  class(obj) <- 'rbm'
  return(obj)
}
# S3 print method for 'rbm' objects: shows the weight matrix, both bias
# vectors and the learn_info line. Extra arguments (e.g. digits) are
# accepted, as the print() generic requires, and forwarded to the inner
# print() calls. The object is returned invisibly, as print methods
# conventionally do.
print.rbm <- function(rbm, ...) {
  cat("edge weights:\n")
  print(rbm$W, ...)
  cat("\nbias for observable nodes:\n")
  print(rbm$bn, ...)
  cat("\nbias for hidden nodes:\n")
  print(rbm$bm, ...)
  cat(paste("\n",rbm$learn_info,"\n",sep=''))
  invisible(rbm)
}
# Apply the hidden_prob closure stored in `obj` to `obs`.
rbm_hidden_prob <- function(obj,obs) {
  prob_fun <- obj$hidden_prob
  prob_fun(obs)
}
# Apply the hidden_sample closure stored in `obj` to `obs`.
rbm_hidden_sample <- function(obj,obs) {
  sample_fun <- obj$hidden_sample
  sample_fun(obs)
}
# Apply the reconstruct closure stored in `obj` to `obs`.
rbm_reconstruct <- function(obj,obs) {
  recon_fun <- obj$reconstruct
  recon_fun(obs)
}
### test program | |
# Demo: train a 2-hidden-unit RBM on four 3-bit patterns, print the model,
# then print five sampled reconstructions of the pattern (1,1,0).
test <- function() {
  patterns <- rbind(c(1,0,1),
                    c(1,1,0),
                    c(1,0,1),
                    c(0,1,1))
  model <- rbm(patterns,2,verbose=T,maxiter=3000)
  print(model)
  probe <- c(1,1,0)
  n_draws <- 5
  cat("original")
  print(probe)
  for (draw in seq_len(n_draws)) {
    cat("reconstructed")
    print(rbm_reconstruct(model,probe))
  }
}
test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment