Multiclass SVMs
"""
Multiclass SVMs (Crammer-Singer formulation).
A pure Python re-implementation of:
Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex.
Mathieu Blondel, Akinori Fujino, and Naonori Ueda.
ICPR 2014.
http://www.mblondel.org/publications/mblondel-icpr2014.pdf
"""
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils import check_random_state
from sklearn.preprocessing import LabelEncoder

def projection_simplex(v, z=1):
    """
    Projection onto the simplex:
        w^* = argmin_w 0.5 ||w-v||^2 s.t. \sum_i w_i = z, w_i >= 0
    """
    # For other algorithms computing the same projection, see
    # https://gist.github.com/mblondel/6f3b7aaad90606b98f71
    n_features = v.shape[0]
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u) - z
    ind = np.arange(n_features) + 1
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)
    w = np.maximum(v - theta, 0)
    return w
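
# A quick sanity check of the projection above (an illustrative example, not
# part of the original gist): the result should sum to z and stay non-negative.
#
#     w = projection_simplex(np.array([0.5, 2.0, -1.0]), z=1)
#     assert np.isclose(w.sum(), 1.0) and np.all(w >= 0.0)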

class MulticlassSVM(BaseEstimator, ClassifierMixin):

    def __init__(self, C=1, max_iter=50, tol=0.05,
                 random_state=None, verbose=0):
        self.C = C
        self.max_iter = max_iter
        self.tol = tol
        self.random_state = random_state
        self.verbose = verbose
    def _partial_gradient(self, X, y, i):
        # Partial gradient for the ith sample.
        g = np.dot(X[i], self.coef_.T) + 1
        g[y[i]] -= 1
        return g
    def _violation(self, g, y, i):
        # Optimality violation for the ith sample.
        smallest = np.inf
        for k in range(g.shape[0]):
            if k == y[i] and self.dual_coef_[k, i] >= self.C:
                continue
            elif k != y[i] and self.dual_coef_[k, i] >= 0:
                continue
            smallest = min(smallest, g[k])
        return g.max() - smallest
    def _solve_subproblem(self, g, y, norms, i):
        # Prepare inputs to the projection.
        Ci = np.zeros(g.shape[0])
        Ci[y[i]] = self.C
        beta_hat = norms[i] * (Ci - self.dual_coef_[:, i]) + g / norms[i]
        z = self.C * norms[i]

        # Compute projection onto the simplex.
        beta = projection_simplex(beta_hat, z)

        return Ci - self.dual_coef_[:, i] - beta / norms[i]
    def fit(self, X, y):
        n_samples, n_features = X.shape

        # Normalize labels.
        self._label_encoder = LabelEncoder()
        y = self._label_encoder.fit_transform(y)

        # Initialize primal and dual coefficients.
        n_classes = len(self._label_encoder.classes_)
        self.dual_coef_ = np.zeros((n_classes, n_samples), dtype=np.float64)
        self.coef_ = np.zeros((n_classes, n_features))

        # Pre-compute norms.
        norms = np.sqrt(np.sum(X ** 2, axis=1))

        # Shuffle sample indices.
        rs = check_random_state(self.random_state)
        ind = np.arange(n_samples)
        rs.shuffle(ind)

        violation_init = None
        for it in range(self.max_iter):
            violation_sum = 0

            for ii in range(n_samples):
                i = ind[ii]

                # All-zero samples can be safely ignored.
                if norms[i] == 0:
                    continue

                g = self._partial_gradient(X, y, i)
                v = self._violation(g, y, i)
                violation_sum += v

                if v < 1e-12:
                    continue

                # Solve subproblem for the ith sample.
                delta = self._solve_subproblem(g, y, norms, i)

                # Update primal and dual coefficients.
                self.coef_ += (delta * X[i][:, np.newaxis]).T
                self.dual_coef_[:, i] += delta

            if it == 0:
                violation_init = violation_sum

            vratio = violation_sum / violation_init

            if self.verbose >= 1:
                print("iter", it + 1, "violation", vratio)

            if vratio < self.tol:
                if self.verbose >= 1:
                    print("Converged")
                break

        return self
    def predict(self, X):
        decision = np.dot(X, self.coef_.T)
        pred = decision.argmax(axis=1)
        return self._label_encoder.inverse_transform(pred)

if __name__ == '__main__':
    from sklearn.datasets import load_iris

    iris = load_iris()
    X, y = iris.data, iris.target

    clf = MulticlassSVM(C=0.1, tol=0.01, max_iter=100, random_state=0, verbose=1)
    clf.fit(X, y)
    print(clf.score(X, y))
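
To get a sense of generalization rather than training accuracy, the same demo can be run on a held-out split (a minimal sketch using scikit-learn's train_test_split; not part of the original gist):

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Hold out 30% of the iris data for evaluation.
X, y = load_iris(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)

clf = MulticlassSVM(C=0.1, tol=0.01, max_iter=100, random_state=0)
clf.fit(X_tr, y_tr)
print(clf.score(X_te, y_te))  # accuracy on the held-out samples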
mblondel (Author) commented Feb 5, 2020
@leme-lab: Indeed, fitting b leads to a more complicated dual. The usual trick is to add a constant feature x_0 to all inputs, so that a weight w_0 is learned in its place. This is not exactly the same as fitting b, though, because w_0 is regularized along with the other weights.
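
A minimal sketch of that trick (illustrative only; the appended column of ones plays the role of x_0, and the corresponding column of coef_ acts as a regularized per-class intercept):

import numpy as np

# Append a constant feature x_0 = 1 to every sample.
X_aug = np.hstack([X, np.ones((X.shape[0], 1))])

clf = MulticlassSVM(C=0.1)
clf.fit(X_aug, y)

# clf.coef_[:, -1] now behaves like a per-class bias, but it is penalized
# together with the other weights, unlike a true unregularized intercept b.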
