Last active
December 14, 2015 18:09
-
-
Save ghl3/5127363 to your computer and use it in GitHub Desktop.
Standardized plotting of features for pandas data frame.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def scatter_plots(df, class_title, feature_names=None, cmap=None):
    """Draw a grid of pairwise scatter plots, one colour per class.

    Every unordered pair of feature columns gets its own subplot on the
    current figure. Within a subplot the rows of ``df`` are split by their
    value in ``class_title`` and each class is drawn in its own colour.

    Relies on a module-level ``plt`` (presumably ``matplotlib.pyplot``).

    Args:
        df: pandas DataFrame holding the feature columns and the class column.
        class_title: Name of the column whose values define the classes.
        feature_names: Iterable of column names to plot. Defaults to every
            column of ``df`` except ``class_title``.
        cmap: Callable mapping a float in [0, 1) to a colour. Defaults to
            the ``gist_rainbow`` colormap obtained from ``plt.get_cmap``.

    Returns:
        None. All output is drawn through ``plt``.

    Raises:
        KeyError: If ``class_title`` or one of ``feature_names`` is not a
            column of ``df``.
    """
    # Function-scope imports, like the original ``import itertools``; the
    # block must not depend on ``math`` being imported somewhere else.
    import itertools
    import math

    # Classes in order of first appearance, so the class -> colour mapping is
    # the same on every run (a plain set() has arbitrary ordering).
    class_names = list(df[class_title].drop_duplicates())
    groups = [df[df[class_title] == name] for name in class_names]

    # ``is None`` rather than ``== None``: the latter is evaluated elementwise
    # for array-likes (numpy arrays, pandas Index) and cannot be used in ``if``.
    if feature_names is None:
        feature_names = [col for col in df.columns if col != class_title]
    else:
        feature_names = list(feature_names)

    num_colors = len(class_names)
    if cmap is None:
        cmap = plt.get_cmap('gist_rainbow')
    # Spread the classes evenly over the colormap.
    colors = [cmap(float(i) / num_colors) for i in range(num_colors)]

    def pair_scatter(col_a, col_b):
        """Scatter ``col_a`` against ``col_b`` on the current axes, per class."""
        for group, name, color in zip(groups, class_names, colors):
            plt.scatter(group[col_a], group[col_b], marker="o", color=color,
                        label=name)
        plt.legend()
        plt.xlabel(col_a)
        plt.ylabel(col_b)

    # Selecting the columns up front raises KeyError for unknown names before
    # anything is drawn; iterating the resulting frame yields its column labels.
    features = df[feature_names]
    combos = list(itertools.combinations(features, 2))

    # Smallest square grid that holds every pair. Coerced to int because
    # math.ceil returns a float on Python 2 and subplot needs integers.
    nrows = int(math.ceil(math.sqrt(len(combos))))
    for idx, (col_a, col_b) in enumerate(combos, start=1):
        plt.subplot(nrows, nrows, idx)
        pair_scatter(col_a, col_b)
def feature_histograms(df, class_title, feature_names=None, nbins=30, cmap=None):
    """Draw one histogram subplot per feature, with one colour per class.

    For each feature the rows of ``df`` are split by their value in
    ``class_title`` and every class is histogrammed on the same axes, using
    a bin range shared by all classes.

    Relies on a module-level ``plt`` (presumably ``matplotlib.pyplot``).

    Args:
        df: pandas DataFrame holding the feature columns and the class column.
        class_title: Name of the column whose values define the classes.
        feature_names: Iterable of column names to plot. Defaults to every
            column of ``df`` except ``class_title``.
        nbins: Number of bins passed to ``plt.hist`` for every histogram.
        cmap: Callable mapping a float in [0, 1) to a colour. Defaults to
            the ``gist_rainbow`` colormap obtained from ``plt.get_cmap``.

    Returns:
        None. All output is drawn through ``plt``.

    Raises:
        KeyError: If ``class_title`` or a plotted feature is not a column
            of ``df``.
    """
    import math

    # Classes in order of first appearance, so the class -> colour mapping is
    # the same on every run (a plain set() has arbitrary ordering).
    class_names = list(df[class_title].drop_duplicates())
    groups = [df[df[class_title] == name] for name in class_names]

    num_colors = len(class_names)
    # ``is None`` rather than ``== None``: the latter is evaluated elementwise
    # for array-likes (numpy arrays, pandas Index) and cannot be used in ``if``.
    if cmap is None:
        cmap = plt.get_cmap('gist_rainbow')
    # Spread the classes evenly over the colormap.
    colors = [cmap(float(i) / num_colors) for i in range(num_colors)]

    if feature_names is None:
        feature_names = [col for col in df.columns if col != class_title]
    else:
        # Materialise so len() below works for any iterable.
        feature_names = list(feature_names)

    def feature_hist(feature):
        """Histogram ``feature`` for every class on the current axes."""
        column = df[feature]
        # Same range for all classes so their bins line up. Series.min/max
        # skip NaN, whereas the builtin min/max give position-dependent
        # results when NaN is present.
        bin_range = (column.min(), column.max())
        for group, class_name, color in zip(groups, class_names, colors):
            plt.hist(group[feature], color=color, label=class_name,
                     bins=nbins, range=bin_range)
        plt.legend()
        plt.xlabel(feature)

    # Smallest square grid that holds every feature. Coerced to int because
    # math.ceil returns a float on Python 2 and subplot needs integers.
    nrows = int(math.ceil(math.sqrt(len(feature_names))))
    for idx, feature in enumerate(feature_names, start=1):
        plt.subplot(nrows, nrows, idx)
        feature_hist(feature)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment