Skip to content

Instantly share code, notes, and snippets.

@swenzel
Last active September 2, 2015 10:10
Show Gist options
  • Save swenzel/e18e3c991b4349c6bcb6 to your computer and use it in GitHub Desktop.
Save swenzel/e18e3c991b4349c6bcb6 to your computer and use it in GitHub Desktop.
Live spike raster plot for PyNEST
# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
# Version 2, December 2004
#
# Copyright (C) 2015 Swen Wenzel <[email protected]>
#
# Everyone is permitted to copy and distribute verbatim or modified
# copies of this license document, and changing it is allowed as long
# as the name is changed.
#
# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
# TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
#
# 0. You just DO WHAT THE FUCK YOU WANT TO.
import multiprocessing as mp
import ctypes
import numpy as np
import nest
class LivePlot(mp.Process):
    """
    Live spike raster plot that runs in its own process.

    :param senders: Sender ids, smallest and largest determine y-axis range.
                    If set to `None` the axis will be automatically updated.
    :type senders: list or tuple of int
    :param title: Title used for the figure
    :type title: str
    :param interval: Time interval in ms between figure updates
    :type interval: int
    :param timeframe: Defines in ms how far the plot goes back in time
    :type timeframe: positive int
    :param bufsize: Defines how many spike events are buffered
    :type bufsize: number (decimals will be ignored)
    :raises ValueError: if `timeframe` or `bufsize` is not positive

    This class automatically opens a figure which is constantly updated within
    its own process. The process is directly started and therefore the object
    is ready to use after creation. Due to process start up delays the figure
    will only appear shortly after creation. If desired, it is possible to wait
    for the figure (see :func:`waitForFigure`).

    Spike events are kept in a fixed-size ring buffer shared between the
    calling process (which writes via :func:`addData`) and the plotting
    process (which reads via :func:`_getData`). Once the buffer is full the
    oldest events are overwritten.

    Usage example:

    >>> import nest
    >>> from live_plot import LivePlot
    >>> spike_generator = nest.Create('poisson_generator')
    >>> nest.SetStatus(spike_generator, 'rate', 10.0)
    >>> spike_detector = nest.Create('spike_detector')
    >>> nest.Connect(spike_generator, spike_detector)
    >>> myLivePlot = LivePlot()
    >>> myLivePlot.waitForFigure()
    >>> # closing the window will finish the process and stop the simulation
    >>> while myLivePlot.is_alive():
    ...     nest.Simulate(10)
    ...     events = nest.GetStatus(spike_detector, 'events')
    ...     # remove events which were just read to avoid adding them twice
    ...     nest.SetStatus(spike_detector, 'n_events', 0)
    ...     myLivePlot.addData(events)
    """

    def __init__(self, senders=None, title=None, interval=100, timeframe=500,
                 bufsize=1e5):
        super(LivePlot, self).__init__()
        bufsize = int(bufsize)
        # an empty ring buffer cannot be indexed (modulo by zero in addData)
        if bufsize <= 0:
            raise ValueError('bufsize has to be >0!')
        if timeframe <= 0:
            raise ValueError('timeframe has to be >0!')

        # synchronisation primitives shared with the plotting process
        self._datalock = mp.Lock()
        self._stopEvent = mp.Event()
        self._readyEvent = mp.Event()

        # shared ring buffers holding spike times and the matching sender ids
        self._times = mp.Array(ctypes.c_double, bufsize)
        self._senders = mp.Array(ctypes.c_int, bufsize)

        # will be used as reference timepoint for the timewindow
        self._now = mp.Value(ctypes.c_double)
        self._now.value = 0
        # since the array won't be filled from the beginning, _len will keep
        # track of the amount of entries
        self._len = mp.Value(ctypes.c_uint)
        self._len.value = 0
        # since we have ringbuffer-like memory arrays, we have to keep track of
        # the next insert point (only used by the process calling addData)
        self._nextInsert = 0

        # y-axis range; stored as doubles so that +-inf can mark "no data yet"
        self._minsender = mp.Value(ctypes.c_double)
        self._maxsender = mp.Value(ctypes.c_double)
        if senders is None:
            self._autoSenders = True
            self._maxsender.value = -np.inf
            self._minsender.value = np.inf
        else:
            self._autoSenders = False
            self._maxsender.value = np.max(senders)
            self._minsender.value = np.min(senders)

        self._interval = interval
        self._timeframe = timeframe
        self._title = title
        self.start()

    def run(self):
        """
        This process' main method. Even though it is not marked as private, it
        is not supposed to be called from outside!
        """
        import matplotlib.pyplot as plt
        from matplotlib import animation
        self._figure = plt.figure()
        # Create the axes explicitly: depending on the Matplotlib version,
        # init_func may not be invoked before the first draw, so self._plt
        # would not exist yet when the axis is configured below.
        self._initFigure()
        self._plt.set_xlim(-self._timeframe, 0)
        self._plt.set_ylabel('Neuron ID')
        # need to save it to local variable, otherwise it doesn't work
        # might be an issue with the garbage collector
        _ani = animation.FuncAnimation(self._figure, self._update,
                                       self._getData, interval=self._interval,
                                       init_func=self._initFigure)
        plt.show(block=True)

    def stop(self):
        """
        Tells the process to close the window and therefore stop the live
        plot. It is safe to call this method multiple times.

        Waits up to 5 seconds for the plotting process to finish.
        """
        if self.is_alive():
            self._stopEvent.set()
            self.join(timeout=5)

    def addData(self, events, now=None):
        """
        :param events: New data coming from
                       `nest.GetStatus(spike_detector, 'events')`
        :type events: dict or list/tuple of dicts
        :param now: Used to set to current simulation time in ms
                    If set to `None`, `nest.GetKernelStatus('time')` will be
                    used.
        :type now: int
        :raises TypeError: if `events` is neither a dict nor a list/tuple

        Adds new spike events to this livePlot's memory. Also updates the
        current simulation time and therefore shifts the spikes in the next
        frame update. If more events are passed than fit into the buffer,
        only the last `bufsize` events of that batch are kept.

        .. warning::
            Duplicate events are not filtered out! So don't forget to call
            `nest.SetStatus(spike_detector, 'n_events', 0)`
        """
        if not isinstance(events, (dict, list, tuple)):
            raise TypeError("events must be dict or list of dicts, " +
                            "not {}".format(type(events)))
        if isinstance(events, (list, tuple)):
            for e in events:
                # forward `now`, otherwise an explicitly given time is lost
                self.addData(e, now)
            return

        senders = events['senders']
        times = events['times']
        assert len(times) == len(senders), ('times and senders have to be ' +
                                            'the same size!')
        times = np.array(times, dtype=np.float64)
        senders = np.array(senders, dtype=np.int32)

        with self._datalock:
            myTimes = np.frombuffer(self._times.get_obj(),
                                    dtype=ctypes.c_double)
            # view with the same C type the buffer was allocated with
            mySenders = np.frombuffer(self._senders.get_obj(),
                                      dtype=ctypes.c_int)
            if now is None:
                self._now.value = nest.GetKernelStatus('time')
            else:
                self._now.value = now
            if times.size == 0:
                return

            capacity = myTimes.size
            # a batch larger than the buffer would overwrite itself, so only
            # its newest part is stored
            if times.size > capacity:
                times = times[-capacity:]
                senders = senders[-capacity:]

            # compute indices which start at the next insert position and
            # wrap around the buffer size
            idx = np.arange(self._nextInsert,
                            self._nextInsert + times.size) % capacity
            # update next insert position
            self._nextInsert = (int(idx[-1]) + 1) % capacity
            # since our buffer is not full from the beginning we have to
            # continually increase self._len
            if self._len.value != capacity:
                self._len.value = min(capacity,
                                      int(self._len.value) + int(times.size))
            # insert new values at their appropriate positions
            myTimes[idx] = times
            mySenders[idx] = senders

            # update minsender and maxsender if necessary
            if self._autoSenders:
                minsender = float(senders.min())
                maxsender = float(senders.max())
                if self._minsender.value > minsender:
                    self._minsender.value = minsender
                if self._maxsender.value < maxsender:
                    self._maxsender.value = maxsender

    def waitForFigure(self, timeout=None):
        """
        Blocks until the plotting process has created its figure.

        :param timeout: Maximum time to wait in seconds, `None` waits forever
        :type timeout: float or None
        :returns: the result of waiting on the internal "ready" event
        """
        return self._readyEvent.wait(timeout)

    def _initFigure(self):
        """
        Creates the axes and the (initially empty) scatter plot and signals
        that the figure is ready. Calling it again is a no-op, so it can be
        used both directly and as `init_func` of the animation.
        """
        if getattr(self, '_plt', None) is not None:
            return
        if self._title is not None:
            self._figure.suptitle(self._title)
        self._figure.show()
        self._plt = self._figure.add_subplot(111)
        self._scat = self._plt.scatter([], [], marker='|')
        self._readyEvent.set()

    def _update(self, data):
        """
        Updates the figure using `data` coming from :func:`_getData`.

        `data` is either the string ``'close'`` (close the figure) or a tuple
        ``(curtime, points)`` where `points` holds one ``(time, sender)`` row
        per visible spike.
        """
        if isinstance(data, str) and data == 'close':
            import matplotlib.pyplot as plt
            plt.close(self._figure)
            return

        with self._datalock:
            minsender = self._minsender.value
            maxsender = self._maxsender.value

        curtime, points = data
        # always set the offsets, so that spikes which left the time window
        # disappear even if no spike is visible any more
        self._scat.set_offsets(np.asarray(points, dtype=float).reshape(-1, 2))

        # the limits are +-inf until the first senders are known
        if np.isfinite(minsender) and np.isfinite(maxsender):
            # try not to have events on the edges of the plot
            margin = 0.1*(maxsender-minsender)
            if margin <= 0:
                # a single sender id would otherwise give identical limits
                margin = 0.5
            self._plt.set_ylim(minsender-margin, maxsender+margin)
        self._plt.set_xlabel('time since: {} ms'.format(curtime))

    def _getData(self):
        """
        Data 'generator' for matplotlib's animation.

        Yields ``(curtime, points)`` tuples, where `points` is an array with
        one ``(time - curtime, sender)`` row for every buffered spike that is
        not older than `timeframe`. Once the stop event is set, a final
        ``'close'`` is yielded.
        """
        myTimes = np.frombuffer(self._times.get_obj(), dtype=ctypes.c_double)
        mySenders = np.frombuffer(self._senders.get_obj(), dtype=ctypes.c_int)
        while not self._stopEvent.is_set():
            with self._datalock:
                curtime = self._now.value
                # only the first _len entries have been written so far
                myLen = self._len.value
                validTimes = myTimes[:myLen]
                # from the valid values filter out those which are too old
                inFrame = validTimes >= curtime-self._timeframe
                # boolean indexing copies, so `data` stays valid after the
                # lock is released; without matches its shape is (0, 2)
                data = np.column_stack((validTimes[inFrame]-curtime,
                                        mySenders[:myLen][inFrame]))
            yield curtime, data
        yield 'close'
def main():
    """
    Small demo: feeds random spike events into a :class:`LivePlot` until the
    window is closed or the maximum number of steps has been reached.
    """
    import time

    step_ms = 50          # simulated time added per step, also the sleep time
    spikes_per_step = 100
    max_steps = 100000

    live = LivePlot(title='Random example data')
    live.waitForFigure()

    sim_time = 0
    steps_done = 0
    # closing the figure ends the plotting process and thereby this loop
    while steps_done < max_steps and live.is_alive():
        ids = np.random.randint(100, size=spikes_per_step)
        # random spike times within the current step, in ascending order
        stamps = np.sort(sim_time + np.random.rand(spikes_per_step)*step_ms)
        sim_time += step_ms
        live.addData({'times': stamps, 'senders': ids}, sim_time)
        time.sleep(step_ms/1000.)
        steps_done += 1
    live.stop()


# run the demo only when executed as a script, not when imported
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment