Created
May 27, 2020 03:59
-
-
Save MattJBritton/b6944218903312f6220bfb48706b593c to your computer and use it in GitHub Desktop.
Feature Space Diagram
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
from scipy.spatial.distance import pdist, squareform | |
from sklearn import datasets | |
from sklearn.decomposition import PCA | |
from sklearn.cluster import FeatureAgglomeration | |
from sklearn.preprocessing import StandardScaler | |
from sklearn.impute import SimpleImputer | |
from sklearn.metrics import silhouette_score | |
# vis | |
import altair as alt | |
alt.renderers.enable("default") | |
alt.data_transformers.disable_max_rows() | |
# import data | |
# from https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic) | |
breast_cancer_dataset = datasets.load_breast_cancer() | |
breast_cancer_df = pd.DataFrame(breast_cancer_dataset.data, columns = breast_cancer_dataset.feature_names) | |
breast_cancer_df["diagnosis"] = breast_cancer_dataset["target"] | |
# function to generate FSD | |
def feature_space_diagram(data, target_name = None): | |
df = data.copy() | |
corr = np.round(df.corr(),2) | |
if target_name is not None: | |
target_corr = np.abs(corr.rename({target_name:"target_corr"}, axis=1)["target_corr"]) | |
df = df.drop(target_name, axis=1) | |
corr = corr.drop(target_name, axis=1).drop(target_name, axis=0) | |
sidebar_width = 200 | |
sidebar_component_height = 75 | |
#compute PCA and store as X,Y coordinates for each feature | |
pca = PCA(n_components = 2) | |
pca.fit(np.abs(corr)) | |
pca_coords = pd.DataFrame.from_dict( | |
dict( | |
zip( | |
list(df.columns), | |
pca.transform(np.abs(corr)).tolist() | |
) | |
), | |
orient="index" | |
).reset_index().rename({0:"X", 1:"Y", "index":"feature"}, axis=1) | |
#get feature clusters | |
scaler = StandardScaler() | |
feature_distances = squareform(pdist(scaler.fit_transform(df).T, "euclidean")) | |
silhouette_scores = [] | |
cluster_range = range(3,df.shape[1]) | |
for n_cluster in cluster_range: | |
corr_clusters = FeatureAgglomeration(n_clusters = n_cluster, affinity = "precomputed", linkage = "average").fit(feature_distances) | |
silhouette_scores = silhouette_scores\ | |
+ [ | |
{ | |
"cluster_num": n_cluster, | |
"silhouette_score": silhouette_score(feature_distances, corr_clusters.labels_, metric = "precomputed"), | |
"feature": list(df.columns)[i], | |
"cluster": label | |
} | |
for i, label in enumerate(corr_clusters.labels_) | |
] | |
cluster_label_df = pd.DataFrame(silhouette_scores) | |
cluster_label_df["cluster_size"] = cluster_label_df.groupby(["cluster_num", "cluster"])["feature"].transform("count") | |
cluster_label_df["key"] = cluster_label_df["cluster_num"].astype(str).str.cat(cluster_label_df["feature"].astype(str), sep=":") | |
cluster_label_df["cluster"] = cluster_label_df.groupby(["cluster_num","cluster"])["feature"].transform("first") | |
default_cluster_num = cluster_label_df.groupby("cluster_num")["silhouette_score"].max().idxmax() | |
# set correlation with target, if using, which determines circle size | |
if target_name is not None: | |
pca_coords = pca_coords.join( | |
target_corr | |
).reset_index() | |
else: | |
pca_coords = pca_coords.reset_index() | |
pca_coords["target_corr"] = 1 | |
# get dataset for lines between features (if they have higher correlation than corr_threshold) | |
corr_lines = corr.reset_index(drop=False).rename({"index":"feature"}, axis=1)\ | |
.melt(id_vars = ["feature"], var_name = "feature_2", value_name = "corr")\ | |
.query("feature > feature_2") | |
corr_lines["corr_abs"] = np.abs(corr_lines["corr"]) | |
corr_selector_data = corr_lines.copy() | |
corr_selector_data["corr_abs"] = np.floor((corr_selector_data["corr_abs"]*10))/10 | |
corr_selector_data = corr_selector_data.groupby("corr_abs").size().reset_index().rename({0:"Count"}, axis = 1) | |
corr_lines_1 = pd.merge( | |
corr_lines, | |
pca_coords.loc[:,["feature", "X", "Y"]], | |
on = "feature" | |
) | |
corr_lines_2 = pd.merge( | |
corr_lines, | |
pca_coords.set_index("feature").loc[:,["X", "Y"]], | |
left_on = "feature_2", right_index = True | |
) | |
corr_lines = corr_lines_1.append(corr_lines_2) | |
corr_lines["key"] = corr_lines["feature"] + corr_lines["feature_2"] | |
corr_line_selector = alt.selection_single(fields = ["corr_abs"], init = {"corr_abs":0.7}) | |
cluster_num_selector = alt.selection_single(fields = ["cluster_num"], init = {"cluster_num":default_cluster_num}) | |
cluster_selection = alt.selection_single(fields=["cluster"]) | |
base = alt.layer().encode( | |
x = alt.X("X", axis=None), | |
y = alt.Y("Y", axis=None), | |
color = alt.condition( | |
cluster_selection, | |
alt.Color("cluster:N", legend = None), | |
alt.value("lightgray") | |
) | |
) | |
base += alt.Chart(pca_coords).mark_circle().encode( | |
size = alt.Size("target_corr:Q", scale=alt.Scale(domain = [0,1]), legend=None) | |
).transform_calculate( | |
key = cluster_num_selector.cluster_num + ":" + alt.datum.feature | |
).transform_lookup( | |
lookup='key', | |
from_=alt.LookupData(data=cluster_label_df, key='key', | |
fields=['cluster_size', 'cluster']) | |
).add_selection( | |
cluster_num_selector | |
) | |
base += alt.Chart(pca_coords).mark_text(dx=20, dy = 10).encode( | |
text = "feature", | |
).transform_calculate( | |
key = cluster_num_selector.cluster_num + ":" + alt.datum.feature | |
).transform_lookup( | |
lookup='key', | |
from_=alt.LookupData(data=cluster_label_df, key='key', | |
fields=['cluster_size', 'cluster']) | |
) | |
base += alt.Chart(corr_lines).mark_line().encode( | |
detail = "key", | |
strokeWidth = alt.StrokeWidth("corr_abs", scale = alt.Scale(domain = [0,1], range = [.3,3])) | |
).transform_filter( | |
alt.datum.corr_abs >= corr_line_selector.corr_abs | |
).transform_calculate( | |
key = cluster_num_selector.cluster_num + ":" + alt.datum.feature | |
).transform_lookup( | |
lookup='key', | |
from_=alt.LookupData(data=cluster_label_df, key='key', | |
fields=['cluster_size', 'cluster']) | |
) | |
base = base.properties( | |
width = 800, | |
height = 500, | |
title = "Feature Space Diagram" | |
).interactive() | |
num_cluster_picker = alt.Chart(cluster_label_df).mark_bar().encode( | |
y = alt.Y("silhouette_score", title = "Silhouette Score"), | |
x = "cluster_num:O", | |
color = alt.condition( | |
cluster_num_selector, | |
alt.value("lightblue"), | |
alt.value("lightgray") | |
) | |
).add_selection( | |
cluster_num_selector | |
).properties( | |
width = sidebar_width, | |
height = sidebar_component_height, | |
title = "Select the Number of Clusters" | |
) | |
corr_threshold_picker = alt.Chart(corr_selector_data).mark_bar().encode( | |
x = "corr_abs:O", | |
y = alt.Y("Count", axis = alt.Axis(labelAngle = 0, title = "Feature Pairs")), | |
color = alt.condition( | |
alt.datum.corr_abs >= corr_line_selector.corr_abs, | |
alt.value("lightblue"), | |
alt.value("lightgray") | |
) | |
).add_selection( | |
corr_line_selector | |
).properties( | |
width = sidebar_width, | |
height = sidebar_component_height, | |
title = "Select Correlation Threshold to Show Lines" | |
) | |
cluster_bar_chart = alt.Chart(cluster_label_df).mark_bar(size=5).encode( | |
y = alt.Y( | |
"cluster:N", | |
sort = alt.EncodingSortField(field = "cluster_size", order="descending"), | |
title = None # "Clusters" | |
), | |
x = "cluster_size", | |
color = alt.Color("cluster:N", legend=None), | |
).add_selection( | |
cluster_selection | |
).transform_filter( | |
(alt.datum.cluster_num >= cluster_num_selector.cluster_num) & (alt.datum.cluster_num <= cluster_num_selector.cluster_num) | |
).properties( | |
width = sidebar_width, | |
height = 200, | |
title = "Cluster Sizes. Click to Highlight" | |
) | |
return (base) | (num_cluster_picker & corr_threshold_picker & cluster_bar_chart) | |
# call function and render chart | |
feature_space_diagram(breast_cancer_df) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment