Created
December 22, 2011 10:34
-
-
Save daien/1509853 to your computer and use it in GitHub Desktop.
Build a regular grid and assign each data point to a cell
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def _as_per_dim_edges(edges, n_dims, label):
    """Return `edges` as a float array of shape (n_dims,), broadcasting scalars."""
    edges = np.asarray(edges, dtype=float)
    if edges.ndim == 0:
        # any real scalar (int, float, numpy scalar) applies to every dimension
        edges = np.full(n_dims, edges, dtype=float)
    assert edges.shape == (n_dims,), "Invalid {0}: {1}".format(label, edges)
    return edges


def regular_multidim_digitize(data, n_bins=3, lbes=None, rbes=None):
    """Build a regular grid and assign each data point to a cell.

    Parameters
    ----------
    data: (n_points, n_dims) array-like,
        the data that we wish to digitize

    n_bins: int or (n_dims, ) array-like, optional, default: 3,
        per-dimension number of bins (must be strictly positive)

    lbes: float or (n_dims, ) array-like, optional, default: None,
        per-dimension left-most bin edges
        by default, use the min along each dimension

    rbes: float or (n_dims, ) array-like, optional, default: None,
        per-dimension right-most bin edges
        by default, use the max along each dimension

    Returns
    -------
    assignments: (n_points, ) array of integers,
        the assignment index of each point

    Raises
    ------
    AssertionError
        if `n_bins`, `lbes` or `rbes` has the wrong shape, if `n_bins` is not
        strictly positive, or if the given edges do not enclose the data.
        These checks are `assert` statements, so they are skipped when Python
        runs with ``-O``.

    Notes
    -----
    "Regular" here means evenly-spaced along each dimension (possibly
    separately).

    Each cell has its unique integer index which is the cell number in C-order
    (last dimension varies fastest).

    A dimension whose left and right edges coincide has zero width; all
    points are put in bin 0 along that dimension.

    To obtain the cell coordinates from the cell numbers do:

        np.column_stack(np.unravel_index(assignments, n_bins))
    """
    data = np.asarray(data)
    n_points, n_dims = data.shape

    # Builtin int/float replace the np.int/np.float aliases that recent NumPy
    # releases no longer provide.
    n_bins = np.asarray(n_bins, dtype=int)
    if n_bins.ndim == 0:
        n_bins = np.full(n_dims, n_bins, dtype=int)
    assert n_bins.shape == (n_dims,), "invalid n_bins: {0}".format(n_bins)
    assert np.all(n_bins > 0), "invalid n_bins: {0}".format(n_bins)

    if n_points == 0:
        # no points to assign; min/max below are undefined on empty data
        return np.zeros(0, dtype=np.intp)

    data_min = np.min(data, axis=0)
    data_max = np.max(data, axis=0)

    if lbes is None:
        lbes = data_min.astype(float)
    else:
        lbes = _as_per_dim_edges(lbes, n_dims, "lbes")
        # check for overflow
        assert np.all(lbes <= data_min), "lbes not low enough"

    if rbes is None:
        rbes = data_max.astype(float)
    else:
        rbes = _as_per_dim_edges(rbes, n_dims, "rbes")
        # check for overflow
        assert np.all(rbes >= data_max), "rbes not high enough"

    # get the bin-widths per dimension
    # (slightly enlarged so that the max falls in the last bin)
    bws = (1. + 1e-15) * (rbes - lbes) / n_bins
    # zero-width dimensions would give 0/0; dividing by 1 maps them to bin 0
    safe_bws = np.where(bws > 0, bws, 1.0)

    # get the per-dim bin multi-dim index of each point
    dis = ((data - lbes[np.newaxis, :]) / safe_bws[np.newaxis, :]).astype(int)
    # guard against floating-point rounding pushing a point past the last bin
    dis = np.minimum(dis, n_bins[np.newaxis, :] - 1)

    # get index of the flattened grid
    assignments = np.ravel_multi_index(tuple(dis.T), tuple(n_bins.tolist()))

    return assignments
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
taldcroft left a comment on my fork that you might also be interested in.