@mblondel
Last active September 12, 2024 21:44
Projection onto the simplex
"""
License: BSD
Author: Mathieu Blondel
Implements three algorithms for projecting a vector onto the simplex: sort, pivot and bisection.
For details and references, see the following paper:
Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex
Mathieu Blondel, Akinori Fujino, and Naonori Ueda.
ICPR 2014.
http://www.mblondel.org/publications/mblondel-icpr2014.pdf
"""
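import numpy as np


def projection_simplex_sort(v, z=1):
    # Sort-based projection of v onto {w : w >= 0, sum(w) = z}, O(n log n).
    n_features = v.shape[0]
    u = np.sort(v)[::-1]  # entries in decreasing order
    cssv = np.cumsum(u) - z  # cumulative sums minus z
    ind = np.arange(n_features) + 1  # 1-based positions
    cond = u - cssv / ind > 0  # true for the first rho sorted positions
    rho = ind[cond][-1]  # size of the support
    theta = cssv[cond][-1] / float(rho)  # threshold
    w = np.maximum(v - theta, 0)
    return w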
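

def projection_simplex_pivot(v, z=1, random_state=None):
    # Randomized pivot search for the threshold theta.
    rs = np.random.RandomState(random_state)
    n_features = len(v)
    U = np.arange(n_features)  # indices not yet classified
    s = 0  # sum of v over indices known to be in the support
    rho = 0  # number of indices known to be in the support
    while len(U) > 0:
        G = []  # indices j != k with v[j] >= v[k]
        L = []  # indices j with v[j] < v[k]
        k = U[rs.randint(0, len(U))]  # random pivot
        ds = v[k]
        for j in U:
            if v[j] >= v[k]:
                if j != k:
                    ds += v[j]
                    G.append(j)
            elif v[j] < v[k]:
                L.append(j)
        drho = len(G) + 1
        if s + ds - (rho + drho) * v[k] < z:
            # theta < v[k]: the pivot and all of G are in the support
            s += ds
            rho += drho
            U = L
        else:
            # theta >= v[k]: the remaining support indices lie in G
            U = G
    theta = (s - z) / float(rho)
    return np.maximum(v - theta, 0)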


def projection_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000):
    # theta is the root of func, which is continuous and non-increasing.
    func = lambda x: np.sum(np.maximum(v - x, 0)) - z
    lower = np.min(v) - z / len(v)  # func(lower) >= 0
    upper = np.max(v)  # func(upper) = -z < 0
    for it in range(max_iter):
        midpoint = (upper + lower) / 2.0
        value = func(midpoint)
        if abs(value) <= tau:
            break
        if value <= 0:
            upper = midpoint
        else:
            lower = midpoint
    return np.maximum(v - midpoint, 0)


if __name__ == '__main__':
    v = np.array([1.1, 0.2, 0.2])
    z = 2
    w = projection_simplex_sort(v, z)
    print(np.sum(w))
    w = projection_simplex_pivot(v, z)
    print(np.sum(w))
    w = projection_simplex_bisection(v, z)
    print(np.sum(w))
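For reference, the problem all three functions solve, written out as a sketch for z > 0. The symbols Δ_z, θ and ρ are notation chosen here; u_1 ≥ u_2 ≥ ... ≥ u_n are the entries of v in decreasing order (the `u` in `projection_simplex_sort`, whose `cssv` array holds the bracketed sums):

```latex
\begin{align*}
\Delta_z &= \Big\{ w \in \mathbb{R}^n : w_i \ge 0,\ \sum_{i=1}^{n} w_i = z \Big\}, \qquad
w^\star = \operatorname*{arg\,min}_{w \in \Delta_z} \tfrac{1}{2} \lVert w - v \rVert_2^2, \\
w^\star_i &= \max(v_i - \theta, 0), \qquad \text{where } \theta \text{ solves } \sum_{i=1}^{n} \max(v_i - \theta, 0) = z, \\
\rho &= \max\Big\{ j \in \{1, \dots, n\} : u_j - \frac{1}{j}\Big(\sum_{r=1}^{j} u_r - z\Big) > 0 \Big\}, \qquad
\theta = \frac{1}{\rho}\Big(\sum_{r=1}^{\rho} u_r - z\Big).
\end{align*}
```

Plugging in the `__main__` example (v = (1.1, 0.2, 0.2), z = 2): `cssv` is (-0.9, -0.7, -0.5), the condition holds at all three positions (2.0, 0.55 and about 0.367), so ρ = 3, θ = -0.5/3 = -1/6, and w ≈ (1.2667, 0.3667, 0.3667), which sums to 2. The pivot and bisection functions look for the same θ, by randomized partitioning and by bisection respectively.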
@joseortiz3

joseortiz3 commented Apr 25, 2021

At least in Python 3.8 with NumPy 1.20.2, after replacing xrange with range and adding parentheses around print for Python 3 compatibility, I find that for

v = np.array([1+0.1,0+0.2,0+0.2])
z = 2

the script prints

2
2.0
2.0
1.5

Is this an issue with projection_simplex_bisection, or just a pathological example?
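One way to tell whether an output like the 1.5 above is a bug, rather than a surprising but valid answer, is to test it against the optimality (KKT) conditions of the projection: w must be non-negative and sum to z, and there must be a single threshold θ with w_i = v_i - θ wherever w_i > 0 and v_i ≤ θ wherever w_i = 0. A minimal sketch, assuming z ≥ 0; `is_simplex_projection` and its `atol` tolerance are hypothetical names, not part of the gist:

```python
import numpy as np


def is_simplex_projection(v, w, z=1.0, atol=1e-8):
    """Return True if w satisfies the KKT conditions of
    argmin_{w >= 0, sum(w) = z} 0.5 * ||w - v||^2 (assumes z >= 0)."""
    v = np.asarray(v, dtype=float)
    w = np.asarray(w, dtype=float)
    # Feasibility: non-negative entries that sum to z.
    if np.any(w < -atol) or abs(w.sum() - z) > atol:
        return False
    support = w > atol
    if not np.any(support):
        # Only reachable when z is numerically 0: the simplex is the point 0.
        return True
    # On the support, v_i - w_i must equal one common threshold theta.
    gaps = v[support] - w[support]
    theta = gaps.mean()
    if np.any(np.abs(gaps - theta) > atol):
        return False
    # Off the support, v_i must not exceed theta.
    return bool(np.all(v[~support] <= theta + atol))


v = np.array([1.1, 0.2, 0.2])
print(is_simplex_projection(v, v + 1.0 / 6.0, z=2))  # True
print(is_simplex_projection(v, v, z=2))  # False: entries sum to 1.5
```

Because the problem is convex, these conditions are sufficient as well as necessary, so any output that sums to 1.5 when z = 2 fails the feasibility check and signals a bug in the solver rather than a pathological input.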

@mblondel
Author

Thanks @joseortiz3. I fixed the bug (there was a mistake in the bracketing interval).
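The bracketing requirement can be made explicit. With f(θ) = Σ_i max(v_i - θ, 0) - z, which is continuous and non-increasing, f(min(v) - z/n) ≥ 0 because each of the n terms is at least z/n, and f(max(v)) = -z < 0, so the interval used by `projection_simplex_bisection` always contains the root. A sketch that asserts this before bisecting; `projection_simplex_bisection_checked` is a hypothetical name, and apart from the assertion and an initial `midpoint` (so that `max_iter=0` does not fail) it mirrors the gist's function:

```python
import numpy as np


def projection_simplex_bisection_checked(v, z=1.0, tau=1e-4, max_iter=1000):
    """Bisection on the threshold theta, with the bracket checked up front."""
    v = np.asarray(v, dtype=float)
    # f is continuous and non-increasing; the projection's threshold is its root.
    f = lambda theta: np.sum(np.maximum(v - theta, 0)) - z
    lower = np.min(v) - z / len(v)  # each term >= z/n, so f(lower) >= 0
    upper = np.max(v)  # every term is 0, so f(upper) = -z
    assert f(lower) >= 0 >= f(upper), "the interval does not bracket the root"
    midpoint = (lower + upper) / 2.0  # defined even if max_iter == 0
    for _ in range(max_iter):
        midpoint = (lower + upper) / 2.0
        value = f(midpoint)
        if abs(value) <= tau:
            break
        if value <= 0:
            upper = midpoint  # sum(w) <= z: theta is too large
        else:
            lower = midpoint  # sum(w) > z: theta is too small
    return np.maximum(v - midpoint, 0)


v = np.array([1.1, 0.2, 0.2])
w = projection_simplex_bisection_checked(v, z=2)
print(np.sum(w))  # within tau of 2
```

With a wrong bracket the assertion fires immediately instead of the loop silently converging to the wrong θ.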

@flbbb

flbbb commented Nov 29, 2021

In case someone needs to project each column of a 2D array onto the simplex, here is an adapted code sample (Python 3.9.5):

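    import numpy as np

    def projection_simplex_sort_2d(v, z=1):
        """v array of shape (n_features, n_samples)."""
        p, n = v.shape
        u = np.sort(v, axis=0)[::-1, ...]  # sort each column in decreasing order
        pi = np.cumsum(u, axis=0) - z  # column-wise cumulative sums minus z
        ind = (np.arange(p) + 1).reshape(-1, 1)
        mask = (u - pi / ind) > 0
        # 0-based index of the last True entry in each column
        rho = p - 1 - np.argmax(mask[::-1, ...], axis=0)
        theta = pi[rho, np.arange(n)] / (rho + 1)  # one threshold per column
        w = np.maximum(v - theta, 0)  # theta broadcasts across rows
        return w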
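A quick consistency check for the column-wise version, assuming it should agree with applying `projection_simplex_sort` to each column separately. The 1-D function is copied from the gist so the block runs on its own, the indexing is written as `pi[rho, np.arange(n)]` (equivalent to the `tuple([...])` form), and the random test matrix is arbitrary:

```python
import numpy as np


def projection_simplex_sort(v, z=1):
    # 1-D sort-based projection, as in the gist
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u) - z
    ind = np.arange(v.shape[0]) + 1
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)
    return np.maximum(v - theta, 0)


def projection_simplex_sort_2d(v, z=1):
    # column-wise version from the comment; v has shape (n_features, n_samples)
    p, n = v.shape
    u = np.sort(v, axis=0)[::-1, ...]
    pi = np.cumsum(u, axis=0) - z
    ind = (np.arange(p) + 1).reshape(-1, 1)
    mask = (u - pi / ind) > 0
    rho = p - 1 - np.argmax(mask[::-1, ...], axis=0)
    theta = pi[rho, np.arange(n)] / (rho + 1)
    return np.maximum(v - theta, 0)


rng = np.random.default_rng(0)
V = rng.normal(size=(5, 4))
W = projection_simplex_sort_2d(V, z=1)
# reference: project each column independently with the 1-D function
W_ref = np.column_stack(
    [projection_simplex_sort(V[:, j], z=1) for j in range(V.shape[1])]
)
print(np.allclose(W, W_ref), np.allclose(W.sum(axis=0), 1.0))
```

The `argmax` over the reversed mask picks the last position where the condition holds in each column, which is the same `rho` that `ind[cond][-1]` selects in the 1-D code.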

@mblondel
Author

Hehe, I also wrote a vectorized version here https://gist.github.com/mblondel/c99e575a5207c76a99d714e8c6e08e89
