Skip to content

Instantly share code, notes, and snippets.

@smurching
Created January 11, 2017 04:19
Show Gist options
  • Save smurching/18b90dc0d0039ce9c95ed92668101502 to your computer and use it in GitHub Desktop.
Save smurching/18b90dc0d0039ce9c95ed92668101502 to your computer and use it in GitHub Desktop.
import numpy as np
from matplotlib import pyplot as plt
def make_plot(X, y, clf, title, filename):
'''
Plots the decision boundary of the classifier <clf> (assumed to have been fitted
to X via clf.fit()) against the matrix of examples X with corresponding labels y.
Uses <title> as the title of the plot, saving the plot to <filename>.
'''
# Create a mesh of points at which to evaluate our classifier
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
np.arange(y_min, y_max, 0.02))
# Plot the decision boundary. For that, we will assign a color to each
# point in the mesh [x_min, x_max]x[y_min, y_max].
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
# Put the result into a color plot
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8, vmin=-1, vmax=1)
# Also plot the training points
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.coolwarm)
plt.xlabel('x1')
plt.ylabel('x2')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.xticks(())
plt.yticks(())
plt.title(title)
plt.savefig(filename)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment