Last active
February 25, 2018 22:38
-
-
Save aaronsnoswell/290d682195660816dc453f8d59ed488b to your computer and use it in GitHub Desktop.
Demonstrates the batch perceptron binary classification algorithm (Rosenblatt 1958)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Demonstrates the batch perceptron binary classification algorithm
(Rosenblatt 1958) from p190 of Shalev-Shwartz & Ben-David, 2014 "Understanding
Machine Learning"
Requires: | |
* Python 3 | |
* numpy | |
* scikit-learn | |
* matplotlib | |
""" | |
import numpy as np | |
def batch_perceptron(x, y, *, on_found_outlier=None, max_iter=None):
    """
    Batch Perceptron algorithm implementation from p190 in Shalev-Shwartz &
    Ben-David, "Understanding Machine Learning" (2014).

    Given M training points of dimensionality N, finds a weight vector w with
    y_i * <w, x_i> > 0 for every training point. The half-space is homogeneous
    (it passes through the origin); append a constant feature to x if a bias
    term is needed.

    On every pass the training data is reshuffled (using numpy's global random
    state). The first point in the new order with y_i * <w, x_i> <= 0 is used
    for the update w <- w + y_i * x_i. The algorithm stops as soon as a pass
    finds no such point.

    @param x Training features: a numpy array (or array-like) of shape (M, N)
    @param y Training labels: a numpy array (or array-like) of shape (M,),
        normally -1 or +1. Zero labels are rejected, because such a point can
        never satisfy y_i * <w, x_i> > 0.
    @param on_found_outlier: An optional callable invoked every time a
        misclassified point is found. It is passed t, x_i, y_i, w_t-1 and w_t
    @param max_iter: Optional maximum number of updates. The perceptron only
        terminates on linearly separable data. When this limit is reached, a
        RuntimeWarning is emitted and the current weight vector is returned.
        None (the default) means no limit.
    @return The weight vector w, a numpy array of shape (N,)
    @raises ValueError If any label is zero
    """
    import warnings

    x = np.asarray(x)
    y = np.asarray(y)
    M, N = x.shape
    assert y.ndim == 1 and y.shape[0] == M, \
        "Label vector must have shape == x.shape[0] ({}), but it was {}".format(
            M,
            y.shape
        )
    if np.any(y == 0):
        # y_i == 0 gives a margin of 0 (always "misclassified") and an update
        # of 0 (w never changes), so the loop below could never finish.
        raise ValueError(
            "Labels must be non-zero (e.g. -1 or +1); a point labelled 0 is "
            "always misclassified, so the algorithm would never terminate"
        )

    # Initialize w
    w_t = np.zeros(shape=N)
    t = 0
    while True:
        t += 1
        # Shuffle the data (in sync) on every pass, so the misclassified point
        # picked for the update varies between passes. The caller's arrays are
        # not modified; fancy indexing makes copies.
        order = np.random.permutation(M)
        x, y = x[order], y[order]

        # Indices of every i with y_i * <w_t, x_i> <= 0, in shuffled order
        violators = np.flatnonzero(y * (x @ w_t) <= 0)
        if violators.size == 0:
            # All points are now correctly classified
            return w_t

        if max_iter is not None and t > max_iter:
            warnings.warn(
                "batch_perceptron stopped after {} updates without separating "
                "the data (it may not be linearly separable)".format(max_iter),
                RuntimeWarning,
                stacklevel=2,
            )
            return w_t

        # Found an incorrectly labelled point - update w
        i = violators[0]
        x_i, y_i = x[i], y[i]
        w_prev = w_t
        # Create a new array rather than updating in place, so w_prev (passed
        # to the callback) still holds the previous weights.
        w_t = w_t + y_i * x_i
        if on_found_outlier is not None:
            on_found_outlier(t, x_i, y_i, w_prev, w_t)
def _hyperplane_segment(w, x_lo, x_hi, y_lo, y_hi, offset=(0.0, 0.0)):
    """
    Compute two points on the decision boundary w[0]*u' + w[1]*v' + b = 0 so
    it can be drawn across the plotting window.

    Here b = w[2] if w has a third (bias) component and 0 otherwise. The
    shifted coordinates are u' = u - offset[0] and v' = v - offset[1], where
    (u, v) are plot coordinates.

    @param w Weight vector with at least two components
    @param x_lo, x_hi Horizontal extent of the plotting window
    @param y_lo, y_hi Vertical extent, used only when the boundary is vertical
    @param offset Point subtracted from the features before training
    @return A tuple (us, vs) of numpy arrays with the plot coordinates of the
        segment's end points (both empty when w[0] == w[1] == 0)
    """
    w0, w1 = w[0], w[1]
    b = w[2] if len(w) > 2 else 0.0
    # Move the boundary from the shifted (training) frame back to plot coords
    b = b - w0 * offset[0] - w1 * offset[1]
    if w1 != 0:
        us = np.array([x_lo, x_hi], dtype=float)
        return us, -(w0 * us + b) / w1
    if w0 != 0:
        # w[1] == 0: the boundary is the vertical line u = -b / w[0]
        u = -b / w0
        return np.array([u, u]), np.array([y_lo, y_hi], dtype=float)
    # w[0] == w[1] == 0: there is no line to draw
    return np.array([]), np.array([])


def main(seed=None):
    """
    Demonstrates the batch perceptron algorithm.

    Trains on up to 50 randomly chosen Iris samples (the first two classes,
    using the first two features) and animates each update of the decision
    boundary.

    @param seed Optional seed for numpy's global random state. It makes both
        the subsampling and the perceptron's per-pass shuffling reproducible.
        None (the default) leaves the random state untouched.
    """
    import matplotlib as mpl
    import matplotlib.pyplot as plt

    if seed is not None:
        np.random.seed(seed)

    # Import some data to play with
    print("Loading Iris dataset...")
    from sklearn import datasets
    iris = datasets.load_iris()

    # We grab only the first two targets as batch perceptron is a binary
    # classifier
    indices = (iris.target == 0) | (iris.target == 1)
    y = iris.target[indices]
    X = iris.data[indices, :2]

    # Reduce the number of training points in the interests of time. Sample
    # without replacement, so no point appears twice in the training set.
    M = min(50, len(y))
    subset_indices = np.random.choice(len(y), size=M, replace=False)
    y = y[subset_indices]
    X = X[subset_indices, :]

    # Batch perceptron expects labels to be -1 or 1
    y[y == 0] = -1

    # The perceptron learns a half-space through the origin. Centre the
    # features on their mean and append a constant 1, so the third weight acts
    # as a bias. Without the bias, the classes need not be separable and the
    # algorithm may never terminate. Centring also shrinks the feature norms,
    # which typically means fewer updates.
    Xbar = np.mean(X, axis=0)
    X_train = np.hstack([X - Xbar, np.ones((M, 1))])

    # Create a figure for visualisation
    print("Creating figure...")
    plt.figure()
    ax = plt.gca()

    # Plot the training points
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1, edgecolor='k')
    plt.xlabel('Sepal length')
    plt.ylabel('Sepal width')

    # Set figure parameters
    x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
    y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    ax.grid()
    ax.set_aspect('equal', adjustable='box')
    plt.title("{} points from the Iris data".format(M))
    plt.show(block=False)

    # Delay in seconds to pause on each iteration
    iteration_delay = 0.001

    current_pt = mpl.patches.Circle((0, 0), 0.15, color='r', fill=False)
    # The boundary line is added once; each frame only replaces its data
    current_hyperplane = mpl.lines.Line2D([], [], linewidth=2)
    ax.add_artist(current_hyperplane)

    def on_found_outlier(t, x_i, y_i, w_prev, w_t):
        """
        Callback to update the plot when we find an outlier.

        x_i is in the centred/augmented training frame; w_t holds
        (w_length, w_width, bias) in that frame.
        """
        # Undo the centring to get the point's plot coordinates
        px, py = x_i[0] + Xbar[0], x_i[1] + Xbar[1]
        print(
            "Iteration {:> 3} - x_i=({:.1f}, {:.1f}), "
            "w_t=({: 6.2f}, {: 6.2f}), b_t={: 6.2f}".format(
                t, px, py, w_t[0], w_t[1], w_t[2]
            )
        )

        # Indicate which point was the outlier
        current_pt.center = px, py
        ax.add_artist(current_pt)

        # Update the hyperplane visualisation
        current_hyperplane.set_data(
            *_hyperplane_segment(w_t, x_min, x_max, y_min, y_max, offset=Xbar)
        )
        plt.pause(iteration_delay)

        # Remove current point for next frame
        current_pt.remove()

    # Solve for w to compute the half-space
    print("Running batch perceptron algorithm")
    w = batch_perceptron(X_train, y, on_found_outlier=on_found_outlier)
    print("Final weight vector: {}".format(w))
    plt.show()
if __name__ == "__main__":
    # Launch the interactive demo only when run as a script, not on import.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment