Last active
December 1, 2016 21:16
-
-
Save jrevels/fec9f75b667eecff897d566e8133aee1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using ForwardDiff
using BenchmarkTools
import DiffBase  # provides DiffBase.HessianResult used by gen_beta_kl_diff
"""
    gen_beta_kl_obj(alpha2, beta2)

Return a closure `x -> kl` where `alpha1, beta1 = x` and `kl` is the KL divergence of
Beta(alpha1, beta1) from the fixed Beta(alpha2, beta2):

    lgamma(alpha1 + beta1) - lgamma(alpha1) - lgamma(beta1)
      - (lgamma(alpha2 + beta2) - lgamma(alpha2) - lgamma(beta2))
      + (alpha1 - alpha2) * digamma(alpha1) + (beta1 - beta2) * digamma(beta1)
      - (alpha1 - alpha2 + beta1 - beta2) * digamma(alpha1 + beta1)

Terms that depend only on `(alpha2, beta2)` are evaluated once, here, rather than on
every call. The closure has no type annotations, so it accepts any element type for `x`
(e.g. ForwardDiff dual numbers).
"""
function gen_beta_kl_obj(alpha2, beta2)
    # -log B(alpha2, beta2): constant for the fixed distribution, so compute it once
    # instead of re-evaluating three lgamma calls on every closure invocation.
    const_log_term = lgamma(alpha2 + beta2) - lgamma(alpha2) - lgamma(beta2)
    return x -> begin
        alpha1, beta1 = x
        alpha_diff = alpha1 - alpha2
        beta_diff = beta1 - beta2
        both_inv_diff = -(alpha_diff + beta_diff)
        di_both1 = digamma(alpha1 + beta1)
        log_term = lgamma(alpha1 + beta1) - lgamma(alpha1) - lgamma(beta1)
        log_term -= const_log_term
        apart_term = alpha_diff * digamma(alpha1) + beta_diff * digamma(beta1)
        together_term = both_inv_diff * di_both1
        return log_term + apart_term + together_term
    end
end
"""
    gen_beta_kl_diff(a2, b2)

Return a closure `(a1, b1) -> result`. It evaluates the objective built by
`gen_beta_kl_obj(a2, b2)` at `[a1, b1]` via `ForwardDiff.hessian!`, filling `result`.

The input buffer, the `DiffBase.HessianResult` and the `ForwardDiff.HessianConfig` are
allocated once, here, and reused by every call. As a consequence, every call returns the
same `result` object, overwritten in place. Callers that need earlier results must copy
them out before calling again. Requires `DiffBase` to be in scope.
"""
function gen_beta_kl_diff(a2, b2)
    objective = gen_beta_kl_obj(a2, b2)
    buffer = zeros(2)
    out = DiffBase.HessianResult(buffer)
    config = ForwardDiff.HessianConfig(out, buffer)
    function evaluate(a1, b1)
        # write the evaluation point into the preallocated input vector
        buffer[1], buffer[2] = a1, b1
        ForwardDiff.hessian!(out, objective, buffer, config)
        return out
    end
    return evaluate
end
"""
    gen_beta_kl_old(alpha2, beta2)

Return a function `(alpha1, beta1) -> (kl, grad, hess)`. Using hand-derived formulas, it
computes:

- `kl`: the KL divergence of Beta(alpha1, beta1) from Beta(alpha2, beta2);
- `grad`: its gradient with respect to `(alpha1, beta1)`;
- `hess`: its 2x2 Hessian.

Notation: psi = digamma, psi1 = trigamma, psi2 = polygamma(2, .),
s = alpha1 + beta1 and d = -(alpha1 - alpha2 + beta1 - beta2). Then:

    grad[1]    = (alpha1 - alpha2) * psi1(alpha1) + d * psi1(s)
    grad[2]    = (beta1 - beta2) * psi1(beta1) + d * psi1(s)
    hess[1, 1] = (alpha1 - alpha2) * psi2(alpha1) + d * psi2(s) + psi1(alpha1) - psi1(s)
    hess[2, 2] = (beta1 - beta2) * psi2(beta1) + d * psi2(s) + psi1(beta1) - psi1(s)
    hess[1, 2] = hess[2, 1] = d * psi2(s) - psi1(s)

`grad` and `hess` are newly allocated on every call. Their element type is that of the
computed entries, e.g. `Float64` for integer arguments. Previously they were built with
the argument type, which threw `InexactError` for integers. Both the generator and the
returned function accept any two `Number`s, not necessarily of the same type.
"""
function gen_beta_kl_old(alpha2::Number, beta2::Number)
    # -log B(alpha2, beta2): constant for the fixed distribution, so compute it once.
    const_log_term = lgamma(alpha2 + beta2) - lgamma(alpha2) - lgamma(beta2)
    function this_beta_kl(alpha1::Number, beta1::Number)
        alpha_diff = alpha1 - alpha2
        beta_diff = beta1 - beta2
        both_inv_diff = -(alpha_diff + beta_diff)
        both1 = alpha1 + beta1

        # objective
        log_term = lgamma(both1) - lgamma(alpha1) - lgamma(beta1) - const_log_term
        apart_term = alpha_diff * digamma(alpha1) + beta_diff * digamma(beta1)
        together_term = both_inv_diff * digamma(both1)
        kl = log_term + apart_term + together_term

        # gradient; the literal takes the entries' own (promoted) element type
        trigamma_alpha1 = trigamma(alpha1)
        trigamma_beta1 = trigamma(beta1)
        trigamma_both = trigamma(both1)
        grad = [alpha_diff * trigamma_alpha1 + both_inv_diff * trigamma_both,
                beta_diff * trigamma_beta1 + both_inv_diff * trigamma_both]

        # Hessian; polygamma(2, .) is the tetragamma function
        tetragamma_both = polygamma(2, both1)
        h11 = alpha_diff * polygamma(2, alpha1) + both_inv_diff * tetragamma_both +
              trigamma_alpha1 - trigamma_both
        h22 = beta_diff * polygamma(2, beta1) + both_inv_diff * tetragamma_both +
              trigamma_beta1 - trigamma_both
        h12 = both_inv_diff * tetragamma_both - trigamma_both
        hess = [h11 h12; h12 h22]

        return kl, grad, hess
    end
    return this_beta_kl
end
# Fixed target distribution Beta(a2, b2).
const a2 = 3.5
const b2 = 4.3
# Evaluation point (a1, b1) for the objective and its derivatives.
const a1 = 4.1
const b1 = 3.9
# Vector form of the evaluation point, as taken by kl_obj.
const kl_obj_input = Float64[a1, b1]
# The three implementations timed in the benchmark transcript.
const kl_obj = gen_beta_kl_obj(a2, b2)
const kl_new = gen_beta_kl_diff(a2, b2)
const kl_old = gen_beta_kl_old(a2, b2)
#= | |
julia> @benchmark kl_obj($(kl_obj_input)) | |
BenchmarkTools.Trial: | |
memory estimate: 0.00 bytes | |
allocs estimate: 0 | |
-------------- | |
minimum time: 204.257 ns (0.00% GC) | |
median time: 204.450 ns (0.00% GC) | |
mean time: 204.485 ns (0.00% GC) | |
maximum time: 234.630 ns (0.00% GC) | |
-------------- | |
samples: 10000 | |
evals/sample: 575 | |
time tolerance: 5.00% | |
memory tolerance: 1.00% | |
julia> @benchmark kl_new($a1, $b1) | |
BenchmarkTools.Trial: | |
memory estimate: 128.00 bytes | |
allocs estimate: 3 | |
-------------- | |
minimum time: 1.108 μs (0.00% GC) | |
median time: 1.121 μs (0.00% GC) | |
mean time: 1.129 μs (0.00% GC) | |
maximum time: 2.466 μs (0.00% GC) | |
-------------- | |
samples: 10000 | |
evals/sample: 10 | |
time tolerance: 5.00% | |
memory tolerance: 1.00% | |
julia> @benchmark kl_old($a1, $b1) | |
BenchmarkTools.Trial: | |
memory estimate: 240.00 bytes | |
allocs estimate: 3 | |
-------------- | |
minimum time: 620.308 ns (0.00% GC) | |
median time: 627.256 ns (0.00% GC) | |
mean time: 678.424 ns (5.15% GC) | |
maximum time: 21.057 μs (95.06% GC) | |
-------------- | |
samples: 10000 | |
evals/sample: 172 | |
time tolerance: 5.00% | |
memory tolerance: 1.00% | |
=#
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment