Last active
September 8, 2015 17:26
-
-
Save kbarbary/2a169680e52bd7089950 to your computer and use it in GitHub Desktop.
Ellipsoid refinement code for Nestle
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def ellipsoid_dist_sq(x, ell):
    """Return the square of each point's distance from the ellipsoid center,
    relative to the ellipsoid boundary.

    Referred to in Feroz, Hobson & Bridges (2009) as the Mahalanobis distance.

    Parameters
    ----------
    x : `~numpy.ndarray`
        Points, array with shape (npoints, ndim).
    ell : Ellipsoid
        Ellipsoid to compare to. Only its ``ctr`` and ``a`` attributes are
        read here.

    Returns
    -------
    `~numpy.ndarray`
        One value per point: ``delta @ ell.a @ delta`` with
        ``delta = point - ell.ctr``.
    """
    delta = x - ell.ctr

    # Fast way to compute delta[i] @ A @ delta[i] for all i: contract the
    # last axis of delta with A once, then take the row-wise dot product
    # of the result with delta.
    projected = np.tensordot(delta, ell.a, axes=1)
    return np.einsum('...i, ...i', projected, delta)
def refine_ellipsoids(x, label, pointvol, xs, ells):
    """Reassign points between two clusters so as to minimize total volume
    of the ellipsoids bounding each cluster.

    This is as described in Feroz, Hobson & Bridges (2009), Algorithm 1.
    There, the distance measure is referred to as the "Mahalanobis distance".

    Parameters
    ----------
    x : `~numpy.ndarray`
        All points in either cluster, array with shape (npoints, ndim).
    label : `~numpy.ndarray`
        1-d integer array of 0 or 1 giving cluster membership of each point
        in x.
    pointvol : float
        Expected volume corresponding to each point.
    xs : list
        Two-element list; ``xs[k]`` holds the points currently assigned to
        cluster ``k``. Updated in place when points are reassigned.
    ells : list
        Two-element list; ``ells[k]`` is the ellipsoid bounding ``xs[k]``.
        Updated in place when points are reassigned.

    Returns
    -------
    xs, ells : list, list
        The same two list objects that were passed in, holding the refined
        clusters and their bounding ellipsoids. If a reassignment would
        leave either cluster with ``ndim`` or fewer points, the clusters
        from the previous iteration are returned instead.
    """
    hs = [None, None]
    ndim = x.shape[1]

    for _ in range(10):  # limit to 10 iterations
        for k in (0, 1):
            # Scale the squared distance by actual/expected cluster volume.
            # A plain multiplication (rather than ``*=``) avoids an in-place
            # casting error if the distances come back with an integer dtype.
            scale = ells[k].vol / (len(xs[k]) * pointvol)
            hs[k] = ellipsoid_dist_sq(x, ells[k]) * scale

        # Reassign each point to the cluster that gives it the smallest h.
        # if hs[1] < hs[0] -> True -> (cast to int) -> 1 (label cluster 1)
        # otherwise        -> False -> (cast to int) -> 0 (label cluster 0)
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # same dtype.
        newlabel = (hs[1] < hs[0]).astype(int)

        # If no points were reassigned, exit the loop.
        if np.all(newlabel == label):
            break

        # Update labels, calculate new member points.
        label = newlabel
        new_xs = [x[label == k, :] for k in (0, 1)]

        # If either cluster doesn't have enough points, return previous
        # clusters.
        if new_xs[0].shape[0] <= ndim or new_xs[1].shape[0] <= ndim:
            return xs, ells

        for k in (0, 1):
            xs[k] = new_xs[k]
            ells[k] = bounding_ellipsoid(xs[k], pointvol=pointvol,
                                         minvol=True)

    # Return the result on the converged / iteration-limit path as well, so
    # every exit yields the same (xs, ells) pair.
    return xs, ells
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I think there is a typo in line 39: `hs` instead of `h`.
...