Created
August 13, 2021 13:39
-
-
Save finloop/97d5834c6f3bee4bf78148389ac3cccc to your computer and use it in GitHub Desktop.
Plot a pie chart for each cluster to show how popular the different categories are within that cluster.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from math import ceil, sqrt

import matplotlib.pyplot as plt
# Parameters for data. The plotting entry point at the bottom assumes the data
# is already loaded into a DataFrame named `df` (e.g. in a notebook session).
cluster_col = "cluster"  # Column holding each row's cluster label
group_by = "customer_city"  # Column to break down per cluster, e.g. city
n_first = 4  # How many of the most frequent `group_by` values get their own slice


def cluster_category_counts(values, n_first):
    """Count category occurrences, lumping all but the most frequent into "other".

    Parameters
    ----------
    values : pandas.Series
        The `group_by` value of every row belonging to one cluster.
    n_first : int
        Number of most frequent categories to keep under their own label.

    Returns
    -------
    pandas.Series
        Counts indexed by category label (plus "other" when anything was
        lumped together), sorted by descending count. Missing values are not
        counted.
    """
    # value_counts counts rows (not non-null cells of some other column),
    # skips NaN and sorts by descending frequency.
    counts = values.value_counts()
    keep = counts.index[:n_first]
    # Keep the top categories and NaN (so the final value_counts drops it, as
    # groupby did originally); relabel everything else as "other". `where`
    # returns a new, upcast Series instead of writing a string into a
    # possibly numeric/categorical column in place.
    relabelled = values.where(values.isin(keep) | values.isna(), "other")
    return relabelled.value_counts()


def plot_cluster_pies(df, cluster_col="cluster", group_by="customer_city",
                      n_first=4, figsize=(40, 30)):
    """Draw one pie chart per cluster showing its most popular `group_by` values.

    Parameters
    ----------
    df : pandas.DataFrame
        Data containing both `cluster_col` and `group_by` columns.
    cluster_col : str
        Column with the cluster label of each row. Rows with a missing
        label are ignored.
    group_by : str
        Column whose value distribution is plotted for each cluster.
    n_first : int
        Number of most frequent values shown by name; the rest form "other".
    figsize : tuple of float
        Size of the whole figure, passed to ``plt.subplots``.

    Returns
    -------
    (matplotlib.figure.Figure, numpy.ndarray)
        The figure and the flattened array of its axes.

    Raises
    ------
    ValueError
        If `cluster_col` contains no non-missing cluster labels.
    """
    # NaN labels never match `==`, so they would only produce empty pies.
    clusters = df[cluster_col].dropna().unique()
    nclusters = len(clusters)
    if nclusters == 0:
        raise ValueError(f"column {cluster_col!r} contains no cluster labels")

    # Near-square grid without wholly empty rows.
    ncols = ceil(sqrt(nclusters))
    nrows = ceil(nclusters / ncols)
    # squeeze=False always yields a 2-D array, so flatten() also works
    # for a single cluster (where subplots would otherwise return one Axes).
    fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axs = axs.flatten()

    for ax, cluster in zip(axs, clusters):
        slices = cluster_category_counts(
            df.loc[df[cluster_col] == cluster, group_by], n_first
        )
        ax.pie(slices.to_numpy(), labels=slices.index.to_numpy(),
               autopct='%1.1f%%', shadow=True, startangle=90,
               textprops={'fontsize': 20})
        ax.set_title(str(cluster), fontsize=25)

    # Hide grid cells that have no cluster to show.
    for ax in axs[nclusters:]:
        ax.axis("off")
    return fig, axs


if __name__ == "__main__":
    # `df` is not defined in this file; it must already exist in the running
    # session (IPython/Jupyter sessions also have __name__ == "__main__").
    plot_cluster_pies(df, cluster_col, group_by, n_first)
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment