Skip to content

Instantly share code, notes, and snippets.

@pltrdy
Last active July 16, 2020 14:43
Show Gist options
Save pltrdy/8a43bdd7d6fda9c81ce7c175e06ab698 to your computer and use it in GitHub Desktop.
def _aligned_bins(lo, hi, width):
    """Return histogram bin edges, spaced *width* apart, that cover [lo, hi].

    Edges are integer multiples of *width* (the same grid the original
    matplotlib example uses). They run from the largest multiple <= ``lo``
    to the smallest multiple > ``hi``. Every value in [lo, hi] therefore
    falls in a bin, and at least one bin exists even when ``lo == hi``.

    Raises:
        ValueError: if *width* is not positive.
    """
    if not width > 0:
        raise ValueError("bin width must be positive, got {!r}".format(width))
    first = int(np.floor(lo / width))
    last = int(np.floor(hi / width)) + 1
    # Build integer indices and scale once, so float rounding does not
    # accumulate across edges as it can with a fractional arange step.
    return np.arange(first, last + 1) * width


def scatter_hist(x, y, xlabel="", ylabel="", title="", left=0.1, width=0.65,
                 bottom=0.1, height=0.65, figsize=(8, 8),
                 ortho=False,
                 xbinwidth=None,
                 ybinwidth=None,
                 binwidth=5,
                 show=True):
    """Draw a scatter plot of ``y`` vs ``x`` with marginal histograms.

    Adapted from the matplotlib ``scatter_hist`` example [1], with tweaks so
    that X and Y may follow different distributions.

    [1]: https://matplotlib.org/examples/pylab_examples/scatter_hist.html

    Requires ``np`` (numpy), ``plt`` (matplotlib.pyplot) and
    ``NullFormatter`` (matplotlib.ticker) to be imported at module level.

    Args:
        x, y: Sequences of numbers to scatter and histogram.
        xlabel, ylabel: Labels for the scatter plot's axes.
        title: Passed to ``plt.figure`` as its ``num`` (figure label). It is
            not drawn as an axes title. An empty title always opens a new
            figure.
        left, width, bottom, height: Rectangle of the scatter axes in
            figure-fraction coordinates. The histograms are placed 0.02
            beyond it (to the right of it and above it) and are 0.2 thick.
        figsize: Figure size, forwarded to ``plt.figure``.
        ortho: If True, both axes get the same limits and the same bin edges
            (width ``binwidth``). ``xbinwidth``/``ybinwidth`` must then be
            None.
        xbinwidth, ybinwidth: Per-axis bin widths used when ``ortho`` is
            False. Each defaults to ``binwidth``.
        binwidth: Default histogram bin width.
        show: If True (the default, as before), call ``plt.show()``.

    Returns:
        ``(fig, ax_scatter, ax_histx, ax_histy)``.

    Raises:
        ValueError: if ``ortho`` is True and ``xbinwidth`` or ``ybinwidth``
            is given, or if a bin width is not positive.
    """
    # Axes rectangles: the histograms sit 0.02 beyond the scatter axes.
    bottom_h = left_h = left + width + 0.02
    rect_scatter = [left, bottom, width, height]
    rect_histx = [left, bottom_h, width, 0.2]
    rect_histy = [left_h, bottom, 0.2, height]

    # plt.figure("") would return any existing unlabelled figure (ignoring
    # figsize) instead of opening a new one, so pass None for an empty title.
    fig = plt.figure(title or None, figsize=figsize)
    ax_scatter = fig.add_axes(rect_scatter)
    ax_histx = fig.add_axes(rect_histx)
    ax_histy = fig.add_axes(rect_histy)

    # The histograms' shared axis matches the scatter plot's limits (set
    # below), so hide its tick labels.
    ax_histx.xaxis.set_major_formatter(NullFormatter())
    ax_histy.yaxis.set_major_formatter(NullFormatter())

    ax_scatter.scatter(x, y)
    ax_scatter.set_xlabel(xlabel)
    ax_scatter.set_ylabel(ylabel)

    xmin, xmax = np.min(x), np.max(x)
    ymin, ymax = np.min(y), np.max(y)

    if ortho:
        if xbinwidth is not None or ybinwidth is not None:
            raise ValueError(
                "xbinwidth/ybinwidth cannot be used with ortho=True; "
                "use binwidth instead")
        # Shared limits and shared bin edges on both axes.
        xmin = ymin = min(xmin, ymin)
        xmax = ymax = max(xmax, ymax)
        xbins = ybins = _aligned_bins(xmin, xmax, binwidth)
    else:
        if xbinwidth is None:
            xbinwidth = binwidth
        if ybinwidth is None:
            ybinwidth = binwidth
        xbins = _aligned_bins(xmin, xmax, xbinwidth)
        ybins = _aligned_bins(ymin, ymax, ybinwidth)

    ax_scatter.set_xlim((xmin, xmax))
    ax_scatter.set_ylim((ymin, ymax))
    ax_histx.hist(x, bins=xbins)
    ax_histy.hist(y, bins=ybins, orientation='horizontal')
    # Keep the histograms aligned with the scatter plot's data range.
    ax_histx.set_xlim(ax_scatter.get_xlim())
    ax_histy.set_ylim(ax_scatter.get_ylim())

    if show:
        plt.show()
    return fig, ax_scatter, ax_histx, ax_histy
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment