@johnhw
Created July 18, 2019 18:14
Matplotlib: efficiently draw connections between two sets of 2D points
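import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection


def plot_connected_pairs(a, b, *args, **kwargs):
    """
    Draw line segments between corresponding rows of two arrays of 2D points.
    a and b must have the same shape, [N,2]; segment i runs from a[i] to b[i].

    Parameters:
    -----------
    a: [N,2] array of points to plot from
    b: [N,2] array of points to plot to
    Any other arguments or keyword arguments are passed
    to LineCollection directly.

    Returns:
    --------
    LineCollection of connecting lines
    """
    # np.stack gives shape [2,N,2]; swap the first two axes to get the
    # [N,2,2] array of (start, end) pairs that LineCollection expects.
    segs = np.einsum('nij -> inj', np.stack([a, b]))
    line_segments = LineCollection(segs, *args, **kwargs)
    # Added to the current axes; call plt.gca().autoscale_view() afterwards
    # if the lines fall outside the current view.
    plt.gca().add_collection(line_segments)
    return line_segments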
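
A minimal usage sketch, not part of the original gist: it repeats the function so the block runs on its own, then connects 100 random points to slightly jittered copies of themselves. The sample data, the `colors`/`linewidths` values, and the headless `Agg` backend are assumptions chosen for illustration; drop the `matplotlib.use` line for interactive use.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection


def plot_connected_pairs(a, b, *args, **kwargs):
    # Same approach as the gist: [2,N,2] stack -> [N,2,2] segments.
    segs = np.einsum('nij -> inj', np.stack([a, b]))
    line_segments = LineCollection(segs, *args, **kwargs)
    plt.gca().add_collection(line_segments)
    return line_segments


# Illustrative data: N random points and a noisy copy of each.
rng = np.random.default_rng(0)
a = rng.uniform(0, 1, size=(100, 2))
b = a + rng.normal(0, 0.05, size=(100, 2))

fig, ax = plt.subplots()
ax.scatter(a[:, 0], a[:, 1], s=8)

# Extra keyword arguments go straight to LineCollection.
lines = plot_connected_pairs(a, b, colors="C1", linewidths=0.5)

# Make sure the view covers the newly added segments.
ax.autoscale_view()
```

All segments are drawn as a single artist, which is what makes this cheaper than calling `plt.plot` once per pair. The same `[N,2,2]` array could also be built with `np.stack([a, b], axis=1)`.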