Created
April 19, 2013 21:25
-
-
Save michaelchughes/5423341 to your computer and use it in GitHub Desktop.
Rank-one update of the Cholesky factor and log-determinant of the posterior parameters of a Wishart distribution.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function [] = RankOneUpdateTest( D )
% RankOneUpdateTest  Verify and time a rank-one Cholesky / log-det update.
%   RankOneUpdateTest( D ) draws a random D-by-D matrix invW = Z'*Z, where
%   Z is 3D-by-D standard normal, and a random D-by-1 vector x. It checks
%   that update_rankone reproduces both the log-determinant and the upper
%   Cholesky factor of invW + x*x' computed from scratch by update_naive.
%   It then prints the mean time per call of each method over nTrial
%   repetitions. D defaults to 64 when omitted or empty.
if nargin < 1 || isempty( D )
  D = 64;
end
nTrial = 1000;

Z = randn( 3*D, D );
invW = Z' * Z;
x = randn( D, 1 );
cholinvW = chol( invW, 'upper' );
logdetinvW = 2*sum( log( diag( cholinvW ) ) );

% Correctness: compare the log-determinant relative to its magnitude (it
% grows with D), and also compare the updated Cholesky factor itself.
[ldetA, cholA] = update_naive( invW, x );
[ldetB, cholB] = update_rankone( cholinvW, logdetinvW, x );
assert( abs( ldetA - ldetB ) < 1e-8 * max( 1, abs( ldetA ) ), 'Bad det calc' );
assert( norm( cholA - cholB, 'fro' ) < 1e-8 * norm( cholA, 'fro' ), 'Bad chol calc' );

tic;
for a = 1:nTrial
  [ldetA, cholA] = update_naive( invW, x );
end
etimeNaive = toc/nTrial;

tic;
for a = 1:nTrial
  [ldetB, cholB] = update_rankone( cholinvW, logdetinvW, x );
end
etimeRankone = toc/nTrial;

fprintf( ' Naive time: %.6f sec/trial\n', etimeNaive);
fprintf( 'Rank one time: %.6f sec/trial\n', etimeRankone);
fprintf('%.2f x speedup\n', etimeNaive/etimeRankone);
end
function [logdetNew, cholNew] = update_rankone( cholinvW, logdetinvW, x)
% update_rankone  Log-determinant and upper Cholesky factor of invW + x*x',
% computed from those of invW without refactorising.
%   cholinvW   : upper-triangular R with R'*R = invW
%   logdetinvW : log(det(invW))
%   x          : D-by-1 update vector
% By the matrix determinant lemma,
%   det(invW + x*x') = det(invW) * (1 + x'*inv(invW)*x),
% and x'*inv(invW)*x = q'*q with q = R' \ x.
q = cholinvW' \ x;  % R' is lower triangular; mldivide detects this
% log1p keeps full precision when q'*q is much smaller than 1.
logdetNew = logdetinvW + log1p( q' * q );
cholNew = cholupdate( cholinvW, x, '+' );
end
function [logdetNew, cholNew] = update_naive( invW, x)
% update_naive  Reference implementation.
%   Refactorises invW + x*x' from scratch and reads the log-determinant
%   off the diagonal of the new upper-triangular Cholesky factor.
updated = invW + x * x';
cholNew = chol( updated, 'upper' );
logdetNew = 2 * sum( log( diag( cholNew ) ) );
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy.linalg | |
import choldate | |
import time | |
import argparse | |
# Defaults for the command-line flags: --nTrial (timed repetitions per
# method) and --D (matrix dimension).
nTrial, D = 1000, 64
def gen_problem( D):
    """Draw a random rank-one update problem of dimension D.

    Returns
    -------
    invW : (D, D) ndarray
        Z^T Z for a (3D, D) standard-normal matrix Z.
    cholinvW : (D, D) ndarray
        Upper-triangular Cholesky factor R of invW (R^T R = invW).
    logdetinvW : float
        log det(invW), computed from the diagonal of R.
    x : (D,) ndarray
        Standard-normal update vector.
    """
    Z = np.random.randn(3 * D, D)
    invW = Z.T.dot(Z)
    cholinvW = scipy.linalg.cholesky(invW, lower=False)
    logdetinvW = 2 * np.log(cholinvW.diagonal()).sum()
    # x is drawn after Z, matching the original random-number order.
    x = np.random.randn(D)
    return invW, cholinvW, logdetinvW, x
def update_naive( invW, x):
    """Refactorise invW + x x^T from scratch (reference implementation).

    Returns
    -------
    logdetnew : float
        log det(invW + x x^T), read off the new factor's diagonal.
    cholnew : ndarray
        Upper-triangular Cholesky factor of invW + x x^T.
    """
    updated = invW + np.outer(x, x)
    factor = scipy.linalg.cholesky(updated, lower=False)
    return 2 * np.log(factor.diagonal()).sum(), factor
def update_rankone( logdetinvW, cholinvW, x):
    """Return log det and upper Cholesky factor of invW + x x^T, without refactorising.

    The log-determinant uses the matrix determinant lemma:
        log det(invW + x x^T) = log det(invW) + log(1 + q^T q),
    where q solves R^T q = x and R = cholinvW. The factor is updated with
    choldate's rank-one Cholesky update.

    Parameters
    ----------
    logdetinvW : float
        log det(invW).
    cholinvW : 2D ndarray
        Upper-triangular R with R^T R = invW. Left unmodified.
    x : 1D ndarray
        Update vector. Left unmodified.

    Returns
    -------
    logdetnew : float
        log det(invW + x x^T).
    cholnew : 2D ndarray
        A new array holding the updated upper-triangular factor.
    """
    # R^T is lower triangular, so a triangular solve suffices.
    q = scipy.linalg.solve_triangular( cholinvW.T, x, lower=True)
    # log1p keeps full precision when q^T q is much smaller than 1.
    logdetnew = logdetinvW + np.log1p( np.dot(q, q))
    # choldate.cholupdate works in place on the factor it is given (and x
    # is passed as a copy in case it is overwritten too). Update a private
    # copy so the caller's cholinvW is not clobbered. Previously the input
    # itself was updated and returned, and the copy was never used.
    cholnew = cholinvW.copy()
    choldate.cholupdate( cholnew, x.copy() )
    return logdetnew, cholnew
def run_verification( D ):
    """Check update_rankone against update_naive on a random problem of size D.

    Raises
    ------
    AssertionError
        If the log-determinants or Cholesky factors disagree beyond the
        np.allclose default tolerances (rtol=1e-5, atol=1e-8), or if
        update_rankone modifies the factor or vector it was given.
    """
    invW, cholinvW, logdetinvW, x = gen_problem( D )
    cholinvW_before = cholinvW.copy()
    x_before = x.copy()
    logdetA, cholA = update_naive( invW, x)
    logdetB, cholB = update_rankone( logdetinvW, cholinvW, x)
    # Same comparison as np.allclose(A, B). Unlike a bare assert, it is not
    # stripped under `python -O` and it reports the mismatching entries.
    np.testing.assert_allclose( logdetA, logdetB, rtol=1e-5, atol=1e-8)
    np.testing.assert_allclose( cholA, cholB, rtol=1e-5, atol=1e-8)
    # The update must not clobber the caller's inputs; otherwise repeated
    # calls, as in the timing loop, would operate on a drifting factor.
    np.testing.assert_array_equal( cholinvW, cholinvW_before)
    np.testing.assert_array_equal( x, x_before)
    print('All tests pass')
def run_timing_experiments( D, nTrial):
    """Print mean seconds per call of update_naive and update_rankone.

    Both methods are timed on the same random problem of size D, each over
    nTrial back-to-back calls. The ratio of the two is printed as the speedup.

    Raises
    ------
    ValueError
        If nTrial < 1, because the per-trial averages would divide by zero.
    """
    import timeit
    if nTrial < 1:
        raise ValueError('nTrial must be a positive integer, got %r' % (nTrial,))
    invW, cholinvW, logdetinvW, x = gen_problem( D )
    # default_timer picks the highest-resolution wall-clock timer available
    # on the running Python (2 or 3).
    clock = timeit.default_timer
    stime = clock()
    for _ in range(nTrial):
        update_naive( invW, x)
    etimeA = (clock() - stime)/nTrial
    stime = clock()
    for _ in range(nTrial):
        update_rankone( logdetinvW, cholinvW, x)
    etimeB = (clock() - stime)/nTrial
    # A coarse timer can report zero elapsed time for the fast method.
    speedup = etimeA/etimeB if etimeB > 0 else float('inf')
    print(' Naive time: %.6f sec/trial' % (etimeA))
    print('Rank one time: %.6f sec/trial' % (etimeB))
    print('%.2f x speedup' % (speedup))
if __name__ == '__main__':
    # Command-line overrides for the module-level defaults, registered in
    # the same order as before: --D first, then --nTrial.
    cli = argparse.ArgumentParser()
    for flag, fallback in (('--D', D), ('--nTrial', nTrial)):
        cli.add_argument(flag, type=int, default=fallback)
    opts = cli.parse_args()
    run_verification(opts.D)
    run_timing_experiments(opts.D, opts.nTrial)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment