Last active
April 10, 2019 04:40
-
-
Save jdthorpe/386db030e0c10fc591b718fe07293b70 to your computer and use it in GitHub Desktop.
Take a step in the direction of the gradient restricted to the positive simplex
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<svg height="210" width="400"> | |
<path d="M150 0 L75 200 L225 200 Z" /> | |
Sorry, your browser does not support inline SVG. | |
</svg> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Gradient descent restricted to the positive simplex | |
Method inspired by the qustion "what would a water droplet stuckin insise the | |
positive simplex go when pulled in the direction of the gradient which would | |
otherwise take the droplet outside of the simplex | |
""" | |
from numpy import ( | |
array, | |
append, | |
arange, | |
random, | |
zeros, | |
ones, | |
logical_or, | |
logical_not, | |
logical_and, | |
argmin, | |
where, | |
) | |
# pylint: disable=invalid-name | |
def _sub_simplex_project(d_hat, indx): | |
""" | |
A utility function which projects the gradient into the subspace of the | |
simplex which intersects the plane x[_index] == 0 | |
""" | |
# now project the gradient perpendicular to the edge we just came up against | |
n = len(d_hat) | |
_n = float(n) | |
a_dot_a = (n - 1) / _n | |
a_tilde = -ones(n) / _n | |
a_tilde[indx] += 1 # plus a' | |
proj_a_d = (d_hat.dot(a_tilde) / a_dot_a) * a_tilde | |
d_tilde = d_hat - proj_a_d | |
return d_tilde | |
def step(x, g, verbose=True): | |
""" | |
follow the gradint as far as you can within the positive simplex | |
""" | |
i = 0 | |
x, g = x.copy(), g.copy() | |
# project the gradient into the simplex | |
g = g - (g.sum() / len(x)) * ones(len(g)) | |
while True: | |
# we can move in the direction of the gradient if either | |
# (a) the gradient points away from the axis | |
# (b) we're not yet touching the axis | |
valid_directions = logical_or(g < 0, x > 0) | |
if verbose: | |
print( | |
" valid_directions(%s, %s, %s): %s " | |
% ( | |
valid_directions.sum(), | |
(g < 0).sum(), | |
(x > 0).sum(), | |
", ".join(str(x) for x in valid_directions), | |
) | |
) | |
if not valid_directions.any(): | |
break | |
if any(g[logical_not(valid_directions)] != 0): | |
# TODO: make sure is is invariant on the order of operations | |
n_valid = where(valid_directions)[0] | |
W = where(logical_not(valid_directions))[0] | |
for i, _w in enumerate(W): | |
# TODO: Project the invalid directions into the current (valid) subspace of | |
# the simplex | |
mask = append(array(_w), n_valid) | |
print("work in progress") | |
g[mask] = _sub_simplex_project(g[mask], 0) | |
g[_w] = 0 # may not be exactly zero due to rounding error | |
i += 1 | |
if i > len(x): | |
raise RuntimeError("something went wrong") | |
# HOW FAR CAN WE GO? | |
limit_directions = logical_and(valid_directions, g > 0) | |
xl = x[limit_directions] | |
gl = g[limit_directions] | |
ratios = xl / gl | |
c = ratios.min() | |
if c > 1: | |
x = x - g | |
break | |
arange(len(g)) | |
indx = argmin(ratios) | |
# MOVE | |
# there's gotta be a better way... | |
_indx = where(limit_directions)[0][indx] | |
tmp = -ones(len(x)) | |
tmp[valid_directions] = arange(valid_directions.sum()) | |
__indx = int(tmp[_indx]) | |
# get the index | |
del xl, gl, ratios | |
x = x - c * g | |
# PROJECT THE GRADIENT | |
d_tilde = _sub_simplex_project(g[valid_directions] * (1 - c), __indx) | |
if verbose: | |
print( | |
"i: %s, which: %s, g.sum(): %f, x.sum(): %f, x[i]: %f, g[i]: %f, d_tilde[i]: %f" | |
% ( | |
i, | |
indx, | |
g.sum(), | |
x.sum(), | |
x[valid_directions][__indx], | |
g[valid_directions][__indx], | |
d_tilde[__indx], | |
) | |
) | |
g[valid_directions] = d_tilde | |
# handle rounding error... | |
x[_indx] = 0 | |
g[_indx] = 0 | |
return x | |
if __name__ == "__main__": | |
# SETUP---------------------------------- | |
random.seed(10101) | |
dim = 14 | |
ZEROS_1 = random.choice(dim, 10) | |
ZEROS_2 = random.choice(dim, 8) | |
X0 = random.random(dim) | |
X0 = X0 / X0.sum() | |
# set most of the directions to zero | |
GRAD_1 = random.normal(0, 0.25, dim) | |
GRAD_1[ZEROS_1] = zeros(len(ZEROS_1)) | |
# set most of the directions to zero | |
GRAD_2 = random.normal(0, 0.25, dim) | |
GRAD_2[ZEROS_2] = zeros(len(ZEROS_2)) | |
# END SETUP------------------------------ | |
print("[GRADIENT STEP 1]") | |
X1 = step(X0, GRAD_1) | |
print("[GRADIENT STEP 2]") | |
X2 = step(X1, GRAD_2) | |
print("[DONE]") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment