Skip to content

Instantly share code, notes, and snippets.

@necroshine0
Created December 2, 2023 22:25
Show Gist options
  • Save necroshine0/f3ab793ddcd0834412e848a02a7798f7 to your computer and use it in GitHub Desktop.
Save necroshine0/f3ab793ddcd0834412e848a02a7798f7 to your computer and use it in GitHub Desktop.
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
# Reproducible synthetic 2-D binary classification problem: 20000 samples,
# two informative features, 70/30 train/test split.
X, y = make_classification(n_samples=20000, n_features=2, n_redundant=0, random_state=10)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10, test_size=0.3)

# Fit a linear-kernel SVM on the training split.
model = SVC(kernel='linear')
model.fit(X_train, y_train)

# Weights of the linear decision function w0*x1 + w1*x2 + b; solving
# w0*x1 + w1*x2 + b = 0 for x2 gives the slope a = -w0 / w1.
# NOTE(review): this divides by weights[1]; a vertical boundary (w1 == 0)
# would yield inf/nan here.
weights = model.coef_[0]
a = -(weights[0] / weights[1])
def get_grid(X_train, *, margin=0.5):
    """Return padded axis limits covering the first two features of ``X_train``.

    Despite its name, this returns only the plot limits, not a mesh grid.

    Parameters
    ----------
    X_train : 2-D NumPy array with at least two columns
        Column 0 is treated as the x-axis feature and column 1 as the y-axis
        feature.
    margin : float, default 0.5
        Padding subtracted from each minimum and added to each maximum.

    Returns
    -------
    xlim, ylim : list of two floats each
        ``[min - margin, max + margin]`` for column 0 and column 1.
    """
    x_col, y_col = X_train[:, 0], X_train[:, 1]
    xlim = [x_col.min() - margin, x_col.max() + margin]
    ylim = [y_col.min() - margin, y_col.max() + margin]
    return xlim, ylim
# Sample the separating line over the padded x-range of the training data.
# From w0*x1 + w1*x2 + b = 0 it follows that x2 = a*x1 - b/w1.
xlim, _ = get_grid(X_train)
xx = np.linspace(*xlim)  # NumPy's default of 50 evenly spaced points
yy = a * xx - model.intercept_[0] / weights[1]
# Adapted from https://jakevdp.github.io/PythonDataScienceHandbook/05.07-support-vector-machines.html
def plot_svc_decision_function(model, ax=None, support=True, line_c='k', n_points=30):
    """Plot the decision function for a 2D SVC.

    Draws the decision boundary (level 0) and the two margins (levels -1 and
    +1) of ``model.decision_function`` over the current view of ``ax``, and
    optionally highlights the model's support vectors.

    Parameters
    ----------
    model : fitted classifier
        Must provide ``decision_function`` accepting an ``(n, 2)`` array and,
        when ``support`` is true, a 2-D ``support_vectors_`` array.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to ``plt.gca()``. Its current x/y limits
        define the region where the decision function is evaluated, so plot
        the data or set the limits before calling.
    support : bool, default True
        Whether to scatter-plot ``model.support_vectors_``.
    line_c : color, default 'k'
        Color of the boundary and margin contour lines.
    n_points : int, default 30
        Number of grid points per axis at which the decision function is
        evaluated.

    Raises
    ------
    ValueError
        If ``n_points`` is less than 2 (a contour needs at least a 2x2 grid).
    """
    if n_points < 2:
        raise ValueError(f"n_points must be at least 2, got {n_points}")
    if ax is None:
        ax = plt.gca()
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    # Evaluate the model on a dense grid spanning the axes' own view. This
    # keeps the function independent of global data and gives smooth
    # contours even for non-linear kernels.
    x = np.linspace(xlim[0], xlim[1], n_points)
    y = np.linspace(ylim[0], ylim[1], n_points)
    Y, X = np.meshgrid(y, x)
    xy = np.vstack([X.ravel(), Y.ravel()]).T
    P = model.decision_function(xy).reshape(X.shape)

    # Plot the decision boundary (solid) and the margins (dashed).
    ax.contour(X, Y, P, colors=line_c,
               levels=[-1, 0, 1], alpha=0.8,
               linestyles=['--', '-', '--'], linewidths=3)

    # Plot the support vectors.
    if support:
        ax.scatter(model.support_vectors_[:, 0],
                   model.support_vectors_[:, 1], linewidth=1, c='purple', s=70,
                   edgecolor='k', label='support vectors')

    # Restore the view captured above in case drawing changed it.
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
# Visualize the held-out samples, the explicit separating line, and the SVC's
# decision boundary/margins together with its support vectors.
xlim, ylim = get_grid(X_train)  # padded extent of feature 0 / feature 1

plt.figure(figsize=(16, 10))
# NOTE(review): in scikit-learn's binary SVC, decision_function > 0 corresponds
# to classes_[1], so class 1 lies on the "+1" margin side; the '+'/'-' legend
# labels below look inverted relative to that. Confirm the intended convention.
plt.scatter(X_test[y_test == 0, 0], X_test[y_test == 0, 1], edgecolor='k', s=70, color='deepskyblue', label='+')
plt.scatter(X_test[y_test == 1, 0], X_test[y_test == 1, 1], edgecolor='k', s=70, color='crimson', label='-')
plt.plot(xx, yy, label='sep line', color='blue')
# Fix the view before drawing the decision function, so its contours cover
# exactly the region that is displayed.
plt.xlim(xlim)
plt.ylim(ylim)  # the y-axis uses the second feature's range, not the first's
plot_svc_decision_function(model)
plt.xlabel('feature_1')
plt.ylabel('feature_2')
plt.title('Samples Visualization (w/ support vectors)')
plt.legend(shadow=False, fontsize=14)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment