Created
November 12, 2010 10:24
-
-
Save pprett/673953 to your computer and use it in GitHub Desktop.
A simple graphical frontend for the scikits.learn libsvm bindings.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
#
# Author: Peter Prettenhofer <[email protected]>
#
# License: BSD Style.
"""
==========
Libsvm GUI
==========

A simple graphical frontend for Libsvm mainly intended for didactic
purposes. You can create data points by point and click and visualize
the decision region induced by different kernels and parameter settings.

To create positive examples click the left mouse button; to create
negative examples click the right button.

If all examples are from the same class, it uses a one-class svm.

Requirements
------------

- Tkinter
- scikits.learn
- matplotlib with TkAgg

"""
from __future__ import division

# Show the usage notes above when the script starts.  The single-argument
# call form prints the same text under Python 2 and Python 3.
print(__doc__)
import matplotlib | |
matplotlib.use('TkAgg') | |
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg | |
from matplotlib.backends.backend_tkagg import NavigationToolbar2TkAgg | |
from matplotlib.figure import Figure | |
import Tkinter as Tk | |
import sys | |
import numpy as np | |
from scikits.learn import svm | |
# Bounds of the plotted region, in data coordinates.  The view uses them as
# axis limits and the controller evaluates the decision function over them.
x_min = y_min = -50
x_max = y_max = 50
class Model(object):
    """Observable store for the training examples and the fitted surface.

    ``changed(event)`` calls ``observer.update(event, model)`` on every
    observer registered with ``add_observer``.
    """

    def __init__(self):
        self.observers = []
        # (X1, X2, Z) tuple set through set_surface; None until trained.
        self.surface = None
        # (x, y, label) tuples appended by Controller.add_example.
        self.data = []
        # Last fitted classifier; Controller.classify assigns model.clf,
        # so declare it here instead of leaving it to appear later.
        self.clf = None
        # Legacy attribute: nothing in this file reads or writes it; kept so
        # external code that touches it keeps working.
        self.cls = None
        # Display mode chosen in the control bar (0 or 1).
        self.surface_type = 0

    def changed(self, event):
        """Notify every registered observer that ``event`` happened."""
        for observer in self.observers:
            observer.update(event, self)

    def add_observer(self, observer):
        """Register ``observer``; it must provide ``update(event, model)``."""
        self.observers.append(observer)

    def set_surface(self, surface):
        """Store the decision surface (an ``(X1, X2, Z)`` tuple)."""
        self.surface = surface
class Controller(object):
    """Mediates between the Tk widgets and the Model.

    Holds the Tk variables for the kernel and display-mode radio buttons.
    ``complexity``, ``gamma``, ``degree`` and ``coef0`` are attached later by
    the control bar and are read when training.
    """

    def __init__(self, model):
        self.model = model
        self.kernel = Tk.IntVar()
        self.surface_type = Tk.IntVar()

    def classify(self):
        """Fit an SVM to the current examples and publish its decision surface.

        Uses ``svm.OneClassSVM`` when every example has the same label and
        ``svm.SVC`` otherwise, then notifies observers with "surface".
        Prints a message and returns without training when there is no data
        or a parameter entry cannot be parsed.
        """
        if not self.model.data:
            # np.array([]) is 1-d, so the column slicing below would raise.
            print("no data to classify")
            return
        print("classifying data")
        train = np.array(self.model.data)
        X = train[:, :2]
        y = train[:, 2]
        try:
            C = float(self.complexity.get())
            gamma = float(self.gamma.get())
            coef0 = float(self.coef0.get())
            degree = int(self.degree.get())
        except ValueError as err:
            print("invalid parameter value: %s" % err)
            return
        kernel_map = {0: "linear", 1: "rbf", 2: "poly"}
        kernel = kernel_map[self.kernel.get()]
        if len(np.unique(y)) == 1:
            clf = svm.OneClassSVM(kernel=kernel, C=C, gamma=gamma,
                                  coef0=coef0, degree=degree)
            clf.fit(X)
        else:
            clf = svm.SVC(kernel=kernel, C=C, gamma=gamma, coef0=coef0,
                          degree=degree)
            clf.fit(X, y)
        if hasattr(clf, 'score'):
            # %s formats the float like the old Python 2 print statement did.
            print("Accuracy: %s" % (clf.score(X, y) * 100))
        X1, X2, Z = self.decision_surface(clf)
        self.model.clf = clf
        self.model.set_surface((X1, X2, Z))
        self.model.surface_type = self.surface_type.get()
        self.model.changed("surface")

    def decision_surface(self, cls, delta=1):
        """Evaluate the decision function of ``cls`` on a regular grid.

        The grid covers [x_min, x_max] x [y_min, y_max] with spacing
        ``delta``.  Returns ``(X1, X2, Z)``: the ``np.meshgrid`` coordinate
        arrays and the decision values reshaped to the same shape.
        """
        x = np.arange(x_min, x_max + delta, delta)
        y = np.arange(y_min, y_max + delta, delta)
        X1, X2 = np.meshgrid(x, y)
        # Old scikits.learn names this method predict_margin; later releases
        # provide decision_function instead.
        margin = getattr(cls, "predict_margin", None)
        if margin is None:
            margin = cls.decision_function
        Z = margin(np.c_[X1.ravel(), X2.ravel()])
        return X1, X2, Z.reshape(X1.shape)

    def clear_data(self):
        """Drop all examples and notify observers with "clear"."""
        self.model.data = []
        self.model.changed("clear")

    def add_example(self, x, y, label):
        """Append the example ``(x, y, label)`` and notify observers."""
        self.model.data.append((x, y, label))
        self.model.changed("example_added")
class View(object):
    """Matplotlib canvas embedded in Tk that shows examples and the surface.

    Registered as a Model observer; ``update`` reacts to the events
    "example_added", "clear" and "surface".
    """

    def __init__(self, root, controller):
        f = Figure()
        ax = f.add_subplot(111)
        self.f = f
        self.ax = ax
        self._setup_axes()
        canvas = FigureCanvasTkAgg(f, master=root)
        # canvas.show() was only an alias of draw() and is gone from newer
        # matplotlib releases.
        canvas.draw()
        # get_tk_widget() is the Tk canvas itself, so packing it once is
        # enough (no need to reach for the private _tkcanvas as well).
        canvas.get_tk_widget().pack(side=Tk.TOP, fill=Tk.BOTH, expand=1)
        canvas.mpl_connect('button_press_event', self.onclick)
        toolbar = NavigationToolbar2TkAgg(canvas, root)
        toolbar.update()
        self.controllbar = ControllBar(root, controller)
        self.canvas = canvas
        self.controller = controller
        self.hascolormaps = False
        self.contours = []
        self.c_labels = None

    def _setup_axes(self):
        """Hide the ticks, fix the data limits and write the kernel legend.

        Needed both at start-up and after ``ax.clear()``, which resets the
        limits to matplotlib's defaults.
        """
        self.ax.set_xticks([])
        self.ax.set_yticks([])
        self.ax.set_xlim((x_min, x_max))
        self.ax.set_ylim((y_min, y_max))
        self.plot_kernels()

    def plot_kernels(self):
        """Write the three kernel formulas below the plot area."""
        # Raw strings: the mathtext backslashes are not Python escapes.
        self.ax.text(-50, -60, r"Linear: $u^T v$")
        self.ax.text(-20, -60, r"RBF: $\exp (-\gamma \| u-v \|^2)$")
        self.ax.text(10, -60, r"Poly: $(\gamma \, u^T v + r)^d$")

    def onclick(self, event):
        """Add a positive (left button) or negative (right button) example."""
        # Compare with None: a click exactly on x == 0 or y == 0 is valid,
        # while clicks outside the axes report None coordinates.
        if event.xdata is None or event.ydata is None:
            return
        if event.button == 1:
            self.controller.add_example(event.xdata, event.ydata, 1)
        elif event.button == 3:
            self.controller.add_example(event.xdata, event.ydata, -1)

    def update(self, event, model):
        """Observer callback: redraw according to the model ``event``."""
        if event == "example_added":
            x, y, label = model.data[-1]
            color = {1: 'w', -1: 'k'}[label]
            self.ax.plot([x], [y], "%so" % color, scalex=False, scaley=False)
        elif event == "clear":
            # ax.clear() removes every artist, so just forget the handles.
            self.ax.clear()
            self.contours = []
            self.c_labels = None
            self._setup_axes()
        elif event == "surface":
            self.plot_decision_surface(model.surface, model.surface_type)
        self.canvas.draw()

    def _remove_contours(self):
        """Remove every contour set and contour label currently drawn."""
        for contour in self.contours:
            if hasattr(contour, 'remove'):
                # Newer matplotlib ContourSets are removable artists and
                # deprecate/drop their ``collections`` attribute.
                contour.remove()
            else:
                for lineset in contour.collections:
                    lineset.remove()
        self.contours = []
        if self.c_labels:
            for label in self.c_labels:
                label.remove()
        self.c_labels = None

    def plot_decision_surface(self, surface, type):
        """Draw the decision values in ``surface`` replacing earlier contours.

        surface -- ``(X1, X2, Z)`` grid as produced by the controller.
        type    -- 0: contour lines at Z = -1, 0, 1 (dashed, solid, dashed);
                   1: ``contourf`` with 10 levels requested and the bone
                   colormap, plus the solid Z = 0 line.  (The parameter
                   name shadows the builtin but is kept for callers.)

        Raises ValueError for any other ``type`` (after clearing old contours).
        """
        X1, X2, Z = surface
        self._remove_contours()
        if type == 0:
            self.contours.append(self.ax.contour(
                X1, X2, Z, [-1.0, 0.0, 1.0], colors='k',
                linestyles=['dashed', 'solid', 'dashed']))
        elif type == 1:
            self.contours.append(self.ax.contourf(
                X1, X2, Z, 10, cmap=matplotlib.cm.bone, origin='lower',
                alpha=0.85))
            self.contours.append(self.ax.contour(
                X1, X2, Z, [0.0], colors='k', linestyles=['solid']))
        else:
            raise ValueError("surface type unknown")
class ControllBar(object):
    """Tk control panel: kernel choice, SVM parameters, display mode, buttons.

    The constructor creates the StringVars for the parameter entries and
    stores them on ``controller`` as ``complexity``, ``gamma``, ``degree``
    and ``coef0``; the radio buttons bind to ``controller.kernel`` and
    ``controller.surface_type``.
    """

    def __init__(self, root, controller):
        fm = Tk.Frame(root)

        # Kernel selection; the values are the keys of the kernel map used
        # in Controller.classify (0 linear, 1 rbf, 2 poly).
        kernel_group = Tk.Frame(fm)
        for value, text in enumerate(("Linear", "RBF", "Poly")):
            Tk.Radiobutton(kernel_group, text=text, variable=controller.kernel,
                           value=value).pack(anchor=Tk.W)
        kernel_group.pack(side=Tk.LEFT)

        # Numeric SVM parameters, one labelled entry per row.
        valbox = Tk.Frame(fm)
        controller.complexity = self._labeled_entry(valbox, "C:", "1.0")
        controller.gamma = self._labeled_entry(valbox, "gamma:", "0.01")
        controller.degree = self._labeled_entry(valbox, "degree:", "3")
        controller.coef0 = self._labeled_entry(valbox, "coef0:", "0")
        valbox.pack(side=Tk.LEFT)

        # Display mode: 0 = hyperplanes, 1 = filled surface.
        cmap_group = Tk.Frame(fm)
        Tk.Radiobutton(cmap_group, text="Hyperplanes",
                       variable=controller.surface_type, value=0).pack(
            anchor=Tk.W)
        Tk.Radiobutton(cmap_group, text="Surface",
                       variable=controller.surface_type, value=1).pack(
            anchor=Tk.W)
        cmap_group.pack(side=Tk.LEFT)

        train_button = Tk.Button(fm, text='Train', command=controller.classify)
        train_button.pack()
        fm.pack(side=Tk.LEFT)
        Tk.Button(fm, text='Clear',
                  command=controller.clear_data).pack(side=Tk.LEFT)

    @staticmethod
    def _labeled_entry(parent, text, default):
        """Add a "label: [entry]" row to ``parent``; return the entry's StringVar."""
        var = Tk.StringVar()
        var.set(default)
        row = Tk.Frame(parent)
        Tk.Label(row, text=text, anchor="e", width=7).pack(side=Tk.LEFT)
        Tk.Entry(row, width=6, textvariable=var).pack(side=Tk.LEFT)
        row.pack()
        return var
def main(argv):
    """Wire up model, controller and view, then run the Tk event loop.

    ``argv`` follows the usual script-entry signature and is not used.
    """
    root = Tk.Tk()
    root.title("SVM")
    model = Model()
    controller = Controller(model)
    model.add_observer(View(root, controller))
    root.mainloop()


if __name__ == "__main__":
    main(sys.argv)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment