Last active
February 9, 2016 14:33
-
-
Save JensMeiners/83eddf8f6676b8ce0f17 to your computer and use it in GitHub Desktop.
For a model that outputs prediction probabilities, this script plots the resulting decision line, colour-codes the classes, and labels the gradient steps.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
This program is free software. It comes without any warranty, to
the extent permitted by applicable law. You can redistribute it
and/or modify it under the terms of the Do What The Fuck You Want
To Public License, Version 2, as published by Sam Hocevar. See
http://www.wtfpl.net/ for more details.
'''
import numpy as np
import matplotlib.pyplot as plt
try:
    # scikit-learn >= 0.17; the old sklearn.qda module was later removed.
    from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as QDA
except ImportError:
    # Fallback for very old scikit-learn releases.
    from sklearn.qda import QDA
# Two isotropic Gaussian clusters. X1/y1 and X2/y2 are the x- and
# y-coordinates of the 200 samples drawn for cluster 1 and cluster 2.
cov1 = [[0.3, 0],
        [0, 0.3]]
cov2 = [[0.2, 0],
        [0, 0.2]]
mean1 = [1., 1.]
mean2 = [-1., -1.]
X1, y1 = np.random.multivariate_normal(mean1, cov1, 200).T
X2, y2 = np.random.multivariate_normal(mean2, cov2, 200).T

# Feature matrix of shape (n_samples, 2): column 0 is x, column 1 is y.
X = np.vstack((np.hstack((X1, X2)), np.hstack((y1, y2)))).T
# Cluster 1 -> class -1, cluster 2 -> class +1. Built from the actual
# cluster sizes instead of slicing at a computed midpoint (a float slice
# bound such as `n / 2` raises TypeError on Python 3).
Y = np.hstack((-np.ones(len(X1)), np.ones(len(X2))))

qda = QDA()
qda.fit(X, Y)

# Evenly sampled grid covering the data plus a 0.5 margin on every side.
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 50),
                     np.linspace(y_min, y_max, 50))
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())

# Plot background colours: predict_proba columns follow qda.classes_,
# which scikit-learn sorts, so column 1 is P(class +1) at each grid point.
ax = plt.gca()
Z = qda.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1]
Z = Z.reshape(xx.shape)
# Explicit levels so that 0.5 -- the decision boundary -- is always among
# the labelled lines; matplotlib's automatic levels need not include it.
levels = np.linspace(0.1, 0.9, 9)
cs = ax.contourf(xx, yy, Z, cmap='RdBu', alpha=.5)
cs2 = ax.contour(xx, yy, Z, levels=levels, cmap='RdBu', alpha=.5)
plt.clabel(cs2, fmt='%2.1f', colors='k', fontsize=10)
# Emphasise the decision line itself (P = 0.5).
ax.contour(xx, yy, Z, levels=[0.5], colors='k', linewidths=2)

# Plot the points; each scatter needs a label or the legend stays empty.
plt.scatter(X1, y1, c='red', label='class -1')
plt.scatter(X2, y2, c='blue', label='class +1')

# make legend
plt.legend(loc='upper left', scatterpoints=1, numpoints=1)
plt.show()
Author
JensMeiners
commented
Dec 12, 2015
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment