Last active
September 12, 2024 21:44
-
-
Save mblondel/6f3b7aaad90606b98f71 to your computer and use it in GitHub Desktop.
Projection onto the simplex
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
License: BSD | |
Author: Mathieu Blondel | |
Implements three algorithms for projecting a vector onto the simplex: sort, pivot and bisection. | |
For details and references, see the following paper: | |
Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex | |
Mathieu Blondel, Akinori Fujino, and Naonori Ueda. | |
ICPR 2014. | |
http://www.mblondel.org/publications/mblondel-icpr2014.pdf | |
""" | |
import numpy as np | |
def projection_simplex_sort(v, z=1):
    """Project a vector onto the simplex using the sort-based algorithm.

    Returns the Euclidean projection of ``v`` onto the set
    ``{w : w >= 0, sum(w) == z}``.

    Parameters
    ----------
    v : array-like of shape (n_features,)
        Vector to project.
    z : float, default=1
        Radius of the simplex (the value the result sums to). Must be
        strictly positive.

    Returns
    -------
    w : ndarray of shape (n_features,)
        The projected vector.

    Raises
    ------
    ValueError
        If ``v`` is not a non-empty 1-D array or if ``z <= 0``.
    """
    v = np.asarray(v)
    if v.ndim != 1 or v.size == 0:
        raise ValueError("v must be a non-empty 1-D array")
    if z <= 0:
        raise ValueError("z must be strictly positive")

    n_features = v.shape[0]
    # Sort in decreasing order; cssv[i] is the sum of the i + 1 largest
    # entries minus the target sum z.
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u) - z
    ind = np.arange(1, n_features + 1)
    cond = u - cssv / ind > 0
    # The last position satisfying the condition fixes the number of
    # non-zero entries (rho) and hence the threshold theta.
    last = np.flatnonzero(cond)[-1]
    theta = cssv[last] / float(ind[last])
    return np.maximum(v - theta, 0)
def projection_simplex_pivot(v, z=1, random_state=None):
    """Project a vector onto the simplex using the randomized pivot algorithm.

    Returns the Euclidean projection of ``v`` onto the set
    ``{w : w >= 0, sum(w) == z}``. Instead of sorting, the threshold is
    found by repeatedly partitioning the remaining candidates around a
    randomly chosen pivot.

    Parameters
    ----------
    v : array-like of shape (n_features,)
        Vector to project.
    z : float, default=1
        Radius of the simplex (the value the result sums to). Must be
        strictly positive.
    random_state : int, array-like or None, default=None
        Seed passed to ``np.random.RandomState`` for choosing pivots.

    Returns
    -------
    w : ndarray of shape (n_features,)
        The projected vector.

    Raises
    ------
    ValueError
        If ``v`` is not a non-empty 1-D array or if ``z <= 0``.
    """
    v = np.asarray(v)
    if v.ndim != 1 or v.size == 0:
        raise ValueError("v must be a non-empty 1-D array")
    if z <= 0:
        # Without this guard rho stays 0 and the final division fails.
        raise ValueError("z must be strictly positive")

    rs = np.random.RandomState(random_state)
    candidates = np.arange(v.shape[0])
    s = 0    # running sum of the entries known to be above the threshold
    rho = 0  # number of such entries

    while candidates.size > 0:
        k = candidates[rs.randint(0, candidates.size)]
        pivot = v[k]
        values = v[candidates]

        at_least_pivot = values >= pivot
        # Entries tied with the pivot go to `greater`; the pivot itself is
        # excluded from it but still counted in ds / drho.
        greater = candidates[at_least_pivot & (candidates != k)]
        lower = candidates[values < pivot]

        ds = values[at_least_pivot].sum()
        drho = greater.size + 1

        if s + ds - (rho + drho) * pivot < z:
            # Pivot and everything >= it are in the support: accept them
            # and keep searching among the smaller entries.
            s += ds
            rho += drho
            candidates = lower
        else:
            # Pivot is too small to be in the support: search above it.
            candidates = greater

    theta = (s - z) / float(rho)
    return np.maximum(v - theta, 0)
def projection_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000):
    """Project a vector onto the simplex using bisection on the threshold.

    Searches for ``theta`` such that ``sum(max(v - theta, 0))`` equals ``z``
    and returns ``max(v - theta, 0)``. The result is approximate: the
    search stops once the sum is within ``tau`` of ``z`` or after
    ``max_iter`` iterations.

    Parameters
    ----------
    v : array-like of shape (n_features,)
        Vector to project.
    z : float, default=1
        Radius of the simplex (the value the result should sum to). Must
        be strictly positive.
    tau : float, default=0.0001
        Absolute tolerance on ``sum(w) - z`` used as stopping criterion.
    max_iter : int, default=1000
        Maximum number of bisection steps. With ``max_iter <= 0`` the
        midpoint of the initial bracket is used as threshold.

    Returns
    -------
    w : ndarray of shape (n_features,)
        The (approximately) projected vector.

    Raises
    ------
    ValueError
        If ``v`` is not a non-empty 1-D array or if ``z <= 0``.
    """
    v = np.asarray(v)
    if v.ndim != 1 or v.size == 0:
        raise ValueError("v must be a non-empty 1-D array")
    if z <= 0:
        raise ValueError("z must be strictly positive")

    def excess(threshold):
        # Positive when the threshold is too small (sum above z).
        return np.sum(np.maximum(v - threshold, 0)) - z

    # Bracket the root: at `upper` the sum is 0 (excess == -z < 0); at
    # `lower` every entry exceeds the threshold by at least z / n, so the
    # sum is at least z (excess >= 0). float() keeps true division even
    # for integer z on interpreters with integer `/`.
    lower = np.min(v) - z / float(v.shape[0])
    upper = np.max(v)

    # Defined up front so the return is valid even if the loop never runs.
    midpoint = (upper + lower) / 2.0
    for _ in range(max_iter):
        midpoint = (upper + lower) / 2.0
        value = excess(midpoint)

        if abs(value) <= tau:
            break

        if value <= 0:
            upper = midpoint
        else:
            lower = midpoint

    return np.maximum(v - midpoint, 0)
if __name__ == '__main__':
    # Small demo: project the same vector with each algorithm and print the
    # sum of the result, which should be (approximately) z in every case.
    v = np.array([1.1, 0.2, 0.2])
    z = 2

    for project in (projection_simplex_sort,
                    projection_simplex_pivot,
                    projection_simplex_bisection):
        w = project(v, z)
        print(np.sum(w))
In case someone needs to project each column of a 2D array onto the simplex, here is an adapted code sample (Python 3.9.5):
def projection_simplex_sort_2d(v, z=1):
    """Project each column of a 2-D array onto the simplex (sort-based).

    Every column of ``v`` is projected independently onto
    ``{w : w >= 0, sum(w) == z}``.

    Parameters
    ----------
    v : array-like of shape (n_features, n_samples)
        Columns are the vectors to project.
    z : float or array-like of shape (n_samples,), default=1
        Radius of the simplex; an array gives one radius per column
        (it is broadcast against the columns). Must be strictly positive.

    Returns
    -------
    w : ndarray of shape (n_features, n_samples)
        Array whose columns are the projected vectors.

    Raises
    ------
    ValueError
        If ``v`` is not 2-D, has no rows, or if any ``z <= 0``.
    """
    v = np.asarray(v)
    if v.ndim != 2 or v.shape[0] == 0:
        raise ValueError("v must be a 2-D array with at least one row")
    if np.any(np.asarray(z) <= 0):
        # With z <= 0 the mask below can be all False for a column, and
        # argmax would then silently pick a meaningless threshold.
        raise ValueError("z must be strictly positive")

    p, n = v.shape
    # Sort every column in decreasing order.
    u = np.sort(v, axis=0)[::-1]
    pi = np.cumsum(u, axis=0) - z
    ind = np.arange(1, p + 1).reshape(-1, 1)
    mask = (u - pi / ind) > 0
    # Row index of the last True in each column: argmax on the flipped
    # mask finds the first True counted from the bottom.
    rho = p - 1 - np.argmax(mask[::-1], axis=0)
    theta = pi[rho, np.arange(n)] / (rho + 1)
    return np.maximum(v - theta, 0)
Hehe, I also wrote a vectorized version here: https://gist.github.com/mblondel/c99e575a5207c76a99d714e8c6e08e89
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks @joseortiz3. I fixed the bug (there was a mistake in the bracketing interval).