Last active
December 20, 2021 10:21
-
-
Save Edenhofer/0f4947e4bf02cb2bcb3af7204ed2e626 to your computer and use it in GitHub Desktop.
Re-Implementation of SciPy's gaussian_kde
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from functools import partial | |
from jax import vmap, jit | |
from jax import numpy as jnp | |
def _det(mat, n_dim=None):
    """Determinant of a matrix given in full, diagonal or scalar form.

    Parameters
    ----------
    mat :
        Either a full matrix (2-d), the diagonal of a diagonal matrix (1-d)
        or a scalar (0-d) that is raised to the power of `n_dim`.
    n_dim : int, optional
        Number of dimensions of the implied matrix. Required for scalar
        input. For 1-d input it is used to validate the length; a length-one
        diagonal is treated like a scalar.

    Raises
    ------
    ValueError
        If a 1-d `mat` is neither of shape (1, ) nor (n_dim, ), or if `mat`
        has more than two dimensions.
    TypeError
        If `mat` is a scalar and `n_dim` is not given.
    """
    import numpy as np

    rank = np.ndim(mat)
    if rank == 2:
        return jnp.linalg.det(mat)

    if rank == 1:
        if n_dim is None:
            # No reference dimension to validate against: plain product of
            # the diagonal entries
            return jnp.prod(mat)
        shape = np.shape(mat)
        if shape == (1, ):
            # Single entry stands in for every element of the diagonal
            return jnp.squeeze(mat)**n_dim
        if shape != (n_dim, ):
            ve = f"incompatible shapes {shape!r} and (n_dim, )"
            raise ValueError(ve)
        return jnp.prod(mat)

    if rank == 0:
        if n_dim is None:
            te = "unable to compute determinant if dimensions are unknown"
            raise TypeError(te)
        return mat**n_dim

    raise ValueError(f"invalid matrix {mat!r}")
@partial(jit, static_argnames=("norm", ))
def _nocheck_singlex_gaussian_kde_logpdf(
    x, dataset, inv_cov, normed_weights=None, norm=True
):
    """Evaluates the Gaussian kernel log-density of one point against a set
    of points.

    Parameters
    ----------
    x :
        Single evaluation point of shape (# of dims, ).
    dataset :
        Points the kernels are centered on, of shape (# of dims, # of data).
    inv_cov :
        Inverse kernel covariance given as full matrix (2-d), as the diagonal
        of a diagonal matrix (1-d) or as scalar (0-d).
    normed_weights : optional
        Per-datum factors handed to `logsumexp` as `b`. If given, the
        normalization uses 1 in place of the number of data.
    norm : bool
        Whether to subtract the logarithmic normalization term. Static under
        `jit`.

    Notes
    -----
    Using a full matrix instead of an implicit diagonal matrix is about 75 %
    slower in two dimensions. Using a scalar in favor of a diagonal matrix in
    two dimensions however yields no measurable performance benefit.
    """
    import numpy as np
    from jax.scipy.special import logsumexp

    # `x` is expected to be one dimension short of `dataset`; callers map this
    # function over the evaluation points
    assert np.ndim(x) == 1, (
        f"unexpected dimension of `x` {np.shape(x)!r}"
        f"; expected (# of dims, ); maybe you forgot `vmap`"
    )
    assert np.ndim(dataset) == 2, (
        f"unexpected dimension of `dataset` {np.shape(dataset)!r}"
        f"; expected (# of dims, # of data)"
    )
    assert np.ndim(inv_cov) in (0, 1, 2), f"invalid `inv_cov` {inv_cov!r}"

    diff = x.reshape(-1, 1) - dataset
    if np.ndim(inv_cov) < 2:
        # Scalar or diagonal: scale the squared residuals per dimension
        sq_dist = jnp.sum(inv_cov.reshape(-1, 1) * diff**2, axis=0)
    else:
        sq_dist = jnp.sum((inv_cov @ diff) * diff, axis=0)

    log_kernel = logsumexp(-0.5 * sq_dist, b=normed_weights)
    if not norm:
        return log_kernel

    n_dim, n_data = dataset.shape
    n_norm = n_data if normed_weights is None else 1.
    log_normalization = 0.5 * jnp.log(
        n_norm**2 * (2. * jnp.pi)**n_dim / _det(inv_cov, n_dim)
    )
    return log_kernel - log_normalization
def gaussian_kde_covarance_factor(n_dim, n_eff_data, method=None):
    """Bandwidth factor by which the data covariance is scaled.

    Parameters
    ----------
    n_dim :
        Number of dimensions of the data.
    n_eff_data :
        Effective number of data points.
    method : str or scalar, optional
        A string starting with "scott" (the default) or "silverman"
        (case-insensitive) selects the respective rule. Any other
        non-callable value is returned unchanged as the factor.

    Raises
    ------
    ValueError
        If `method` is a string naming neither rule.
    NotImplementedError
        If `method` is callable.
    """
    if method is None:
        method = "scott"

    if not isinstance(method, str):
        if callable(method):
            raise NotImplementedError("there is no `gaussian_kde` object")
        # Anything else is taken verbatim as the factor
        return method

    rule = method.lower()
    if rule.startswith("scott"):
        return n_eff_data**(-1. / (n_dim + 4))
    if rule.startswith("silverman"):
        scaled = n_eff_data * (n_dim + 2.) / 4.
        return scaled**(-1. / (n_dim + 4.))
    raise ValueError(f"invalid method {method!r}")
def gaussian_kde_covariance_inv_cov(
    dataset, bw_method=None, weights=None, _normed_weights=None
):
    """Kernel covariance and its inverse for a Gaussian KDE of `dataset`.

    Parameters
    ----------
    dataset :
        Data of shape (# of dims, # of data); promoted to at least 2-d.
    bw_method : optional
        Bandwidth selection handed to `gaussian_kde_covarance_factor`.
    weights : optional
        Per-datum weights; normalized to sum to one before use.
    _normed_weights : optional
        Per-datum weights that are used as given, without normalization.
        Takes precedence over `weights`.

    Returns
    -------
    covariance, inv_cov :
        The (weighted) data covariance scaled by the squared bandwidth factor
        and the correspondingly scaled inverse.

    Raises
    ------
    ValueError
        If the weights are not 1-d with one entry per datum.
    """
    data = jnp.atleast_2d(dataset).astype(float)
    n_dim, n_data = data.shape

    if _normed_weights is not None:
        wgt = jnp.atleast_1d(_normed_weights).astype(float)
    elif weights is not None:
        raw = jnp.atleast_1d(weights).astype(float)
        wgt = raw / jnp.sum(raw)
    else:
        wgt = None

    if wgt is None:
        n_eff_data = jnp.array(n_data, float)
    else:
        if wgt.ndim != 1 or wgt.size != n_data:
            ve = (
                "incompatible or invalid shape of `weights`"
                f" {wgt.shape!r}"
            )
            raise ValueError(ve)
        n_eff_data = 1. / jnp.sum(wgt**2)

    data_cov = jnp.atleast_2d(
        jnp.cov(data, rowvar=True, bias=False, aweights=wgt)
    )
    data_inv_cov = jnp.linalg.inv(data_cov)

    factor = gaussian_kde_covarance_factor(n_dim, n_eff_data, bw_method)
    sq_factor = factor**2
    return data_cov * sq_factor, data_inv_cov / sq_factor
def gaussian_kde_logpdf(x, dataset, bw_method=None, weights=None):
    """Logarithm of a Gaussian kernel density estimate evaluated at `x`.

    Parameters
    ----------
    x :
        Evaluation points of shape (# of dims, # of points). Promoted to at
        least 2-d. A single row of length # of dims is interpreted as one
        point if the dataset has more than one dimension.
    dataset :
        Data of shape (# of dims, # of data); promoted to at least 2-d.
    bw_method : optional
        Bandwidth selection forwarded to `gaussian_kde_covariance_inv_cov`.
    weights : optional
        Per-datum weights; normalized to sum to one before use.

    Returns
    -------
    One log-density per evaluation point (the second axis of `x` is mapped
    over).

    Raises
    ------
    ValueError
        If `x` or `dataset` have more than two dimensions or if the
        dimensionality of `x` does not match that of `dataset`.
    """
    x = jnp.atleast_2d(x)
    dataset = jnp.atleast_2d(dataset).astype(float)
    if x.ndim != 2:
        raise ValueError(f"invalid shape of `x` {x.shape!r}")
    if dataset.ndim != 2:
        raise ValueError(f"invalid shape of `dataset` {dataset.shape!r}")

    # Without this check a mismatching `x` is either silently broadcast
    # against `dataset` (yielding wrong densities) or fails deep inside
    # `vmap`; mirror the handling of `scipy.stats.gaussian_kde.evaluate`
    n_dim = dataset.shape[0]
    if x.shape[0] != n_dim:
        if x.shape == (1, n_dim):
            # `x` was passed as a row vector describing a single point
            x = x.reshape(n_dim, 1)
        else:
            ve = (
                f"points have dimension {x.shape[0]}"
                f", dataset has dimension {n_dim}"
            )
            raise ValueError(ve)

    if weights is not None:
        weights = jnp.atleast_1d(weights).astype(float)
        weights /= jnp.sum(weights)

    _, inv_cov = gaussian_kde_covariance_inv_cov(
        dataset, bw_method, _normed_weights=weights
    )

    # Map over the points (axis 1 of `x`); everything else is shared
    return vmap(
        _nocheck_singlex_gaussian_kde_logpdf,
        in_axes=(1, None, None, None, None)
    )(x, dataset, inv_cov, weights, True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment