Created October 29, 2019 12:47
-
-
Save benkrikler/6786b384184d056f18dcda96a6ef4b33 to your computer and use it in GitHub Desktop.
Convert a pandas DataFrame with lists in its cells into a wide-form table with one item per column
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def flatten_to_wideform(df, mask_column=None, max_objects=None):
    """Flatten list-valued cells of a DataFrame into one column per item.

    A column counts as a "list column" when it has object dtype and its first
    ten cells are all lists, tuples or numpy arrays.  All list columns must
    have the same length in every row.  Item ``i`` of list column ``X`` is
    written to a new column named ``"X[i]"``.  Rows with fewer than ``i + 1``
    items get NaN there.

    Non-list columns are kept, in the order returned by
    ``Index.difference`` (sorted where possible), in front of the new
    columns.  The new columns are grouped by item position.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame.  Rows are matched up by position, so the index may be
        non-default, unsorted or contain duplicates.
    mask_column : optional
        Name of a list column holding per-item truth values.  Items whose
        mask is false are dropped before items are numbered.  The mask
        column itself does not appear in the output.
    max_objects : int, optional
        Keep at most this many (unmasked) items per row.  ``None`` or ``0``
        keeps every item.

    Returns
    -------
    pandas.DataFrame
        The wide-form frame with a fresh 0..n-1 index.  If ``df`` has no list
        columns, ``df`` itself is returned unchanged.

    Raises
    ------
    ValueError
        If the list columns differ in length in any row.
    KeyError
        If ``mask_column`` is not one of the detected list columns.
    """
    import numpy as np
    import pandas as pd

    # Detect list-valued columns by sampling the first ten cells.
    lst_cols = [
        col for col, dtype in df.dtypes.items()
        if pd.api.types.is_object_dtype(dtype)
        and df[col].head(10).apply(
            lambda x: isinstance(x, (list, np.ndarray, tuple))).all()
    ]
    if not lst_cols:
        return df
    if mask_column is not None and mask_column not in lst_cols:
        raise KeyError("mask_column %r is not a list-valued column" % (mask_column,))

    # Non-list columns, re-indexed by row position.  The final join then
    # lines up with the positional row numbers used below, whatever df's own
    # index looks like.
    idx_cols = df.columns.difference(lst_cols)
    scalars = df[idx_cols].reset_index(drop=True)

    # Check that all list columns have the same length in every row.
    lens = pd.DataFrame({col: df[col].str.len().values for col in lst_cols})
    diffs = lens[lens.columns[1:]].subtract(lens[lens.columns[0]], axis=0)
    if (diffs != 0).any().any():
        raise ValueError("Cannot bin multiple arrays with different jaggedness")
    lens = lens[lst_cols[0]].values

    if lens.sum() == 0:
        # No items at all; np.concatenate would fail on an empty sequence.
        return scalars

    # "Exploded" frame with one item per row.  Empty cells are skipped so an
    # empty Python list (which numpy treats as float64) cannot upcast dtypes.
    nonempty = lens > 0
    flattened = pd.DataFrame({
        col: np.concatenate(list(df[col].values[nonempty])) for col in lst_cols
    })
    index = "new_index"
    index += str(hash(index))  # unlikely to clash with a real column name
    flattened[index] = np.repeat(np.arange(len(lens)), lens)

    # Drop masked-out items.  astype(bool) makes 0/1 masks filter rows
    # instead of being treated as row labels by .loc.
    if mask_column is not None:
        keep = flattened.pop(mask_column).astype(bool)
        flattened = flattened.loc[keep].copy()
        if len(flattened) == 0:
            return scalars

    # Number the surviving items within each original row.
    subindex = "sub" + index
    flattened[subindex] = flattened.groupby(index).cumcount()
    flattened.set_index([index, subindex], inplace=True, drop=True)

    # Keep up to max_objects items per row (falsy means no limit).
    if max_objects:
        flattened = flattened.groupby(level=index).head(max_objects)

    # Pivot items back into columns named "col[i]" and attach the scalars.
    res = flattened.unstack(subindex).sort_index(level=subindex, axis=1)
    res.columns = ["%s[%d]" % vals for vals in res.columns.values]
    return scalars.join(res, how="left").reset_index(drop=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import pandas as pd

from explode_wide import flatten_to_wideform
# Build a synthetic jagged DataFrame.  "A" and "B" are scalar columns.
# "C", "D" and the boolean "mask" hold per-row arrays that share the same
# per-row length, drawn from 0..14.
rng = np.random.RandomState(42)  # fixed seed so the check is reproducible
length = 1000
nobjs = rng.randint(15, size=length)
data = {"B": rng.normal(3, 4, length), "A": np.arange(length) ** 2,
        "C": [], "D": [], "mask": []}
for n in nobjs:
    data["C"].append(rng.normal(1, 1, n) ** 2)
    data["D"].append(np.arange(n) + 2)
    data["mask"].append(rng.normal(0, 1, n) > 0)
fake = pd.DataFrame(data)

# Flatten, keeping at most 3 unmasked objects per row.
fake_flat = flatten_to_wideform(fake, max_objects=3, mask_column="mask")

# The scalar columns must come through unchanged and in the original row order.
for col in "AB":
    print(col, (fake_flat[col] == fake[col]).all())
def check_index(orig, mask_index, compare, index):
    """Print whether one flattened column matches the masked source values.

    orig: per-row source arrays (e.g. the original list column).
    mask_index: per-row arrays of positions that survived the mask.
    compare: the flattened column holding item number ``index``; rows with
        no such item are NaN and are ignored.
    index: which surviving item (0-based) ``compare`` holds.

    Prints ``compare.name`` followed by True/False.  A length mismatch is
    reported as False rather than raising.
    """
    expected = [values[kept[index]]
                for kept, values in zip(mask_index, orig)
                if len(kept) > index]
    # np.array_equal returns False on a shape mismatch instead of raising as
    # `Series == list` does, so a broken flattening prints as a failure.
    print(compare.name, np.array_equal(compare.dropna().values, expected))
# For every flattened object column, compare it with the source values that
# survive the mask.  The number of "col[i]" columns is read from fake_flat
# itself, so this stays in sync with whatever max_objects was used.
mask_index = [np.where(val)[0] for val in data["mask"]]
for col in "CD":
    n_sub = sum(1 for name in fake_flat.columns if str(name).startswith(col + "["))
    for i in range(n_sub):
        check_index(fake[col], mask_index, fake_flat[col + "[%d]" % i], index=i)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.