Created
April 20, 2022 01:00
-
-
Save sithu/8f11085d02dc0d8c716b17b4739d3978 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from matplotlib import pyplot | |
# Some functions to plot our points and draw the lines | |
def plot_points(features, labels, fix_margins=True):
    """Scatter-plot two-feature points, coloured by their 0/1 label.

    Points labelled 1 are drawn as cyan triangles ("Spam"); points labelled 0
    are drawn as red squares ("Ham"). Only the first two feature columns are
    plotted, on axes labelled "Lottery" and "Sale".

    Parameters
    ----------
    features : array-like
        One row of features per point.
    labels : array-like
        Label per point; only the values 1 and 0 are plotted.
    fix_margins : bool, default True
        When True, pin both axes to the range 0-11.
    """
    points = np.array(features)
    classes = np.array(labels)
    # Indexing with np.argwhere(...) adds an extra axis (for 1-D labels each
    # selected row has shape (1, n_features)), hence the r[0][...] below.
    spam_rows = points[np.argwhere(classes == 1)]
    ham_rows = points[np.argwhere(classes == 0)]
    if fix_margins:
        pyplot.xlim(0, 11)
        pyplot.ylim(0, 11)
    # Spam is drawn first so the legend labels below line up with the artists.
    for rows, colour, shape in ((spam_rows, 'cyan', '^'), (ham_rows, 'red', 's')):
        xs = [r[0][0] for r in rows]
        ys = [r[0][1] for r in rows]
        pyplot.scatter(xs, ys, s=100, color=colour, edgecolor='k', marker=shape)
    pyplot.xlabel('Lottery')
    pyplot.ylabel('Sale')
    pyplot.legend(['Spam', 'Ham'])
def plot_model(X, y, model, fix_margins=True):
    """Plot labelled points and shade the decision regions predicted by ``model``.

    Parameters
    ----------
    X : array-like
        Feature rows; the first two columns are used as plot coordinates.
    y : array-like
        Labels, forwarded to ``plot_points`` (1 = spam, 0 = ham).
    model : object
        Anything exposing ``predict`` that accepts an (n, 2) array of points.
    fix_margins : bool, default True
        When True, use fixed bounds (0-12 for the prediction mesh, and 0-11 axes
        for the scatter). When False, the mesh spans the data range padded by
        1 unit on each side and the scatter axes are left to autoscale.
    """
    X = np.array(X)
    y = np.array(y)
    # Forward fix_margins: previously plot_points was always called with its
    # default (True), so fix_margins=False still forced the axes to 0-11.
    plot_points(X, y, fix_margins=fix_margins)
    plot_step = 0.01
    if fix_margins:
        x_min, x_max = 0, 12
        y_min, y_max = 0, 12
    else:
        x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
        y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
                         np.arange(y_min, y_max, plot_step))
    # Predict every mesh point at once, then restore the grid shape.
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    # Boundaries -1, 0, 1 give two filled bands: [-1, 0] (red) and (0, 1]
    # (blue), so 0/1 predictions map to red/blue respectively.
    pyplot.contourf(xx, yy, Z, colors=['red', 'blue'], alpha=0.2, levels=range(-1, 2))
    # Draw the decision boundary itself as a thick black line.
    pyplot.contour(xx, yy, Z, colors='k', linewidths=3)
    pyplot.show()
def display_tree(dt):
    """Render a fitted scikit-learn decision tree as an inline PNG image.

    The original implementation imported ``StringIO`` from
    ``sklearn.externals.six``, which no longer exists in scikit-learn, so it
    raised ImportError. ``export_graphviz`` returns the DOT source as a string
    when ``out_file=None``, so no intermediate buffer is needed.

    Parameters
    ----------
    dt : fitted decision tree accepted by ``sklearn.tree.export_graphviz``.

    Returns
    -------
    IPython.display.Image
        The tree rendered to PNG via pydotplus.
    """
    from IPython.display import Image
    from sklearn.tree import export_graphviz
    import pydotplus

    dot_data = export_graphviz(dt, out_file=None,
                               filled=True, rounded=True,
                               special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    return Image(graph.create_png())
def plot_trees(model):
    """Draw every tree of a fitted boosting ensemble, one figure per tree.

    Parameters
    ----------
    model : fitted estimator exposing ``estimators_``, whose entries are
        indexable and whose first element is a tree accepted by
        ``sklearn.tree.plot_tree``. For scikit-learn GradientBoosting models,
        ``estimators_`` has shape (n_estimators, K) and ``[0]`` selects the
        first of the K trees in each stage.
    """
    # Local import: the module header does not import sklearn.tree, so the
    # original reference to ``tree`` raised NameError.
    from sklearn import tree

    # Iterate the argument itself; the original ignored ``model`` and read an
    # undefined global named ``gradient_boosting_model``.
    for stage in model.estimators_:
        tree.plot_tree(stage[0])
        pyplot.show()
def plot_regressor(model, features, labels):
    """Scatter the (features, labels) data and overlay ``model``'s prediction curve.

    The curve is evaluated at 1000 evenly spaced points on [0, 85]; axes are
    labelled "Age" and "Days per week" before the figure is shown.
    """
    grid = np.linspace(0, 85, 1000)
    pyplot.scatter(features, labels)
    # predict receives the grid as a single-column 2-D array (one row per point).
    column = grid.reshape([-1, 1])
    predictions = model.predict(column)
    pyplot.plot(grid, predictions)
    pyplot.xlabel("Age")
    pyplot.ylabel("Days per week")
    pyplot.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment