Last active
September 2, 2015 10:10
-
-
Save swenzel/e18e3c991b4349c6bcb6 to your computer and use it in GitHub Desktop.
Live spike raster plot for PyNEST
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE | |
# Version 2, December 2004 | |
# | |
# Copyright (C) 2015 Swen Wenzel <[email protected]> | |
# | |
# Everyone is permitted to copy and distribute verbatim or modified | |
# copies of this license document, and changing it is allowed as long | |
# as the name is changed. | |
# | |
# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE | |
# TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION | |
# | |
# 0. You just DO WHAT THE FUCK YOU WANT TO. | |
import multiprocessing as mp | |
import ctypes | |
import numpy as np | |
import nest | |
class LivePlot(mp.Process):
    """
    Live spike raster plot that runs in its own process.

    :param senders: Sender ids, smallest and largest determine y-axis range.
                    If set to `None` the axis will be automatically updated.
    :type senders: list or tuple of int
    :param title: Title used for the figure
    :type title: str
    :param interval: Time interval in ms between figure updates
    :type interval: int
    :param timeframe: Defines in ms how far the plot goes back in time
    :type timeframe: positive int
    :param bufsize: Defines how many spike events are buffered
    :type bufsize: positive number (decimals will be ignored)
    :raises ValueError: if `timeframe` or `bufsize` is not positive

    This class automatically opens a figure which is constantly updated within
    its own process. The process is started directly by the constructor and
    therefore the object is ready to use after creation. Due to process start
    up delays the figure will only appear shortly after creation. If desired,
    it is possible to wait for the figure (see :func:`waitForFigure`).

    Spike events are exchanged with the plotting process through shared
    memory arrays which are used like a ring buffer: once `bufsize` events
    are stored, the oldest ones are overwritten.

    Usage example:

    >>> import nest
    >>> from live_plot import LivePlot
    >>> spike_generator = nest.Create('poisson_generator')
    >>> nest.SetStatus(spike_generator, 'rate', 10.0)
    >>> spike_detector = nest.Create('spike_detector')
    >>> nest.Connect(spike_generator, spike_detector)
    >>> myLivePlot = LivePlot()
    >>> myLivePlot.waitForFigure()
    >>> # closing the window will finish the process and stop the simulation
    >>> while myLivePlot.is_alive():
    ...     nest.Simulate(10)
    ...     events = nest.GetStatus(spike_detector, 'events')
    ...     # remove events which were just read to avoid adding them twice
    ...     nest.SetStatus(spike_detector, 'n_events', 0)
    ...     myLivePlot.addData(events)
    """

    def __init__(self, senders=None, title=None, interval=100, timeframe=500,
                 bufsize=1e5):
        super(LivePlot, self).__init__()
        # validate first so nothing is allocated or started on bad input
        if timeframe <= 0:
            raise ValueError('timeframe has to be >0!')
        bufsize = int(bufsize)
        if bufsize <= 0:
            raise ValueError('bufsize has to be >0!')

        self._datalock = mp.Lock()
        self._stopEvent = mp.Event()
        self._readyEvent = mp.Event()
        # shared ring buffers; every numpy view on them has to use the
        # matching dtype (float64 for times, C int for senders)
        self._times = mp.Array(ctypes.c_double, bufsize)
        self._senders = mp.Array(ctypes.c_int, bufsize)
        self._now = mp.Value(ctypes.c_double)
        self._len = mp.Value(ctypes.c_uint)
        self._interval = interval
        self._minsender = mp.Value(ctypes.c_double)
        self._maxsender = mp.Value(ctypes.c_double)
        if senders is None:
            self._autoSenders = True
            # inverted infinite range: the first event always narrows it
            self._maxsender.value = -np.inf
            self._minsender.value = np.inf
        else:
            self._autoSenders = False
            self._maxsender.value = np.max(senders)
            self._minsender.value = np.min(senders)
        # will be used as reference timepoint for the timewindow
        self._now.value = 0
        # since the array won't be filled from the beginning, _len will keep
        # track of the amount of entries
        self._len.value = 0
        # since we have ringbuffer-like memory arrays, we have to keep track of
        # the next insert point (only used by the process calling addData)
        self._nextInsert = 0
        self._timeframe = timeframe
        self._title = title
        self.start()

    def run(self):
        """
        This process' main method. Even though it is not marked as private, it
        is not supposed to be called from outside!
        """
        # imported here so matplotlib is only loaded in the plotting process
        import matplotlib.pyplot as plt
        from matplotlib import animation
        self._figure = plt.figure()
        # The animation has to stay referenced for as long as the window is
        # open, otherwise it may be garbage collected and stop updating.
        # The axes are configured in _initFigure, because depending on the
        # matplotlib version init_func only runs on the first draw, i.e.
        # after this constructor call has returned.
        _ani = animation.FuncAnimation(self._figure, self._update,
                                       self._getData, interval=self._interval,
                                       init_func=self._initFigure)
        plt.show(block=True)

    def stop(self):
        """
        Tells the process to close the window and therefore stop the live
        plot. Waits up to 5 seconds for the process to finish.
        It is safe to call this method multiple times.
        """
        if self.is_alive():
            self._stopEvent.set()
            self.join(timeout=5)

    def addData(self, events, now=None):
        """
        :param events: New data coming from
                       `nest.GetStatus(spike_detector, 'events')`
        :type events: dict or list/tuple of dicts
        :param now: Used to set to current simulation time in ms
                    If set to `None`, `nest.GetKernelStatus('time')` will be
                    used.
        :type now: int
        :raises TypeError: if `events` is neither a dict nor a list/tuple
        :raises ValueError: if 'times' and 'senders' differ in size

        Adds new spike events to this livePlot's memory. Also updates the
        current simulation time and therefore shifts the spikes in the next
        frame update. If more events are passed than fit into the buffer, only
        the last `bufsize` ones are kept.

        .. warning::
            Duplicate events are not filtered out! So don't forget to call
            `nest.SetStatus(spike_detector, 'n_events', 0)`
        """
        if not isinstance(events, (dict, list, tuple)):
            raise TypeError("events must be dict or list of dicts, " +
                            "not {}".format(type(events)))
        if isinstance(events, (list, tuple)):
            for e in events:
                # forward `now`, otherwise an explicitly given time is lost
                self.addData(e, now)
            return

        times = np.array(events['times'], dtype=np.float64)
        senders = np.array(events['senders'], dtype=np.intc)
        if times.size != senders.size:
            raise ValueError('times and senders have to be ' +
                             'the same size!')

        with self._datalock:
            if now is None:
                self._now.value = nest.GetKernelStatus('time')
            else:
                self._now.value = now
            if times.size == 0:
                return

            # update minsender and maxsender if necessary
            if self._autoSenders:
                minsender = float(senders.min())
                maxsender = float(senders.max())
                if self._minsender.value > minsender:
                    self._minsender.value = minsender
                if self._maxsender.value < maxsender:
                    self._maxsender.value = maxsender

            myTimes = np.frombuffer(self._times.get_obj(), dtype=np.float64)
            mySenders = np.frombuffer(self._senders.get_obj(), dtype=np.intc)
            bufsize = myTimes.size

            # a batch larger than the buffer would overwrite itself, so only
            # its newest part is stored
            if times.size > bufsize:
                times = times[-bufsize:]
                senders = senders[-bufsize:]

            # indices start at the next insert position and wrap around
            idx = (self._nextInsert + np.arange(times.size)) % bufsize
            myTimes[idx] = times
            mySenders[idx] = senders
            self._nextInsert = int((self._nextInsert + times.size) % bufsize)
            # since our buffer is not full from the beginning we have to
            # continually increase self._len until it is
            self._len.value = min(bufsize, self._len.value + times.size)

    def waitForFigure(self, timeout=None):
        """
        Blocks until the plotting process has set up its figure.

        :param timeout: Maximum time to wait in seconds, `None` waits forever
        :type timeout: float or None
        :returns: `True` if the figure is ready, `False` on timeout
        """
        return self._readyEvent.wait(timeout)

    def _initFigure(self):
        """
        Init function for matplotlib's animation: creates the axes and the
        scatter artist and signals that the figure is ready.
        """
        if getattr(self, '_plt', None) is not None:
            # already set up, don't stack a second axes on top
            return
        if self._title is not None:
            self._figure.suptitle(self._title)
        self._figure.show()
        self._plt = self._figure.add_subplot(111)
        self._scat = self._plt.scatter([], [], marker='|')
        # x-axis shows the time relative to the current simulation time
        self._plt.set_xlim(-self._timeframe, 0)
        self._plt.set_ylabel('Neuron ID')
        self._readyEvent.set()

    def _update(self, data):
        """
        Updates the figure using `data` coming from :func:`_getData`.

        `data` is either the string 'close', which closes the figure, or a
        tuple `(curtime, points)` with `points` being an (N, 2) array of
        (relative time, sender id) pairs.
        """
        if isinstance(data, str) and data == 'close':
            import matplotlib.pyplot as plt
            plt.close(self._figure)
            return

        curtime, points = data
        with self._datalock:
            minsender = self._minsender.value
            maxsender = self._maxsender.value

        # always set the offsets, so spikes which left the timeframe vanish
        self._scat.set_offsets(points)
        # in auto mode the range is infinite until the first event arrived
        if np.isfinite(minsender) and np.isfinite(maxsender):
            # try not to have events on the edges of the plot
            margin = 0.1*(maxsender-minsender)
            if margin <= 0:
                # single sender: avoid identical lower and upper limits
                margin = 0.5
            self._plt.set_ylim(minsender-margin, maxsender+margin)
        self._plt.set_xlabel('time since: {} ms'.format(curtime))

    def _getData(self):
        """
        Data 'generator' for matplotlib's animation.

        Yields tuples `(curtime, points)` until :func:`stop` was called, then
        yields the string 'close' once. `points` is an (N, 2) float array
        holding the time relative to `curtime` and the sender id of all
        buffered events which are not older than the timeframe.
        """
        myTimes = np.frombuffer(self._times.get_obj(), dtype=np.float64)
        mySenders = np.frombuffer(self._senders.get_obj(), dtype=np.intc)
        while not self._stopEvent.is_set():
            with self._datalock:
                curtime = self._now.value
                # only the first _len entries have been written so far
                myLen = self._len.value
                validTimes = myTimes[:myLen]
                # from the valid values filter out those which are too old
                inFrame = validTimes >= curtime-self._timeframe
                # column_stack copies, so the result stays valid after the
                # lock is released; it has shape (0, 2) if nothing is in frame
                data = np.column_stack((validTimes[inFrame]-curtime,
                                        mySenders[:myLen][inFrame]))
            yield curtime, data
        yield 'close'
def main():
    """
    Demo: feeds random spike events into a LivePlot until the window is
    closed or 100000 batches have been sent.
    """
    import time
    plot = LivePlot(title='Random example data')
    try:
        plot.waitForFigure()
        now = 0
        # simulated time per batch in ms, also used as real-time delay
        timedif = 50
        nEvents = 100
        for _ in range(100000):
            # closing the plot window ends the demo
            if not plot.is_alive():
                break
            senders = np.random.randint(100, size=nEvents)
            times = now + np.random.rand(nEvents)*timedif
            times.sort()
            now += timedif
            plot.addData({'times': times, 'senders': senders}, now)
            time.sleep(timedif/1000.)
    finally:
        # also stop the plot process on errors or Ctrl-C, otherwise it keeps
        # running and blocks interpreter exit
        plot.stop()
# Run the random-data demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment