Last active August 24, 2021 17:52
Kernel regression in dex-lang
import plot

-- Conjugate gradients solver for the linear system mat x = b.
-- Assumes mat is symmetric positive definite; runs one iteration per element of m.
def solve (mat:m=>m=>Float) (b:m=>Float) : m=>Float =
  x0 = for i:m. 0.0
  ax = mat **. x0
  r0 = b - ax
  (xOut, _, _) = fold (x0, r0, r0) $
     \s:m (x, r, p).
       ap = mat **. p
       alpha = vdot r r / vdot p ap
       x' = x + alpha .* p
       r' = r - alpha .* ap
       -- the small constant guards against division by zero once r has vanished
       beta = vdot r' r' / (vdot r r + 0.000001)
       p' = r' + beta .* p
       (x', r', p')
  xOut

' # Kernel ridge regression

' To learn a function $f_{true}: \mathcal{X} \to \mathbb R$
from data $(x_1, y_1),\dots,(x_N, y_N)\in \mathcal{X}\times\mathbb R$,\
kernel ridge regression uses a hypothesis of the form
$f(x)=\sum_{i=1}^N \alpha_i k(x_i, x)$,\
where $k:\mathcal X \times \mathcal X \to \mathbb R$ is a positive semidefinite kernel function.\
The optimal coefficients are found by solving the linear system $\alpha=G^{-1}y$,\
where $G_{ij}:=k(x_i, x_j)$ and $y = (y_1,\dots,y_N)^\top\in\mathbb R^N$.\
In the code, `regress` adds $10^{-4}$ to the diagonal of $G$ before solving.

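
' As a sketch of where such a linear system comes from, assume the standard ridge objective
with a regularization weight $\lambda \ge 0$ (a symbol introduced here for illustration, not used in the text above):\
$\min_{\alpha\in\mathbb R^N} \|y - G\alpha\|^2 + \lambda\, \alpha^\top G \alpha$.\
Since $G$ is symmetric, setting the gradient with respect to $\alpha$ to zero gives\
$-2G(y - G\alpha) + 2\lambda G\alpha = 0$, that is, $G\big((G + \lambda I)\alpha - y\big) = 0$,\
which is satisfied by $\alpha = (G+\lambda I)^{-1}y$.\
For $\lambda = 0$ and invertible $G$ this reduces to $\alpha = G^{-1}y$.\
Under this reading, the constant `0.0001` added to the diagonal in `regress` would play the role of $\lambda$;
that interpretation is an assumption, since the constant may equally be intended as numerical jitter.
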
-- Synthetic data
Nx = Fin 100
noise = 0.1
[k1, k2] = splitKey (newKey 0)

def trueFun (x:Float) : Float =
  x + sin (20.0 * x)

xs : Nx=>Float = for i. rand (ixkey k1 i)
ys : Nx=>Float = for i. trueFun xs.i + noise * randn (ixkey k2 i)

-- Kernel ridge regression
def regress (kernel: a -> a -> Float) (xs: Nx=>a) (ys: Nx=>Float) : a->Float =
  -- Gram matrix G_ij = k(x_i, x_j), with 0.0001 added on the diagonal
  gram = for i j. kernel xs.i xs.j + select (i==j) 0.0001 0.0
  alpha = solve gram ys
  predict = \x. sum for i. alpha.i * kernel xs.i x
  predict

-- RBF kernel: k(x, y) = exp(-(x - y)^2 / (2 * lengthscale^2))
def rbf (lengthscale:Float) (x:Float) (y:Float) : Float =
  exp (-0.5 * sq ((x - y) / lengthscale))

predict = regress (rbf 0.2) xs ys

-- Evaluation
Nxtest = Fin 1000
-- Note: this reuses key k1, so the first 100 test inputs coincide with the training inputs xs.
xtest : Nxtest=>Float = for i. rand (ixkey k1 i)
preds = map predict xtest

:html showPlot $ xyPlot xs ys

:html showPlot $ xyPlot xtest preds
A kernelized version of https://github.com/google-research/dex-lang/blob/68e07022921d898a3962337ae3154b2c101f4b48/examples/regression.dx;
the kernel matrix `gram` is written as an array comprehension.

EDIT: this example is now upstreamed to the main repo via google-research/dex-lang#541
Check out dex-lang for more!