Computing Fisher information via forward-mode automatic differentiation
Gist by simonbyrne (cdd6a6b44f06036dee75e19db5de66ca), created October 4, 2017 05:25
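The identity that the `EI` operator below appears to exploit can be written out as follows. This is a reading of the code's comments, not a derivation from the author. Because the expected score is zero, the terms involving second derivatives of the parameter map θ(x) drop out, and the Hessian reduces to a Jacobian sandwich of the Normal's Fisher information in (μ, σ):

```latex
% x: input vector; \theta(x) = (\mu(x), \sigma(x)); J_{ki} = \partial\theta_k / \partial x_i
\mathcal{I}_x(x)
  = -\,\mathbb{E}\!\left[\frac{\partial^2 \log p(Y \mid \theta(x))}{\partial x\,\partial x^\top}\right]
  = J^\top \mathcal{I}_\theta\, J
    - \sum_k \underbrace{\mathbb{E}\!\left[\frac{\partial \log p}{\partial \theta_k}\right]}_{=\,0}
      \frac{\partial^2 \theta_k}{\partial x\,\partial x^\top}
  = J^\top \mathcal{I}_\theta\, J,
\qquad
\mathcal{I}_\theta(\mu,\sigma) = \begin{pmatrix} 1/\sigma^2 & 0 \\ 0 & 2/\sigma^2 \end{pmatrix}

% entrywise, matching the two partials terms assembled in EI:
[\mathcal{I}_x]_{ij}
  = \frac{1}{\sigma^2}\,\frac{\partial\mu}{\partial x_i}\frac{\partial\mu}{\partial x_j}
  + \frac{2}{\sigma^2}\,\frac{\partial\sigma}{\partial x_i}\frac{\partial\sigma}{\partial x_j}
```

For the identity map at (μ, σ) = (0, 1.2) this gives diag(1/1.44, 2/1.44) ≈ diag(0.694, 1.389).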
using Distributions
import ForwardDiff  # brings the module name itself into scope (needed for ForwardDiff.tupexpr etc.)
import ForwardDiff: Dual, value, partials

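# Strip one level of dual numbers from every element of a tuple
# (the expression is unrolled at compile time by ForwardDiff's internal `tupexpr`).
@generated function get_values(a::NTuple{N}) where {N}
    return ForwardDiff.tupexpr(i -> :(value(a[$i])), N)
end

# Extend `value` to `Partials`: keep only the primal part of each (dual-valued) partial.
ForwardDiff.value(p::ForwardDiff.Partials) =
    ForwardDiff.Partials(get_values(p.values))
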
fisherinf(f, x) = ForwardDiff.hessian(y -> EI(f(y)), x)
# should probably check we're getting the right type...

# "expected information" operator
# - value: technically entropy, but with an arbitrary constant, so choose 0
# - gradient: 0 (expected score)
# - hessian: Fisher information
function EI(d::Normal{Dual{A,Dual{B,T,N},M}}) where {A,B,M,N,T}
    vσ = value(value(d.σ))
    # Fisher information of Normal(μ, σ) in (μ, σ) coordinates (off-diagonal is 0)
    Iμμ = 1/vσ^2
    Iσσ = 2/vσ^2
    Dual{A}(zero(value(d.μ)), # Dual{B,T,N}
            value(partials(d.μ)) * Dual{B}(zero(T), Iμμ*partials(value(d.μ))) +
            value(partials(d.σ)) * Dual{B}(zero(T), Iσσ*partials(value(d.σ))))
end

fisherinf(x -> Normal(x[1], x[2]), [0.0, 1.2])
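As a hedged cross-check that avoids nested duals entirely, the same quantity can be formed explicitly as Jᵀ I(θ) J, with the Jacobian J of the parameter map taken by `ForwardDiff.jacobian`. This is a sketch, not the gist's method: the helper names `normal_fisher` and `fisherinf_jacobian` are hypothetical, and the closed-form information matrix is hard-coded for the Normal(μ, σ) case only.

```julia
import ForwardDiff
using LinearAlgebra: Diagonal

# Hypothetical helper: Fisher information of Normal(μ, σ) in (μ, σ) coordinates.
normal_fisher(μ, σ) = Diagonal([1 / σ^2, 2 / σ^2])

# Hypothetical helper: Fisher information with respect to x for a parameter map
# θ(x) = [μ(x), σ(x)], via the chain rule Jᵀ I(θ) J (valid because E[score] = 0).
function fisherinf_jacobian(θ, x)
    t = θ(x)                         # (μ, σ) at x
    J = ForwardDiff.jacobian(θ, x)   # J[k, i] = ∂θ_k / ∂x_i
    return J' * normal_fisher(t[1], t[2]) * J
end

# Same point as the gist: identity map at (μ, σ) = (0.0, 1.2)
I1 = fisherinf_jacobian(identity, [0.0, 1.2])

# Log-scale parameterisation σ = exp(x[2]): the information for log σ is 2
I2 = fisherinf_jacobian(x -> [x[1], exp(x[2])], [0.0, log(1.2)])
```

For the identity map, `I1` should agree with `fisherinf(x -> Normal(x[1], x[2]), [0.0, 1.2])`. The log-scale case shows why the reparameterised form is useful: the σ-entry becomes the constant 2, independent of σ.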