Skip to content

Instantly share code, notes, and snippets.

@adrn
Created April 3, 2013 13:25
Show Gist options
  • Select an option

  • Save adrn/5301190 to your computer and use it in GitHub Desktop.

Select an option

Save adrn/5301190 to your computer and use it in GitHub Desktop.
Make a scatter-plot matrix with matplotlib
# coding: utf-8
""" Create a scatter-plot matrix using Matplotlib. """
from __future__ import division, print_function
__author__ = "adrn <adrn@astro.columbia.edu>"
# Standard library
import os, sys
# Third-party
import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u
def scatter_plot_matrix(data, labels=None, axes=None, subplots_kwargs=None,
                        scatter_kwargs=None):
    """ Create a scatter plot matrix from the given data.

    Panel ``(i, j)`` shows dimension ``j`` on the x axis against dimension
    ``i`` on the y axis. Axis labels and tick marks are set only on the
    first column (y) and the last row (x).

    Parameters
    ----------
    data : array_like
        The scatter data to plot, with shape (M, N) where M is the number
        of dimensions and N is the number of data points (N >= M). The
        input is converted with ``numpy.asanyarray``, so ndarray
        subclasses are passed through unchanged.
    labels : sequence (optional)
        A sequence (list or numpy array) of length M containing the axis
        labels.
    axes : 2D array of matplotlib Axes (optional)
        If you've already created the axes objects, pass them in (shape
        (M, M)) to plot the data on them.
    subplots_kwargs : dict (optional)
        Keyword arguments passed to ``matplotlib.pyplot.subplots``. Only
        used when ``axes`` is None. Unless overridden, ``sharex=True``,
        ``sharey=True`` and ``squeeze=False`` are used.
    scatter_kwargs : dict (optional)
        Keyword arguments passed to every ``Axes.scatter`` call. Unless
        overridden, ``edgecolor="none"``, ``s=10`` and ``c="k"`` are used
        (``c`` is not defaulted when the caller supplies ``color``).

    Returns
    -------
    fig : matplotlib Figure
        The figure that owns the axes.
    axes : numpy.ndarray
        The (M, M) array of Axes that was plotted on.

    Raises
    ------
    ValueError
        If ``data`` is not 2-dimensional or has more rows (M) than
        columns (N).
    """
    data = np.asanyarray(data)
    # Requiring N >= M guards against passing the data transposed.
    if data.ndim != 2 or data.shape[0] > data.shape[1]:
        raise ValueError("Invalid data shape {0}. You must pass in an array of "
                         "shape (M, N) where N >= M.".format(data.shape))
    M, N = data.shape

    # `is None` rather than `== None`: `==` is elementwise for numpy arrays
    # and would make the `if` ambiguous.
    if labels is None:
        labels = [None] * M

    if axes is None:
        # Copy so the caller's dict is never mutated.
        skwargs = dict(subplots_kwargs or {})
        skwargs.setdefault("sharex", True)
        skwargs.setdefault("sharey", True)
        # Always get a 2D array of Axes back, even for M == 1.
        skwargs.setdefault("squeeze", False)
        fig, axes = plt.subplots(M, M, **skwargs)
    else:
        # Allow nested lists of Axes as well as ndarrays (an ndarray is
        # returned unchanged).
        axes = np.asarray(axes)

    sc_kwargs = dict(scatter_kwargs or {})
    sc_kwargs.setdefault("edgecolor", "none")
    # matplotlib's scatter refuses both `c` and `color`, so only default
    # `c` when the caller has not chosen a color via `color`.
    if "color" not in sc_kwargs:
        sc_kwargs.setdefault("c", "k")
    sc_kwargs.setdefault("s", 10)

    # Hack so tick labels of adjacent panels don't overlap (the panels are
    # packed with zero spacing below): capture tick locations once, from
    # the first panel, dropping the outermost tick on each side.
    xticks = yticks = None
    for ii in range(M):
        for jj in range(M):
            ax = axes[ii, jj]
            ax.scatter(data[jj], data[ii], **sc_kwargs)

            if yticks is None:
                yticks = ax.get_yticks()[1:-1]
            if xticks is None:
                xticks = ax.get_xticks()[1:-1]

            # first column
            if jj == 0:
                ax.set_ylabel(labels[ii])
                ax.yaxis.set_ticks(yticks)

            # last row
            if ii == M - 1:
                ax.set_xlabel(labels[jj])
                ax.xaxis.set_ticks(xticks)

    fig = axes[0, 0].figure
    fig.subplots_adjust(hspace=0.0, wspace=0.0, left=0.08, bottom=0.08,
                        top=0.9, right=0.9)
    return fig, axes
@keflavich
Copy link
Copy Markdown

Add some examples! Also, are you using astropy.units anywhere?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment