Skip to content

Instantly share code, notes, and snippets.

@rldotai
Created February 7, 2018 00:04
Show Gist options
  • Select an option

  • Save rldotai/8e319a2a5fa796c8ac97d42dc8d9ecd9 to your computer and use it in GitHub Desktop.

Select an option

Save rldotai/8e319a2a5fa796c8ac97d42dc8d9ecd9 to your computer and use it in GitHub Desktop.
Plot trajectories using matplotlib.
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
def plot_trajectories(lst, ax=None, colors=None, cmap=None, alpha=None):
    """Plot trajectories via matplotlib's line segments.

    Each trajectory is split into consecutive two-point segments, which are
    added to the axis as a single ``LineCollection`` per trajectory.

    Parameters
    ----------
    lst: sequence of sequence containing points
        For example, a list of numpy arrays (or plain lists of pairs), where
        each entry contains the points along a trajectory, e.g.,
        [(x0, y0), (x1, y1), ..., (xn, yn)].
    ax: matplotlib axis, optional
        The axis on which to plot. If not provided, an axis is created.
    colors: sequence of sequence of floats, optional
        A list of the same "shape" as `lst`, containing the values that should
        be associated with each point along the trajectory. ``colors[i]`` is
        handed to the collection via ``set_array`` for trajectory ``i``.
    cmap: matplotlib colormap, optional
        Colormap to use for the trajectories.
    alpha:
        Transparency value for the line segments.

    Returns
    -------
    ax: matplotlib axis
        The axis the collections were added to. No axis limits are set here,
        so the caller may need to adjust them.
    """
    if ax is None:
        _, ax = plt.subplots()

    for ix, pts in enumerate(lst):
        # Coerce to an array so plain lists/tuples of (x, y) pairs work too;
        # `.reshape` only exists on ndarrays.
        line = np.asarray(pts).reshape(-1, 1, 2)
        # Pair every point with its successor: shape (n_points - 1, 2, 2).
        segments = np.concatenate([line[:-1], line[1:]], axis=1)

        # create line segments and apply options
        lc = mpl.collections.LineCollection(segments, cmap=cmap)
        if alpha is not None:
            lc.set_alpha(alpha)
        if colors is not None:
            # NOTE(review): n points give n - 1 segments, while `colors` is
            # documented per point -- confirm how the extra value is meant to
            # be handled.
            lc.set_array(np.asarray(colors[ix]))

        # plot the segment
        ax.add_collection(lc)
    return ax
# Example usage, kept inside a bare string so that it is not executed on import:
# plot ten curves that are slightly offset from one another, then set the axis
# limits by hand because `plot_trajectories` does not set them itself.
"""
#xvals = [np.sin(np.linspace(0, 4*np.pi)) for i in range(10)]
>>> xvals = [np.linspace(-1.5, 1.5) for i in range(10)]
>>> yvals = [i*0.1 + np.linspace(0,1)*np.cos(np.linspace(0, 4*np.pi)) for i in range(10)]
>>> tlst = np.array([[(x, y) for x, y in zip(i, j)] for i, j in zip(xvals, yvals)])
>>> np.shape(tlst)
(10, 50, 2)
# Plot the trajectories
>>> fig, ax = plt.subplots()
>>> plot_trajectories(tlst, ax=ax)
# Adjust limits
>>> ax.set_xlim([-2,2])
>>> ax.set_ylim([-2, 2])
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment