Created
May 31, 2014 16:55
-
-
Save kingjr/0807ebd6f458fdfd8bef to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def plot_evoked_img(evoked, picks=None, exclude='bads', unit=True, show=True,
                    ylim=None, proj=False, xlim='tight', hline=None,
                    units=None, scalings=None, titles=None, axes=None):
    """Plot evoked data as an image (channels x time), color encoding amplitude.

    One image panel is drawn per channel type present among the selected
    channels. Image rows are channels (in pick order) and columns are time
    samples; each panel gets its own colorbar.

    Parameters
    ----------
    evoked : instance of Evoked
        The evoked data.
    picks : array-like of int | None
        The indices of channels to plot. If None show all. Duplicate indices
        are ignored; the order of first occurrence is kept for image rows.
    exclude : list of str | tuple of str | 'bads' | None
        Channel names to exclude from being shown. If 'bads', the bad
        channels (``evoked.info['bads']``) are excluded. None or an empty
        sequence excludes nothing.
    unit : bool
        Currently unused; kept for signature compatibility.
    show : bool
        Call pyplot.show() at the end or not. Never called on the 'agg'
        backend.
    ylim : dict | None
        Color-scale limits ``[vmin, vmax]`` per channel type, expressed in
        *scaled* units (i.e. after multiplication by ``scalings``), e.g.
        ``ylim = dict(eeg=[-200, 200])`` for +/-200 uV with the default EEG
        scaling. Channel types without an entry use matplotlib's automatic
        color scaling.
    proj : bool | 'interactive'
        If True, SSP projections are applied (on a copy) before display.
        'interactive' is only checked against ``axes`` here (it is rejected
        when ``axes`` is given); no projector selector is drawn and no
        projection is applied in that case.
    xlim : 'tight' | tuple | None
        Time limits of the plots in ms. 'tight' spans the full time range;
        None keeps matplotlib's default.
    hline : list of floats | None
        Currently unused; kept for signature compatibility.
    units : dict | None
        The units of the channel types, used to label each colorbar. If
        None, defaults to `dict(eeg='uV', grad='fT/cm', mag='fT')`.
    scalings : dict | None
        The scalings of the channel types to be applied for plotting. If
        None, defaults to `dict(eeg=1e6, grad=1e13, mag=1e15)`.
    titles : dict | None
        The titles associated with the channels. If None, defaults to
        `dict(eeg='EEG', grad='Gradiometers', mag='Magnetometers')`.
    axes : instance of Axes | list | array of Axes | None
        The axes to plot to. If a list or array (arrays are flattened), it
        must contain one Axes per plotted channel type. If an instance of
        Axes, exactly one channel type must be plotted.

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        The figure containing the plots.

    Raises
    ------
    RuntimeError
        If ``proj='interactive'`` is combined with ``axes``.
    ValueError
        If ``exclude`` is malformed or names unknown channels, if no channel
        of a known type remains to be plotted, or if the number of axes does
        not match the number of plotted channel types.

    Notes
    -----
    Panels (and the entries of a list of ``axes``) follow the order in which
    channel types are first encountered when iterating the keys of
    ``scalings``, then ``titles``, then ``units`` (after None arguments are
    replaced by their defaults). This order is deterministic, unlike
    iterating a set of strings.
    """
    import matplotlib.pyplot as plt

    if axes is not None and proj == 'interactive':
        raise RuntimeError('Currently only single axis figures are supported'
                           ' for interactive SSP selection.')
    scalings, titles, units = _mutable_defaults(('scalings', scalings),
                                                ('titles', titles),
                                                ('units', units))
    # Order-preserving union of the known channel types. A set of strings
    # would iterate in a hash-randomized order, making the mapping between
    # panels/axes and channel types unpredictable across runs.
    channel_types = []
    for d in (scalings, titles, units):
        for key in d:
            if key not in channel_types:
                channel_types.append(key)

    # --- resolve which channels to plot -----------------------------------
    if picks is None:
        picks = list(range(evoked.info['nchan']))
    if not exclude:
        exclude_idx = []
    elif isinstance(exclude, string_types):
        if exclude != 'bads':
            raise ValueError('exclude has to be a list of channel names or '
                             '"bads"')
        exclude_idx = [evoked.ch_names.index(ch) for ch in evoked.info['bads']
                       if ch in evoked.ch_names]
    elif (isinstance(exclude, (list, tuple))
          and all(isinstance(ch, string_types) for ch in exclude)):
        missing = [ch for ch in exclude if ch not in evoked.ch_names]
        if missing:
            raise ValueError('exclude contains unknown channel names: %s'
                             % ', '.join(missing))
        exclude_idx = [evoked.ch_names.index(ch) for ch in exclude]
    else:
        raise ValueError('exclude has to be a list of channel names or '
                         '"bads"')
    excluded = set(exclude_idx)
    # Drop excluded and duplicate picks while keeping the caller's order:
    # for an image plot the order of rows is meaningful.
    seen = set()
    kept = []
    for pick in picks:
        if pick not in excluded and pick not in seen:
            seen.add(pick)
            kept.append(pick)
    picks = kept

    types = [channel_type(evoked.info, idx) for idx in picks]
    present_types = set(types)
    ch_types_used = [t for t in channel_types if t in present_types]
    n_channel_types = len(ch_types_used)
    if n_channel_types == 0:
        raise ValueError('No channels of a known type are left to plot after '
                         'applying picks and exclude.')

    # --- set up figure / axes ---------------------------------------------
    axes_init = axes  # remember whether axes were given as input
    fig = None
    if axes is None:
        fig, axes = plt.subplots(n_channel_types, 1)
    if isinstance(axes, plt.Axes):
        axes = [axes]
    elif isinstance(axes, np.ndarray):
        axes = list(axes.ravel())  # also accept 2-D grids of axes
    else:
        axes = list(axes)
    # Check the count before indexing so an empty list gives a clear error.
    if len(axes) != n_channel_types:
        raise ValueError('Number of axes (%d) must match number of channel '
                         'types (%d)' % (len(axes), n_channel_types))
    if axes_init is not None:
        fig = axes[0].get_figure()

    # Project once on a copy instead of during each iteration.
    if proj is True and evoked.proj is not True:
        evoked = evoked.copy()
        evoked.apply_proj()
    times = 1e3 * evoked.times  # time in milliseconds
    # Resolve 'tight' once; comparing an array/tuple to a string inside the
    # loop is fragile.
    if isinstance(xlim, string_types) and xlim == 'tight':
        xlim = (times[0], times[-1])

    # --- one image per channel type ---------------------------------------
    for ax, t in zip(axes, ch_types_used):
        idx = [pick for pick, pick_type in zip(picks, types) if pick_type == t]
        data = scalings[t] * evoked.data[idx, :]
        color_lim = {}
        if ylim is not None and t in ylim:
            color_lim = dict(vmin=ylim[t][0], vmax=ylim[t][1])
        im = ax.imshow(data, interpolation='nearest', origin='lower',
                       extent=[times[0], times[-1], 0, data.shape[0]],
                       aspect='auto', **color_lim)
        if xlim is not None:
            ax.set_xlim(xlim)
        n_ch = data.shape[0]
        ax.set_title(titles[t] + ' (%d channel%s)' % (
            n_ch, 's' if n_ch > 1 else ''))
        ax.set_xlabel('time (ms)')
        ax.set_ylabel('channels')
        cbar = plt.colorbar(im, ax=ax)
        if t in units:
            cbar.set_label(units[t])

    # Finish the layout *before* showing, so a blocking show() displays the
    # final figure rather than applying the layout after it is closed.
    if axes_init is None:
        fig.subplots_adjust(0.175, 0.08, 0.94, 0.94, 0.2, 0.63)
    tight_layout(fig=fig)
    fig.canvas.draw()  # also refreshes axes that were passed in by the caller
    if show and plt.get_backend().lower() != 'agg':
        plt.show()
    return fig
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment