Skip to content

Instantly share code, notes, and snippets.

@benkrikler
Created October 29, 2019 12:47
Show Gist options
  • Save benkrikler/6786b384184d056f18dcda96a6ef4b33 to your computer and use it in GitHub Desktop.
Save benkrikler/6786b384184d056f18dcda96a6ef4b33 to your computer and use it in GitHub Desktop.
Convert a pandas DataFrame with lists in its cells to wide form, with one item per column
def flatten_to_wideform(df, mask_column=None, max_objects=None):
    """Convert a DataFrame whose cells hold lists into wide form, one item per column.

    Every object-dtype column whose first ten cells are all ``list``, ``tuple``
    or ``np.ndarray`` is treated as a "list column".  The i-th item of list
    column ``X`` in a row ends up in output column ``"X[i]"``; rows with fewer
    items get NaN in the missing slots.  All other columns are carried over
    unchanged.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame.  Within each row, all list columns must hold the same
        number of items.
    mask_column : str, optional
        Name of a list column used as a boolean indexer on the items; items
        whose entry is False are dropped before they are numbered, and the
        mask column itself is not emitted.
    max_objects : int, optional
        Keep at most this many (unmasked) items per row.  ``None`` or ``0``
        keeps every item.

    Returns
    -------
    pandas.DataFrame
        ``df`` itself if it has no list columns.  Otherwise the non-list
        columns (in sorted column order) followed by the ``"X[i]"`` columns
        ordered by item position, with a fresh ``RangeIndex`` and one row per
        input row, in input order.

    Raises
    ------
    ValueError
        If list columns have different lengths within the same row.
    KeyError
        If ``mask_column`` is given but is not one of the detected list
        columns (only checked when there is at least one item to flatten).
    """
    # Imported locally: this file has no module-level numpy/pandas imports.
    import numpy as np
    import pandas as pd

    # Detect list-like columns by sampling the first ten cells of each object column.
    lst_cols = [col for col, dtype in df.dtypes.items()
                if pd.api.types.is_object_dtype(dtype)
                and df[col].head(10).apply(
                    lambda x: isinstance(x, (list, np.ndarray, tuple))).all()]
    if not lst_cols:
        return df
    # All columns except `lst_cols`; these are copied through unchanged.
    idx_cols = df.columns.difference(lst_cols)

    # Every list column must have the same number of items in a given row,
    # otherwise the items cannot be paired up into objects.
    lens = pd.DataFrame({col: df[col].str.len() for col in lst_cols})
    diffs = lens[lens.columns[1:]].subtract(lens[lens.columns[0]], axis=0)
    if (diffs != 0).any().any():
        raise ValueError("Cannot bin multiple arrays with different jaggedness")
    lens = lens[lst_cols[0]]

    # Work positionally from here on: exploded rows are labelled 0..n-1, so the
    # scalar columns must be labelled the same way whatever index `df` has;
    # joining on df's own index would misalign any non-RangeIndex input.
    scalars = df[idx_cols].reset_index(drop=True)
    if not lens.any():
        # No items anywhere (or no rows): np.concatenate would fail on an
        # empty sequence, and there are no item columns to add anyway.
        return scalars

    # Create an "exploded" frame with one row per object.  Empty cells are
    # skipped because an empty list concatenates as float64 and would upcast
    # integer items.
    nonempty = (lens > 0).to_numpy()
    flattened = pd.DataFrame(
        {col: np.concatenate(df[col].to_numpy()[nonempty]) for col in lst_cols})

    # Pick helper column names that cannot clash with the item columns.
    index = "new_index"
    while index in flattened.columns or "sub" + index in flattened.columns:
        index = "_" + index
    subindex = "sub" + index
    # Label each object with the position of the row it came from.
    flattened[index] = np.repeat(np.arange(len(lens)), lens.to_numpy())

    # Optionally drop masked-out objects; the mask column itself is consumed.
    if mask_column is not None:
        mask = flattened.pop(mask_column)
        flattened = flattened.loc[mask].copy()

    # Number the surviving objects within each original row: 0, 1, 2, ...
    flattened[subindex] = flattened.groupby(index).cumcount()
    flattened = flattened.set_index([index, subindex])
    # Keep up to max_objects items per row (a falsy value keeps all of them).
    if max_objects:
        flattened = flattened.groupby(level=index).head(max_objects)

    # Pivot the objects back into columns named "C[0]", "D[0]", "C[1]", ...
    res = flattened.unstack(subindex).sort_index(level=subindex, axis=1)
    res.columns = ["%s[%d]" % vals for vals in res.columns.values]
    # Re-attach the scalar columns; the outer join keeps rows with no objects.
    return scalars.join(res, how="outer").reset_index(drop=True)
import numpy as np
import pandas as pd

from explode_wide import flatten_to_wideform

# Build a fake jagged dataset: each row carries 0-14 objects in the list
# columns C and D, plus a per-object boolean mask.
rng = np.random.default_rng(42)  # fixed seed so the check is reproducible
length = 1000
nobjs = rng.integers(15, size=length)
data = {"B": rng.normal(3, 4, length), "A": np.arange(length) ** 2,
        "C": [], "D": [], "mask": []}
for n in nobjs:
    data["C"].append(rng.normal(1, 1, n) ** 2)
    data["D"].append(np.arange(n) + 2)
    data["mask"].append(rng.normal(0, 1, n) > 0)
fake = pd.DataFrame(data)
fake_flat = flatten_to_wideform(fake, max_objects=3, mask_column="mask")
# The scalar columns must come through unchanged, row for row.
for col in "AB":
    print(col, (fake_flat[col] == fake[col]).all())
def check_index(orig, mask_index, compare, index):
    """Print whether `compare` holds the `index`-th unmasked item of each row.

    `orig` is the original list column and `mask_index` gives, per row, the
    positions of the items that survived the mask.  Rows with too few
    surviving items are skipped, matching the NaNs that `compare.dropna()`
    removes.  Prints the column name followed by True/False.
    """
    expected = []
    for selected, values in zip(mask_index, orig):
        if len(selected) <= index:
            continue
        expected.append(values[selected[index]])
    matches = compare.dropna() == expected
    print(compare.name, matches.all())
# Positions of the items that pass the mask, one array per row of `fake`.
mask_index = [np.flatnonzero(row_mask) for row_mask in data["mask"]]
# Compare the first three wide-form columns (max_objects=3 above) of C and D
# against the masked original items.
for col in "CD":
    for i in range(3):
        check_index(fake[col], mask_index, fake_flat[f"{col}[{i}]"], index=i)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment