Skip to content

Instantly share code, notes, and snippets.

@treper
Created August 4, 2013 04:35
Show Gist options
  • Save treper/6149146 to your computer and use it in GitHub Desktop.
Save treper/6149146 to your computer and use it in GitHub Desktop.
draft code
#include <cmath>
#include <iostream>
#include <Eigen/Dense>
// NOTE(review): <Eigen/Array> is an Eigen 2-era header; in Eigen 3 array
// support lives in <Eigen/Core> (pulled in by <Eigen/Dense>), so this include
// may fail to compile and can likely be dropped - confirm against the
// installed Eigen version.
#include <Eigen/Array>
using namespace Eigen;
using namespace std;
void sigmoid(MatrixXf& input, MatrixXf& output)
{
output = (1+ (input.array().exp()).array()).array().inverse();
}
class SparseAutoEncoder
{
private:
int numpatches;
int visibleSize;
int hiddenSize;
//cost parameters
float lambda;
float sparsityParam;
float beta;
//training parameters
float cost;
MatrixXf data;//visibleSize x numpatches
MatrixXf W1;//hiddenSize x visibleSize
MatrixXf W2;//visibleSize x hiddenSize
VectorXf b1;//hiddenSize x 1
VectorXf b2;//visibleSize x 1
MatrixXf gradient;
MatrixXf hiddeninputs;//hiddenSize x numpatches
MatrixXf hiddenvalues;//hiddenSize x numpatches
MatrixXf finalinputs;//visibleSize x numpatches
MatrixXf outputs;//visibleSize x numpatches
MatrixXf errors;//visibleSize x numpatches
SparseAutoEncoder(int visibleSize, int hiddenSize, float lambda= 0.0001 , float sparsityParam = 0.01 , float beta = 3);
public:
void initializeParameters();
void sparseAutoEncoderCost();
void train();
void computeNumericalGradient();
void checkNumericalGradient();
}
SparseAutoEncoder::SparseAutoEncoder(int visibleSize, int hiddenSize, int numpatches)
{
initializeParameters(hiddenSize, visibleSize);
data = MatrixXf::Zero(visibleSize,numpatches);
hiddeninputs = MatrixXf::Zero(hiddenSize,numpatches);
hiddenvalues = MatrixXf::Zero(hiddenSize,numpatches);
finalinputs = MatrixXf::Zero(visibleSize,numpatches);
outputs = MatrixXf::Zero(visibleSize,numpatches);
}
void SparseAutoEncoder::initializeParameters(int hiddenSize, int visibleSize)
{
float r = sqrt(6)/sqrt(hiddenSize + visibleSize + 1);
W1.resize(hiddenSize, visibleSize);
W1.setRandom();
W1 = r*W1;
W2.resize(visibleSize, hiddenSize);
W2.setRandom();
W2 = r*W2;
b1 = MatrixXf::Zero(hiddenSize, 1);
b2 = MatrixXf::Zero(visibleSize, 1);
}
float sparseAutoEncoderCost()
{
MatrixXf hiddeninputs(hiddenSize,numpatches);
MatrixXf hiddenvalues(hiddenSize,numpatches);
MatrixXf finalinputs(visibleSize,numpatches);
MatrixXf outputs(visibleSize,numpatches);
MatrixXf weightsbuffer = MatrixXf::Ones(1,numpatches);
hiddeninputs = W1*data + b1*weightsbuffer;
sigmoid(hiddeninputs,hiddenvalues);
finalinputs = W2*hiddenvalues + b2*weightsbuffer;
sigmoid(finalinputs,outputs);
errors = outputs - data;
//Least squares component of cost
leastsquares = (errors.array()*errors.array()).sum()/(2*numpatches);
//Back-propagation calculation of gradients
delta3 = (errors.array()*outputs.array()).array()*(1-outputs.array()).array();//visibleSize x numpatches
W2grad = (delta3*hiddenvalues.transpose()).array()/numpatches;//visibleSize x hiddenSize
b2grad = (delta3*weightsbuffer.transpose()).array()/numpatches;//visibleSize x 1
//Sparsity
avgactivations = hiddenvalues*weightsbuffer.transpose()/numpatches;//hiddenSize x 1
sparsityvec = -sparsityParam*avgactivations.array().inverse() + (1-sparsityParam)*(1-avgactivations.array()).array().inverse();//hiddenSize x 1
kldiv = sparsityParam*log((sparsityParam*avgactivations.array().inverse()).prod()) + (1-sparsityParam)*log(((1-sparsityParam)*(1-avgactivations.array().inverse()).array()).prod());
delta2 = (((W2.transpose()*delta3).array() + (beta*(sparsityvec*weightsbuffer).array()).array()).array()*hiddenvalues.array()*(1-hiddenvalues.array()).array();//hiddenSize x numpatches
W1grad = delta2*data.transpose()/numpatches;//hiddenSize x visibleSize
b1grad = delta2*weightsbuffer.transpose()/numpatches;//hiddenSize x 1
cost = leastsquares + beta*kldiv;
//weight decay
cost = cost + lambda/2*(power(W1.norm(),2)+power(W2.norm(),2));
W1grad = (W1grad.array() + lambda*W1.array()).eval();
W2grad = (W2grad.array() + lambda*W2.array()).eval();
}
/// Draft driver: builds an autoencoder, loads its data, and trains it.
int main()
{
    // NOTE(review): placeholder problem sizes (8x8 patches, 25 hidden units,
    // 10000 patches) - adjust to the real dataset.
    const int visibleSize = 8 * 8;
    const int hiddenSize = 25;
    const int numpatches = 10000;

    // Must pass constructor arguments: `SparseAutoEncoder sparseAutoEncoder();`
    // declares a function (most vexing parse), not an object.
    SparseAutoEncoder sparseAutoEncoder(visibleSize, hiddenSize, numpatches);
    // NOTE(review): loadData() and train() have no definitions in this draft yet.
    sparseAutoEncoder.loadData();
    sparseAutoEncoder.train();
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment