Created
April 27, 2020 21:16
-
-
Save Vindaar/ca092e987f595359c5b88742f8c69a74 to your computer and use it in GitHub Desktop.
KDE in Nim for `geom_density`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ggplotnim, seqmath, sequtils, stats, algorithm, strutils | |
type
  # Built-in kernels understood by `kde`. The string values are the names
  # accepted by `parseEnum` in the string-based `kde` overload.
  KernelKind = enum
    knCustom = "custom"             ## user supplied kernel function
    knBox = "box"                   ## uniform kernel
    knTriangular = "triangular"
    knTrig = "trigonometric"        ## 1 + cos(2*PI*u) kernel
    knEpanechnikov = "epanechnikov"
    knGauss = "gauss"

  # Common kernel signature: evaluate the kernel centred on the data point
  # `x_i` with bandwidth `bw` at the position `x`.
  KernelFunc = proc (x, x_i, bw: float): float
proc boxKernel(x, x_i, bw: float): float =
  ## Box (uniform) kernel: 1.0 while |x - x_i| / bw < 0.5, otherwise 0.0.
  let u = (x - x_i) / bw
  if abs(u) < 0.5:
    result = 1.0
proc triangularKernel(x, x_i, bw: float): float =
  ## Triangular kernel: 1 - u for u = |x - x_i| / bw below 1, otherwise 0.0.
  let u = abs(x - x_i) / bw
  if u < 1.0:
    result = 1.0 - u
proc trigonometricKernel(x, x_i, bw: float): float =
  ## Trigonometric kernel: 1 + cos(2*PI*u) for u = |x - x_i| / bw below 0.5,
  ## otherwise 0.0.
  let u = abs(x - x_i) / bw
  if u < 0.5:
    result = 1.0 + cos(2 * PI * u)
proc epanechnikovKernel(x, x_i, bw: float): float =
  ## Epanechnikov kernel: 3/4 * (1 - u^2) for u = |x - x_i| / bw below 1,
  ## otherwise 0.0.
  let u = abs(x - x_i) / bw
  if u < 1.0:
    result = 0.75 * (1.0 - u * u)
proc gaussKernel(x, x_i, bw: float): float =
  ## Gaussian kernel centred on `x_i` with width `bw`, delegated to the
  ## imported `gauss` (presumably from seqmath).
  # NOTE(review): whether `gauss` returns the normalized pdf or only the
  # exp(-u^2 / 2) shape is not visible here; confirm against its definition.
  gauss(sigma = bw, mean = x_i, x = x)
proc getCutoff(bw: float, kind: KernelKind): float =
  ## Distance beyond which a kernel of type `kind` with bandwidth `bw` is
  ## considered to contribute nothing. Custom kernels have no known support,
  ## so asking for one is an assertion failure.
  case kind
  of knBox, knTrig:
    # both kernels vanish for |x - x_i| / bw >= 0.5
    result = 0.5 * bw
  of knTriangular, knEpanechnikov:
    # both kernels vanish for |x - x_i| / bw >= 1
    result = bw
  of knGauss:
    # 3 sigma amounts to 99.7% of gaussian kernel contribution
    result = 3.0 * bw
  of knCustom:
    doAssert false, "`getCutoff` must not be called with custom kernel!"
proc getKernelFunc(kind: KernelKind): KernelFunc =
  ## Maps a built-in `KernelKind` to its kernel procedure. `knCustom` has no
  ## built-in implementation and triggers an assertion failure.
  case kind
  of knCustom:
    doAssert false, "`getKernelFunc` must not be called with custom kernel!"
  of knGauss: result = gaussKernel
  of knEpanechnikov: result = epanechnikovKernel
  of knTrig: result = trigonometricKernel
  of knTriangular: result = triangularKernel
  of knBox: result = boxKernel
func iqr[T](s: openArray[T]): float =
  ## Interquartile range of `s`: the 75th percentile minus the 25th
  ## percentile, both as computed by `percentile`.
  let
    upper = percentile(s, 75)
    lower = percentile(s, 25)
  upper - lower
proc findWindow[T](dist: T, s: T, x: seq[T], oldStart = 0, oldStop = 0): (int, int) =
  ## Returns the half-open index range `(start, stop)` such that
  ## `x[start ..< stop]` are the points of `x` within `dist` of `s`.
  ## Points exactly `dist` to the left of `s` are excluded; points exactly
  ## `dist` to the right are included. An empty window has `start == stop`.
  ## If the window reaches the end of `x`, `stop == x.len`.
  ##
  ## `x` must be sorted ascending. `oldStart` / `oldStop` may be the window
  ## returned for a previous `s` that was not larger than the current one.
  ## The search then resumes from there instead of rescanning `x`.
  var start = oldStart
  # skip grid points lying at least `dist` to the left of `s`
  while start < x.len and s - x[start] >= dist:
    inc start
  # With non-decreasing `s`, every point in [start, oldStop) is already in
  # range. The right-edge scan can therefore resume at `oldStop`, but never
  # before `start`: the previous window may end left of the new one.
  var stop = max(start, oldStop)
  while stop < x.len and abs(s - x[stop]) <= dist:
    inc stop
  result = (start, stop)
proc kde[T: SomeNumber](s: openArray[T],
                        kernel: KernelFunc,
                        kernelKind = knCustom,
                        adjust: float = 1.0,
                        samples: int = 1000,
                        bw: float = NaN,
                        normalize = true,
                        cutoff: float = NaN): seq[float] =
  ## Kernel density estimate of `s`, evaluated at `samples` evenly spaced
  ## points from `min(s)` to `max(s)`. Returns one density value per point.
  ##
  ## * `kernel` is called as `kernel(gridPoint, dataPoint, bandwidth)`.
  ## * `kernelKind` selects the default `cutoff` via `getCutoff`. For
  ##   `knCustom` a non-NaN `cutoff` must be given (asserted).
  ## * `bw`: bandwidth. When NaN it is estimated with Silverman's rule of
  ##   thumb `0.9 * min(sd, IQR / 1.34) * N^(-1/5)`, multiplied by `adjust`.
  ##   The rule assumes roughly normally distributed data. A manually given
  ##   `bw` is used as is.
  ## * `normalize`: rescale the result so that
  ##   `sum(result) * (max(s) - min(s)) / samples == 1`, i.e. an approximate
  ##   integral of 1.
  ## * `cutoff`: a data point only contributes to grid points within this
  ##   distance.
  ##
  ## An empty `s` yields `samples` zeros; `samples <= 0` yields an empty seq.
  doAssert classify(cutoff) != fcNaN or kernelKind != knCustom,
    "If a custom kernel is used you have to provide a cutoff distance!"
  if samples <= 0:
    return
  result = newSeq[float](samples)
  if s.len == 0:
    return
  # Work on a sorted float copy: `findWindow` relies on ascending data, and
  # the kernels take `float` arguments.
  let data = sorted(s.mapIt(it.float))
  let (minS, maxS) = (min(data), max(data))
  let x = linspace(minS, maxS, samples)
  let nPoints = data.len.float
  let sd = standardDeviation(data)
  var spread = min(sd, iqr(data) / 1.34)
  if spread <= 0.0:
    # Silverman's rule collapses to bw = 0 for (near-)constant data, which
    # would divide by zero below. Fall back the way R's `bw.nrd0` does.
    spread =
      if sd > 0.0: sd
      elif data[0] != 0.0: abs(data[0])
      else: 1.0
  let bwAct =
    if classify(bw) != fcNaN: bw
    else: adjust * 0.9 * spread * pow(nPoints, -1.0 / 5.0)
  let cutoffAct =
    if classify(cutoff) != fcNaN: cutoff
    else: getCutoff(bwAct, kernelKind)
  let norm = 1.0 / (nPoints * bwAct)
  var start, stop = 0
  for point in data:
    # data and grid are both ascending, so each window search resumes from
    # the previous window
    (start, stop) = findWindow(cutoffAct, point, x, start, stop)
    for j in start ..< stop:
      result[j] += norm * kernel(x[j], point, bwAct)
  if normalize:
    # approximate integral: sum of densities times (range / samples)
    let area = result.sum * (maxS - minS) / samples.float
    # a zero area (e.g. all data identical) cannot be normalized; dividing
    # would only produce NaN/Inf
    if area > 0.0:
      for v in result.mitems:
        v /= area
func kde[T: SomeNumber; U: KernelKind | string](
    s: openArray[T],
    kernel: U = "gauss",
    adjust: float = 1.0,
    samples: int = 1000,
    bw: float = NaN,
    normalize = true): seq[float] =
  ## Convenience overload of `kde` that picks a built-in kernel either by
  ## `KernelKind` or by its string name (e.g. "gauss", "epanechnikov").
  ## Strings are converted with `parseEnum`. The cutoff is derived from the
  ## kernel kind, and the remaining arguments are forwarded unchanged.
  when U is string:
    let kind = parseEnum[KernelKind](kernel)
  else:
    let kind = kernel
  kde(s, getKernelFunc(kind),
      kernelKind = kind,
      adjust = adjust,
      samples = samples,
      bw = bw,
      normalize = normalize)
proc normalize[T](s, bins: openArray[T]): seq[T] =
  ## Rescales `s`, sampled at the evenly spaced points `bins`, so that the
  ## approximate integral `sum(s) * (max(bins) - min(bins)) / bins.len`
  ## equals 1.
  ##
  ## If that integral cannot be formed or is zero (empty `bins`, all-zero
  ## `s`, or single-valued `bins`), a plain copy of `s` is returned. This
  ## avoids a crash or NaN/Inf values.
  # NOTE: evenly spaced binning required!
  if bins.len == 0:
    return @s
  let factor = s.sum * (max(bins) - min(bins)) / bins.len.float
  if factor == 0:
    return @s
  result = s.mapIt(it / factor)
proc main() =
  ## Estimates the density of the `carat` column of `data/diamonds.csv` and
  ## writes a line plot of it to `density_test.pdf`.
  let
    diamonds = toDf(readCsv("data/diamonds.csv"))
    carat = diamonds["carat"].toTensor(float).toRawSeq
    # same grid `kde` evaluates on by default: 1000 points over the data range
    x = linspace(min(carat), max(carat), 1000)
    estimate = kde(carat)
  # `x` / `estimate` keep their names: they become the data frame columns
  # referenced in `aes` below
  ggplot(seqsToDf(x, estimate), aes("x", "estimate")) +
    geom_line(fillColor = some(parseHex("9B4EFF")), alpha = some(0.3)) +
    ggsave("density_test.pdf")
# run the demo only when this file is compiled as the main module
when isMainModule: main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is essentially the `geom_density` example (density of diamond carat) from the ggplot2 reference: https://ggplot2.tidyverse.org/reference/geom_density.html