# Gist mschauer/6841531d0370d46c1d9ad6e23feda489 (last active April 3, 2021 10:54).
# Linear regression, p = 5, n = 10_000_000, with subsampling and an approximate
# ML estimate as control variate.
using ZigZagBoomerang
using StaticArrays
using LinearAlgebra
using SparseArrays
using Random
using Test
using Statistics

# Fix the global RNG seed so the simulated data and the sampler run are repeatable.
Random.seed!(2)

# Model sketch from the original author (the scale prior is left disabled below):
# scale ~ Exponential(λ)
# coefs ~ Normal()
# preds ~ Normal(dot(x, coefs), scale)
#λ = 1.0
#σ = randexp()*λ

β = @SVector randn(5)    # coefficients used to simulate the responses
n = 10_000_000           # number of observations
const d = 5              # dimension; keep in sync with the literal 5 used for β
X = randn(typeof(β), n)  # one static predictor vector per observation
# Responses: linear predictor plus standard normal noise. The `.+` fuses the
# addition into the broadcast instead of allocating an extra length-n temporary.
y = dot.(X, Ref(β)) .+ randn(n)
"""
    ∇ϕkhat(β, samples, X, y, μ, bias)

Return a subsampled gradient estimate at `β` that uses `μ` as a control variate.

Starting from `bias`, the loop draws `samples` indices uniformly (with replacement)
from `1:length(y)`. For every drawn index `i` it adds the per-observation term
`-X[i] * (y[i] - dot(X[i], β))` and subtracts the same term evaluated at `μ`, each
scaled by `length(y) / samples`.

# Arguments
- `β`: point at which the gradient is estimated.
- `samples`: number of indices drawn per call; with `0` the result is just `bias`.
- `X`, `y`: predictors and responses, where `X[i]` belongs to `y[i]`.
- `μ`: reference point of the control variate.
- `bias`: initial value of the accumulator (e.g. the full-data gradient at `μ`).
  It is not mutated.
"""
function ∇ϕkhat(β, samples, X, y, μ, bias)
    n = length(y)
    weight = n / samples  # loop-invariant scaling of every sampled term
    s = bias
    for _ in 1:samples
        i = rand(1:n)
        xi, yi = X[i], y[i]
        s += weight * (-xi * (yi - dot(xi, β)))  # likelihood
        s -= weight * (-xi * (yi - dot(xi, μ)))  # control
    end
    return s
end
"""
    prior(x)

Return `x` unchanged.

This is the prior's contribution to the gradient: for a standard Gaussian prior the
gradient of the negative log-density at `x` is `x` itself.
"""
prior(x) = x  # Gaussian prior
"""
    ∇ϕ!(x_, x, args...)

Return `prior(x) + ∇ϕkhat(x, args...)`, the gradient estimate passed to `pdmp`.

Despite the `!` in the name, nothing is mutated: `x_` is accepted and ignored, and the
result is returned as a new value (presumably the call shape `pdmp` expects — confirm
against the ZigZagBoomerang documentation). The value returned by `∇ϕkhat` is asserted
to have the same type `T` as `x`.
"""
function ∇ϕ!(x_, x::T, args...) where {T}
    likelihood = ∇ϕkhat(x, args...)::T
    return prior(x) + likelihood
end
"""
    ∇ϕfull(μ, X, y)

Return the full-data sum of `-X[i] * (y[i] - dot(X[i], μ))` over all observations.

# Throws
- `DimensionMismatch` if `X` and `y` do not share the same indices.
"""
function ∇ϕfull(μ, X, y)
    # The indices are validated against both arrays by `eachindex(X, y)`, which is
    # what makes skipping the bounds checks safe.
    return @inbounds sum(-X[i] * (y[i] - dot(X[i], μ)) for i in eachindex(X, y))
end
# Look at a fraction of the data:
c = 50000.0  # initial value of the `c` argument of `pdmp`
X_ = reinterpret(reshape, Float64, X[1:end÷200])'  # first 1/200 of X as a matrix
μ = SVector{5,Float64}(X_ \ (y[1:end÷200]))  # least-squares fit on that subsample

# one look at the full gradient
bias = ∇ϕfull(μ, X, y)

t0 = 0.0
x0 = μ
θ0 = @SVector randn(Float64, 5)

# The original first built Γ from X_'X_ and overwrote it on the very next line, so
# only this diagonal matrix ever reached the sampler; the dead assignment is dropped.
Γ = SMatrix{5, 5, Float64, 25}(Diagonal(1e7*ones(5)))
T = 2000.0 * 1/sqrt(n)
BP = BouncyParticle(Γ, μ, T/100)
samples = 20  # subsample size forwarded to ∇ϕkhat through ∇ϕ!

# Short preliminary run (time horizon T/20) started at the approximate mode.
# NOTE(review): its fourth return value is discarded, so the timed run below starts
# again from the initial `c` and from `θ0` rather than `θT` — confirm this is intended.
trace, (tT, xT, θT), (acc, num), _ = pdmp(
    ∇ϕ!, t0, x0, θ0, T/20, c, BP, samples, X, y, μ, bias, adapt=true
)
@time trace, (tT, xT, θT), (acc, num), c = pdmp(
    ∇ϕ!, t0, xT, θ0, T, c, BP, samples, X, y, μ, bias, adapt=true
)
ts, xs = ZigZagBoomerang.sep(trace)
xsd = last.(collect(discretize(trace, 1/sqrt(n))))

using Makie
# Plot the first two coordinates of the trace together with reference points.
p1 = lines(getindex.(xs, 1), getindex.(xs, 2), linewidth=0.4, color=ts)
scatter!(getindex.(xsd, 1), getindex.(xsd, 2))
scatter!([β[1]], [β[2]], color=:red)  # truth
scatter!([μ[1]], [μ[2]], color=:blue)  # approximate mode
m = reinterpret(reshape, Float64, X)' \ y  # least-squares fit on the full data
scatter!([m[1]], [m[2]], color=:orange)  # analytic posterior mean
p1
# (GitHub page footer: sign-up / sign-in prompt; not part of the script.)