-
-
Save ixaxaar/c9f54432f49b69985dba to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE | |
# Version 2, December 2004 | |
# | |
# Copyright (C) 2015 Swen Wenzel <[email protected]> | |
# | |
# Everyone is permitted to copy and distribute verbatim or modified | |
# copies of this license document, and changing it is allowed as long | |
# as the name is changed. | |
# | |
# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE | |
# TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION | |
# | |
# 0. You just DO WHAT THE FUCK YOU WANT TO. | |
import matplotlib.pyplot as plt | |
from matplotlib import animation | |
import threading | |
import numpy as np | |
import nest | |
class LivePlot(threading.Thread):
    """
    Live raster plot of NEST spike events, updated from its own thread.

    :param senders: Sender ids, smallest and largest determine y-axis range.
        If set to `None` the axis will be automatically updated.
    :type senders: list or tuple of int
    :param title: Title used for the figure
    :type title: str
    :param interval: Time interval in ms between figure updates
    :type interval: int
    :param timeframe: Defines in ms how far the plot goes back in time
    :type timeframe: positive int
    :param bufsize: Defines how many spike events are buffered
    :type bufsize: number >= 1 (decimals will be ignored)
    :raises ValueError: if `timeframe` is not > 0 or `bufsize` is < 1

    This class automatically opens a figure which is constantly updated within
    its own thread. The thread is automatically set as daemon and also directly
    started.

    Usage example::

        import nest
        from live_plot import LivePlot
        spike_generator = nest.Create('poisson_generator')
        nest.SetStatus(spike_generator, 'rate', 10.0)
        spike_detector = nest.Create('spike_detector')
        nest.Connect(spike_generator, spike_detector)
        myLivePlot = LivePlot()
        # closing the window will finish the thread and stop the simulation
        while myLivePlot.is_alive():
            nest.Simulate(10)
            events = nest.GetStatus(spike_detector, 'events')
            # remove events which were just read to avoid adding them twice
            nest.SetStatus(spike_detector, 'n_events', 0)
            myLivePlot.addData(events)
    """

    def __init__(self, senders=None, title=None, interval=100, timeframe=500,
                 bufsize=1e5):
        super(LivePlot, self).__init__()
        # validate before allocating anything or starting the thread
        if timeframe <= 0:
            raise ValueError('timeframe has to be >0!')
        self._datalock = threading.Lock()
        self._stopEvent = threading.Event()
        self._allocateBuffers(bufsize)
        self._interval = interval
        if senders is None:
            # empty (inverted) range; addData widens it as events arrive
            self._autoSenders = True
            self._maxsender = -np.inf
            self._minsender = np.inf
        else:
            self._autoSenders = False
            self._maxsender = np.max(senders)
            self._minsender = np.min(senders)
        # will be used as reference time point (ms) for the time window
        self._now = 0
        self._timeframe = timeframe
        self._title = title
        # daemon so it will not keep the program running once the main
        # thread is finished (`daemon` replaces the deprecated setDaemon)
        self.daemon = True
        self.start()

    def _allocateBuffers(self, bufsize):
        """
        Allocate empty ring buffers able to hold `bufsize` spike events.

        :param bufsize: buffer capacity; decimals are truncated
        :raises ValueError: if the truncated capacity is smaller than 1
        """
        # np.empty needs an int; the default bufsize (1e5) is a float
        size = int(bufsize)
        if size < 1:
            raise ValueError('bufsize has to be >=1!')
        self._times = np.empty(size, dtype=np.float64)
        self._senders = np.empty(size, dtype=np.int32)
        # the buffers are filled gradually: _len counts the valid entries,
        # which always occupy the first _len slots
        self._len = 0
        # ring-buffer slot that receives the next event
        self._nextInsert = 0

    def run(self):
        """
        This thread's main method. Even though it is not marked as private, it
        is not supposed to be called from outside!

        Creates the figure, drives it with a matplotlib animation fed by
        :meth:`_getData` / :meth:`_update` and blocks in `plt.show()` until
        the window is closed.
        """
        # NOTE(review): several GUI backends only support plotting from the
        # main thread; confirm the backend in use tolerates this.
        self._figure = plt.figure()
        if self._title is not None:
            self._figure.suptitle(self._title)
        self._figure.show()
        self._plt = self._figure.add_subplot(111)
        self._scat = self._plt.scatter([], [], marker='|')
        # matplotlib requires a live reference to the animation object for
        # as long as it should keep running, hence the unused local
        ani = animation.FuncAnimation(self._figure, self._update,
                                      self._getData,
                                      interval=self._interval)
        self._plt.set_xlim(-self._timeframe, 0)
        self._plt.set_ylabel('Neuron ID')
        plt.show()

    def stop(self):
        """
        Tells the thread to close the window and therefore stop the live plot.
        It is safe to call this method multiple times.

        Waits at most 5 seconds for the thread to finish.
        """
        if self.is_alive():
            self._stopEvent.set()
            self.join(timeout=5)

    def addData(self, events, now=None):
        """
        Add new spike events to this LivePlot's memory.

        Also updates the current simulation time and therefore shifts the
        spikes in the next frame update. Once the buffer is full, the oldest
        events are overwritten.

        :param events: New data coming from
            `nest.GetStatus(spike_detector, 'events')`, i.e. dict(s) holding
            equally long 'senders' and 'times' sequences
        :type events: dict or list/tuple of dicts
        :param now: Used to set the current simulation time in ms.
            If set to `None`, `nest.GetStatus((0,), 'time')[0]` will be used.
            For a list/tuple of dicts it applies to every dict.
        :type now: int
        :raises TypeError: if `events` is neither a dict nor a list/tuple

        .. warning::
            Duplicate events are not filtered out! So don't forget to call
            `nest.SetStatus(spike_detector, 'n_events', 0)`
        """
        if not isinstance(events, (dict, list, tuple)):
            raise TypeError("events must be dict or list of dicts, " +
                            "not {}".format(type(events)))
        if isinstance(events, (list, tuple)):
            # forward `now`, otherwise an explicit time would be silently
            # replaced by nest's simulation time
            for e in events:
                self.addData(e, now)
            return
        senders = events['senders']
        times = events['times']
        assert len(times) == len(senders), ('times and senders have to be the' +
                                            ' same size!')
        times = np.array(times, dtype=np.float64)
        senders = np.array(senders, dtype=np.int32)
        with self._datalock:
            if now is None:
                self._now = nest.GetStatus((0,), 'time')[0]
            else:
                self._now = now
            if times.size == 0:
                return
            # widen the y-axis range to include every new sender
            if self._autoSenders:
                self._minsender = min(self._minsender, senders.min())
                self._maxsender = max(self._maxsender, senders.max())
            capacity = self._times.size
            if times.size > capacity:
                # a batch larger than the buffer would hit the same slots
                # several times; only the newest `capacity` events fit anyway
                times = times[-capacity:]
                senders = senders[-capacity:]
            # consecutive slots starting at the insert position, wrapped
            idx = (self._nextInsert + np.arange(times.size)) % capacity
            self._times[idx] = times
            self._senders[idx] = senders
            self._nextInsert = int((idx[-1] + 1) % capacity)
            self._len = min(capacity, self._len + times.size)

    @staticmethod
    def _yLimits(minsender, maxsender):
        """
        Compute y-axis limits for the sender range [minsender, maxsender].

        Adds a 10% margin on both sides so events are not drawn on the plot's
        edges; a single-sender range gets a margin of 0.5 so the limits never
        coincide.

        :returns: `(bottom, top)` floats, or `None` while the range is not
            finite (no events seen yet with automatic sender range)
        """
        if not (np.isfinite(minsender) and np.isfinite(maxsender)):
            return None
        margin = 0.1 * (maxsender - minsender)
        if margin <= 0:
            margin = 0.5
        return float(minsender - margin), float(maxsender + margin)

    def _update(self, data):
        """
        Updates the figure using `data` coming from :func:`_getData`.

        :param data: either 'close', which closes the figure, or a tuple
            `(curtime, offsets)` with `offsets` an (N, 2) array of
            (spike time - curtime, sender) pairs
        """
        if isinstance(data, str):
            # _getData signals the end of the animation with 'close'
            plt.close(self._figure)
            return
        curtime, offsets = data
        # always replace the offsets so spikes leaving the window disappear
        self._scat.set_offsets(offsets)
        limits = self._yLimits(self._minsender, self._maxsender)
        if limits is not None:
            self._plt.set_ylim(*limits)
        self._plt.set_xlabel('time since: {} ms'.format(curtime))

    def _getData(self):
        """
        Data 'generator' for matplotlib's animation.

        While the stop event is not set, yields `(curtime, offsets)` where
        `offsets` is an (N, 2) float array of (spike time - curtime, sender)
        pairs for all buffered spikes not older than `timeframe` ms. After
        the stop event is set, yields 'close' once and finishes.
        """
        while not self._stopEvent.is_set():
            with self._datalock:
                curtime = self._now
                # only the first _len slots hold valid events
                times = self._times[:self._len]
                senders = self._senders[:self._len]
                inFrame = times >= curtime - self._timeframe
                # column_stack copies, so the result stays valid after the
                # lock is released; empty selections give shape (0, 2)
                offsets = np.column_stack((times[inFrame] - curtime,
                                           senders[inFrame]))
            # yield outside the lock so addData is never blocked by a frame
            yield curtime, offsets
        yield 'close'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment