Created July 18, 2019 18:14
Gist johnhw/e35ff373a70e26333edc1e9f7cffbc14
Matplotlib: draw connections between two sets of 2D points efficiently
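The efficiency comes from drawing all N pairs as a single LineCollection artist instead of calling plot once per pair. LineCollection takes its segments as an [N, 2, 2] array, where segment k is the pair (start point, end point). The following is a minimal sketch of just that reshaping step. make_segments is a hypothetical helper, not part of the gist, and it adds a shape check that the original function does not perform:

```python
import numpy as np


def make_segments(a, b):
    """Hypothetical helper: pair rows of a and b into an [N, 2, 2] segment array."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if a.shape != b.shape or a.ndim != 2 or a.shape[1] != 2:
        raise ValueError("a and b must both have shape [N, 2]")
    # np.stack([a, b]) has shape [2, N, 2]; swapping the first two axes
    # gives [N, 2, 2], so that segs[k] == [a[k], b[k]].
    return np.einsum('nij->inj', np.stack([a, b]))


a = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]])
b = np.array([[0.0, 1.0], [1.0, 2.0], [2.0, 1.0]])
segs = make_segments(a, b)
print(segs.shape)  # (3, 2, 2)

# The same layout, written with axis-based stacking instead of einsum.
print(np.array_equal(segs, np.stack([a, b], axis=1)))  # True
```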
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection


def plot_connected_pairs(a, b, *args, **kwargs):
    """
    Draw lines between two arrays of 2D points.

    a and b must be the same shape, both [N,2] arrays to be plotted.

    Parameters:
    -----------
    a: [N,2] array of points to plot from
    b: [N,2] array of points to plot to
    Any other arguments or keyword arguments are passed
    to LineCollection directly.

    Returns:
    --------
    LineCollection of connecting lines
    """
    # np.stack gives shape [2, N, 2]; reorder the axes to [N, 2, 2] so each
    # row is one segment: (start point from a, end point from b).
    segs = np.einsum('nij->inj', np.stack([a, b]))
    line_segments = LineCollection(segs, *args, **kwargs)
    ax = plt.gca()
    ax.add_collection(line_segments)
    # add_collection updates the data limits but does not rescale the view,
    # so autoscale explicitly to make sure the lines are visible.
    ax.autoscale_view()
    return line_segments
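Here is a usage sketch, assuming the non-interactive Agg backend and randomly generated points. To keep it self-contained, it defines plot_connected_pairs_ax, a hypothetical variant that takes an explicit Axes instead of calling plt.gca(). It builds the segments with np.stack(..., axis=1), which produces the same [N, 2, 2] array as the einsum call above.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection


def plot_connected_pairs_ax(ax, a, b, **kwargs):
    """Hypothetical variant of plot_connected_pairs that draws on an explicit Axes."""
    segs = np.stack([a, b], axis=1)  # [N, 2, 2]: segs[k] == [a[k], b[k]]
    lines = LineCollection(segs, **kwargs)
    ax.add_collection(lines)
    ax.autoscale_view()
    return lines


# Illustrative data: 1000 points and slightly perturbed copies of them.
rng = np.random.default_rng(0)
a = rng.normal(size=(1000, 2))
b = a + rng.normal(scale=0.1, size=(1000, 2))

fig, ax = plt.subplots()
ax.scatter(a[:, 0], a[:, 1], s=2, c="tab:blue")
ax.scatter(b[:, 0], b[:, 1], s=2, c="tab:orange")
# One artist for all 1000 connections, styled via LineCollection keywords.
lines = plot_connected_pairs_ax(ax, a, b, colors="gray", linewidths=0.5, alpha=0.5)
print(len(lines.get_segments()))  # 1000
```

Passing the Axes explicitly avoids depending on pyplot's current-axes state, which is convenient when working with several subplots.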