Last active
October 26, 2021 14:04
-
-
Save alexland/d6d64d3f634895b9dc8e to your computer and use it in GitHub Desktop.
Cross tabulation in NumPy: a table whose cells count how often each combination of values occurs across the table's dimensions (two or more input columns).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def xtab(*cols, apply_wt=False):
    '''
    Cross-tabulate one or more 1D arrays of values.

    returns:
      (i) uniq_vals_all_cols, a tuple holding one sorted 1D NumPy array of
          unique values for each dimension of xt (for a 2D xtab, the tuple
          comprises the row and column headers)
      (ii) xt, NumPy array storing the xtab results; its number of dimensions
          equals the number of value arrays passed in (len(cols), or
          len(cols) - 1 when apply_wt is True)

    pass in:
      (i) 1 or more 1D NumPy arrays (or sequences) of values, all of the same
          length; arrays of shape (n, 1) or (1, n) are accepted and flattened
      (ii) apply_wt: if True, the last array in cols is an array of weights;
          each cell of xt then holds the sum of the weights of the rows that
          fall in it (dtype float) rather than a row count (dtype uint64)

    raises:
      TypeError if no value arrays are passed in
      ValueError if an input is not 1D or the inputs differ in length

    NP.unique(..., return_inverse=True) also returns an integer index (from 0,
    & of same len as the array passed in) such that uniq_vals[idx] gives the
    original array passed in; those indices address the cells of xt directly.

    higher dimensional cross tabulations are supported (eg, 2D & 3D)

    cross tabulation on two variables (columns):
    >>> q1 = NP.array([7, 8, 8, 8, 5, 6, 4, 6, 6, 8, 4, 6, 6, 6, 6, 8, 8, 5, 8, 6])
    >>> q2 = NP.array([6, 4, 6, 4, 8, 8, 4, 8, 7, 4, 4, 8, 8, 7, 5, 4, 8, 4, 4, 4])
    >>> uv, xt = xtab(q1, q2)
    >>> uv
    (array([4, 5, 6, 7, 8]), array([4, 5, 6, 7, 8]))
    >>> xt
    array([[2, 0, 0, 0, 0],
           [1, 0, 0, 0, 1],
           [1, 1, 0, 2, 4],
           [0, 0, 1, 0, 0],
           [5, 0, 1, 0, 1]], dtype=uint64)
    '''
    if len(cols) == 0:
        raise TypeError("xtab() requires at least one argument")
    cols = [NP.asarray(col) for col in cols]
    # squeeze so (n, 1) / (1, n) vectors pass, but keep at least one dimension
    # so that a length-1 array is not mistaken for a scalar
    if not all(NP.atleast_1d(col.squeeze()).ndim == 1 for col in cols):
        raise ValueError("all input arrays must be 1D")
    # flatten: on NumPy >= 2 the inverse index returned by NP.unique keeps the
    # input's shape, and mixed (n,) / (n, 1) indices would broadcast in add.at
    cols = [col.ravel() for col in cols]
    n_rows = cols[0].size
    if any(col.size != n_rows for col in cols[1:]):
        raise ValueError("all arguments must be same size")
    if apply_wt:
        cols, wt = cols[:-1], cols[-1]
        if not cols:
            raise TypeError(
                "xtab() requires at least one argument besides the weights")
    else:
        wt = 1
    uniq_vals_all_cols, idx = zip(
        *(NP.unique(col, return_inverse=True) for col in cols))
    shape_xt = [uniq_vals_col.size for uniq_vals_col in uniq_vals_all_cols]
    dtype_xt = 'float' if apply_wt else NP.uint64
    xt = NP.zeros(shape_xt, dtype=dtype_xt)
    # unbuffered add, so rows that map to the same cell all accumulate
    NP.add.at(xt, idx, wt)
    return uniq_vals_all_cols, xt
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment