"""
Implements four algorithms for projecting a vector onto the simplex
{w : w >= 0, sum(w) = z}: sort, pivot, bisection, and Brent's method.

For details and references, see the following paper:

Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex
Mathieu Blondel, Akinori Fujino, and Naonori Ueda.
ICPR 2014.
"""
import numpy as np
from scipy.optimize import brentq
def projection_simplex_sort(v, z=1):
    """Project ``v`` onto the simplex {w : w >= 0, sum(w) == z} by sorting.

    The entries of ``v`` are sorted in decreasing order. The last position
    where ``u_j - (sum_{i<=j} u_i - z) / j`` is still positive is the size of
    the support. The threshold ``theta`` follows from it, and the projection
    is ``max(v - theta, 0)``.

    Parameters
    ----------
    v : numpy.ndarray
        One-dimensional array to project (``v.shape[0]`` is read, so a NumPy
        array is required).
    z : float, optional
        Target sum of the projected vector (default 1).

    Returns
    -------
    numpy.ndarray
        The projection of ``v``, with the same length as ``v``.
    """
    descending = np.sort(v)[::-1]
    # Running sums of the largest entries, shifted by the target mass.
    shifted_cumsum = np.cumsum(descending) - z
    counts = np.arange(1, v.shape[0] + 1)
    support = descending - shifted_cumsum / counts > 0
    # Index of the last entry that stays in the support.
    last = np.nonzero(support)[0][-1]
    threshold = shifted_cumsum[last] / float(counts[last])
    return np.maximum(v - threshold, 0)
def projection_simplex_pivot(v, z=1, random_state=None):
    """Project ``v`` onto the simplex {w : w >= 0, sum(w) == z} by random pivoting.

    Instead of fully sorting ``v``, a random pivot repeatedly partitions the
    remaining candidate indices (quickselect-style). Elements >= the pivot
    are kept apart from elements < the pivot. Whole groups are then accepted
    into, or rejected from, the support of the solution.

    Parameters
    ----------
    v : array_like
        One-dimensional vector to project.
    z : float, optional
        Target sum of the projected vector (default 1). Assumed > 0.
    random_state : int, numpy.random.RandomState-compatible seed or None
        Seed for the pivot choices.

    Returns
    -------
    numpy.ndarray
        ``max(v - theta, 0)`` for the threshold ``theta`` found by the search.
    """
    v = np.asarray(v)
    rs = np.random.RandomState(random_state)
    U = np.arange(len(v))  # candidate indices whose support status is undecided
    s = 0.0  # sum of the entries accepted into the support so far
    rho = 0  # number of entries accepted into the support so far
    while len(U) > 0:
        k = U[rs.randint(0, len(U))]
        vk = v[k]
        ge = v[U] >= vk
        G = U[ge & (U != k)]  # strictly other entries >= pivot
        L = U[~ge]  # entries < pivot
        ds = vk + np.sum(v[G])
        drho = len(G) + 1
        if s + ds - (rho + drho) * vk < z:
            # The pivot and everything above it belong to the support;
            # continue searching among the smaller entries.
            s += ds
            rho += drho
            U = L
        else:
            # The pivot is outside the support; only larger entries remain.
            U = G
    theta = (s - z) / float(rho)
    return np.maximum(v - theta, 0)
def projection_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000, verbose=False):
    """Project ``v`` onto the simplex {w : w >= 0, sum(w) == z} by bisection on theta.

    Searches for the threshold ``theta`` with ``sum(max(v - theta, 0)) == z``.
    The left-hand side is non-increasing in ``theta``, so bisection applies.

    Parameters
    ----------
    v : array_like
        One-dimensional vector to project.
    z : float, optional
        Target sum of the projected vector (default 1). Assumed > 0.
    tau : float, optional
        Relative tolerance. Iteration stops once the current iterate
        undershoots ``z`` by less than ``tau * z``.
    max_iter : int, optional
        Maximum number of bisection steps (must be >= 1).
    verbose : bool, optional
        If True, print the number of bisection steps taken.

    Returns
    -------
    numpy.ndarray
        ``max(v - theta, 0)`` for the last ``theta`` evaluated. Its sum is at
        most ``z`` when converged within ``max_iter`` steps.

    Raises
    ------
    ValueError
        If ``max_iter < 1``.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be >= 1, got %r" % (max_iter,))
    v = np.asarray(v)
    # At theta = min(v) - z/n every entry contributes at least z/n, so the
    # sum is >= z. At theta = max(v) the sum is 0 < z. This brackets the root
    # even when theta is negative (sum(v) < z), unlike a fixed lower bound of 0.
    lower = np.min(v) - float(z) / len(v)
    upper = np.max(v)
    current = np.inf
    n_iter = 0
    for _ in range(max_iter):
        # Stop only on an undershoot, so the returned w never exceeds z.
        if np.abs(current) / z < tau and current < 0:
            break
        theta = (upper + lower) / 2.0
        w = np.maximum(v - theta, 0)
        current = np.sum(w) - z
        if current <= 0:
            upper = theta
        else:
            lower = theta
        n_iter += 1
    if verbose:
        print("bisection took %d iterations" % n_iter)
    return w
def projection_simplex_brent(v, z=1, tau=1e-9, verbose=False):
    """Project ``v`` onto the simplex {w : w >= 0, sum(w) == z} with Brent's method.

    Finds the root ``theta`` of ``sum(max(v - theta, 0)) - z`` using
    ``scipy.optimize.brentq`` and returns ``max(v - theta, 0)``.

    Parameters
    ----------
    v : array_like
        One-dimensional vector to project.
    z : float, optional
        Target sum of the projected vector (default 1). Assumed > 0.
    tau : float, optional
        Absolute tolerance on ``theta``, passed to ``brentq`` as ``xtol``.
    verbose : bool, optional
        If True, print ``brentq``'s iteration and function-call counts.

    Returns
    -------
    numpy.ndarray
        The projection of ``v``.
    """
    v = np.asarray(v)
    # brentq needs a sign change over [lower, upper]. At min(v) - z/n the
    # residual is >= 0, and at max(v) it is -z < 0. This also covers negative
    # theta (sum(v) < z), where a lower bound of 0 makes brentq raise.
    lower = np.min(v) - float(z) / len(v)
    upper = np.max(v)

    def residual(theta):
        return np.sum(np.maximum(v - theta, 0.0)) - z

    theta, info = brentq(residual, lower, upper, xtol=tau, full_output=True)
    if verbose:
        print("brent took %d iterations and %d function calls"
              % (info.iterations, info.function_calls))
    return np.maximum(v - theta, 0)
if __name__ == '__main__':
    # Smoke test: project a random vector onto a simplex whose mass is half
    # of the vector's sum, and check that each solver hits the target sum.
    rs = np.random.RandomState(0)
    v = rs.rand(1000)
    z = np.sum(v) * 0.5
    print("z = %f" % z)
    for project in (projection_simplex_sort, projection_simplex_pivot,
                    projection_simplex_bisection, projection_simplex_brent):
        w = project(v, z)
        print("%s: sum(w) = %f" % (project.__name__, np.sum(w)))
