Last active November 14, 2016 03:45
-
-
Save kbarbary/6075466 to your computer and use it in GitHub Desktop.
animate_source function for sncosmo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Licensed under a 3-clause BSD style license - see LICENSE.rst | |
"""Animate one or more sncosmo Source spectral time series.""" | |
from __future__ import division | |
import math | |
import numpy as np | |
from astropy.utils.misc import isiterable | |
from astropy.extern import six | |
from astropy.extern.six.moves import range | |
from sncosmo import Source, get_source | |
def animate_source(source, label=None, fps=30, length=20.,
                   phase_range=(None, None), wave_range=(None, None),
                   match_peakphase=True, match_peakflux=True,
                   peakwave=4000., fname=None, still=False):
    """Animate spectral timeseries of model(s) using matplotlib.animation.

    *Note:* Requires matplotlib v1.1 or higher.

    Parameters
    ----------
    source : `~sncosmo.Source` or str or iterable thereof
        The Source to animate or list of sources to animate.
    label : str or list of str, optional
        If given, label(s) for Sources, to be displayed in a legend on
        the animation. Must have one entry per source.
    fps : int, optional
        Frames per second. Default is 30.
    length : float, optional
        Movie length in seconds. Default is 20.
    phase_range : (float, float), optional
        Phase range to plot (in the timeframe of the first source if multiple
        sources are given). `None` indicates to use the maximum extent of the
        source(s).
    wave_range : (float, float), optional
        Wavelength range to plot. `None` indicates to use the maximum extent
        of the source(s).
    match_peakphase : bool, optional
        For multiple sources, shift additional sources in phase so that each
        source's peak phase at `peakwave` (``source.peakphase(peakwave)``)
        lines up with that of the first source. Default is True.
    match_peakflux : bool, optional
        For multiple sources, scale fluxes so that each source's flux at
        `peakwave`, evaluated at its own peak phase, matches that of the
        first source. Default is True.
    peakwave : float, optional
        Wavelength used in match_peakflux and match_peakphase. Default is
        4000.
    fname : str, optional
        If not `None`, save the animation to file `fname` instead of
        returning it. Requires ffmpeg to be installed with the appropriate
        codecs: if `fname` has the extension '.mp4' the libx264 codec is
        used; if the extension is '.webm' the VP8 (libvpx) codec is used;
        otherwise, the 'mpeg4' codec is used.
    still : bool, optional
        When writing to a file, also save the first frame as a png file
        (`fname` with its extension replaced by '.png'). This is useful for
        displaying videos on a webpage. Ignored when `fname` is `None`.

    Returns
    -------
    ani : `~matplotlib.animation.FuncAnimation` or None
        Animation object that can be shown or saved. When `fname` is given,
        the animation is written to disk, the figure is closed, and `None`
        is returned.

    Raises
    ------
    ValueError
        If no sources are given, if any source is not a str or Source
        instance, or if `label` is given and its length does not match the
        number of sources.

    Examples
    --------
    Compare the salt2 and hsiao sources:

    >>> import matplotlib.pyplot as plt
    >>> ani = animate_source(['salt2', 'hsiao'], phase_range=(None, 30.),
    ...                      wave_range=(2000., 9200.))
    >>> plt.show()  # doctest: +SKIP

    Compare the salt2 source with ``x1=1`` to the same source with ``x1=0.``:

    >>> import sncosmo
    >>> m1 = sncosmo.get_source('salt2')
    >>> m1.set(x1=1.)
    >>> m2 = sncosmo.get_source('salt2')
    >>> m2.set(x1=0.)
    >>> ani = animate_source([m1, m2], label=['salt2, x1=1', 'salt2, x1=0'])
    >>> plt.show()  # doctest: +SKIP
    """
    import os

    from matplotlib import pyplot as plt
    from matplotlib import animation

    # Normalize input to a list. A bare string or Source is a single source;
    # any other iterable (including a generator) is materialized once so it
    # can be both validated and converted below.
    if (not isiterable(source)) or isinstance(source, six.string_types):
        sources = [source]
    else:
        sources = list(source)
    if not sources:
        raise ValueError('at least one source is required')

    # Check that all entries are Source or strings.
    allowed_types = tuple(six.string_types) + (Source,)
    for m in sources:
        if not isinstance(m, allowed_types):
            raise ValueError('str or Source instance expected for '
                             'source(s)')
    sources = [get_source(m) for m in sources]

    # Get the source labels (one per source).
    if label is None:
        labels = [None] * len(sources)
    elif isinstance(label, six.string_types):
        labels = [label]
    else:
        labels = list(label)
    if len(labels) != len(sources):
        raise ValueError('if given, length of label must match '
                         'that of source')

    # Get a wavelength array (10-unit grid) for each source.
    waves = [np.arange(m.minwave(), m.maxwave(), 10.) for m in sources]

    # Phase offsets needed to align each source's peak with the first one.
    peakphases = [m.peakphase(peakwave) for m in sources]
    if match_peakphase:
        phase_offsets = [p - peakphases[0] for p in peakphases]
    else:
        phase_offsets = [0.] * len(sources)

    # Determine phase range to display (in the first source's timeframe).
    minphase, maxphase = phase_range
    if minphase is None:
        minphase = min(m.minphase() - offset
                       for m, offset in zip(sources, phase_offsets))
    if maxphase is None:
        maxphase = max(m.maxphase() - offset
                       for m, offset in zip(sources, phase_offsets))

    # Determine the wavelength range to display.
    minwave, maxwave = wave_range
    if minwave is None:
        minwave = min(m.minwave() for m in sources)
    if maxwave is None:
        maxwave = max(m.maxwave() for m in sources)

    # Source phase interval between consecutive frames.
    nframes = int(fps * length)
    phase_interval = (maxphase - minphase) / (length * fps)

    # Maximum flux density of the entire spectrum at the peak phase, for
    # each source; used to size the y axis.
    max_fluxes = [np.max(m.flux(phase, w))
                  for m, phase, w in zip(sources, peakphases, waves)]

    # Scaling factors applied to each source's flux when plotting.
    if match_peakflux:
        # Flux at `peakwave` only -- not the same as max_fluxes!
        peakfluxes = [m.flux(phase, peakwave)
                      for m, phase in zip(sources, peakphases)]
        scaling_factors = [peakfluxes[0] / f for f in peakfluxes]
        global_max_flux = max_fluxes[0]
    else:
        scaling_factors = [1.] * len(sources)
        global_max_flux = max(max_fluxes)

    ymin = -0.06 * global_max_flux
    ymax = 1.1 * global_max_flux

    # Set up the figure, the axis, and the plot elements we want to animate.
    fig = plt.figure()
    ax = plt.axes(xlim=(minwave, maxwave), ylim=(ymin, ymax))
    plt.axhline(y=0., c='k')
    plt.xlabel('Wavelength ($\\AA$)')
    plt.ylabel('Flux Density ($F_\\lambda$)')
    phase_text = ax.text(0.05, 0.95, '', ha='left', va='top',
                         transform=ax.transAxes)

    # ax.plot consumes (x, y) pairs, so 2 * N empty lists -> N empty lines.
    empty_lists = 2 * len(sources) * [[]]
    lines = ax.plot(*empty_lists, lw=1)
    if label is not None:
        for line, line_label in zip(lines, labels):
            line.set_label(line_label)
        plt.legend(loc='upper right')

    def init():
        """Clear all lines and the phase text (first blitted frame)."""
        for line in lines:
            line.set_data([], [])
        phase_text.set_text('')
        return tuple(lines) + (phase_text,)

    def animate(i):
        """Draw frame `i`: every source's spectrum at the frame's phase."""
        current_phase = minphase + phase_interval * i
        for line, m, wave, offset, scale in zip(lines, sources, waves,
                                                phase_offsets,
                                                scaling_factors):
            flux = m.flux(current_phase + offset, wave)
            line.set_data(wave, flux * scale)
        phase_text.set_text('phase = {0:.1f}'.format(current_phase))
        return tuple(lines) + (phase_text,)

    ani = animation.FuncAnimation(fig, animate, init_func=init,
                                  frames=nframes, interval=(1000. / fps),
                                  blit=True)

    if fname is None:
        return ani

    # Save the animation as a movie file. This requires that ffmpeg is
    # installed. splitext handles names without an extension and dots in
    # directory names, unlike a bare rfind('.').
    root, ext = os.path.splitext(fname)
    if still:
        # Populate the first frame explicitly; otherwise the lines are
        # still empty when the png is rendered.
        init()
        animate(0)
        fig.savefig(root + '.png')
    codec = {'.mp4': 'libx264', '.webm': 'libvpx'}.get(ext.lower(), 'mpeg4')
    ani.save(fname, fps=fps, codec=codec, extra_args=['-vcodec', codec],
             writer='ffmpeg_file', bitrate=1800)
    plt.close(fig)
    return None
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.