Skip to content

Instantly share code, notes, and snippets.

@Edenhofer
Last active December 20, 2021 10:21
Show Gist options
  • Save Edenhofer/0f4947e4bf02cb2bcb3af7204ed2e626 to your computer and use it in GitHub Desktop.
Save Edenhofer/0f4947e4bf02cb2bcb3af7204ed2e626 to your computer and use it in GitHub Desktop.
Re-Implementation of SciPy's gaussian_kde
#!/usr/bin/env python3
from functools import partial
from jax import vmap, jit
from jax import numpy as jnp
def _det(mat, n_dim=None):
    """Return the determinant of a matrix given in full or compressed form.

    Parameters
    ----------
    mat :
        Two-dimensional input is handed to `jnp.linalg.det`. One-dimensional
        input is treated as the diagonal of a matrix, i.e. the determinant
        is the product of its entries. Zero-dimensional input is treated as
        a scalar that is raised to the power of `n_dim`.
    n_dim : optional
        Number of dimensions of the implied square matrix. Required for
        zero-dimensional input. For one-dimensional input of shape `(1, )`
        the single entry is raised to the power of `n_dim`.

    Returns
    -------
    The determinant of the (implied) matrix.

    Raises
    ------
    ValueError
        If a one-dimensional `mat` is neither of shape `(1, )` nor of shape
        `(n_dim, )` while `n_dim` is given, or if `mat` has more than two
        dimensions.
    TypeError
        If `mat` is zero-dimensional and `n_dim` is not given.
    """
    import numpy as np

    rank = np.ndim(mat)

    # Full matrix
    if rank == 2:
        return jnp.linalg.det(mat)

    # Diagonal of a matrix (or a single value shared by all diagonal entries)
    if rank == 1:
        dims_known = n_dim is not None
        if dims_known and np.shape(mat) == (1, ):
            return jnp.squeeze(mat)**n_dim
        if dims_known and np.shape(mat) != (n_dim, ):
            ve = f"incompatible shapes {np.shape(mat)!r} and (n_dim, )"
            raise ValueError(ve)
        return jnp.prod(mat)

    # Scalar shared by all diagonal entries
    if rank == 0:
        if n_dim is None:
            te = "unable to compute determinant if dimensions are unknown"
            raise TypeError(te)
        return mat**n_dim

    raise ValueError(f"invalid matrix {mat!r}")
@partial(jit, static_argnames=("norm", ))
def _nocheck_singlex_gaussian_kde_logpdf(
    x, dataset, inv_cov, normed_weights=None, norm=True
):
    """Evaluate the Gaussian kernel log-density of a data set at one point.

    Parameters
    ----------
    x :
        Single evaluation point of shape (# of dims, ).
    dataset :
        Data points of shape (# of dims, # of data).
    inv_cov :
        Inverse kernel covariance; either a full matrix (2-D), the diagonal
        of a matrix (1-D) or a scalar (0-D).
    normed_weights : optional
        Per-datum weights forwarded as `b` to `logsumexp`. If given, the
        number of data points does not enter the normalization; the weights
        are presumably expected to sum to one (verify against the caller).
    norm : bool, optional
        Whether to subtract the log-normalization of the kernel. Static
        under `jit`.

    Returns
    -------
    The (optionally normalized) log-density at `x`.

    Notes
    -----
    Using a full matrix instead of an implicit diagonal matrix is about 75 %
    slower in two dimensions. Using a scalar in favor of a diagonal matrix in
    two dimensions however yields no measurable performance benefit.
    """
    import numpy as np
    from jax.scipy.special import logsumexp

    # This function handles exactly one point; batches of points have to be
    # mapped over from the outside
    assert np.ndim(x) == 1, (
        f"unexpected dimension of `x` {np.shape(x)!r}"
        f"; expected (# of dims, ); maybe you forgot `vmap`"
    )
    assert np.ndim(dataset) == 2, (
        f"unexpected dimension of `dataset` {np.shape(dataset)!r}"
        f"; expected (# of dims, # of data)"
    )
    assert np.ndim(inv_cov) in (0, 1, 2), f"invalid `inv_cov` {inv_cov!r}"

    # Offsets between the point and every datum; shape (# of dims, # of data)
    diff = x.reshape(-1, 1) - dataset
    if np.ndim(inv_cov) == 2:
        scaled_diff = inv_cov @ diff
        sq_dist = jnp.sum(scaled_diff * diff, axis=0)
    else:
        assert np.ndim(inv_cov) in (0, 1)
        # A scalar or diagonal inverse covariance acts per dimension
        precision = inv_cov.reshape(-1, 1)
        sq_dist = jnp.sum(precision * diff**2, axis=0)

    log_density = logsumexp(-0.5 * sq_dist, b=normed_weights)
    if not norm:
        return log_density

    n_dim, n_data = dataset.shape
    # With weights the number of data points drops out of the normalization
    n_norm = n_data if normed_weights is None else 1.
    volume_sq = n_norm**2 * (2. * jnp.pi)**n_dim / _det(inv_cov, n_dim)
    return log_density - 0.5 * jnp.log(volume_sq)
def gaussian_kde_covarance_factor(n_dim, n_eff_data, method=None):
    """Return the bandwidth factor that scales the data covariance.

    Parameters
    ----------
    n_dim :
        Number of dimensions of the data.
    n_eff_data :
        Effective number of data points.
    method : str or scalar, optional
        A string starting with "scott" or "silverman" (case-insensitive)
        selects the respective rule. Any other non-callable value is
        returned unchanged as the factor. Defaults to "scott".

    Returns
    -------
    The bandwidth factor.

    Raises
    ------
    ValueError
        If `method` is a string that names neither rule.
    NotImplementedError
        If `method` is callable.
    """
    if method is None:
        method = "scott"

    if not isinstance(method, str):
        if callable(method):
            raise NotImplementedError("there is no `gaussian_kde` object")
        # Anything else is taken to be the factor itself
        return method

    rule = method.lower()
    if rule.startswith("scott"):
        return n_eff_data**(-1. / (n_dim + 4))
    if rule.startswith("silverman"):
        scaled_n = n_eff_data * (n_dim + 2.) / 4.
        return scaled_n**(-1. / (n_dim + 4.))
    raise ValueError(f"invalid method {method!r}")
def gaussian_kde_covariance_inv_cov(
    dataset, bw_method=None, weights=None, _normed_weights=None
):
    """Compute the kernel covariance of a data set and its inverse.

    Parameters
    ----------
    dataset :
        Data points; promoted to at least two dimensions and unpacked as
        (# of dims, # of data).
    bw_method : optional
        Bandwidth selection, forwarded to `gaussian_kde_covarance_factor`.
    weights : optional
        Per-datum weights; normalized here to sum to one.
    _normed_weights : optional
        Per-datum weights that are used as given, without normalization.
        Takes precedence over `weights` if both are provided.

    Returns
    -------
    covariance, inv_cov :
        The (weighted) data covariance multiplied by the squared bandwidth
        factor and the inverse data covariance divided by the same value.

    Raises
    ------
    ValueError
        If the weights are not one-dimensional or their size differs from
        the number of data points.
    """
    dataset = jnp.atleast_2d(dataset).astype(float)
    n_dim, n_data = dataset.shape

    # Pick the weights; pre-normalized ones take precedence
    if _normed_weights is not None:
        wgt = jnp.atleast_1d(_normed_weights).astype(float)
    elif weights is not None:
        wgt = jnp.atleast_1d(weights).astype(float)
        wgt = wgt / jnp.sum(wgt)
    else:
        wgt = None

    if wgt is None:
        n_eff_data = jnp.array(n_data, float)
    else:
        n_eff_data = 1. / jnp.sum(wgt**2)
        shape_ok = wgt.ndim == 1 and wgt.size == n_data
        if not shape_ok:
            ve = (
                "incompatible or invalid shape of `weights`"
                f" {wgt.shape!r}"
            )
            raise ValueError(ve)

    data_cov = jnp.atleast_2d(
        jnp.cov(dataset, rowvar=True, bias=False, aweights=wgt)
    )
    data_inv_cov = jnp.linalg.inv(data_cov)

    # The bandwidth factor rescales the covariance of the data
    factor = gaussian_kde_covarance_factor(n_dim, n_eff_data, bw_method)
    scale = factor**2
    return data_cov * scale, data_inv_cov / scale
def gaussian_kde_logpdf(x, dataset, bw_method=None, weights=None):
    """Evaluate the log-density of a Gaussian kernel density estimate.

    Parameters
    ----------
    x :
        Points at which to evaluate the estimate; promoted to at least two
        dimensions and interpreted as (# of dims, # of points). A single
        point passed as a flat vector of length # of dims is accepted too.
    dataset :
        Data points from which to estimate the density; promoted to at least
        two dimensions and interpreted as (# of dims, # of data).
    bw_method : optional
        Bandwidth selection, forwarded to `gaussian_kde_covariance_inv_cov`.
    weights : optional
        Per-datum weights; normalized here to sum to one.

    Returns
    -------
    The log-density at every point, mapped over the second axis of `x`.

    Raises
    ------
    ValueError
        If the dimensionality of `x` does not match that of `dataset`.
    """
    x = jnp.atleast_2d(x)
    dataset = jnp.atleast_2d(dataset).astype(float)

    # Without this check a dimensionality mismatch is silently broadcast
    # whenever either `x` or `dataset` has exactly one dimension, yielding
    # wrong densities instead of an error. Inputs of higher rank are left to
    # the sanity checks of the single-point kernel.
    n_dim = dataset.shape[0]
    if x.ndim == 2 and x.shape[0] != n_dim:
        if x.shape == (1, n_dim):
            # `x` was passed as a row vector describing a single point
            x = x.reshape(n_dim, 1)
        else:
            ve = (
                f"points have dimension {x.shape[0]}"
                f", dataset has dimension {n_dim}"
            )
            raise ValueError(ve)

    if weights is not None:
        weights = jnp.atleast_1d(weights).astype(float)
        weights = weights / jnp.sum(weights)

    _, inv_cov = gaussian_kde_covariance_inv_cov(
        dataset, bw_method, _normed_weights=weights
    )
    # Map the single-point kernel over the points, i.e. the columns of `x`
    return vmap(
        _nocheck_singlex_gaussian_kde_logpdf,
        in_axes=(1, None, None, None, None)
    )(x, dataset, inv_cov, weights, True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment