-
-
Save EdwardRaff/f4f4cf0c927c2addfb39 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Implements four algorithms for the Euclidean projection of a vector v onto the
simplex {w : w >= 0, sum(w) = z}: sort, pivot, bisection, and Brent's method
(via scipy.optimize.brentq).
For details and references, see the following paper:
Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex
Mathieu Blondel, Akinori Fujino, and Naonori Ueda.
ICPR 2014.
http://www.mblondel.org/publications/mblondel-icpr2014.pdf
"""
import numpy as np | |
from scipy.optimize import brentq | |
def projection_simplex_sort(v, z=1):
    """Project ``v`` onto the simplex of radius ``z`` by sorting.

    Computes the Euclidean projection of ``v`` onto
    ``{w : w >= 0, sum(w) = z}`` as ``max(v - theta, 0)``. The threshold
    ``theta`` is read off the entries sorted in decreasing order, so the
    cost is dominated by one sort, O(n log n).

    Parameters
    ----------
    v : array-like of shape (n_features,)
        Vector to project. It is converted to a float array, so plain lists
        and integer arrays are accepted.
    z : float, default 1
        Radius of the simplex. Must be positive.

    Returns
    -------
    w : ndarray of shape (n_features,)
        The projected vector.

    Raises
    ------
    ValueError
        If ``v`` is not a non-empty 1-D array, or if ``z <= 0``.
    """
    # Converting to float avoids floor division in ``cssv / ind`` for
    # integer input under Python 2, and lets plain lists through.
    v = np.asarray(v, dtype=float)
    if v.ndim != 1 or v.size == 0:
        raise ValueError("v must be a non-empty 1-D array")
    if z <= 0:
        # With z > 0, cond[0] is u[0] - (u[0] - z) = z > 0, so rho below is
        # always defined. Otherwise ``ind[cond][-1]`` could be an empty index.
        raise ValueError("z must be positive, got %r" % (z,))
    u = np.sort(v)[::-1]            # entries in decreasing order
    cssv = np.cumsum(u) - z         # running sums of the largest entries, minus z
    ind = np.arange(1, v.size + 1)  # 1-based positions in the sorted order
    # rho is the largest j for which the j-th largest entry stays positive
    # after subtracting the candidate threshold cssv[j - 1] / j.
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)
    return np.maximum(v - theta, 0)
def projection_simplex_pivot(v, z=1, random_state=None):
    """Project ``v`` onto the simplex of radius ``z`` by randomized pivoting.

    Computes ``max(v - theta, 0)`` with ``sum == z`` without fully sorting
    ``v``. Each round picks a random pivot among the still-undecided
    indices. It then decides whether the pivot, together with every
    undecided entry at least as large, belongs to the support of the
    projection. Decided entries are folded into the running sum ``s`` and
    count ``rho``.

    Parameters
    ----------
    v : array-like of shape (n_features,)
        Vector to project. It is converted to a float array.
    z : float, default 1
        Radius of the simplex. Must be positive.
    random_state : int, None, or any seed accepted by
        ``np.random.RandomState``
        Seeds the pivot choices. It changes the order of work, not the
        projection itself (up to floating-point rounding).

    Returns
    -------
    w : ndarray of shape (n_features,)
        The projected vector.

    Raises
    ------
    ValueError
        If ``v`` is not a non-empty 1-D array, or if ``z <= 0``. Either case
        would otherwise leave ``rho == 0`` and divide by zero.
    """
    v = np.asarray(v, dtype=float)
    if v.ndim != 1 or v.size == 0:
        raise ValueError("v must be a non-empty 1-D array")
    if z <= 0:
        raise ValueError("z must be positive, got %r" % (z,))
    rs = np.random.RandomState(random_state)
    U = np.arange(v.size)  # indices not yet classified
    s = 0.0                # sum of the entries known to be in the support
    rho = 0                # number of entries known to be in the support
    while len(U) > 0:
        k = U[rs.randint(0, len(U))]
        pivot = v[k]
        values = v[U]
        # Boolean masks keep U's order, matching an element-by-element
        # partition, so the same seed yields the same pivot sequence.
        G = U[(values >= pivot) & (U != k)]
        L = U[values < pivot]
        ds = pivot + np.sum(v[G])
        drho = len(G) + 1  # G plus the pivot itself
        if s + ds - (rho + drho) * pivot < z:
            # The pivot and everything at least as large stay positive.
            s += ds
            rho += drho
            U = L
        else:
            U = G
    theta = (s - z) / float(rho)
    return np.maximum(v - theta, 0)
def projection_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000, verbose=False):
    """Project ``v`` onto the simplex of radius ``z`` by bisection on theta.

    The projection is ``max(v - theta, 0)`` for the threshold ``theta`` at
    which its entries sum to ``z``. The function
    ``f(theta) = sum(max(v - theta, 0)) - z`` is non-increasing. It is
    bracketed and bisected until ``f(theta)`` is non-positive and within a
    relative tolerance ``tau`` of zero.

    Parameters
    ----------
    v : array-like of shape (n_features,)
        Vector to project. It is converted to a float array.
    z : float, default 1
        Radius of the simplex. Must be positive.
    tau : float, default 1e-4
        Relative tolerance. Iteration stops once
        ``z * (1 - tau) < sum(w) <= z``.
    max_iter : int, default 1000
        Maximum number of bisection steps. Must be at least 1.
    verbose : bool, default False
        If True, print the number of iterations performed.

    Returns
    -------
    w : ndarray of shape (n_features,)
        The iterate at the last evaluated theta. If ``max_iter`` is reached
        first, its sum may miss ``z`` by more than ``tau``.

    Raises
    ------
    ValueError
        If ``v`` is not a non-empty 1-D array, ``z <= 0``, or ``max_iter < 1``.
    """
    v = np.asarray(v, dtype=float)
    if v.ndim != 1 or v.size == 0:
        raise ValueError("v must be a non-empty 1-D array")
    if z <= 0:
        raise ValueError("z must be positive, got %r" % (z,))
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1, got %r" % (max_iter,))
    # At ``upper`` every entry is clipped, so f(upper) = -z < 0. At ``lower``
    # every entry exceeds it by at least 2z/n, so f(lower) >= z > 0. A fixed
    # lower bound of 0 would miss negative thresholds, which occur whenever
    # the positive part of v sums to less than z.
    lower = np.min(v) - 2.0 * z / v.size
    upper = np.max(v)
    n_iter = 0
    for n_iter in range(1, max_iter + 1):
        theta = (upper + lower) / 2.0
        w = np.maximum(v - theta, 0)
        current = np.sum(w) - z
        if current <= 0:
            # Accept only from the feasible side (sum(w) <= z). Using "<="
            # also stops immediately on an exact hit.
            if -current / z < tau:
                break
            upper = theta
        else:
            lower = theta
    if verbose:
        print("bisection took %d iterations" % n_iter)
    return w
def projection_simplex_brent(v, z=1, tau=1e-9, verbose=False):
    """Project ``v`` onto the simplex of radius ``z`` using Brent's method.

    Finds the root ``theta`` of ``f(theta) = sum(max(v - theta, 0)) - z``
    with ``scipy.optimize.brentq`` and returns ``max(v - theta, 0)``.

    Parameters
    ----------
    v : array-like of shape (n_features,)
        Vector to project. It is converted to a float array.
    z : float, default 1
        Radius of the simplex. Must be positive.
    tau : float, default 1e-9
        Absolute tolerance on ``theta``, passed to ``brentq`` as ``xtol``.
    verbose : bool, default False
        If True, print the iteration and function-call counts reported by
        ``brentq``.

    Returns
    -------
    w : ndarray of shape (n_features,)
        The projected vector.

    Raises
    ------
    ValueError
        If ``v`` is not a non-empty 1-D array, or if ``z <= 0``.
    """
    v = np.asarray(v, dtype=float)
    if v.ndim != 1 or v.size == 0:
        raise ValueError("v must be a non-empty 1-D array")
    if z <= 0:
        raise ValueError("z must be positive, got %r" % (z,))

    def excess(theta):
        # Non-increasing in theta; its root is the projection threshold.
        return np.sum(np.maximum(v - theta, 0.0)) - z

    # brentq needs a strict sign change. f(upper) = -z < 0, and every entry
    # exceeds ``lower`` by at least 2z/n, so f(lower) >= z > 0. A fixed lower
    # bound of 0 fails whenever the threshold is negative.
    lower = np.min(v) - 2.0 * z / v.size
    upper = np.max(v)
    theta, r = brentq(excess, lower, upper, xtol=tau, full_output=True)
    if verbose:
        print("brent took %d iterations and %d function calls"
              % (r.iterations, r.function_calls))
    return np.maximum(v - theta, 0)
if __name__ == '__main__':
    # Demo: project a random vector onto the simplex whose radius is half of
    # the vector's own sum. Print z, then the sum of each solver's output;
    # each sum should be close to z. print() with a single argument behaves
    # the same under Python 2 and 3.
    rs = np.random.RandomState(0)
    v = rs.rand(1000)
    z = np.sum(v) * 0.5
    print(z)
    solvers = (projection_simplex_sort, projection_simplex_pivot,
               projection_simplex_bisection, projection_simplex_brent)
    for solver in solvers:
        w = solver(v, z)
        print(np.sum(w))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment