Skip to content

Instantly share code, notes, and snippets.

@p-i-
Created October 5, 2016 21:15
Show Gist options
  • Save p-i-/f52dbd431217b5f7ae0a8bdc9592a85b to your computer and use it in GitHub Desktop.
Save p-i-/f52dbd431217b5f7ae0a8bdc9592a85b to your computer and use it in GitHub Desktop.
% Fine-tune the network parameters W01, W21, W23, W32, b1, b2, b3 with
% conjugate gradient. In every epoch the mini-batches are visited in a new
% random order; 100 mini-batches of 100 rows each are merged into one
% 10000-row chunk, and minimize() runs max_iter line searches on that chunk.
%
% Variables read from the workspace (defined outside this section):
%   maxepoch, nBatches_train, neurons0, neurons1, neurons2,
%   batchdata_train, h2_train, batchtargets_train,
%   W01, W21, W23, W32, b1, b2, b3
for epoch = 1 : maxepoch
    %%% Do Conjugate Gradient Optimization
    fprintf( 1, 'epoch %d, batch: ', epoch );

    % NOTE(review): 600 is presumably nBatches_train -- confirm. The loop
    % below indexes randomBatch up to position nBatches_train.
    randomBatch = randperm(600);

    for b = 1 : nBatches_train/100   % do 100 batches at a go
        fprintf(1, ' %d', b);

        data    = zeros(10000, neurons0);
        pd2_all = zeros(10000, neurons2);
        targets = zeros(10000, 10);

        % Positions in the shuffled order that belong to this chunk.
        R = (b-1)*100+1 : b*100;
        for i = 1 : 100
            rows = (i-1)*100+1 : i*100;
            src  = randomBatch(R(i));
            data(   rows, :) = batchdata_train(   :, :, src);
            pd2_all(rows, :) = h2_train(          :, :, src);
            targets(rows, :) = batchtargets_train(:, :, src);
        end

        %%%%%%%% DO CG with 3 linesearches
        % checkgrad('CG_MNIST_INIT',VV,10^-5,Dim,data,targets);

        % Flatten all parameters into a single column vector. The order here
        % must match the unpacking below and the layout CG_MNIST expects.
        % (Do not bind a variable called "reshape": it would shadow the
        % builtin that the unpacking code relies on.)
        theta = [ W01(:); W21(:); W23(:); W32(:); b1(:); b2(:); b3(:) ];

        dims     = [neurons0; neurons1; neurons2];
        do_init  = epoch < 6;   % forwarded to CG_MNIST as its do_init flag
        max_iter = 3;

        [X, fX] ...
            = minimize( theta, 'CG_MNIST', max_iter, dims, data, targets, pd2_all, do_init );

        % Unpack the optimised vector; N is the running offset into X.
        % Note: W23 here has shape neurons1 x neurons2 and W32 has shape
        % neurons2 x 10 (they correspond to w12 and w23 inside CG_MNIST).
        W01 = reshape( X(  1:  neurons0*neurons1), neurons0, neurons1 );  N = neurons0*neurons1;
        W21 = reshape( X(N+1:N+neurons2*neurons1), neurons2, neurons1 );  N = N + neurons2*neurons1;
        W23 = reshape( X(N+1:N+neurons1*neurons2), neurons1, neurons2 );  N = N + neurons1*neurons2;
        W32 = reshape( X(N+1:N+neurons2*10      ), neurons2, 10       );  N = N + neurons2*10;
        b1  = reshape( X(N+1:N+neurons1         ), 1       , neurons1 );  N = N + neurons1;
        b2  = reshape( X(N+1:N+neurons2         ), 1       , neurons2 );  N = N + neurons2;
        b3  = reshape( X(N+1:N+10               ), 1       , 10       );  N = N + 10;
    end
end
function [f, df] = CG_MNIST( V, Dim, X, target, temp_h2, do_init )
% CG_MNIST  Cross-entropy loss and gradient for conjugate-gradient training.
%
%   [f, df] = CG_MNIST(V, Dim, X, target, temp_h2, do_init)
%
%   Inputs
%     V        parameter vector laid out as
%              [w01(:); w21(:); w12(:); w23(:); b1(:); b2(:); b3(:)]
%     Dim      [n0 n1 n2], the layer sizes used to unpack V (n3 is fixed at 10)
%     X        batchSize x n0 input matrix
%     target   batchSize x n3 target matrix
%     temp_h2  batchSize x n2 matrix fed into layer 1 through w21
%     do_init  when true, every gradient except those of w23 and b3 is
%              zeroed, so only the output layer is updated
%
%   Outputs
%     f        cross-entropy summed over all rows and classes
%     df       gradient of f w.r.t. V, as a column vector in V's layout
%
%   Relies on an external sigmoid() helper that is not defined in this file.

    n0 = Dim(1);
    n1 = Dim(2);
    n2 = Dim(3);
    n3 = 10;

    % Unpack the flat parameter vector; k is the running offset into V.
    k = 0;
    w01 = reshape( V(k+1:k+n0*n1), n0, n1 );  k = k + n0*n1;
    w21 = reshape( V(k+1:k+n2*n1), n2, n1 );  k = k + n2*n1;
    w12 = reshape( V(k+1:k+n1*n2), n1, n2 );  k = k + n1*n2;
    w23 = reshape( V(k+1:k+n2*n3), n2, n3 );  k = k + n2*n3;
    b1  = reshape( V(k+1:k+n1   ), 1,  n1 );  k = k + n1;
    b2  = reshape( V(k+1:k+n2   ), 1,  n2 );  k = k + n2;
    b3  = reshape( V(k+1:k+n3   ), 1,  n3 );

    % Forward pass (bias rows are broadcast over the batch).
    pd1 = sigmoid( bsxfun(@plus, X*w01 + temp_h2*w21, b1) );
    pd2 = sigmoid( bsxfun(@plus, pd1*w12, b2) );

    % Softmax via log-sum-exp: shifting each row by its maximum leaves the
    % probabilities unchanged but keeps exp() from overflowing, and working
    % with log-probabilities directly avoids 0*log(0) = NaN in the loss.
    Z    = bsxfun(@plus, pd2*w23, b3);
    Z    = bsxfun(@minus, Z, max(Z, [], 2));
    logP = bsxfun(@minus, Z, log(sum(exp(Z), 2)));
    pd3  = exp(logP);

    f = - sum(sum( target .* logP ));

    % Backward pass.
    Ix3 = pd3 - target;
    dw23 = pd2' * Ix3;
    db3  = sum(Ix3);

    Ix2 = (Ix3 * w23') .* pd2 .* (1 - pd2);
    dw12 = pd1' * Ix2;
    db2  = sum(Ix2);

    Ix1 = (Ix2 * w12') .* pd1 .* (1 - pd1);
    dw21 = temp_h2' * Ix1;
    db1  = sum(Ix1);
    dw01 = X' * Ix1;

    if do_init
        % Freeze everything below the output layer.
        dw01 = 0 * dw01;
        dw12 = 0 * dw12;
        dw21 = 0 * dw21;
        db1  = 0 * db1;
        db2  = 0 * db2;
    end

    % Same ordering as the unpacking of V above.
    df = [dw01(:)' dw21(:)' dw12(:)' dw23(:)' db1(:)' db2(:)' db3(:)']';
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment