Created
January 11, 2017 04:19
-
-
Save smurching/18b90dc0d0039ce9c95ed92668101502 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from matplotlib import pyplot as plt | |
def make_plot(X, y, clf, title, filename, step=0.02, vmin=None, vmax=None):
    '''
    Plot the decision regions of a fitted classifier together with its examples.

    Evaluates clf.predict on a regular grid covering the first two columns of
    X (padded by 1 unit on every side). It then draws the predicted labels as
    filled contours and overlays the examples X colored by their labels y.
    Finally it titles the plot with <title>, saves it to <filename>, and calls
    plt.show(). Drawing happens on a new figure, so repeated calls do not
    overlay each other.

    Parameters
    ----------
    X : array-like
        Examples; only columns 0 and 1 are used (at least two columns needed).
    y : array-like
        Numeric labels for the rows of X (they are color-mapped).
    clf : object
        Classifier assumed to be already fitted to X (via clf.fit()). Its
        predict() is called with an (m, 2) array of grid points.
    title : str
        Plot title.
    filename : str or path-like
        Destination passed to plt.savefig.
    step : float, optional
        Grid spacing (default 0.02, the original fixed value). Increase it for
        data with a wide range to keep the grid small.
    vmin, vmax : float, optional
        Color-scale limits shared by the decision regions and the points.
        Each defaults to the min / max over both y and the predicted grid
        labels, so a given label gets the same color in both layers.

    Raises
    ------
    ValueError
        If step is not positive.
    '''
    if step <= 0:
        raise ValueError('step must be positive')
    X = np.asarray(X)
    labels = np.asarray(y)

    # Create a mesh of points at which to evaluate the classifier, spanning
    # the data with a 1-unit margin.
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, step),
                         np.arange(y_min, y_max, step))

    # Predict a label for every mesh point and lay the results back out
    # on the grid.
    Z = np.asarray(clf.predict(np.c_[xx.ravel(), yy.ravel()])).reshape(xx.shape)

    # Use one color scale for both layers. Previously the contour was pinned
    # to [-1, 1] while the scatter autoscaled to y, so colors disagreed for
    # any labels other than a -1/+1 pair. For -1/+1 labels this reproduces
    # the old [-1, 1] range.
    if vmin is None:
        vmin = min(labels.min(), Z.min())
    if vmax is None:
        vmax = max(labels.max(), Z.max())

    fig = plt.figure()
    plt.contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8,
                 vmin=vmin, vmax=vmax)
    # Also plot the training points on the same color scale.
    plt.scatter(X[:, 0], X[:, 1], c=labels, cmap=plt.cm.coolwarm,
                vmin=vmin, vmax=vmax)
    plt.xlabel('x1')
    plt.ylabel('x2')
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.xticks(())
    plt.yticks(())
    plt.title(title)
    plt.savefig(filename)
    plt.show()
    # In non-interactive mode show() either blocked until the window was
    # closed or (e.g. Agg backend) did nothing. Either way, release the
    # figure so later calls don't accumulate open figures. In interactive
    # mode, keep it so a non-blocking window stays open.
    if not plt.isinteractive():
        plt.close(fig)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment