Created
April 3, 2013 13:25
-
-
Save adrn/5301190 to your computer and use it in GitHub Desktop.
Make a scatter-plot matrix with matplotlib
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
""" Create a scatter-plot matrix using Matplotlib. """ | |
from __future__ import division, print_function | |
__author__ = "adrn <[email protected]>" | |
# Standard library | |
import os, sys | |
# Third-party | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import astropy.units as u | |
def scatter_plot_matrix(data, labels=None, axes=None, subplots_kwargs=None,
                        scatter_kwargs=None):
    """ Create a scatter plot matrix from the given data.

        The result is an M x M grid of panels; the panel in row ``i`` and
        column ``j`` shows dimension ``j`` on the x axis against dimension
        ``i`` on the y axis.

        Parameters
        ----------
        data : numpy.ndarray (or array-like)
            A numpy array containing the scatter data to plot. The data
            should be shape MxN where M is the number of dimensions and
            with N data points (M >= 1 and N >= M).
        labels : sequence (optional)
            A sequence (list, tuple, numpy array, ...) of at least M axis
            labels. If not given, no axis labels are set.
        axes : matplotlib Axes array (optional)
            If you've already created the axes objects, pass this in to
            plot the data on that. It is indexed as ``axes[i, j]`` for
            ``i, j`` in ``range(M)``.
        subplots_kwargs : dict (optional)
            A dictionary of keyword arguments to pass to the
            matplotlib.pyplot.subplots call. Note: only relevant if axes=None.
            ``sharex`` and ``sharey`` default to True.
        scatter_kwargs : dict (optional)
            A dictionary of keyword arguments to pass to the
            matplotlib.pyplot.scatter function calls. Defaults: black
            points (``c="k"``), no marker edge, and marker size ``s=10``.

        Returns
        -------
        fig : matplotlib Figure
            The figure that owns ``axes[0, 0]``.
        axes : matplotlib Axes array
            The axes that were drawn on (created here if ``axes=None``).

        Raises
        ------
        ValueError
            If ``data`` is not 2-dimensional, is empty along the first axis,
            has more dimensions than data points (M > N), or if fewer than
            M labels are supplied.
    """
    # asanyarray accepts plain nested lists while leaving ndarray subclasses
    # untouched.
    data = np.asanyarray(data)

    # Validate up front, before any figure is created.
    if data.ndim != 2 or data.shape[0] < 1 or data.shape[0] > data.shape[1]:
        raise ValueError("Invalid data shape {0}. You must pass in an array of "
                         "shape (M, N) where N >= M.".format(data.shape))
    M = data.shape[0]

    # Identity checks: `== None` on a numpy array compares element-wise and
    # the resulting array has no unambiguous truth value.
    if labels is None:
        labels = [None] * M
    elif len(labels) < M:
        raise ValueError("Expected at least {0} labels, got {1}."
                         .format(M, len(labels)))

    if axes is None:
        # Copy so the caller's dictionary is never modified.
        skwargs = dict(subplots_kwargs or {})
        skwargs.setdefault("sharex", True)
        skwargs.setdefault("sharey", True)
        # The loop below indexes axes[i, j]; without this a 1x1 grid would be
        # squeezed down to a bare Axes object.
        skwargs["squeeze"] = False
        fig, axes = plt.subplots(M, M, **skwargs)

    sc_kwargs = dict(scatter_kwargs or {})
    # Only fill in a default when the caller used neither spelling of the
    # option, so a user-supplied `color` / `edgecolors` is not clobbered by
    # (or put in conflict with) the defaults.
    if "edgecolor" not in sc_kwargs and "edgecolors" not in sc_kwargs:
        sc_kwargs["edgecolor"] = "none"
    if "c" not in sc_kwargs and "color" not in sc_kwargs:
        sc_kwargs["c"] = "k"
    sc_kwargs.setdefault("s", 10)

    xticks = yticks = None
    for ii in range(M):
        for jj in range(M):
            ax = axes[ii, jj]
            ax.scatter(data[jj], data[ii], **sc_kwargs)

            # Tick positions are captured once, from the first panel drawn,
            # with the outermost tick on each side dropped.
            if yticks is None:
                yticks = ax.get_yticks()[1:-1]

            if xticks is None:
                xticks = ax.get_xticks()[1:-1]

            # first column
            if jj == 0:
                ax.set_ylabel(labels[ii])

                # Hack so ticklabels don't overlap
                ax.yaxis.set_ticks(yticks)

            # last row
            if ii == M - 1:
                ax.set_xlabel(labels[jj])

                # Hack so ticklabels don't overlap
                ax.xaxis.set_ticks(xticks)

    # Taken from the axes so this also works when the caller supplied them.
    fig = axes[0, 0].figure
    fig.subplots_adjust(hspace=0.0, wspace=0.0, left=0.08, bottom=0.08,
                        top=0.9, right=0.9)
    return fig, axes
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add some examples! Also, are you using `astropy.units` anywhere?