# GitHub Gist 3c4113e1c1e6b42fb3b8 by @armanbilge (last active August 29, 2015 14:27).
using Distributions
using Iterators
using StatsBase
# A weighted particle for sequential Monte Carlo: a mutable state `value`
# paired with an unnormalized importance weight.
type Particle{T}
    value::T
    weight::Float64
end

# New particles start with unit weight.
Particle{T}(value::T) = Particle(value, 1.0)

# Fresh unit-weight particle holding an independent deep copy of `p.value`, so
# that duplicates produced by resampling do not share mutable state.
resetweight(p::Particle) = Particle(deepcopy(p.value))

# Scale the weight of `p` in place by `w`; returns the new weight.
function multiplyweight!(p::Particle, w::Float64)
    p.weight = p.weight * w
    return p.weight
end
# State of the uncolored birth–death process.
#   n — number of individuals currently alive
#   t — current time; `forward!` decrements it (callers treat it as the height
#       above the present)
type Trajectory
    n::Int
    t::Float64
end

# A trajectory starts with a single individual at time `t`.
Trajectory(t::Float64) = Trajectory(1, t)

# Add one individual (a birth); returns the new count.
increment!(E::Trajectory) = (E.n = E.n + 1)

# Remove one individual (a death); returns the new count.
decrement!(E::Trajectory) = (E.n = E.n - 1)

# Advance the process by `t` time units, lowering the current height by `t`;
# returns the new time.
forward!(E::Trajectory, t::Float64) = (E.t = E.t - t)
# State of the colored (multi-type) birth–death process.
#   n — n[c] is the number of individuals of color c
#   N — total number of individuals (the mutators below update it together with n)
#   l — l[i] is the current color of tracked lineage i
#   t — current time; `forward!` decrements it
type ColoredTrajectory
    n::Array{Int,1}
    N::Int
    l::Array{Int,1}
    t::Float64
end

# Empty population over `n` colors, with no tracked lineages, at time `t`.
ColoredTrajectory(t::Float64, n::Int) = ColoredTrajectory(fill(0, n), 0, Int[], t)

countcolors(E::ColoredTrajectory) = length(E.n)
lineagecount(E::ColoredTrajectory) = length(E.l)

# Indices of the tracked lineages whose current color is `color`.
function lineages(E::ColoredTrajectory, color::Int)
    found = Int[]
    for i = 1:length(E.l)
        if E.l[i] == color
            push!(found, i)
        end
    end
    return found
end

# Number of tracked lineages whose current color is `color`.
lineagecount(E::ColoredTrajectory, color::Int) = length(lineages(E, color))

# Add one individual of color `c`; returns the new total N.
function increment!(E::ColoredTrajectory, c::Int)
    E.n[c] = E.n[c] + 1
    E.N = E.N + 1
    return E.N
end

# Remove one individual of color `c`; returns the new total N.
function decrement!(E::ColoredTrajectory, c::Int)
    E.n[c] = E.n[c] - 1
    E.N = E.N - 1
    return E.N
end

# Advance the process by `t` time units; returns the new time.
function forward!(E::ColoredTrajectory, t::Float64)
    E.t = E.t - t
    return E.t
end

# Start tracking a new lineage on a new individual of color `c`; returns the
# new total N.
function newlineage!(E::ColoredTrajectory, c::Int)
    push!(E.l, c)
    return increment!(E, c)
end
# State of the SEIS epidemic process in a closed population.
#   N       — population size (not changed by the mutators below)
#   S, E, I — numbers of susceptible, exposed and infected individuals
#   l       — state of each tracked lineage: 1 = exposed, 2 = infected
#   t       — current time; `forward!` decrements it
type SEISTrajectory
    N::Int
    S::Int
    E::Int
    I::Int
    l::Array{Int,1}
    t::Float64
end

# Start at time `t` with all `S` individuals susceptible and no tracked lineages.
SEISTrajectory(S::Int, t::Float64) = SEISTrajectory(S, S, 0, 0, zeros(Int, 0), t)

lineagecount(E::SEISTrajectory) = length(E.l)
lineagecount(E::SEISTrajectory, color::Int) = count(c -> c == color, E.l)
lineages(E::SEISTrajectory, color::Int) = filter(i -> E.l[i] == color, 1:lineagecount(E))

# S -> E: a susceptible individual becomes exposed.
function incrementexposed!(E::SEISTrajectory)
    E.S -= 1
    E.E += 1
end

# E -> I: an exposed individual becomes infected.
function incrementinfected!(E::SEISTrajectory)
    E.E -= 1
    E.I += 1
end

# I -> S: an infected individual returns to the susceptible class.
function incrementsusceptible!(E::SEISTrajectory)
    E.I -= 1
    E.S += 1
end

# Advance the process by `t` time units.
function forward!(E::SEISTrajectory, t::Float64)
    E.t -= t
end

# Start tracking a new lineage on a newly infected individual: the lineage is
# recorded as exposed (1) and one susceptible moves to E.
function newlineage!(E::SEISTrajectory)
    push!(E.l, 1)
    incrementexposed!(E)
end
# State tracking, for each of a fixed number of lineages, how many individuals
# of each color belong to it.
#   n — n[l, c] is the number of individuals of color c in lineage l
#   m — m[l] is the number of individuals in lineage l (row sum of n)
#   N — total number of individuals
#   t — current time; `forward!` decrements it
type NewColoredTrajectory
    n::Array{Int,2}
    m::Array{Int,1}
    N::Int
    t::Float64
end

# Start at time `t` with no individuals, two lineages and `n` colors. `m` is
# indexed by lineage, so it has one entry per row of the count matrix.
function NewColoredTrajectory(t::Float64, n::Int)
    nlineages = 2
    return NewColoredTrajectory(zeros(Int, nlineages, n), zeros(Int, nlineages), 0, t)
end

countcolors(E::NewColoredTrajectory) = size(E.n, 2)
lineagecount(E::NewColoredTrajectory) = size(E.n, 1)

# Add one individual of color `c` to lineage `l`; returns the new total N.
function increment!(E::NewColoredTrajectory, l::Int, c::Int)
    # `n` is a matrix and must be indexed as n[l, c]; the chained form n[l][c]
    # linearly indexes a single entry and cannot be assigned to.
    E.n[l, c] += 1
    E.m[l] += 1
    E.N += 1
    return E.N
end

# Remove one individual of color `c` from lineage `l`; returns the new total N.
function decrement!(E::NewColoredTrajectory, l::Int, c::Int)
    E.n[l, c] -= 1
    E.m[l] -= 1
    E.N -= 1
    return E.N
end

# Advance the process by `t` time units.
function forward!(E::NewColoredTrajectory, t::Float64)
    E.t -= t
end
# Parameters of the constant-rate birth–death–sampling model used by the
# uncolored simulate!/f_hat.
type BirthDeathSamplingModel
lambda::Float64 # per-individual birth rate
mu::Float64 # per-individual death rate
rho::Float64 # probability that an individual alive at the present (height 0) is sampled
end
# Parameters of the colored (multi-type) birth–death–sampling model.
type ColoredBirthDeathSamplingModel
lambda::Float64 # per-individual birth rate
tau::Float64 # fraction of births whose offspring is placed in a different color
mu::Float64 # per-individual death rate
rho::Float64 # probability that an individual alive at the present is sampled (same for every color)
end
# Parameters of the SEIS epidemic model.
type SEISModel
N::Int # population size
beta::Float64 # transmission rate: infections occur at total rate beta * S * I / N
gamma::Float64 # per-exposed rate of becoming infected (E -> I)
mu::Float64 # per-infected rate of returning to S; only the sum mu + psi is used
psi::Float64 # per-infected rate of returning to S, added to mu; presumably the sampling rate — TODO confirm
end
# Node of a rooted binary tree.
#   label               — optional name (set only by the 3-argument constructor)
#   height              — height above the present; simulations run from larger
#                         to smaller heights
#   parent, left, right — tree links, unassigned until `setchildren!` sets them
#   color               — color/state assigned to the node
#   id                  — lineage index this node maps to (assigned by `f_hat`)
type Node
    label::String
    height::Float64
    parent::Node
    left::Node
    right::Node
    color::Int
    id::Int

    # Both constructors leave the tree links and `id` unassigned.
    function Node(height::Float64, color::Int)
        self = new()
        self.height = height
        self.color = color
        return self
    end

    function Node(label::String, height::Float64, color::Int)
        self = Node(height, color)
        self.label = label
        return self
    end
end
# Make `l` and `r` the left and right children of `p`, and point both back at
# `p` as their parent. Returns `p`.
function setchildren!(p::Node, l::Node, r::Node)
    l.parent = p
    r.parent = p
    p.left, p.right = l, r
    return p
end
# A rooted binary tree together with the origin of the process above its root.
type Tree
root::Node # topmost node (a leaf for a one-leaf tree)
origin::Float64 # height at which the process starts; f_hat uses it as the initial trajectory time
color::Int # color of the single lineage at the origin (f_hat passes it to newlineage!)
end
# Structural predicates. In this file the parent/child links are only ever
# assigned by `setchildren!`, so they are tested with `isdefined`.
function isroot(node::Node)
    return !isdefined(node, :parent)
end

function isinternal(node::Node)
    return isdefined(node, :left)
end

isleaf(node::Node) = !isinternal(node)
# Preorder traversal of the subtree rooted at `node` (the node, then its left
# subtree, then its right subtree), returned as a producer `Task` that yields
# the nodes one at a time.
function preorder(node::Node)
    function visit()
        produce(node)
        if isinternal(node)
            for child in (node.left, node.right)
                for descendant in preorder(child)
                    produce(descendant)
                end
            end
        end
    end
    return Task(visit)
end

preorder(tree::Tree) = preorder(tree.root)

# Internal nodes of `tree`, in preorder.
function internalnodes(tree::Tree)
    return filter(isinternal, preorder(tree))
end

# Leaves of `tree`, in preorder.
function leaves(tree::Tree)
    return filter(isleaf, preorder(tree))
end

# Number of leaves of `tree` whose color is `c`.
function leafcount(tree::Tree, c::Int)
    total = 0
    for leaf in leaves(tree)
        if leaf.color == c
            total += 1
        end
    end
    return total
end
# Simulate the uncolored birth–death process forward from the current state of
# `E`, applying events only while they fall strictly above height `stop`.
# `Ti` is the number of tracked (tree) lineages during this interval; f_hat
# passes the current count.
# Returns an importance weight: the product, over simulated births, of
#   q = (C(n, 2) - C(Ti, 2)) / C(n, 2),  n = population size right after the birth,
# presumably the probability that the hidden birth does not join two tracked
# lineages (TODO confirm). Returns 0.0 as soon as some q <= 0.
# The waiting time that overshoots `stop` is discarded; callers reset `E.t`
# afterwards (f_hat sets it to `stop`).
function simulate!(E::Trajectory, bds::BirthDeathSamplingModel, stop::Float64, Ti::Int)
TiC2 = binomial(Ti, 2)
p = 1.0
nu = bds.lambda + bds.mu # total per-individual event rate
birthdist = Bernoulli(bds.lambda / nu) # 1 = birth, 0 = death
t = rand(Exponential(1/(E.n * nu))) # Exponential takes the mean (scale): total rate is E.n * nu
while E.t - t > stop
b = rand(birthdist)
forward!(E, t)
if b == 1
increment!(E)
EnC2 = binomial(E.n, 2)
q = (EnC2 - TiC2) / EnC2
if q <= 0
return 0.0
else
p *= q
end
else
decrement!(E)
end
# NOTE(review): once E.n reaches 0 the scale is Inf; this presumably yields an
# infinite waiting time that ends the loop — confirm for the Distributions version in use.
t = rand(Exponential(1/(E.n * nu)))
end
p
end
# Simulate the colored birth–death process forward from the current state of
# `E`, applying events only while they fall strictly above height `stop`, and
# carry the tracked lineages in `E.l` along. Events (per-individual rates):
#   1 — birth, offspring keeps the parent's color (lambda * (1 - tau))
#   2 — birth, offspring gets another color       (lambda * tau)
#   3 — death                                     (mu)
# Returns the accumulated importance weight, or 0.0 as soon as an individual
# carrying a tracked lineage dies.
function simulate!(E::ColoredTrajectory, cbds::ColoredBirthDeathSamplingModel, stop::Float64)
p = 1.0
eventdist = WeightVec([cbds.lambda * (1 - cbds.tau), cbds.lambda * cbds.tau, cbds.mu])
nu = sum(eventdist) # per-individual total rate; independent of the state, so computed once
t = rand(Exponential(1/(E.N * nu)))
while E.t - t > stop
event = sample(eventdist)
color = sample(WeightVec(E.n)) # color of the individual involved, chosen in proportion to the counts
lineageaffected = rand(Bernoulli(lineagecount(E, color) / E.n[color])) # does that individual carry a tracked lineage?
forward!(E, t)
if event == 1 # Birth, no transfer
increment!(E, color)
p *= lineageaffected + 1 # 2x when the parent carries a tracked lineage
elseif event == 2 # Birth, transfer
# NOTE(review): with a single color this filter is empty and `sample` fails; reachable only when lambda * tau > 0.
destination = sample(filter(c -> c != color, 1:countcolors(E)))
increment!(E, destination)
if lineageaffected == 1
# The tracked lineage always follows the offspring into `destination`.
affectedlineage = sample(lineages(E, color))
E.l[affectedlineage] = destination
end
p *= lineageaffected + 1
else # Death
if lineageaffected == 1
return 0.0
end
decrement!(E, color)
end
t = rand(Exponential(1/(E.N * nu)))
end
p
end
# Current event rates of the SEIS process, in the event order used by
# `simulate!`: 1 = infection (S + I -> E + I), 2 = progression (E -> I),
# 3 = return to susceptible (I -> S).
_seisrates(E::SEISTrajectory, seis::SEISModel) =
    WeightVec([seis.beta * E.S * E.I / E.N, seis.gamma * E.E, (seis.mu + seis.psi) * E.I])

# Simulate the SEIS process forward (Gillespie-style: exponential waiting time
# with the total rate, event chosen in proportion to its rate), applying events
# only while they fall strictly above height `stop`, and carry the tracked
# lineages in `E.l` along.
# Returns the accumulated importance weight, or 0.0 as soon as an infected
# individual carrying a tracked lineage leaves the infected class.
function simulate!(E::SEISTrajectory, seis::SEISModel, stop::Float64)
    p = 1.0
    eventdist = _seisrates(E, seis)
    nu = sum(eventdist)
    t = rand(Exponential(1/nu))
    while E.t - t > stop
        event = sample(eventdist)
        forward!(E, t)
        if event == 1 # S + I -> E + I
            # Does the infector carry a tracked (infected) lineage?
            lineageaffected = rand(Bernoulli(lineagecount(E, 2) / E.I))
            if lineageaffected == 1 && randbool()
                # The tracked lineage continues in the newly exposed individual.
                affected = sample(lineages(E, 2))
                E.l[affected] = 1
            end
            p *= lineageaffected + 1
            incrementexposed!(E)
        elseif event == 2 # E -> I
            lineageaffected = rand(Bernoulli(lineagecount(E, 1) / E.E))
            # NOTE(review): unlike the infection event there is no random choice
            # here, so the 2x weight for an affected lineage looks suspicious — confirm.
            p *= lineageaffected + 1
            if lineageaffected == 1
                affected = sample(lineages(E, 1))
                E.l[affected] = 2
            end
            incrementinfected!(E)
        else # I -> S
            lineageaffected = rand(Bernoulli(lineagecount(E, 2) / E.I))
            if lineageaffected == 1
                return 0.0
            end
            incrementsusceptible!(E)
        end
        # All three rates depend on S, E and I, so they are recomputed from the
        # new state before drawing the next waiting time.
        eventdist = _seisrates(E, seis)
        nu = sum(eventdist)
        t = rand(Exponential(1/nu))
    end
    return p
end
# Work-in-progress simulator for NewColoredTrajectory; it does not run as
# written (see the NOTE(review) comments). Apparent intent, by analogy with the
# ColoredTrajectory method: pick a lineage and a color, then apply a birth, a
# transfer-birth or a death with rates lambda*(1-tau), lambda*tau and mu.
function simulate!(E::NewColoredTrajectory, cbds::ColoredBirthDeathSamplingModel, stop::Float64, Ti::Int)
TiC2 = binomial(Ti, 2) # NOTE(review): computed but never used
p = 1.0
eventdist = WeightVec([cbds.lambda * (1 - cbds.tau), cbds.lambda * cbds.tau, cbds.mu])
nu = sum(eventdist)
t = rand(Exponential(1/(E.N * nu)))
while E.t - t > stop
event = sample(eventdist)
# NOTE(review): this draws an entry (a count) of E.m and uses it as a lineage
# index; a weighted index draw such as sample(WeightVec(E.m)) was probably intended.
lineage = sample(E.m)
# NOTE(review): E.n is a matrix, so E.n[lineage] is a single linearly indexed
# entry, not the lineage's row E.n[lineage, :].
color = sample(WeightVec(E.n[lineage]))
forward!(E, t)
if event == 1 # Birth, no transfer
increment!(E, lineage, color)
# NOTE(review): `lineageaffected` is not defined in this method, so this line
# throws UndefVarError whenever this branch runs.
p *= lineageaffected + 1
elseif event == 2 # Birth, transfer
destination = sample(filter(c -> c != color, 1:countcolors(E)))
increment!(E, lineage, destination)
# p *= lineageaffected + 1
else # Death
decrement!(E, lineage, color)
# NOTE(review): linear index into the matrix; E.n[lineage, color] (or
# E.m[lineage]) was probably meant.
if E.n[color] == 0
return 0.0
end
end
t = rand(Exponential(1/(E.N * nu)))
end
p
end
# p1(t) of the constant-rate birth–death model (birth rate `l`, death rate `m`)
# with present-day sampling probability `rho`. Presumably this is the
# probability that a lineage alive at height `t` leaves exactly one sampled
# descendant at the present (TODO confirm against the source of the formula).
# Mathematically p1(0) = rho.
#
# When l == m the general expression evaluates to 0/0 = NaN, so that case
# returns its limit as l - m -> 0, rho / (1 + rho*l*t)^2, instead.
function p1(t::Float64, l::Float64, m::Float64, rho::Float64)
    if l == m
        return rho / (1 + rho*l*t)^2
    end
    r = l - m
    decay = exp(-r*t)
    return rho*r^2 * decay / (rho*l + (l*(1 - rho) - m)*decay)^2
end
# Closed-form log-density under the birth–death–sampling model of a tree with
# origin height `tau` and branching times `x`:
#     log p1(tau) + sum_i log(lambda * p1(x_i)).
# `x` is not modified.
function f(x::AbstractArray{Float64,1}, tau::Float64, lambda::Float64, mu::Float64, rho::Float64)
    originterm = log(p1(tau, lambda, mu, rho))
    branchterms = Float64[log(lambda * p1(xi, lambda, mu, rho)) for xi in x]
    return originterm + sum(branchterms)
end
# Sequential Monte Carlo (particle filter) estimate, with `N` particles, of the
# log-density of branching times `x` and origin height `tau` under the
# birth–death–sampling model (lambda, mu, rho). Presumably it is the Monte Carlo
# counterpart of `f`.
#
# Between consecutive branching times each particle's population is simulated
# with `simulate!`. At each branching time the particle gains one individual and
# is weighted by (n - 1) * lambda / C(n, 2); particles are then resampled. At the
# present each particle is weighted by the Binomial(n_alive, rho) probability of
# sampling exactly the n observed leaves. The total is shifted by
# -(n - 1) * log(2) + log(n!).
#
# `x` is not modified: the sorting is done on a copy.
function f_hat(x::AbstractArray{Float64,1}, tau::Float64, lambda::Float64, mu::Float64, rho::Float64, N::Int=100000)
    n = length(x) + 1
    # Process branching times from the oldest (largest height) to the youngest.
    # Sort a copy: `sort!` here would reorder the caller's array as a side effect
    # (and fails on immutable vectors such as ranges).
    heights = sort(x, rev=true)
    bds = BirthDeathSamplingModel(lambda, mu, rho)
    particles = collect(repeatedly(() -> Particle(Trajectory(tau)), N))
    logl = 0.0
    for (Ti, xi) in enumerate(heights)
        for p in particles
            # Ti lineages are tracked between the previous branching time and xi.
            multiplyweight!(p, simulate!(p.value, bds, xi, Ti))
            p.value.t = xi
            increment!(p.value)
            q = binomial(p.value.n, 2)
            if q > 0
                multiplyweight!(p, (p.value.n - 1) * lambda / q)
            else
                multiplyweight!(p, 0.0)
            end
        end
        w = Float64[p.weight for p in particles]
        logl += log(mean(w))
        if logl == -Inf
            return -Inf
        end
        particles = collect(map(resetweight, sample(particles, WeightVec(w), N)))
    end
    for p in particles
        multiplyweight!(p, simulate!(p.value, bds, 0.0, n))
        p.value.t = 0.0
        multiplyweight!(p, pdf(Binomial(p.value.n, rho), n))
    end
    w = Float64[p.weight for p in particles]
    logl += log(mean(w))
    logl -= log(2) * (n-1)
    logl += lfact(n)
    return logl
end
# Sequential Monte Carlo estimate, with `N` particles, of the likelihood of a
# colored tree under the colored birth–death–sampling model `cbds` with
# `colors` colors.
#
# Internal nodes are processed in order of decreasing height. Between nodes
# each particle is propagated with `simulate!`. At a node the particle is zeroed
# if the lineage reaching it has the wrong color; otherwise it gains a new
# lineage (same color with probability 1 - tau, another color otherwise) and is
# weighted by lambda. Particles are resampled after every node. At the present,
# every leaf must be reached by a lineage of its color, and each present-day
# individual of color c is sampled with probability rho.
#
# Side effect: assigns the `id` field of the nodes of `tree` (the lineage index
# each node maps to).
#
# Returns (logl, ess): the log of the estimated likelihood and `effective` of
# the final weights; (-Inf, 0.0) if every particle is eliminated at some node.
function f_hat(tree::Tree, cbds::ColoredBirthDeathSamplingModel, colors::Int, N::Int=10000)
    nodes = collect(internalnodes(tree))
    sort!(nodes, by=n -> n.height, rev=true)
    # The leaves and per-color leaf counts do not change during the run; compute
    # them once instead of re-walking the tree for every particle.
    leafnodes = collect(leaves(tree))
    leafcounts = Int[leafcount(tree, c) for c = 1:colors]
    eventdist = Bernoulli(1 - cbds.tau)
    particles = collect(repeatedly(() -> Particle(ColoredTrajectory(tree.origin, colors)), N))
    logl = 0.0
    tree.root.id = 1
    for p in particles
        newlineage!(p.value, tree.color)
    end
    for node in nodes
        for p in particles
            multiplyweight!(p, simulate!(p.value, cbds, node.height))
            p.value.t = node.height
            if p.value.l[node.id] != node.color
                multiplyweight!(p, 0.0)
            else
                if rand(eventdist) == 1
                    newlineage!(p.value, node.color)
                else
                    destination = sample(filter(c -> c != node.color, 1:countcolors(p.value)))
                    newlineage!(p.value, destination)
                end
                multiplyweight!(p, cbds.lambda)
                # NOTE(review): these ids live on the shared tree, so whichever
                # particle last takes this branch decides them for all particles.
                if randbool()
                    node.left.id = node.id
                    node.right.id = length(p.value.l)
                else
                    node.left.id = length(p.value.l)
                    node.right.id = node.id
                end
            end
        end
        w = Float64[p.weight for p in particles]
        logl += log(mean(w))
        if logl == -Inf
            return -Inf, 0.0
        end
        particles = collect(map(resetweight, sample(particles, WeightVec(w), N)))
    end
    for p in particles
        multiplyweight!(p, simulate!(p.value, cbds, 0.0))
        p.value.t = 0.0
    end
    for p in particles
        for leaf in leafnodes
            if p.value.l[leaf.id] != leaf.color
                multiplyweight!(p, 0.0)
                break
            end
        end
        if p.weight > 0
            for c = 1:colors
                lc = leafcounts[c]
                multiplyweight!(p, cbds.rho ^ lc * (1 - cbds.rho) ^ (p.value.n[c] - lc))
            end
        end
    end
    w = Float64[p.weight for p in particles]
    logl += log(mean(w))
    return logl, effective(w)
end
# Sequential Monte Carlo estimate, with `N` particles, of the likelihood of a
# tree under the SEIS model `seis`. Node colors are compared with the SEIS
# lineage states stored in the trajectory (1 = exposed, 2 = infected).
#
# `rho` is the probability that an infected individual alive at the present is
# sampled. SEISModel has no such field, so it is taken as an argument.
# NOTE(review): the default of 1.0 (every present-day infected individual is
# sampled) is an assumption — confirm.
#
# Side effect: assigns the `id` field of the nodes of `tree`.
# Returns the log of the estimated likelihood, or -Inf if every particle is
# eliminated at some node.
function f_hat(tree::Tree, seis::SEISModel, N::Int=10000, rho::Float64=1.0)
    nodes = collect(internalnodes(tree))
    sort!(nodes, by=n -> n.height, rev=true)
    n = length(nodes) + 1
    leafnodes = collect(leaves(tree))
    particles = collect(repeatedly(() -> Particle(SEISTrajectory(seis.N, tree.origin)), N))
    logl = 0.0
    tree.root.id = 1
    for p in particles
        # The SEIS `newlineage!` takes no color: the founding lineage starts as
        # exposed and one susceptible becomes exposed.
        newlineage!(p.value)
    end
    for node in nodes
        for p in particles
            multiplyweight!(p, simulate!(p.value, seis, node.height))
            p.value.t = node.height
            if p.value.l[node.id] != node.color
                multiplyweight!(p, 0.0)
            else
                # Rate at which this particular infected lineage infects one of the
                # S susceptibles: beta * S / N (the total rate is beta * S * I / N).
                # NOTE(review): the `p.value.S` factor was reconstructed from a
                # garbled expression — confirm.
                multiplyweight!(p, seis.beta * p.value.S / seis.N)
                newlineage!(p.value)
                # NOTE(review): these ids live on the shared tree, so whichever
                # particle last takes this branch decides them for all particles.
                if randbool()
                    node.left.id = node.id
                    node.right.id = length(p.value.l)
                else
                    node.left.id = length(p.value.l)
                    node.right.id = node.id
                end
            end
        end
        w = Float64[p.weight for p in particles]
        logl += log(mean(w))
        if logl == -Inf
            # A bare value, matching the normal return below.
            return -Inf
        end
        particles = collect(map(resetweight, sample(particles, WeightVec(w), N)))
    end
    for p in particles
        multiplyweight!(p, simulate!(p.value, seis, 0.0))
        p.value.t = 0.0
    end
    for p in particles
        for leaf in leafnodes
            if p.value.l[leaf.id] != leaf.color
                multiplyweight!(p, 0.0)
                break
            end
        end
        if p.weight > 0
            if p.value.I < n
                # Cannot sample n leaves from fewer than n infected individuals.
                multiplyweight!(p, 0.0)
            else
                multiplyweight!(p, rho ^ n * (1 - rho) ^ (p.value.I - n))
            end
        end
    end
    w = Float64[p.weight for p in particles]
    logl += log(mean(w))
    return logl
end
# Perplexity-style effective sample size of a weight vector: the exponential of
# the Shannon entropy of the normalized positive weights (0.0 when no weight is
# positive, since `entropy` then returns -Inf).
function effective(weight::Array{Float64,1})
    h = entropy(weight)
    return exp(h)
end
# Shannon entropy (natural log) of the distribution obtained by normalizing the
# strictly positive entries of `weight`. Non-positive and NaN entries are
# ignored. Returns -Inf when no entry is positive.
function entropy(weight::Array{Float64,1})
    positive = filter(w -> w > 0.0, weight)
    isempty(positive) && return -Inf
    probs = positive / sum(positive)
    return -sum(Float64[q * log(q) for q in probs])
end
# Numerically stable log(sum(exp.(x))).
# Shifts by the maximum rather than by x[1]. Shifting by x[1] gives NaN when
# x[1] == -Inf (which f_hat can return) and overflows to Inf when x[1] is much
# smaller than the other entries.
# Returns -Inf for an empty input (the log of an empty sum) and when every entry
# is -Inf; returns +Inf if any entry is +Inf.
function logsum(x::AbstractArray{Float64})
    isempty(x) && return -Inf
    m = maximum(x)
    isfinite(m) || return m
    return m + log(mapreduce(xi -> exp(xi - m), +, x))
end
# Experiment driver. The active experiment estimates, with the colored particle
# filter, the likelihood of a single-leaf tree (origin height pi, one color) and
# prints it. The commented-out blocks are earlier experiments comparing the
# particle-filter estimates `f_hat` with the closed form `f`.
# Returns the result of the active `f_hat` call.
function main()
    # for x in 1/16:1/16:3
    # args = (Float64[0.5, 1.0, float(e), float(pi)], 5.0, x, 0.3, 0.6)
    # a = f_hat(args..., 100000)
    # b = f(args...)
    # println(x, "\t", a, "\t", x, "\t", b, "\t", exp(b-a))
    # end
    nc = 2
    # NOTE(review): this 5-leaf tree is overwritten below before the active
    # experiment uses it; only the commented-out experiments refer to it.
    node = [Node(0.0, 1), Node(0.0, 1), Node(0.0, 1), Node(0.0, 1), Node(0.0, 1), Node(1.0, 1), Node(2.0, 1), Node(float(e), 1), Node(float(pi), 1)]
    setchildren!(node[6], node[1], node[2])
    setchildren!(node[7], node[4], node[5])
    setchildren!(node[8], node[6], node[3])
    setchildren!(node[9], node[7], node[8])
    tree = Tree(node[9], 5.0, 1)
    # node = [Node(0.0, 1), Node(0.0, 1), Node(1.0, 1)]
    # setchildren!(node[3], node[1], node[2])
    # tree = Tree(node[end], 2.0, 1)
    # node = [Node(0.0, 1), Node(0.0, 1), Node(0.0, 1), Node(1.0, 1), Node(float(e), 1)]
    # setchildren!(node[4], node[1], node[2])
    # setchildren!(node[5], node[3], node[4])
    # tree = Tree(node[end], float(pi), 1)
    node = [Node(0.0, 1)]
    tree = Tree(node[1], float(pi), 1)
    # (log of the estimated likelihood, effective sample size of the final weights)
    result = f_hat(tree, ColoredBirthDeathSamplingModel(0.5, 0.0, 0.0, 0.1), 1)
    println(result)
    # for x in 1/16:1/16:3
    # mu = 0.1
    # rho = 0.5
    # cbds = ColoredBirthDeathSamplingModel(x, 0.25, mu, rho)
    # as = Float64[]
    # for color in product(1:nc, 1:nc, 1:nc)
    # for (c, n) in zip(color, node)
    # n.color = c
    # end
    # push!(as, f_hat(tree, cbds, nc, 10000))
    # end
    # a = logsum(as)
    # b = f(Float64[1.0], 2.0, x, mu, rho)
    # # c = f_hat(Float64[], 1.0, x, mu, rho, 1000)
    # println(x, "\t", a, "\t", b, "\t", exp(a-b))
    # end
    # lambda = 1.0
    # mu = 0.5
    # rho = 0.8
    # cbds = ColoredBirthDeathSamplingModel(lambda, 1/3, mu, rho)
    # truth = f(Float64[1.0], 2.0, lambda, mu, rho)
    # f(Float64[1.0, float(e)], float(pi), lambda, mu, rho)
    # println("# ", truth)
    # println("10 20 50 100 200 500 1000")
    # # println("500 1000 2000 4000")
    # for _ in 1:100
    # for N in [10, 20, 50, 100, 200, 500, 1000]
    # # for N in [500, 1000, 2000, 4000]
    # as = Float64[]
    # # for color in product(1:nc, 1:nc, 1:nc, 1:nc, 1:nc)
    # for color in product(1:nc, 1:nc, 1:nc)
    # for (c, n) in zip(color, node)
    # n.color = c
    # end
    # x, y = f_hat(tree, cbds, nc, N)
    # push!(as, x)
    # end
    # a = logsum(as)
    # print(a, " ")
    # end
    # println()
    # end
    # N = 1000
    # for _ in 1:100
    # x, y = f_hat(tree, cbds, nc, N)
    # xs = Float64[]
    # ys = Float64[]
    # for color in product(1:nc, 1:nc, 1:nc)
    # for (c, n) in zip(color, node)
    # n.color = c
    # end
    # x, y = f_hat(tree, cbds, nc, N)
    # push!(xs, x)
    # push!(ys, y)
    # end
    # println(sum(ys), " ", exp(logsum(xs) - truth))
    # end
    # for N in [10, 20, 50, 100, 200, 500, 1000]
    # X = Float64[]
    # Y = Float64[]
    # for _ in 1:100
    # xs = Float64[]
    # ys = Float64[]
    # for color in product(1:nc, 1:nc, 1:nc)
    # for (c, n) in zip(color, node)
    # n.color = c
    # end
    # x, y = f_hat(tree, cbds, nc, N)
    # push!(xs, x)
    # push!(ys, y)
    # end
    # push!(X, logsum(xs))
    # push!(Y, logsum(ys))
    # end
    # println(mean(Y), " ", mean(exp(X - truth)), )
    # end
    # println(f(Float64[], 0.25, 0.0, 0.0, 0.8))
    # println(f_hat(Float64[], 0.26, 0.0, 0.0, 1.0))
    return result
end
# Run the experiment whenever this file is loaded (including via `include`).
main()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment