Skip to content

Instantly share code, notes, and snippets.

Last active January 20, 2019 13:18
Show Gist options
  • Save philastrophist/d7aff0810158269424e95b62ea98b4bc to your computer and use it in GitHub Desktop.
Save philastrophist/d7aff0810158269424e95b62ea98b4bc to your computer and use it in GitHub Desktop.
EM with low cutoff
from functools import partial
from itertools import cycle
import numpy as np
from astroML.density_estimation import XDGMM
import pytest
from matplotlib.patches import Ellipse
def no_selection(x):
    """Trivial selection function that keeps (essentially) every sample.

    Parameters
    ----------
    x : ndarray of shape (samples, dims)
        Points to be tested.

    Returns
    -------
    ndarray of bool, shape (samples,)
        ``True`` wherever the first coordinate is strictly greater than
        ``-inf``.  Rows whose first coordinate is NaN or ``-inf`` therefore
        come out ``False``.
    """
    lower_bound = -np.inf
    leading_coordinate = x[:, 0]
    return leading_coordinate > lower_bound
def gaussians_in_a_box(ndim, ncomp, density, percent=5):
    """Build a random XDGMM mixture plus a box-shaped selection function.

    Parameters
    ----------
    ndim : int
        Dimensionality of the mixture.
    ncomp : int
        Number of mixture components.
    density : float
        Component means are drawn uniformly from ``[-width/density,
        width/density]`` in every dimension, where ``width`` is the largest
        ellipse axis length of any component.  Larger values pack the
        components closer together.
    percent : float, optional
        Per-dimension percentile cut, taken from 10000 samples of the
        mixture, that defines the box.  The box spans the ``percent``-th to
        the ``(100 - percent)``-th percentile in each dimension.

    Returns
    -------
    xdgmm : XDGMM
        Mixture with ``V``, ``mu`` and ``alpha`` set.
    selection : callable
        ``selection(x)`` takes an array of shape (samples, ndim).  It returns
        a boolean mask that is ``True`` for rows strictly inside the box in
        every dimension.

    Raises
    ------
    ValueError
        If ``percent`` is not in ``[0, 50)``; otherwise the box would be
        empty or inverted.
    """
    if not 0 <= percent < 50:
        raise ValueError("percent must be in [0, 50)")

    xdgmm = XDGMM(ncomp)
    # Every component gets the same isotropic covariance 0.1 * I.
    xdgmm.V = np.asarray([np.eye(ndim)] * ncomp) * 0.10
    # Largest full axis length 2*sqrt(5.991*lambda) over all components.
    # 5.991 appears to be the 95% chi-squared quantile for 2 dof.
    # eigvalsh is used because covariances are symmetric, which guarantees
    # real eigenvalues.
    width = np.asarray(
        [2 * np.sqrt(5.991 * np.linalg.eigvalsh(v)) for v in xdgmm.V]
    ).max()
    # The component means were lost from the original line; XDGMM.sample
    # needs them.
    xdgmm.mu = np.random.uniform(-width / density, width / density,
                                 size=(ncomp, ndim))
    # Mixture weights cycle through a fixed pattern and are then normalised.
    weights = cycle([0.25, 0.5, 0.1])
    xdgmm.alpha = np.array([next(weights) for _ in range(ncomp)])
    xdgmm.alpha /= xdgmm.alpha.sum()

    data = xdgmm.sample(10000)
    # limits[0] holds the lower corner of the box; limits[1] the upper corner.
    limits = np.percentile(data, [percent, 100 - percent], axis=0)

    def gaussians_in_a_box_selection(x):
        """Return True for rows of ``x`` strictly inside the box in every dimension."""
        return ((x > limits[0]) & (x < limits[1])).all(axis=1)

    return xdgmm, gaussians_in_a_box_selection
if __name__ == '__main__':
    # Demo: sample a known mixture, truncate it with a box-shaped selection,
    # and compare pygmmis fits against the true components.
    #   green = truth, black = k-means init (1 EM step on the full data),
    #   red = selection-aware fit with split-and-merge, blue = further EM.
    import logging

    import matplotlib.pyplot as plt
    import pygmmis
    from numpy.random import RandomState

    # Define RNGs for deterministic behavior.  ``rng`` is handed to pygmmis.
    # The global NumPy state drives np.random.uniform in gaussians_in_a_box,
    # and presumably XDGMM.sample too -- TODO confirm for the astroML version
    # in use.
    seed = 13
    np.random.seed(seed)
    rng = RandomState(seed)

    ndim, ncomp = 2, 4
    model, _ = gaussians_in_a_box(ndim, ncomp, 0.75, 5)

    def selection(x):
        """Keep rows with 0.5 < x[:, 0] < 2.75 and -1.5 < x[:, 1] < 2.5."""
        return (x[:, 1] > -1.5) & (x[:, 1] < 2.5) & (x[:, 0] > 0.5) & (x[:, 0] < 2.75)

    data = model.sample(5000)
    observed_data = data[selection(data)]
    plt.scatter(*data.T, c='k', s=1)
    plt.scatter(*observed_data.T, c='g', s=1)

    def plot_ellipse(ax, mu, covariance, color, linewidth=2, alpha=0.5):
        """Draw the 2-D ellipse 2*sqrt(5.991*lambda) for one Gaussian on ``ax``.

        5.991 appears to be the 95% chi-squared quantile for 2 dof, which
        would make this the 95% contour.  Returns the added Ellipse patch.
        """
        # eigh: the covariance is symmetric.  Column U[:, 0] is the
        # eigenvector that belongs to var[0].
        var, U = np.linalg.eigh(covariance)
        # arctan2 keeps the sign of the rotation.  The former
        # arccos(|U[0, 0]|) mirrored negatively correlated components.
        angle = np.degrees(np.arctan2(U[1, 0], U[0, 0]))
        # Copy the centre so that later in-place updates of the source array
        # (e.g. gmm.mean during a subsequent fit) cannot move this ellipse.
        center = tuple(np.array(mu, dtype=float))
        e = Ellipse(center,
                    2 * np.sqrt(5.991 * var[0]),
                    2 * np.sqrt(5.991 * var[1]),
                    angle=angle, edgecolor=color, fill=False,
                    linewidth=linewidth, alpha=alpha)
        ax.add_patch(e)
        return e

    for i in range(model.n_components):
        plot_ellipse(plt.gca(), model.mu[i], model.V[i], 'g')

    gmm = pygmmis.GMM(K=ncomp, D=ndim)
    w = 0.1  # minimum covariance regularization, same units as data
    cutoff = 50  # segment the data set into neighborhoods within `cutoff` (here 50) sigma around components
    tol = 1e-6  # tolerance on logL to terminate EM
    oversampling = 10
    maxiter = 300
    # Number of split-and-merge attempts.  It must be a count, so use integer
    # division (true division gives a float in Python 3).
    split_n_merge = gmm.K * (gmm.K - 1) * (gmm.K - 2) // 2

    logging.basicConfig(format='%(message)s', level=logging.INFO)

    # run EM: a single k-means-initialised step on the complete data set
    logL, U = pygmmis.fit(gmm, data, init_method='kmeans', w=w, cutoff=cutoff,
                          tol=tol, rng=rng, maxiter=1,
                          split_n_merge=split_n_merge)
    for i in range(gmm.K):
        plot_ellipse(plt.gca(), gmm.mean[i], gmm.covar[i], 'k')

    # Selection-aware fit of the observed (truncated) data, starting from the
    # current state of gmm.
    logL, U = pygmmis.fit(gmm, observed_data, init_method='none',
                          sel_callback=selection, w=w, cutoff=cutoff,
                          oversampling=oversampling, tol=tol, rng=rng,
                          maxiter=maxiter, split_n_merge=split_n_merge)
    for i in range(gmm.K):
        plot_ellipse(plt.gca(), gmm.mean[i], gmm.covar[i], 'r')

    # Continue plain EM (no split-and-merge) for more iterations.
    logL, U = pygmmis.fit(gmm, observed_data, init_method='none',
                          sel_callback=selection, w=w, cutoff=cutoff,
                          oversampling=oversampling, tol=tol, rng=rng,
                          maxiter=700)
    for i in range(gmm.K):
        plot_ellipse(plt.gca(), gmm.mean[i], gmm.covar[i], 'b')

    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment