Skip to content

Instantly share code, notes, and snippets.

@kbarbary
Last active November 14, 2016 03:45
Show Gist options
  • Save kbarbary/6075466 to your computer and use it in GitHub Desktop.
Save kbarbary/6075466 to your computer and use it in GitHub Desktop.
animate_source function for sncosmo
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Animate one or more sncosmo Source spectral time series."""
from __future__ import division
import math
import numpy as np
from astropy.utils.misc import isiterable
from astropy.extern import six
from astropy.extern.six.moves import range
from sncosmo import Source, get_source
def animate_source(source, label=None, fps=30, length=20.,
                   phase_range=(None, None), wave_range=(None, None),
                   match_peakphase=True, match_peakflux=True,
                   peakwave=4000., fname=None, still=False):
    """Animate spectral timeseries of model(s) using matplotlib.animation.

    *Note:* Requires matplotlib v1.1 or higher.

    Parameters
    ----------
    source : `~sncosmo.Source` or str or iterable thereof
        The Source to animate or list of sources to animate. Each entry is
        passed through `~sncosmo.get_source`.
    label : str or list of str, optional
        If given, label(s) for Sources, to be displayed in a legend on
        the animation. Must contain exactly one label per source.
    fps : int, optional
        Frames per second. Default is 30.
    length : float, optional
        Movie length in seconds. Default is 20.
    phase_range : (float, float), optional
        Phase range to plot (in the timeframe of the first source if multiple
        sources are given). `None` indicates to use the maximum extent of the
        source(s).
    wave_range : (float, float), optional
        Wavelength range to plot. `None` indicates to use the maximum extent
        of the source(s).
    match_peakphase : bool, optional
        For multiple sources, shift additional sources in phase so that
        their ``peakphase(peakwave)`` coincides with that of the first
        source. Default is True.
    match_peakflux : bool, optional
        For multiple sources, scale each source so that its flux at
        ``peakwave`` (evaluated at its ``peakphase(peakwave)``) matches that
        of the first source. Default is True.
    peakwave : float, optional
        Wavelength used in match_peakflux and match_peakphase. Default is
        4000.
    fname : str, optional
        If not `None`, save the animation to file `fname` instead of
        returning it. Requires ffmpeg to be installed with the appropriate
        codecs: If `fname` has the extension '.mp4' the libx264 codec is
        used. If the extension is '.webm' the VP8 (libvpx) codec is used.
        Otherwise, the 'mpeg4' codec is used.
    still : bool, optional
        When writing to a file, also save the first frame as a png file
        with the same base name as `fname`. This is useful for displaying
        videos on a webpage.

    Returns
    -------
    ani : `~matplotlib.animation.FuncAnimation` or None
        Animation object that can be shown or saved. When `fname` is given
        the animation is written to disk, the figure is closed and `None`
        is returned.

    Raises
    ------
    ValueError
        If no source is given, if an entry of `source` is neither a str nor
        a `~sncosmo.Source`, or if the number of labels does not match the
        number of sources.

    Examples
    --------
    Compare the salt2 and hsiao sources:

    >>> import matplotlib.pyplot as plt
    >>> ani = animate_source(['salt2', 'hsiao'], phase_range=(None, 30.),
    ...                      wave_range=(2000., 9200.))
    >>> plt.show()  # doctest: +SKIP

    Compare the salt2 source with ``x1=1`` to the same source with ``x1=0.``:

    >>> import sncosmo
    >>> m1 = sncosmo.get_source('salt2')
    >>> m1.set(x1=1.)
    >>> m2 = sncosmo.get_source('salt2')
    >>> m2.set(x1=0.)
    >>> ani = animate_source([m1, m2], label=['salt2, x1=1', 'salt2, x1=0'])
    >>> plt.show()  # doctest: +SKIP
    """
    import os

    from matplotlib import pyplot as plt
    from matplotlib import animation

    # Convert input to a list. A single str or a non-iterable (e.g. a lone
    # Source) becomes a one-element list; other iterables are materialized
    # so that generators are not exhausted by the validation loop below.
    if (not isiterable(source)) or isinstance(source, six.string_types):
        sources = [source]
    else:
        sources = list(source)
    if len(sources) == 0:
        raise ValueError('at least one source must be given')

    # Check that all entries are Source or strings.
    allowed_types = six.string_types + (Source,)
    for m in sources:
        if not isinstance(m, allowed_types):
            raise ValueError('str or Source instance expected for '
                             'source(s)')
    sources = [get_source(m) for m in sources]

    # Get the source labels (one per source).
    if label is None:
        labels = [None] * len(sources)
    elif isinstance(label, six.string_types):
        labels = [label]
    else:
        labels = list(label)
    if len(labels) != len(sources):
        raise ValueError('if given, length of label must match '
                         'that of source')

    # Wavelength grid for each source, 10 Angstrom spacing over its range.
    waves = [np.arange(m.minwave(), m.maxwave(), 10.) for m in sources]

    # Phase offsets needed to line up each source's peak phase with the
    # first source's.
    peakphases = [m.peakphase(peakwave) for m in sources]
    if match_peakphase:
        phase_offsets = [p - peakphases[0] for p in peakphases]
    else:
        phase_offsets = [0.] * len(sources)

    # Determine phase range to display (in the first source's timeframe).
    minphase, maxphase = phase_range
    if minphase is None:
        minphase = min(m.minphase() - offset
                       for m, offset in zip(sources, phase_offsets))
    if maxphase is None:
        maxphase = max(m.maxphase() - offset
                       for m, offset in zip(sources, phase_offsets))

    # Determine the wavelength range to display.
    minwave, maxwave = wave_range
    if minwave is None:
        minwave = min(m.minwave() for m in sources)
    if maxwave is None:
        maxwave = max(m.maxwave() for m in sources)

    # Source phase interval between consecutive frames.
    nframes = int(fps * length)
    phase_interval = (maxphase - minphase) / (length * fps)

    # Maximum flux density of the entire spectrum at the peak phase of
    # each source; used to choose the y-axis limits.
    max_fluxes = [np.max(m.flux(phase, w))
                  for m, phase, w in zip(sources, peakphases, waves)]

    # Scaling factors applied to each source's flux.
    if match_peakflux:
        # Flux at `peakwave` only -- not the same as max_fluxes!
        peakfluxes = [m.flux(phase, peakwave)
                      for m, phase in zip(sources, peakphases)]
        scaling_factors = [peakfluxes[0] / f for f in peakfluxes]
        global_max_flux = max_fluxes[0]
    else:
        scaling_factors = [1.] * len(sources)
        global_max_flux = max(max_fluxes)
    ymin = -0.06 * global_max_flux
    ymax = 1.1 * global_max_flux

    # Set up the figure, the axis, and the plot elements we want to animate.
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(minwave, maxwave)
    ax.set_ylim(ymin, ymax)
    ax.axhline(y=0., c='k')
    ax.set_xlabel('Wavelength ($\\AA$)')
    ax.set_ylabel('Flux Density ($F_\\lambda$)')
    phase_text = ax.text(0.05, 0.95, '', ha='left', va='top',
                         transform=ax.transAxes)
    lines = [ax.plot([], [], lw=1)[0] for _ in sources]
    if label is not None:
        for line, lbl in zip(lines, labels):
            line.set_label(lbl)
        ax.legend(loc='upper right')

    def init():
        # Blank frame used by FuncAnimation before the first frame.
        for line in lines:
            line.set_data([], [])
        phase_text.set_text('')
        return tuple(lines) + (phase_text,)

    def animate(i):
        # Draw frame `i`: every source evaluated at the same (offset) phase.
        current_phase = minphase + phase_interval * i
        for line, m, w, offset, scale in zip(lines, sources, waves,
                                              phase_offsets,
                                              scaling_factors):
            y = m.flux(current_phase + offset, w)
            line.set_data(w, y * scale)
        phase_text.set_text('phase = {0:.1f}'.format(current_phase))
        return tuple(lines) + (phase_text,)

    ani = animation.FuncAnimation(fig, animate, init_func=init,
                                  frames=nframes, interval=(1000. / fps),
                                  blit=True)

    if fname is None:
        return ani

    # Save the animation as an mp4 or webm file.
    # This requires that ffmpeg is installed.
    root, ext = os.path.splitext(fname)
    if still:
        # Render the first frame so the png is not just empty axes.
        animate(0)
        fig.savefig(root + '.png')
    codec = {'.mp4': 'libx264', '.webm': 'libvpx'}.get(ext.lower(), 'mpeg4')
    ani.save(fname, fps=fps, codec=codec, extra_args=['-vcodec', codec],
             writer='ffmpeg_file', bitrate=1800)
    plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment