Last active
September 12, 2024 21:44
-
-
Save mblondel/6f3b7aaad90606b98f71 to your computer and use it in GitHub Desktop.
Projection onto the simplex
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
License: BSD
Author: Mathieu Blondel
Implements three algorithms for projecting a vector onto the simplex: sort, pivot and bisection.
For details and references, see the following paper:
Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex
Mathieu Blondel, Akinori Fujino, and Naonori Ueda.
ICPR 2014.
http://www.mblondel.org/publications/mblondel-icpr2014.pdf
"""
import numpy as np | |
def projection_simplex_sort(v, z=1):
    """Project ``v`` onto the simplex ``{w : w >= 0, sum(w) = z}`` by sorting.

    Parameters
    ----------
    v : array-like, 1-D
        Vector to project. Converted with ``np.asarray``, so plain
        sequences are accepted as well as arrays.
    z : float, default 1
        Required sum of the projected vector; must be strictly positive.

    Returns
    -------
    numpy.ndarray
        ``np.maximum(v - theta, 0)``, where ``theta`` is the threshold
        derived from the sorted cumulative sums below.

    Raises
    ------
    ValueError
        If ``v`` is not a non-empty 1-D array or if ``z <= 0``.
    """
    v = np.asarray(v)
    if v.ndim != 1 or v.size == 0:
        raise ValueError("v must be a non-empty 1-D array")
    if z <= 0:
        # With z <= 0 no entry satisfies the support condition below and
        # the lookup of the last qualifying index would fail.
        raise ValueError("z must be strictly positive")

    n_features = v.shape[0]
    # Entries in decreasing order.
    u = np.sort(v)[::-1]
    # cssv[i] = (sum of the i + 1 largest entries) - z.
    cssv = np.cumsum(u) - z
    ind = np.arange(1, n_features + 1)
    # Entry i qualifies when it stays positive after subtracting the
    # threshold computed from the i + 1 largest entries.
    cond = u - cssv / ind > 0
    # The last qualifying position gives the support size.
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)
    return np.maximum(v - theta, 0)
def projection_simplex_pivot(v, z=1, random_state=None):
    """Project ``v`` onto the simplex ``{w : w >= 0, sum(w) = z}`` by pivoting.

    A random pivot is drawn from the remaining candidate indices; the
    candidates are split into those not smaller than the pivot and those
    strictly smaller, and one of the two groups is discarded each round.

    Parameters
    ----------
    v : array-like, 1-D
        Vector to project. Converted with ``np.asarray``.
    z : float, default 1
        Required sum of the projected vector; must be strictly positive.
    random_state : optional
        Seed forwarded to ``np.random.RandomState`` for the pivot draws.

    Returns
    -------
    numpy.ndarray
        ``np.maximum(v - theta, 0)`` with ``theta = (s - z) / rho``.

    Raises
    ------
    ValueError
        If ``v`` is not a non-empty 1-D array or if ``z <= 0`` (either case
        would otherwise end in a division by ``rho == 0``).
    """
    v = np.asarray(v)
    if v.ndim != 1 or v.size == 0:
        raise ValueError("v must be a non-empty 1-D array")
    if z <= 0:
        raise ValueError("z must be strictly positive")

    rs = np.random.RandomState(random_state)
    candidates = np.arange(v.shape[0])
    s = 0.0   # running sum of the entries accepted into the support
    rho = 0   # number of entries accepted into the support

    while candidates.size > 0:
        k = candidates[rs.randint(0, candidates.size)]
        pivot = v[k]
        values = v[candidates]
        # Candidates not smaller than the pivot; this includes the pivot.
        at_least_pivot = values >= pivot
        ds = values[at_least_pivot].sum()
        drho = np.count_nonzero(at_least_pivot)

        if s + ds - (rho + drho) * pivot < z:
            # The pivot and everything above it belong to the support;
            # keep searching among the strictly smaller entries.
            s += ds
            rho += drho
            candidates = candidates[values < pivot]
        else:
            # The pivot is outside the support; drop it and continue with
            # the remaining entries that are not smaller than it.
            candidates = candidates[at_least_pivot & (candidates != k)]

    theta = (s - z) / float(rho)
    return np.maximum(v - theta, 0)
def projection_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000):
    """Project ``v`` onto the simplex ``{w : w >= 0, sum(w) = z}`` by bisection.

    Searches for a threshold ``x`` such that
    ``sum(maximum(v - x, 0)) - z`` is within ``tau`` of zero.

    Parameters
    ----------
    v : array-like, 1-D
        Vector to project. Converted with ``np.asarray``.
    z : float, default 1
        Required sum of the projected vector; must be strictly positive.
    tau : float, default 0.0001
        Absolute tolerance on the residual ``sum(w) - z``.
    max_iter : int, default 1000
        Maximum number of bisection steps. With ``max_iter <= 0`` the
        midpoint of the initial bracket is used as the threshold.

    Returns
    -------
    numpy.ndarray
        ``np.maximum(v - midpoint, 0)`` for the last midpoint evaluated.

    Raises
    ------
    ValueError
        If ``v`` is not a non-empty 1-D array or if ``z <= 0``.
    """
    v = np.asarray(v)
    if v.ndim != 1 or v.size == 0:
        raise ValueError("v must be a non-empty 1-D array")
    if z <= 0:
        raise ValueError("z must be strictly positive")

    # Initial bracket: at ``lower`` every entry contributes at least
    # z / len(v), so the residual is >= 0; at ``upper`` every entry is
    # clipped to zero, so the residual is -z.
    lower = np.min(v) - z / len(v)
    upper = np.max(v)
    # Defined before the loop so the return is valid even with no steps.
    midpoint = (upper + lower) / 2.0

    for _ in range(max_iter):
        midpoint = (upper + lower) / 2.0
        value = np.sum(np.maximum(v - midpoint, 0)) - z

        if abs(value) <= tau:
            break

        if value <= 0:
            # Threshold too large: the clipped sum fell below z.
            upper = midpoint
        else:
            lower = midpoint

    return np.maximum(v - midpoint, 0)
if __name__ == '__main__':
    # Demo: project one vector onto the simplex of sum 2 with each of the
    # three algorithms and print the sum of every projection.
    demo_vector = np.array([1.1, 0.2, 0.2])
    target_sum = 2

    for project in (
        projection_simplex_sort,
        projection_simplex_pivot,
        projection_simplex_bisection,
    ):
        print(np.sum(project(demo_vector, target_sum)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hehe, I also wrote a vectorized version here https://gist.github.com/mblondel/c99e575a5207c76a99d714e8c6e08e89