@niklasschmitz
Last active August 24, 2021 17:52
Kernel regression in dex-lang
import plot

-- Conjugate gradients solver
def solve (mat:m=>m=>Float) (b:m=>Float) : m=>Float =
  x0 = for i:m. 0.0
  ax = mat **. x0
  r0 = b - ax  -- initial residual
  (xOut, _, _) = fold (x0, r0, r0) $
    \s:m (x, r, p).
      ap = mat **. p
      alpha = vdot r r / vdot p ap  -- step size along search direction p
      x' = x + alpha .* p
      r' = r - alpha .* ap  -- updated residual
      beta = vdot r' r' / (vdot r r + 0.000001)  -- small constant avoids division by zero
      p' = r' + beta .* p  -- next search direction
      (x', r', p')
  xOut
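
' As a sketch of what `solve` computes (the iteration index $k$ and the symbol $\varepsilon$ are notation assumed here, not part of the original file):
writing $A$ for `mat` and starting from $x_0 = 0$, $r_0 = b - A x_0$, $p_0 = r_0$, each step of the fold performs
$$\alpha_k = \frac{r_k^\top r_k}{p_k^\top A p_k}, \qquad x_{k+1} = x_k + \alpha_k p_k, \qquad r_{k+1} = r_k - \alpha_k A p_k,$$
$$\beta_k = \frac{r_{k+1}^\top r_{k+1}}{r_k^\top r_k + \varepsilon}, \qquad p_{k+1} = r_{k+1} + \beta_k p_k,$$
with $\varepsilon = 10^{-6}$ in the code (the textbook conjugate gradient iteration has $\varepsilon = 0$).\
The step size $\alpha_k$ here is unrelated to the regression coefficients $\alpha_i$ used later.\
The fold runs for exactly as many steps as the index set `m` has elements, which in exact arithmetic is enough for a symmetric positive definite $A$.
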

' # Kernel ridge regression

' To learn a function $f_{true}: \mathcal{X} \to \mathbb R$
from data $(x_1, y_1),\dots,(x_N, y_N)\in \mathcal{X}\times\mathbb R$,\
kernel ridge regression uses a hypothesis of the form
$f(x)=\sum_{i=1}^N \alpha_i k(x_i, x)$,\
where $k:\mathcal X \times \mathcal X \to \mathbb R$ is a positive semidefinite kernel function.\
The optimal coefficients are found by solving the linear system $G\alpha=y$, i.e. $\alpha=G^{-1}y$,\
where $G_{ij}:=k(x_i, x_j)$ and $y = (y_1,\dots,y_N)^\top\in\mathbb R^N$.\
In the code, a small ridge term ($10^{-4}$) is added to the diagonal of $G$ before solving.
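
' As a sketch of where this linear system comes from (the regularization weight $\lambda$ and the objective $J$ are notation assumed here, not taken from the original text):
with $f(x_j) = (G\alpha)_j$ and squared kernel norm $\alpha^\top G \alpha$, the regularized least-squares objective is
$$J(\alpha) = \|G\alpha - y\|^2 + \lambda\, \alpha^\top G \alpha .$$
Since $G$ is symmetric, its gradient is
$$\nabla J(\alpha) = 2G(G\alpha - y) + 2\lambda G\alpha = 2G\big((G + \lambda I)\alpha - y\big),$$
which vanishes whenever $(G + \lambda I)\alpha = y$.\
For $\lambda = 0$ and invertible $G$ this is $\alpha = G^{-1}y$; the value 0.0001 added to the diagonal in `regress` plays the role of $\lambda$.
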

-- Synthetic data
Nx = Fin 100
noise = 0.1
[k1, k2] = splitKey (newKey 0)

def trueFun (x:Float) : Float =
  x + sin (20.0 * x)

xs : Nx=>Float = for i. rand (ixkey k1 i)
ys : Nx=>Float = for i. trueFun xs.i + noise * randn (ixkey k2 i)

-- Kernel ridge regression
def regress (kernel: a -> a -> Float) (xs: Nx=>a) (ys: Nx=>Float) : a->Float =
  gram = for i j. kernel xs.i xs.j + select (i==j) 0.0001 0.0
  alpha = solve gram ys
  predict = \x. sum for i. alpha.i * kernel xs.i x
  predict

def rbf (lengthscale:Float) (x:Float) (y:Float) : Float =
  exp (-0.5 * sq ((x - y) / lengthscale))

predict = regress (rbf 0.2) xs ys
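
-- Evaluation
-- Note: the test inputs reuse key k1, so the first 100 of them coincide with the training inputs xs.
Nxtest = Fin 1000
xtest : Nxtest=>Float = for i. rand (ixkey k1 i)
preds = map predict xtest

:html showPlot $ xyPlot xs ys

:html showPlot $ xyPlot xtest preds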
niklasschmitz commented May 19, 2021

This is a kernelized version of https://github.com/google-research/dex-lang/blob/68e07022921d898a3962337ae3154b2c101f4b48/examples/regression.dx
The kernel matrix `gram` is written as an array comprehension.

EDIT: this example is now upstreamed to the main repo via google-research/dex-lang#541

Check out dex-lang for more!
