Implementation of Gibbs sampling for TrueSkill in Turing.jl.
julia> using LinearAlgebra
julia> using MATLAB
julia> #############
       ### Setup ###
       #############
       # `G` is an N×2 matrix of games (column 1: winner, column 2: loser);
       # `W` holds the corresponding player names.
       data = read_matfile("tennis_data.mat");
julia> G = Int.(jarray(data["G"]));
julia> W = string.(jvector(data["W"]));
julia> M = size(W, 1);      # number of players
julia> N = size(G, 1);      # number of games
julia> w = zeros(M);        # player skills
julia> pv = 0.5 * ones(M);  # prior skill variances
julia> #################
       ### Turing.jl ###
       #################
       using Turing, Random
julia> @model function trueskill(G, M)
           # Placeholder declarations; the actual conditional distributions
           # are supplied via `GibbsConditional` when sampling below.
           t ~ MvNormal(ones(size(G, 1)))  # performance differences, one per game
           w ~ MvNormal(ones(M))           # player skills
       end
trueskill (generic function with 1 method)
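For reference, the generative model targeted here, as reconstructed from the conditionals implemented below, is:

    w ~ N(0, Diagonal(pv))                     (player skills, pv = 0.5 each)
    t[g] ∣ w ~ N(w[G[g, 1]] − w[G[g, 2]], 1)   (performance difference in game g)
    t[g] > 0                                   (the winner is listed first in G)

The Gibbs sampler alternates between p(w ∣ t), a multivariate Gaussian, and p(t ∣ w), a product of independent truncated Gaussians.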
julia> "`MvNormal` parameterized by the Cholesky of the precision matrix."
       struct MvNormalPrecisionChol{T1, T2} <: ContinuousMultivariateDistribution
           μ::T1
           L::T2
       end
MvNormalPrecisionChol
julia> Base.length(d::MvNormalPrecisionChol) = length(d.μ)
julia> Base.size(d::MvNormalPrecisionChol) = size(d.μ)
julia> function Base.rand(rng::Random.AbstractRNG, d::MvNormalPrecisionChol)
           # If the precision is Λ = UᵀU, then μ + U⁻¹z with z ∼ N(0, I)
           # has covariance U⁻¹U⁻ᵀ = Λ⁻¹, as desired.
           return d.μ + d.L.U \ randn(rng, length(d.μ))
       end
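As a quick sanity check (not part of the original gist), one can verify empirically that samples from `MvNormalPrecisionChol` have covariance `inv(Λ)`; the matrix `Λ` here is just an arbitrary positive-definite example:

using Statistics  # for `cov`
Λ = [2.0 0.5; 0.5 1.0]
d_check = MvNormalPrecisionChol(zeros(2), cholesky(Λ))
xs = reduce(hcat, [rand(Random.default_rng(), d_check) for _ in 1:100_000])
# The empirical covariance should be close to inv(Λ) up to Monte Carlo error.
isapprox(cov(xs; dims = 2), inv(Λ); atol = 0.05)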
julia> "Conditional posterior of performance difference, i.e. p(t ∣ w₁, w₂, y)."
       struct PerfDiffCondPosterior{T1, T2} <: ContinuousMultivariateDistribution
           weights::T1
           G::T2
       end
PerfDiffCondPosterior
julia> Base.length(d::PerfDiffCondPosterior) = size(d.G, 1)
julia> Base.size(d::PerfDiffCondPosterior) = (Base.length(d), )
julia> function Distributions.rand(rng::Random.AbstractRNG, d::PerfDiffCondPosterior)
           N = size(d.G, 1)
           t = zeros(N)
           for g in 1:N
               s = d.weights[d.G[g, 1]] - d.weights[d.G[g, 2]]
               # Rejection-sample from N(s, 1) truncated to t[g] > 0: the winner
               # is listed first in G, so the performance difference is positive.
               t[g] = s + randn(rng)
               while t[g] < 0
                   t[g] = s + randn(rng)
               end
           end
           return t
       end
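The rejection loop above draws from a normal truncated to the positive half-line. A sketch of an equivalent draw using `truncated` from Distributions (the helper name `rand_truncated` is hypothetical, not part of the gist):

# Hypothetical alternative using `truncated` from Distributions:
function rand_truncated(rng::Random.AbstractRNG, d::PerfDiffCondPosterior)
    N = size(d.G, 1)
    t = zeros(N)
    for g in 1:N
        s = d.weights[d.G[g, 1]] - d.weights[d.G[g, 2]]
        t[g] = rand(rng, truncated(Normal(s, 1), 0, Inf))
    end
    return t
end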
julia> function make_conditionals(G, M, pv)
           function cond_w(c)
               t = c.t
               N = length(t)
               # m = Aᵀt, where A is the N×M incidence matrix with +1 in the
               # winner's column and -1 in the loser's column of each game.
               m = zeros(M)
               for p in 1:M
                   for g in 1:N
                       if G[g, 1] == p
                           m[p] += t[g]
                       elseif G[g, 2] == p
                           m[p] -= t[g]
                       end
                   end
               end
               # Likelihood precision AᵀA, accumulated game by game.
               Σ⁻¹_likelihood = zeros(M, M)
               for g in 1:N
                   p_left, p_right = G[g, :]
                   Σ⁻¹_likelihood[p_left, p_left] += 1
                   Σ⁻¹_likelihood[p_right, p_right] += 1
                   Σ⁻¹_likelihood[p_left, p_right] -= 1
                   Σ⁻¹_likelihood[p_right, p_left] -= 1
               end
               # Posterior precision matrix: likelihood precision + prior precision.
               Σ⁻¹ = Σ⁻¹_likelihood + Diagonal(1. ./ pv)
               L = cholesky(Σ⁻¹)
               # Posterior mean solves Σ⁻¹ μ = m.
               μ = L \ m
               return MvNormalPrecisionChol(μ, L)
           end
           cond_t(c) = PerfDiffCondPosterior(c.w, G)
           return (w = cond_w, t = cond_t)
       end
make_conditionals (generic function with 1 method)
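To make the linear algebra explicit (a sketch, not part of the original gist): `cond_w` is the standard Gaussian conditional for the linear model t = A·w + ε with ε ∼ N(0, I) and prior w ∼ N(0, Diagonal(pv)). The loops above are equivalent to:

# Sketch: build the incidence matrix A explicitly and compare.
A = zeros(N, M)
for g in 1:N
    A[g, G[g, 1]] += 1  # winner
    A[g, G[g, 2]] -= 1  # loser
end
t_check = randn(N)  # arbitrary performance differences, just for the check
# These match `m` and `Σ⁻¹_likelihood` as computed inside `cond_w`:
m_check = A' * t_check
prec_likelihood = A' * A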
julia> # Create the closures for the conditional distributions.
       conds = make_conditionals(G, M, pv);
julia> # Test-run
       t = rand(conds.t((w = zeros(M), )));
julia> w = rand(conds.w((t = t, )));
julia> ####################################
       ### Non-Turing.jl implementation ###
       ####################################
       num_samples = 1_000;
julia> skill_samples = zeros(num_samples, M);
julia> w = zeros(M);
julia> t = zeros(N);
julia> # Gibbs sweeps: alternate between the two conditionals.
       for i = 1:num_samples
           t = rand(conds.t((w = w, )))
           w = rand(conds.w((t = t, )))
           skill_samples[i, :] .= w
       end
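The chain starts at w = 0 and no burn-in is discarded before the summary below; a common tweak (an assumption on my part, not in the original gist) is to drop the initial sweeps first:

# Hypothetical burn-in: discard the first 100 sweeps before summarising.
burnin = 100;
means_burned = vec(mean(skill_samples[(burnin + 1):end, :]; dims = 1));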
julia> means = vec(mean(skill_samples; dims = 1));
julia> indices = reverse(sortperm(means));
julia> collect(zip(means[indices], W[indices]))
107-element Array{Tuple{Float64,String},1}:
 (1.9094688313646073, "Novak-Djokovic")
 (1.5132669604628082, "Roger-Federer")
 (1.4869177962237277, "Rafael-Nadal")
 (1.2919070767690484, "Andy-Murray")
 ...
julia> ################################
       ### Turing.jl implementation ###
       ################################
       m = trueskill(G, M)
DynamicPPL.Model{var"#1#2",(:G, :M),(),(),Tuple{Array{Int64,2},Int64},Tuple{}}(:trueskill, var"#1#2"(), (G = [1 2; 1 3; … ; 96 105; 105 107], M = 107), NamedTuple())
julia> # `GibbsConditional(:x, f)` samples `x` exactly from the distribution
       # returned by `f`, which receives a `NamedTuple` of the other variables.
       samples = sample(
           m,
           Gibbs(GibbsConditional(:t, conds.t), GibbsConditional(:w, conds.w)),
           num_samples
       );
Sampling 100%|████████████████████████████████████████| Time: 0:00:02
julia> weights = group(samples, :w);
julia> means_turing = mean(weights).nt.mean;
julia> indices_turing = reverse(sortperm(means_turing));
julia> collect(zip(means[indices], W[indices], means_turing[indices_turing], W[indices_turing]))
107-element Array{Tuple{Float64,String,Float64,String},1}:
 (1.9094688313646073, "Novak-Djokovic", 1.8872984517143065, "Novak-Djokovic")
 (1.5132669604628082, "Roger-Federer", 1.4937776943608394, "Roger-Federer")
 (1.4869177962237277, "Rafael-Nadal", 1.4813510143847124, "Rafael-Nadal")
 (1.2919070767690484, "Andy-Murray", 1.2555192408353655, "Andy-Murray")
 ...
julia> println("Top 4 players: $(W[indices_turing[1:4]])")
Top 4 players: ["Novak-Djokovic", "Roger-Federer", "Rafael-Nadal", "Andy-Murray"]
julia> weights[:, indices_turing[1:4], :]
Chains MCMC chain (1000×4×1 Array{Float64,3}):

Iterations        = 1:1000
Thinning interval = 1
Chains            = 1
Samples per chain = 1000
parameters        = w[1], w[5], w[11], w[16]
internals         =

Summary Statistics
  parameters     mean     std  naive_se    mcse       ess    rhat
      Symbol  Float64 Float64   Float64 Float64   Float64 Float64

       w[16]   1.8873  0.2208    0.0070  0.0120  219.2681  1.0020
        w[5]   1.4938  0.2079    0.0066  0.0101  325.5945  0.9998
        w[1]   1.4814  0.1951    0.0062  0.0068  292.2588  0.9990
       w[11]   1.2555  0.1995    0.0063  0.0109  244.1839  1.0059

Quantiles
  parameters     2.5%   25.0%   50.0%   75.0%   97.5%
      Symbol  Float64 Float64 Float64 Float64 Float64

       w[16]   1.4727  1.7368  1.8792  2.0322  2.3561
        w[5]   1.1111  1.3485  1.4862  1.6349  1.9075
        w[1]   1.0840  1.3492  1.4862  1.6063  1.8483
       w[11]   0.8805  1.1209  1.2533  1.3922  1.6365