Created November 13, 2019 05:38
-
-
Save mgritter/4bf003cd399da2e57096af1050d64ddd to your computer and use it in GitHub Desktop.
Chen and Ye's algorithm for simplex projection
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def project_onto_standard_simplex( y ):
    """Find the nearest neighbor (Euclidean) of y on the |y|-element standard simplex

        x_1 + ... + x_n = 1,   x_i >= 0 for all i

    using the sort-based algorithm of Yunmei Chen and Xiaojing Ye,
    "Projection Onto a Simplex", https://arxiv.org/abs/1101.6081

    :param y: finite sequence or iterable of real numbers (consumed once).
    :returns: list of floats in the same order as y; every entry is
        nonnegative and the entries sum to 1 (up to floating-point rounding).
    :raises ValueError: if y is empty (there is no 0-element simplex point).
    """
    # Materialize once so generators/iterators work and y can be read twice.
    y = list( y )
    n = len( y )
    if n == 0:
        raise ValueError( "cannot project an empty vector onto the standard simplex" )

    y_s = sorted( y, reverse=True )
    # Sum the i largest y's and form the candidate threshold
    #     t_i = (sum - 1) / i.
    # The first i (1 <= i < n) with t_i >= the next-smallest y gives t_hat.
    # The paper indexes the sorted vector in ascending order, so there it reads
    #     t_i = ( -1 + sum_{j=i+1}^{n} y_j ) / (n-i).
    sum_y = 0
    for i in range( 1, n ):
        sum_y += y_s[i - 1]
        t_hat = (sum_y - 1) / i
        if t_hat >= y_s[i]:
            break
    else:
        # No early break (including n == 1): every coordinate stays in the
        # support, so t_hat = (sum of all y - 1) / n, the paper's "i = 0" case.
        t_hat = (sum_y + y_s[-1] - 1) / n

    # Shift by the threshold and clip at zero; 0.0 keeps the output all-float.
    return [ max( 0.0, y_i - t_hat ) for y_i in y ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.