Skip to content

Instantly share code, notes, and snippets.

@TNick
Created March 23, 2015 23:13
Show Gist options
  • Save TNick/ead6e553fc5395449d03 to your computer and use it in GitHub Desktop.
Save TNick/ead6e553fc5395449d03 to your computer and use it in GitHub Desktop.
pylearn2/train_extensions/live_monitoring.py
"""
Training extension for allowing querying of monitoring values while an
experiment executes.
"""
__authors__ = "Dustin Webb, Adam Stone"
__copyright__ = "Copyright 2010-2012, Universite de Montreal"
__credits__ = ["Dustin Webb, Adam Stone"]
__license__ = "3-clause BSD"
__maintainer__ = "LISA Lab"
__email__ = "pylearn-dev@googlegroups"
import copy
import logging
LOG = logging.getLogger('LiveMonitor')
try:
import zmq
ZMQ_AVAILABLE = True
except Exception:
ZMQ_AVAILABLE = False
try:
from PySide import QtCore, QtGui
import matplotlib
import numpy as np
matplotlib.use('Qt4Agg')
matplotlib.rcParams['backend.qt4'] = 'PySide'
from matplotlib.backends.backend_qt4agg import (
FigureCanvasQTAgg as FigureCanvas,
NavigationToolbar2QT as NavigationToolbar)
from matplotlib.figure import Figure
QT_AVAILABLE = True
except Exception:
QT_AVAILABLE = False
try:
import matplotlib.pyplot as plt
PYPLOT_AVAILABLE = True
except ImportError:
PYPLOT_AVAILABLE = False
MPLDC_AVAILABLE = False
try:
if PYPLOT_AVAILABLE:
import mpldatacursor as mpldc
MPLDC_AVAILABLE = True
except ImportError:
pass
from functools import wraps
from pylearn2.monitor import Monitor
from pylearn2.train_extensions import TrainExtension
class LiveMonitorMsg(object):
    """
    Common ancestor of every message exchanged with a Live Monitor.

    Request messages override `get_response` to create the matching
    response object; response messages may leave it untouched.
    """
    response_set = False

    def get_response(self):
        """
        Build the response message that answers this request message.

        Raises
        ------
        NotImplementedError
            Always, in this base class; request subclasses override it.
        """
        message = 'get_response is not implemented.'
        raise NotImplementedError(message)
class ChannelListResponse(LiveMonitorMsg):
    """
    Response message that carries the names of the monitored channels.

    The class defines no attributes of its own; the responder attaches the
    channel names to the instance before sending it back.
    """
class ChannelListRequest(LiveMonitorMsg):
    """
    Request message asking for the names of the channels being monitored.
    """
    @wraps(LiveMonitorMsg.get_response)
    def get_response(self):
        # The answer to a channel-list request is an (initially empty)
        # channel-list response.
        response = ChannelListResponse()
        return response
class ChannelsResponse(LiveMonitorMsg):
    """
    Response message holding monitoring data for the requested channels.

    The data may cover every epoch or only a selected range of epochs.

    Parameters
    ----------
    channel_list : list
        Non-empty list of the channels for which data was requested.
    start : int
        The starting epoch of the returned data; must not be negative.
    end : int
        The epoch at which the returned data stops.
    step : int
        The number of epochs between consecutive data points; must be
        positive.
    """
    def __init__(self, channel_list, start, end, step=1):
        # Validate every argument first (same order as the signature),
        # then store them.
        assert isinstance(channel_list, list) and len(channel_list) > 0
        assert start >= 0
        assert step > 0

        self.channel_list = channel_list
        self.start, self.end, self.step = start, end, step

        # Filled in by the responder before the message is sent.
        self.data = {}
class ChannelsRequest(LiveMonitorMsg):
    """
    Request message asking for the data of a given set of channels.

    Parameters
    ----------
    channel_list : list
        Non-empty list of the channels for which data is requested.
    start : int
        The starting epoch of the requested data; must not be negative.
    end : int
        The epoch at which the requested data stops (defaults to -1).
    step : int
        The number of epochs between consecutive data points; must be
        positive.
    """
    def __init__(self, channel_list, start=0, end=-1, step=1):
        # Validate every argument first (same order as the signature),
        # then store them.
        assert isinstance(channel_list, list) and len(channel_list) > 0
        assert start >= 0
        assert step > 0

        self.channel_list = channel_list
        self.start, self.end, self.step = start, end, step

    @wraps(LiveMonitorMsg.get_response)
    def get_response(self):
        # The response mirrors the parameters of the request.
        response = ChannelsResponse(self.channel_list,
                                    self.start,
                                    self.end,
                                    step=self.step)
        return response
class LiveMonitoring(TrainExtension):
    """
    A training extension for remotely monitoring and filtering the channels
    being monitored in real time. PyZMQ must be installed for this extension
    to work.

    Parameters
    ----------
    address : string
        The IP addresses of the interfaces on which the monitor should
        listen. It is inserted verbatim into the ZMQ endpoint
        'tcp://<address>:<port>'.
    req_port : int
        The port number to be used to service requests (REP socket).
    pub_port : int
        The port number to be used to publish updates (PUB socket).

    Raises
    ------
    ImportError
        If zmq could not be imported.
    AssertionError
        If the two ports are equal or outside the range 1025..65535.
    """
    def __init__(self, address='*', req_port=5555, pub_port=5556):
        if not ZMQ_AVAILABLE:
            raise ImportError('zeromq needs to be installed to '
                              'use this module.')

        self.address = 'tcp://%s' % address

        assert req_port != pub_port

        assert req_port > 1024 and req_port < 65536
        self.req_port = req_port

        assert pub_port > 1024 and pub_port < 65536
        self.pub_port = pub_port

        address_template = self.address + ':%d'
        self.context = zmq.Context()

        # NOTE(review): the assertions above already force both ports to be
        # positive, so the `> 0` tests below always succeed. They are kept
        # so the "socket may be None" handling elsewhere stays meaningful
        # if the assertions are ever relaxed.
        self.req_sock = None
        if self.req_port > 0:
            self.req_sock = self.context.socket(zmq.REP)
            self.req_sock.bind(address_template % self.req_port)

        self.pub_sock = None
        if self.pub_port > 0:
            self.pub_sock = self.context.socket(zmq.PUB)
            self.pub_sock.bind(address_template % self.pub_port)

        # Tracks the number of times on_monitor has been called
        self.counter = 0
        # Number of records a channel held the first time a response was
        # built; used as the size of each published slice.
        self.post_size = 0

    def __build_channel_resp__(self, monitor, channel_list,
                               start=0, end=-1, step=1):
        """
        Constructs a response for a channel data request.

        Parameters
        ----------
        monitor : Monitor
            The monitor whose `channels` mapping is queried.
        channel_list : list
            Names of the requested channels.
        start, end, step : int
            Slice applied to every record list of each channel. An `end`
            of -1 means "up to the last record of that channel".

        Returns
        -------
        dict or TypeError
            A TypeError instance if `channel_list` is not a non-empty
            list. Otherwise a dictionary keyed by channel name whose
            values are either a sliced deep copy of the channel or a
            KeyError instance for unknown channels.
        """
        if not isinstance(channel_list, list) or len(channel_list) == 0:
            return TypeError('ChannelResponse requires a list of channels.')

        # Resolve the open end per channel instead of freezing it to the
        # length of the first channel encountered.
        stop = None if end == -1 else end

        result = {}
        for channel_name in channel_list:
            if channel_name not in monitor.channels:
                result[channel_name] = KeyError(
                    'Invalid channel: %s' % channel_name
                )
                continue

            chan = copy.deepcopy(monitor.channels[channel_name])
            if self.post_size == 0:
                self.post_size = len(chan.batch_record)

            # TODO copying and truncating the records individually
            # like this is brittle. Is there a more robust
            # solution?
            chan.batch_record = chan.batch_record[start:stop:step]
            chan.epoch_record = chan.epoch_record[start:stop:step]
            chan.example_record = chan.example_record[start:stop:step]
            chan.time_record = chan.time_record[start:stop:step]
            chan.val_record = chan.val_record[start:stop:step]
            result[channel_name] = chan

        return result

    def __reply_to_req__(self, monitor):
        """
        Replies to a request for specific channels or to list all channels.

        The receive is non-blocking: when no request is pending the method
        returns immediately.
        """
        if self.req_sock is None:
            return

        try:
            rsqt_msg = self.req_sock.recv_pyobj(flags=zmq.NOBLOCK)
        except zmq.Again:
            # No request is waiting.
            return

        # Determine what type of message was received
        rsp_msg = rsqt_msg.get_response()

        if isinstance(rsp_msg, ChannelListResponse):
            rsp_msg.data = list(monitor.channels.keys())
        elif isinstance(rsp_msg, ChannelsResponse):
            rsp_msg.data = self.__build_channel_resp__(monitor,
                                                       rsp_msg.channel_list,
                                                       rsp_msg.start,
                                                       rsp_msg.end,
                                                       rsp_msg.step)

        self.req_sock.send_pyobj(rsp_msg)

    def __publish_results__(self, monitor):
        """
        Publishes all channels to dedicated ZMQ slot.

        Publishing is best-effort: failures are logged and never
        interrupt training.
        """
        if self.pub_sock is None:
            return

        channel_list = list(monitor.channels.keys())
        if not channel_list:
            # Nothing to publish yet (ChannelsResponse rejects empty lists).
            return

        try:
            start = self.counter * self.post_size
            end = -1 if self.post_size == 0 else start + self.post_size
            rsp_msg = ChannelsResponse(channel_list, start, end, step=1)
            rsp_msg.data = self.__build_channel_resp__(monitor,
                                                       channel_list,
                                                       start, end)
            self.pub_sock.send_pyobj(rsp_msg)
        except Exception:
            # Keep the best-effort behaviour but do not hide the cause, and
            # let KeyboardInterrupt/SystemExit propagate.
            LOG.warning('LiveMonitoring failed to publish monitoring '
                        'channels', exc_info=True)

    @wraps(TrainExtension.on_monitor)
    def on_monitor(self, model, dataset, algorithm):
        monitor = Monitor.get_monitor(model)
        self.__reply_to_req__(monitor)
        self.__publish_results__(monitor)
        self.counter += 1
class LiveMonitor(object):
    """
    A utility class for requesting data from a LiveMonitoring training
    extension.

    Parameters
    ----------
    address : string
        The IP address on which a LiveMonitoring process is listening.
    req_port : int
        The port number on which a LiveMonitoring process is listening.
        The socket created here connects to this port in both modes.
    subscribe : bool
        If True a SUB socket subscribed to every message is used and data
        is only received; otherwise a REQ socket is used and data is
        requested explicitly.
    """
    def __init__(self, address='127.0.0.1', req_port=5555, subscribe=False):
        """
        Create the ZMQ context and socket and connect to the remote end.

        Raises
        ------
        ImportError
            If zmq could not be imported.
        """
        if not ZMQ_AVAILABLE:
            raise ImportError('zeromq needs to be installed to '
                              'use this module.')

        self.address = 'tcp://%s' % address

        assert req_port > 0
        self.req_port = req_port
        self.subscribe = subscribe

        self.context = zmq.Context()

        if subscribe:
            self.req_sock = self.context.socket(zmq.SUB)
            # The subscription filter must be bytes; pyzmq rejects a text
            # string here on Python 3. An empty filter matches everything.
            self.req_sock.setsockopt(zmq.SUBSCRIBE, b"")
        else:
            self.req_sock = self.context.socket(zmq.REQ)
        self.req_sock.connect(self.address + ':' + str(self.req_port))

        # Channels received so far, keyed by channel name.
        self.channels = {}

    def list_channels(self):
        """
        Returns the channels being monitored.

        In subscribe mode this is the dictionary of channels received so
        far. Otherwise a ChannelListRequest is sent and the reply object
        is returned exactly as received.
        """
        if self.subscribe:
            return self.channels

        self.req_sock.send_pyobj(ChannelListRequest())
        return self.req_sock.recv_pyobj()

    def update_channels(self, channel_list, start=-1, end=-1, step=1):
        """
        Retrieves data for a specified set of channels and combines that data
        with any previously retrieved data.

        This assumes all the channels have the same number of values. It is
        unclear as to whether this is a reasonable assumption. If they do not
        have the same number of values then it may request too much or too
        little data leading to duplicated data or holes in the data
        respectively. This could be made more robust by making a call to
        retrieve all the data for all of the channels.

        Parameters
        ----------
        channel_list : list
            A list of the channels for which data should be requested. In
            subscribe mode it filters the channels that are kept.
        start : int
            The starting epoch for which data should be requested. -1
            means "continue after the data already held".
        end : int
            The epoch at which the requested data stops; -1 means no limit.
        step : int
            The number of epochs to be skipped between data points.

        Raises
        ------
        Exception
            Any exception instance sent back by the remote end, either as
            the whole payload or for an individual channel, is re-raised.
        """
        # An open end (-1) is valid with any start; otherwise the range
        # must be non-empty.
        assert end == -1 or end > start

        if not self.subscribe:
            if start == -1:
                start = 0
                if self.channels:
                    first_chan = next(iter(self.channels.values()))
                    start = len(first_chan.epoch_record)

            self.req_sock.send_pyobj(ChannelsRequest(
                channel_list, start=start, end=end, step=step
            ))
        # In subscribe mode there is nothing to send: the publisher pushes
        # updates and we only wait for the next one.

        rsp_msg = self.req_sock.recv_pyobj()

        # Check for the attribute before touching it.
        if not hasattr(rsp_msg, 'data'):
            LOG.warning("No data attribute received by update_channels")
            return

        if isinstance(rsp_msg.data, Exception):
            raise rsp_msg.data

        for channel, rsp_chan in rsp_msg.data.items():
            if isinstance(rsp_chan, Exception):
                raise rsp_chan

            # The publisher sends every channel; keep only those asked for.
            if self.subscribe and channel not in channel_list:
                continue

            if channel not in self.channels:
                self.channels[channel] = rsp_chan
            else:
                chan = self.channels[channel]
                chan.batch_record += rsp_chan.batch_record
                chan.epoch_record += rsp_chan.epoch_record
                chan.example_record += rsp_chan.example_record
                chan.time_record += rsp_chan.time_record
                chan.val_record += rsp_chan.val_record

    def follow_channels(self, channel_list, use_qt=False):
        """
        Tracks and plots a specified set of channels in real time.

        Parameters
        ----------
        channel_list : list or dict
            A list of the channels for which data will be requested and
            plotted or a dictionary where keys will become the names of the
            plots while values are lists of channel names.
        use_qt : bool
            Use a PySide GUI for plotting, if available. When PySide is
            not available a warning is logged and pyplot is used instead.

        Raises
        ------
        ValueError
            If no channel name is provided.
        ImportError
            If pyplot is needed but not installed.
        """
        # Normalise the argument so both plotting paths receive a flat
        # list of channel names.
        if isinstance(channel_list, dict):
            self.channel_dict = channel_list
            # Flatten the groups, dropping duplicates but keeping order.
            flat_list = []
            for group in channel_list.values():
                for name in group:
                    if name not in flat_list:
                        flat_list.append(name)
            self.channel_list = flat_list
        else:
            self.channel_list = channel_list
            self.channel_dict = {'': channel_list}

        if len(self.channel_list) == 0:
            raise ValueError('No channel name provided; '
                             'channel_list must be either '
                             'a list or a dict')

        if use_qt and not QT_AVAILABLE:
            LOG.warning(
                'follow_channels called with use_qt=True, but PySide '
                'is not available. Falling back on matplotlib ion().')
            use_qt = False

        if use_qt:
            # only create new qt app if running the first time in session
            if not hasattr(self, 'gui'):
                self.gui = LiveMonitorGUI(self,
                                          self.channel_list,
                                          self.channel_dict)
            self.gui.start()
            return

        if not PYPLOT_AVAILABLE:
            raise ImportError('pyplot needs to be installed for '
                              'this functionality.')

        plt.clf()
        plt.ion()

        # Runs until interrupted: fetch new data, then redraw everything.
        while True:
            self.update_channels(self.channel_list)
            plt.clf()
            for channel_name in self.channels:
                plt.plot(
                    self.channels[channel_name].epoch_record,
                    self.channels[channel_name].val_record,
                    label=channel_name
                )
            plt.legend()
            plt.ion()
            plt.draw()
if QT_AVAILABLE:

    class LiveMonitorGUI(QtGui.QMainWindow):
        """
        PySide GUI implementation for live monitoring channels.

        Parameters
        ----------
        live_mon : LiveMonitor instance
            The LiveMonitor instance to which the GUI belongs.
        channel_list : list
            A list of the channels to display.
        channel_dict : dict
            Maps each subplot title to the list of channel names drawn in
            that subplot; one subplot is created per key.
        """
        def __init__(self, live_mon, channel_list, channel_dict):
            # Only one QApplication may exist per process: reuse a running
            # one and create a new one only when there is none.
            self.app = (QtGui.QApplication.instance() or
                        QtGui.QApplication(["Live Monitor"]))

            super(LiveMonitorGUI, self).__init__()
            self.live_mon = live_mon
            self.channel_list = channel_list
            self.channel_dict = channel_dict
            self.updater_thread = UpdaterThread(live_mon, channel_list)
            self.updater_thread.updated.connect(self.__refresh__)
            self.__init_ui__()

        def __common_ui__(self):
            """
            Attach interactive data cursors to all axes when mpldatacursor
            is available; otherwise do nothing.
            """
            if MPLDC_AVAILABLE:
                opts = {'hover': True,
                        'xytext': (15, -30),
                        'formatter':
                            "{label} {y:0.3g}\nat epoch {x:0.0f}".format,
                        'keybindings': {'hide': 'h', 'toggle': 'e'},
                        'bbox': {'fc': 'white'},
                        'arrowprops': {'arrowstyle': 'simple',
                                       'fc': 'white',
                                       'alpha': 0.1}}
                mpldc.datacursor(axes=self.fig.axes, **opts)

        def __init_ui__(self):
            """
            Create the figure, one subplot per entry of `channel_dict`,
            the canvas and the navigation toolbar.
            """
            matplotlib.rcParams.update({'font.size': 8})
            self.resize(600, 400)

            # figsize is expressed in inches: convert the 600x400 pixel
            # window size using the figure's dpi.
            dpi = 72
            self.fig = Figure(figsize=(600. / dpi, 400. / dpi), dpi=dpi,
                              facecolor=(1, 1, 1), edgecolor=(0, 0, 0))

            # Hand-picked [rows, columns] grids for up to 12 subplots.
            arrange = {1: [1, 1], 2: [1, 2], 3: [1, 3], 4: [2, 2],
                       5: [2, 3], 6: [2, 3], 7: [2, 4], 8: [2, 4],
                       9: [3, 3], 10: [3, 4], 11: [3, 4], 12: [3, 4]}
            splot_len = len(self.channel_dict)
            if splot_len in arrange:
                splot_layout = arrange[splot_len]
            else:
                # Five columns and enough rows to hold every subplot
                # (round the row count up, not down).
                splot_layout = [(splot_len + 4) // 5, 5]

            self.ax = []
            for splot_i in range(splot_len):
                # Subplot positions are 1-based.
                self.ax.append(self.fig.add_subplot(splot_layout[0],
                                                    splot_layout[1],
                                                    splot_i + 1))

            self.fig.subplots_adjust(left=0.02, right=0.98,
                                     top=0.98, bottom=0.02,
                                     hspace=0.1)
            self.__common_ui__()
            self.canvas = FigureCanvas(self.fig)
            self.setCentralWidget(self.canvas)
            ntb = NavigationToolbar(self.canvas, self)
            self.addToolBar(ntb)

        def __refresh__(self):
            """
            Redraw every subplot from the data currently held by the
            LiveMonitor, then start the updater thread again.
            """
            channels = self.live_mon.channels
            if not channels:
                # Nothing received yet: just wait for the next update.
                self.updater_thread.start()
                return

            # self.ax holds one axes per key of channel_dict, created in
            # the same iteration order.
            for axes, splot_name in zip(self.ax, self.channel_dict):
                axes.cla()  # clear previous plot

                for channel_name in self.channel_dict[splot_name]:
                    if channel_name not in channels:
                        # No data for this channel yet; stay on the same
                        # subplot and try the next channel.
                        continue

                    epoch_record = channels[channel_name].epoch_record
                    val_record = channels[channel_name].val_record

                    # Split the records wherever the epoch value changes.
                    indices = np.nonzero(np.diff(epoch_record))[0] + 1
                    epoch_record_split = np.split(epoch_record, indices)
                    val_record_split = np.split(val_record, indices)

                    X = np.zeros(len(epoch_record))
                    Y = np.zeros(len(epoch_record))

                    # Spread the records of one epoch evenly over
                    # [epoch, epoch + 1). A running offset keeps the
                    # placement correct when epochs differ in length.
                    offset = 0
                    for epoch, values in zip(epoch_record_split,
                                             val_record_split):
                        count = len(epoch)
                        if count == 0:
                            continue
                        X[offset: offset + count] = (
                            1. * np.arange(count) / count + epoch[0])
                        Y[offset: offset + count] = values
                        offset += count

                    axes.plot(X, Y, label=channel_name)

                axes.legend(loc='best', fancybox=True, framealpha=0.5)
                axes.set_xlabel('Epoch')
                axes.set_ylabel('Value')
                axes.set_title(splot_name)

            self.__common_ui__()
            self.canvas.draw()
            self.updater_thread.start()

        def start(self):
            """
            Show the window, start the first update and enter the Qt
            event loop.
            """
            self.show()
            self.updater_thread.start()
            self.app.exec_()

    class UpdaterThread(QtCore.QThread):
        """
        Worker thread that fetches new channel data without blocking the
        GUI and emits `updated` once the data has been merged.

        Parameters
        ----------
        live_mon : LiveMonitor instance
            The monitor whose `update_channels` method is called.
        channel_list : list
            The channels to request.
        """
        updated = QtCore.Signal()

        def __init__(self, live_mon, channel_list):
            super(UpdaterThread, self).__init__()
            self.live_mon = live_mon
            self.channel_list = channel_list

        def run(self):
            self.live_mon.update_channels(self.channel_list)  # blocking
            self.updated.emit()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment