Last active
October 15, 2022 23:17
-
-
Save slwu89/62729fe92e26e3bebf22f0fa46dbc575 to your computer and use it in GitHub Desktop.
penalized cubic splines in julia
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Random, Distributions | |
using LinearAlgebra | |
using Statistics | |
using Plots | |
using JuMP, Ipopt | |
using RCall | |
# References (mgcv source):
#   initial fixup and optional (not done here) scaling:
#     https://github.com/cran/mgcv/blob/a534170c82ce06ccd8d76b1a7d472c50a2d7bbd2/R/smooth.r#L1602
#   penalty scaling:
#     https://github.com/cran/mgcv/blob/a534170c82ce06ccd8d76b1a7d472c50a2d7bbd2/R/smooth.r#L3757
"""
    penalty_cr(D, B, X=X)

Return the symmetrized, rescaled penalty matrix `S` built from `D` and `B`.

The raw penalty is `transpose(D) * B⁻¹ * D`. It is then symmetrized as `(S + Sᵀ)/2`
and divided by `opnorm(S, 1) / opnorm(X, Inf)^2`, following the scaling step of
mgcv's `smoothCon` (see the links above).

# Arguments
- `D`, `B`: matrices such as those returned by `basis_cr`; `B` must be square and
  nonsingular.
- `X`: model matrix used only to set the scale of `S`. When omitted, the global `X`
  visible where this function is defined is used, so two-argument calls behave as
  they did when `X` was read implicitly from global scope.

# Returns
The scaled penalty matrix, with as many rows and columns as `D` has columns.
"""
function penalty_cr(D, B, X=X)
    # Solve B \ D rather than forming inv(B) explicitly.
    S = transpose(D) * (B \ D)
    # Remove any asymmetry introduced by floating-point round-off.
    S = (S + transpose(S)) / 2
    # Rescale relative to the squared infinity-norm of the model matrix.
    maXX = opnorm(X, Inf)^2
    maS = opnorm(S, 1) / maXX
    return S ./ maS
end
"""
    basis_cr(xk)

Build the matrices `D`, `B` and `F` of a cubic regression spline basis from the
knot vector `xk`.

With `k = length(xk)` and knot spacings `h = diff(xk)`:

- `D` is `(k-2) × k` with `D[i,i] = 1/h[i]`, `D[i,i+1] = -1/h[i] - 1/h[i+1]` and
  `D[i,i+2] = 1/h[i+1]`.
- `B` is `(k-2) × (k-2)`, symmetric tridiagonal, with `B[i,i] = (h[i] + h[i+1])/3`
  and `B[i,i+1] = B[i+1,i] = h[i+1]/6`.
- `F` is `k × k`; its first and last rows are zero and rows `2:k-1` hold `B⁻¹ D`.

# Arguments
- `xk`: knot locations; at least two are required (presumably strictly increasing,
  since repeated knots give zero spacings and divisions by zero).

# Returns
A named tuple `(D=D, B=B, F=F)`.

# Throws
- `ArgumentError` if fewer than two knots are supplied.
"""
function basis_cr(xk)
    k = length(xk)
    k >= 2 || throw(ArgumentError("basis_cr needs at least 2 knots, got $k"))
    # knot spacing vector
    h = diff(xk)

    D = zeros(Float64, k - 2, k)
    for i in axes(D, 1)
        D[i, i] = 1 / h[i]
        D[i, i + 1] = -1 / h[i] - 1 / h[i + 1]
        D[i, i + 2] = 1 / h[i + 1]
    end

    B = zeros(Float64, k - 2, k - 2)
    for i in axes(B, 1)
        B[i, i] = (h[i] + h[i + 1]) / 3
        if i < size(B, 1)
            B[i, i + 1] = h[i + 1] / 6
            B[i + 1, i] = h[i + 1] / 6
        end
    end

    F = zeros(Float64, k, k)
    if k > 2
        # Interior rows are B^{-1} D; solve the system instead of inverting B.
        F[2:(end - 1), :] = B \ D
    end
    return (D=D, B=B, F=F)
end
"""
    model_matrix_cr(xk, F, x = nothing)

Evaluate the cubic regression spline basis defined by knots `xk` and matrix `F`
at the points `x`, returning an `n × k` model matrix (`n = length(x)`,
`k = length(xk)`). When `x` is `nothing` the basis is evaluated at the knots.

Each row is a combination of two consecutive rows of `F` (the cubic part) plus
two linear weights added to the matching pair of columns. Points below `xk[1]`
or above `xk[end]` use the extrapolation formulas attached to the first or last
knot interval; points inside the knot range use the basis functions from
table 5.1 of the GAM book.
"""
function model_matrix_cr(xk, F, x = nothing)
    pts = isnothing(x) ? xk : x
    nknots = length(xk)
    X = zeros(Float64, length(pts), nknots)

    for (i, xi) in enumerate(pts)
        if xi < xk[1]
            # below the knot range: extrapolate from the first interval
            j = 1
            hj = xk[2] - xk[1]
            xik = xi - xk[1]
            cm = -xik * hj / 3
            cp = -xik * hj / 6
            am = 1 - xik / hj
            ap = xik / hj
        elseif xi > xk[end]
            # above the knot range: extrapolate from the last interval
            j = nknots - 1
            hj = xk[end] - xk[end - 1]
            xik = xi - xk[end]
            cm = xik * hj / 6
            cp = xik * hj / 3
            am = -xik / hj
            ap = 1 + xik / hj
        else
            # inside the knot sequence; j is the left end of the interval
            # containing xi, so it is capped at nknots - 1
            j = min(searchsortedlast(xk, xi), nknots - 1)
            hj = xk[j + 1] - xk[j]
            dm = xk[j + 1] - xi
            dp = xi - xk[j]
            cm = (dm * (dm * dm / hj - hj)) / 6
            cp = (dp * (dp * dp / hj - hj)) / 6
            am = dm / hj
            ap = dp / hj
        end

        # cubic part from rows j and j+1 of F, then the two linear weights
        @views X[i, :] .= cm .* F[j, :] .+ cp .* F[j + 1, :]
        X[i, j] += am
        X[i, j + 1] += ap
    end
    return X
end
# Simulated data: a logistic curve f observed with Gaussian noise (sd 0.1)
# at n sorted points drawn uniformly from [-1, 3).
n = 100
x = sort(rand(n) .* 4 .- 1)
f = @. exp(4 * x) / (1 + exp(4 * x))
y = f .+ rand(Normal(), n) .* 0.1

# k = 10 knots placed at evenly spaced quantiles of x
k = 10
xk = quantile(x, range(0, 1; length=k))

# CR basis, model matrix and penalty for those knots
D, B, F = basis_cr(xk)
X = model_matrix_cr(xk, F, x)
S = penalty_cr(D, B)
C = mean(X; dims=1)         # column means of X; not used by the JuMP model below
W = Diagonal(fill(1.0, n))  # unit observation weights
λ = 0.5                     # completely arbitrary value

# Fit with JuMP + Ipopt: minimize the weighted residual sum of squares plus
# λ times the quadratic penalty βᵀSβ. No constraints are added to the model.
spline_fit = Model(Ipopt.Optimizer)
@variable(spline_fit, β[1:k])
@objective(spline_fit, Min, (X * β - y)' * W * (X * β - y) + λ * β' * S * β)
optimize!(spline_fit)
# Cross-check in R: copy x and y over, let mgcv build the same kind of smooth
# (10-dimensional "cr" basis) and fit it with pcls at smoothing parameter 0.5.
@rput x y
R"""
library(mgcv)
dat <- data.frame(x=x)
sm <- smoothCon(s(x,k=10,bs="cr"),dat,knots=NULL)[[1]]
sm_S <- sm$S[[1]]
sm_X <- sm$X
sm_F <- matrix(sm$F,10,10)
sm_C <- sm$C
sm_xp <- sm$xp
# fit it
M <- list(
X=sm_X,
p=sm$xp,
off=array(0),
S=list(sm_S),
Ain=matrix(0,1,10),
bin=-2e16,
C=matrix(0,0,0),
sp=array(0.5),
y=y,
w=rep(1,length(y))
)
pcls(M) -> beta
"""
# Bring mgcv's model matrix, penalty, constraint, coefficients and knots back
# into Julia (sm_F stays on the R side).
@rget sm_X sm_S sm_C beta sm_xp
# Overlay on the raw data: the true curve, the JuMP fit, and the curve obtained
# by applying the Julia model matrix to the coefficients returned by mgcv.
fit_jump = X * value.(β)
fit_mgcv = X * beta
scatter(x, y; legend=false)
plot!(x, f)
plot!(x, fit_jump)
plot!(x, fit_mgcv)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment