@mgritter
Created November 13, 2019 05:38
Chen and Ye's algorithm for simplex projection
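For reference, here is the problem and the threshold rule as we read them from Chen and Ye's paper. The paper sorts y in ascending order, which is the reverse of the code's descending convention. The code's step k (sum of the k largest values) corresponds to the paper's index i = n - k.

$$
P(y) = \arg\min_{x \in \Delta^n} \tfrac{1}{2}\lVert x - y \rVert_2^2,
\qquad
\Delta^n = \Big\{ x \in \mathbb{R}^n : \sum_{i=1}^{n} x_i = 1,\ x_i \ge 0 \Big\}
$$

$$
y_{(1)} \le \dots \le y_{(n)}, \qquad
t_i = \frac{\sum_{j=i+1}^{n} y_{(j)} - 1}{n - i}, \quad i = n-1, \dots, 1
$$

$$
\hat t =
\begin{cases}
t_i & \text{for the first } i \text{ (counting down from } n-1\text{) with } t_i \ge y_{(i)},\\[2pt]
\dfrac{\sum_{j=1}^{n} y_j - 1}{n} & \text{if no such } i \text{ exists,}
\end{cases}
\qquad
P(y)_j = \max(y_j - \hat t,\ 0).
$$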
```python
def project_onto_standard_simplex( y ):
    """Find the nearest neighbor (in Euclidean distance) of y on the
    |y|-element standard simplex

        x_1 + ... + x_n = 1,  x_i >= 0

    See Yunmei Chen and Xiaojing Ye, "Projection Onto a Simplex",
    https://arxiv.org/abs/1101.6081
    """
    n = len( y )
    y_s = sorted( y, reverse=True )
    # Sum the i largest y's.
    # Compute t_i = (sum - 1) / i.
    # If t_i >= the next y in descending order, stop: t_i is the threshold.
    #
    # The paper sorts in ascending order instead, so there
    #   t_i = ( -1 + sum_{j=i+1}^{n} y_j ) / (n-i)
    sum_y = 0
    for i, y_i, y_next in zip( range( 1, n+1 ), y_s, y_s[1:] + [0.0] ):
        sum_y += y_i
        t = (sum_y - 1) / i
        if t >= y_next:
            break
    # If the loop never breaks, t = (sum(y) - 1) / n, which is the paper's
    # i = 0 case, so t_hat has the same form either way.
    return [ max( 0, y_i - t ) for y_i in y ]
```