Skip to content

Instantly share code, notes, and snippets.

@michaelchughes
Created April 19, 2013 21:25
Show Gist options
  • Save michaelchughes/5423341 to your computer and use it in GitHub Desktop.
Save michaelchughes/5423341 to your computer and use it in GitHub Desktop.
Rank-one update of the Cholesky factor and log-determinant of the posterior parameters of a Wishart distribution.
function [] = RankOneUpdateTest( D, nTrial )
% RANKONEUPDATETEST  Verify and benchmark a rank-one Cholesky/log-det update.
%   RankOneUpdateTest(D, nTrial) builds a random D x D symmetric positive
%   definite matrix invW and a random D x 1 vector x. It checks that
%   update_rankone (which reuses chol(invW)) agrees with update_naive
%   (which refactors invW + x*x' from scratch), in both the log-determinant
%   and the updated Cholesky factor. It then prints the mean time per call
%   of each over nTrial repetitions and the resulting speedup.
%   D defaults to 64 and nTrial to 1000 when omitted or empty.
if nargin < 1 || isempty(D)
    D = 64;
end
if nargin < 2 || isempty(nTrial)
    nTrial = 1000;
end

% A tall (3D x D) Gaussian matrix has full column rank with probability one,
% so A'*A is symmetric positive definite and has a Cholesky factor.
invW = randn( 3*D, D);
invW = invW'*invW;
x = randn(D,1);
cholinvW = chol(invW,'upper');
logdetinvW = 2*sum(log(diag(cholinvW)));

% Correctness. The factor is checked by reconstruction, R'*R == invW + x*x',
% so the test does not depend on the sign convention cholupdate uses.
[ldetA,cholA] = update_naive( invW, x); %#ok<ASGLU>
[ldetB,cholB] = update_rankone( cholinvW, logdetinvW,x);
assert( abs(ldetA-ldetB) < 1e-8, 'Bad det calc');
target = invW + x*x';
assert( norm(cholB'*cholB - target,'fro') <= 1e-8*norm(target,'fro'), ...
    'Bad chol calc');

% Timing: mean wall-clock seconds per call.
tNaive = tic;
for a = 1:nTrial
    [ldetA,cholA] = update_naive( invW, x); %#ok<ASGLU>
end
etimeNaive = toc(tNaive)/nTrial;

tRankone = tic;
for a = 1:nTrial
    [ldetB,cholB] = update_rankone( cholinvW, logdetinvW,x); %#ok<ASGLU>
end
etimeRankone = toc(tRankone)/nTrial;

fprintf( ' Naive time: %.6f sec/trial\n', etimeNaive);
fprintf( 'Rank one time: %.6f sec/trial\n', etimeRankone);
fprintf('%.2f x speedup\n', etimeNaive/etimeRankone);
end
function [logdetNew, cholNew] = update_rankone( cholinvW, logdetinvW, x)
% UPDATE_RANKONE  Rank-one update of an upper Cholesky factor and log det.
%   Given R = cholinvW (upper triangular, R'*R = invW) and
%   logdetinvW = log(det(invW)), returns the log-determinant and the upper
%   Cholesky factor of invW + x*x' without refactoring from scratch.
%   The log-determinant uses the matrix determinant lemma:
%       det(invW + x*x') = det(invW) * (1 + q'*q),   where R'*q = x.
cholNew = cholupdate( cholinvW, x, '+');
lowerFactor = cholinvW';
q = lowerFactor \ x;
qq = q' * q;
logdetNew = logdetinvW + log( 1.0 + qq );
end
function [logdetNew, cholNew] = update_naive( invW, x)
% UPDATE_NAIVE  Reference update: refactor invW + x*x' from scratch.
%   Returns the log-determinant of invW + x*x' and its upper Cholesky
%   factor. The log-determinant is twice the sum of the logs of the
%   factor's diagonal.
updated = invW + x*x';
cholNew = chol( updated, 'upper');
d = diag(cholNew);
logdetNew = 2*sum(log(d));
end
import argparse
import time
import timeit

import choldate
import numpy as np
import scipy.linalg
# Defaults for the command-line options in the __main__ block:
# D is the matrix dimension, nTrial the number of timed repetitions.
D, nTrial = 64, 1000
def gen_problem(D, rng=None):
    """Draw a random rank-one-update problem of dimension D.

    Parameters
    ----------
    D : int
        Matrix dimension.
    rng : numpy.random.RandomState or numpy.random.Generator, optional
        Source of randomness, for reproducible problems. When omitted, the
        global ``numpy.random`` state is used. This produces the same draws
        as ``np.random.randn``.

    Returns
    -------
    invW : (D, D) ndarray
        Symmetric positive definite matrix ``A.T A``, where ``A`` is a
        (3*D, D) standard-normal matrix.
    cholinvW : (D, D) ndarray
        Upper-triangular Cholesky factor with ``cholinvW.T cholinvW == invW``.
    logdetinvW : float
        ``log(det(invW))``, computed from the Cholesky diagonal.
    x : (D,) ndarray
        Standard-normal update vector.
    """
    if rng is None:
        rng = np.random
    # A tall Gaussian matrix has full column rank with probability one,
    # so A^T A is positive definite and admits a Cholesky factorization.
    A = rng.standard_normal((3 * D, D))
    invW = np.dot(A.T, A)
    cholinvW = scipy.linalg.cholesky(invW, lower=False)
    logdetinvW = 2 * np.sum(np.log(np.diag(cholinvW)))
    x = rng.standard_normal(D)
    return invW, cholinvW, logdetinvW, x
def update_naive(invW, x):
    """Recompute the upper Cholesky factor of ``invW + outer(x, x)`` from scratch.

    Returns ``(logdet, R)``. ``R`` is upper triangular with
    ``R.T R == invW + outer(x, x)``. ``logdet`` is the log-determinant of that
    matrix, i.e. twice the sum of the logs of ``R``'s diagonal. This is the
    full-refactorization baseline that update_rankone is compared against.
    """
    updated = invW + np.outer(x, x)
    upper = scipy.linalg.cholesky(updated, lower=False)
    log_diag = np.log(np.diag(upper))
    return 2 * np.sum(log_diag), upper
def update_rankone(logdetinvW, cholinvW, x):
    """Rank-one update of an upper Cholesky factor and its log-determinant.

    Parameters
    ----------
    logdetinvW : float
        ``log(det(invW))`` for the matrix factored by ``cholinvW``.
    cholinvW : (D, D) ndarray
        Upper-triangular factor with ``cholinvW.T cholinvW == invW``.
        Left unmodified.
    x : (D,) ndarray
        Update vector. Left unmodified.

    Returns
    -------
    logdetnew : float
        ``log(det(invW + outer(x, x)))``.
    cholnew : (D, D) ndarray
        A new upper-triangular factor of ``invW + outer(x, x)``.
    """
    # Matrix determinant lemma: det(A + x x^T) = det(A) * (1 + x^T A^{-1} x).
    # With A = R^T R, x^T A^{-1} x = q^T q where R^T q = x (a triangular solve).
    q = scipy.linalg.solve_triangular(cholinvW.T, x, lower=True)
    # log1p keeps precision when q^T q is small relative to 1.
    logdetnew = logdetinvW + np.log1p(np.inner(q, q))
    # choldate.cholupdate rewrites its matrix argument in place, so update a
    # private copy rather than the caller's factor. order='K' keeps the memory
    # layout of the array the caller passed in. x is passed as a copy as well;
    # presumably choldate also overwrites it (confirm against choldate).
    cholnew = cholinvW.copy(order='K')
    choldate.cholupdate(cholnew, x.copy())
    return logdetnew, cholnew
def run_verification(D):
    """Check update_rankone against update_naive on one random problem of size D.

    Raises AssertionError if the log-determinants or the Cholesky factors
    disagree (within ``np.allclose`` tolerances). Otherwise prints a success
    message. The checks use explicit ``raise`` instead of ``assert`` statements,
    so they still run under ``python -O``.
    """
    invW, cholinvW, logdetinvW, x = gen_problem(D)
    logdetA, cholA = update_naive(invW, x)
    logdetB, cholB = update_rankone(logdetinvW, cholinvW, x)
    if not np.allclose(logdetA, logdetB):
        raise AssertionError(
            'log-determinant mismatch: naive %r vs rank-one %r' % (logdetA, logdetB))
    if not np.allclose(cholA, cholB):
        raise AssertionError('Cholesky factor mismatch between naive and rank-one update')
    print('All tests pass')
def run_timing_experiments(D, nTrial):
    """Print the mean seconds per call of update_naive and update_rankone.

    Both functions are timed on the same random problem of dimension D, each
    over nTrial back-to-back calls, and the ratio is reported as a speedup.

    Raises ValueError if nTrial is not positive.
    """
    if nTrial <= 0:
        raise ValueError('nTrial must be a positive integer, got %r' % (nTrial,))
    invW, cholinvW, logdetinvW, x = gen_problem(D)
    # timeit.default_timer is the stdlib's benchmarking clock
    # (time.perf_counter on Python 3). It is preferred over the wall clock
    # time.time(), which can be coarse and can jump.
    clock = timeit.default_timer

    stime = clock()
    for _ in range(nTrial):
        update_naive(invW, x)
    etimeA = (clock() - stime) / nTrial

    stime = clock()
    for _ in range(nTrial):
        update_rankone(logdetinvW, cholinvW, x)
    etimeB = (clock() - stime) / nTrial

    print(' Naive time: %.6f sec/trial' % etimeA)
    print('Rank one time: %.6f sec/trial' % etimeB)
    if etimeB > 0:
        print('%.2f x speedup' % (etimeA / etimeB))
    else:
        # A clock too coarse for the measured interval can report zero time.
        print('speedup: n/a (rank-one time below clock resolution)')
if __name__ == '__main__':
    # Command-line entry point. --D and --nTrial default to the module-level
    # constants. Correctness is checked once before the timing comparison.
    cli = argparse.ArgumentParser()
    cli.add_argument('--D', type=int, default=D)
    cli.add_argument('--nTrial', type=int, default=nTrial)
    opts = cli.parse_args()
    problem_size, trials = opts.D, opts.nTrial
    run_verification(problem_size)
    run_timing_experiments(problem_size, trials)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment