# GitHub gist fec9f75b667eecff897d566e8133aee1 by @jrevels
# (last active December 1, 2016 21:16).
using ForwardDiff
using BenchmarkTools
"""
    gen_beta_kl_obj(alpha2, beta2)

Return a closure `x -> kl` that evaluates the Kullback-Leibler divergence

    KL(Beta(alpha1, beta1) || Beta(alpha2, beta2))
        = log B(alpha2, beta2) - log B(alpha1, beta1)
          + (alpha1 - alpha2) * digamma(alpha1) + (beta1 - beta2) * digamma(beta1)
          + (alpha2 - alpha1 + beta2 - beta1) * digamma(alpha1 + beta1)

where `B` is the Beta function and `(alpha1, beta1)` are the first two elements
of `x` (e.g. a two-element `Vector` or tuple). The closure puts no type
restrictions on `x`, so it can serve as an objective for `ForwardDiff.hessian!`.

Every term that depends only on the fixed `(alpha2, beta2)` is computed once,
when the closure is built, rather than on every call.
"""
function gen_beta_kl_obj(alpha2, beta2)
    lgamma_alpha2 = lgamma(alpha2)
    lgamma_beta2 = lgamma(beta2)
    # -log B(alpha2, beta2). It does not depend on x, so it is hoisted out of the
    # closure. The expression and its evaluation order are kept as-is, so the
    # floating-point result is the same as computing it inside the closure.
    neg_log_beta2 = lgamma(alpha2 + beta2) - lgamma_alpha2 - lgamma_beta2
    return x -> begin
        alpha1, beta1 = x
        alpha_diff = alpha1 - alpha2
        beta_diff = beta1 - beta2
        both_inv_diff = -(alpha_diff + beta_diff)
        di_both1 = digamma(alpha1 + beta1)
        # log B(alpha2, beta2) - log B(alpha1, beta1)
        log_term = lgamma(alpha1 + beta1) - lgamma(alpha1) - lgamma(beta1)
        log_term -= neg_log_beta2
        apart_term = alpha_diff * digamma(alpha1) + beta_diff * digamma(beta1)
        together_term = both_inv_diff * di_both1
        return log_term + apart_term + together_term
    end
end
"""
    gen_beta_kl_diff(a2, b2)

Build a closure `(a1, b1) -> result` that runs `ForwardDiff.hessian!` on the
objective `gen_beta_kl_obj(a2, b2)` at the point `[a1, b1]`. Results are stored
in a `DiffBase.HessianResult` that is allocated once, together with a matching
`ForwardDiff.HessianConfig`.

The input point is a `Float64` buffer (`zeros(2)`), so `a1` and `b1` are
converted to `Float64` when stored.

NOTE: every call overwrites and returns the *same* `result` object. A result
kept from an earlier call therefore changes after the next call, so copy out
whatever you need. Because the buffers are shared, do not call the closure
concurrently.
"""
function gen_beta_kl_diff(a2, b2)
    objective = gen_beta_kl_obj(a2, b2)
    # Preallocated once and reused by every call to avoid per-call allocation.
    point = zeros(2)
    hess_result = DiffBase.HessianResult(point)
    hess_cfg = ForwardDiff.HessianConfig(hess_result, point)
    evaluate = function (a1, b1)
        point[1] = a1
        point[2] = b1
        ForwardDiff.hessian!(hess_result, objective, point, hess_cfg)
        return hess_result
    end
    return evaluate
end
"""
    gen_beta_kl_old(alpha2, beta2)

Return a closure `(alpha1, beta1) -> (kl, grad, hess)` that uses hand-derived
formulas to evaluate three quantities:

- `kl`: the divergence KL(Beta(alpha1, beta1) || Beta(alpha2, beta2));
- `grad`: its gradient with respect to `(alpha1, beta1)`, a length-2 `Vector`;
- `hess`: its Hessian, a symmetric 2x2 `Matrix`.

The derivatives used are

    dKL/dalpha1 = (alpha1 - alpha2) * trigamma(alpha1)
                  - (alpha1 - alpha2 + beta1 - beta2) * trigamma(alpha1 + beta1)
    dKL/dbeta1  = (beta1 - beta2) * trigamma(beta1)
                  - (alpha1 - alpha2 + beta1 - beta2) * trigamma(alpha1 + beta1)

The second derivatives additionally use `polygamma(2, .)` (tetragamma).

Both the fixed parameters and the evaluation point may be any `Number`s; the two
arguments of each call do not need to share a type. `grad` and `hess` take the
element type of the computed values, not of the inputs.
"""
function gen_beta_kl_old(alpha2::Number, beta2::Number)
    lgamma_alpha2 = lgamma(alpha2)
    lgamma_beta2 = lgamma(beta2)
    # -log B(alpha2, beta2): depends only on the fixed distribution, so compute once.
    neg_log_beta2 = lgamma(alpha2 + beta2) - lgamma_alpha2 - lgamma_beta2
    function this_beta_kl(alpha1::Number, beta1::Number)
        alpha_diff = alpha1 - alpha2
        beta_diff = beta1 - beta2
        both_inv_diff = -(alpha_diff + beta_diff)
        sum1 = alpha1 + beta1

        # Value: log B(alpha2, beta2) - log B(alpha1, beta1) + digamma terms.
        di_both1 = digamma(sum1)
        log_term = lgamma(sum1) - lgamma(alpha1) - lgamma(beta1)
        log_term -= neg_log_beta2
        apart_term = alpha_diff * digamma(alpha1) + beta_diff * digamma(beta1)
        together_term = both_inv_diff * di_both1
        kl = log_term + apart_term + together_term

        trigamma_alpha1 = trigamma(alpha1)
        trigamma_beta1 = trigamma(beta1)
        trigamma_both = trigamma(sum1)
        tetragamma_both = polygamma(2, sum1)

        grad_alpha = alpha_diff * trigamma_alpha1 + both_inv_diff * trigamma_both
        grad_beta = beta_diff * trigamma_beta1 + both_inv_diff * trigamma_both

        hess_aa = alpha_diff * polygamma(2, alpha1) +
                  both_inv_diff * tetragamma_both +
                  trigamma_alpha1 - trigamma_both
        hess_bb = beta_diff * polygamma(2, beta1) +
                  both_inv_diff * tetragamma_both +
                  trigamma_beta1 - trigamma_both
        hess_ab = -trigamma_both + both_inv_diff * tetragamma_both

        # Build the arrays from the computed values so their element type is the
        # promoted result type. Preallocating with the *input* type would raise
        # InexactError for integer inputs and narrow Float64 results to Float32.
        grad = [grad_alpha, grad_beta]
        hess = [hess_aa hess_ab; hess_ab hess_bb]
        return kl, grad, hess
    end
    return this_beta_kl
end
# Fixed second distribution Beta(a2, b2) that all three generators are built for.
const a2 = 3.5
const b2 = 4.3

# Evaluation point (alpha1, beta1) = (a1, b1) used by the benchmarks below.
const a1 = 4.1
const b1 = 3.9
const kl_obj_input = Float64[a1, b1]   # vector form, for the x -> kl closure

# The three variants being compared.
const kl_obj = gen_beta_kl_obj(a2, b2)    # value only; called with a 2-vector
const kl_new = gen_beta_kl_diff(a2, b2)   # ForwardDiff Hessian; called as (a1, b1)
const kl_old = gen_beta_kl_old(a2, b2)    # hand-coded (kl, grad, hess); called as (a1, b1)
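#=
Sample @benchmark output (timings are machine-specific) for three variants: the
bare objective `kl_obj` (value only), the ForwardDiff Hessian closure `kl_new`,
and the hand-written value/gradient/Hessian closure `kl_old`.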
julia> @benchmark kl_obj($(kl_obj_input))
BenchmarkTools.Trial:
memory estimate: 0.00 bytes
allocs estimate: 0
--------------
minimum time: 204.257 ns (0.00% GC)
median time: 204.450 ns (0.00% GC)
mean time: 204.485 ns (0.00% GC)
maximum time: 234.630 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 575
time tolerance: 5.00%
memory tolerance: 1.00%
julia> @benchmark kl_new($a1, $b1)
BenchmarkTools.Trial:
memory estimate: 128.00 bytes
allocs estimate: 3
--------------
minimum time: 1.108 μs (0.00% GC)
median time: 1.121 μs (0.00% GC)
mean time: 1.129 μs (0.00% GC)
maximum time: 2.466 μs (0.00% GC)
--------------
samples: 10000
evals/sample: 10
time tolerance: 5.00%
memory tolerance: 1.00%
julia> @benchmark kl_old($a1, $b1)
BenchmarkTools.Trial:
memory estimate: 240.00 bytes
allocs estimate: 3
--------------
minimum time: 620.308 ns (0.00% GC)
median time: 627.256 ns (0.00% GC)
mean time: 678.424 ns (5.15% GC)
maximum time: 21.057 μs (95.06% GC)
--------------
samples: 10000
evals/sample: 172
time tolerance: 5.00%
memory tolerance: 1.00%
=#
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment