Skip to content

Instantly share code, notes, and snippets.

@sir-wabbit
Created February 17, 2017 18:31
Show Gist options
  • Select an option

  • Save sir-wabbit/2917182bb0f7d612fe647d14cf9d2792 to your computer and use it in GitHub Desktop.

Select an option

Save sir-wabbit/2917182bb0f7d612fe647d14cf9d2792 to your computer and use it in GitHub Desktop.
def train_adadelta(func, fprime, x0, x_min, x_max,
                   step_rate=1, decay=0.9, momentum=0, offset=1e-4,
                   iterations=1000, tolerance=1e-1, verbose=True):
    """Minimize ``func`` with a box-constrained Adadelta variant plus retries.

    Each regular iteration first applies a momentum step (``momentum`` times
    the previous total step), then evaluates the gradient ``g`` and takes the
    Adadelta step ``-sqrt(E[dx^2] + offset) / sqrt(E[g^2] + offset) * g *
    step_rate``. Every step is shortened ("bounced") so that the iterate ends
    on the box boundary instead of leaving ``[x_min, x_max]``.

    After 5 consecutive non-improving regular iterations the optimizer enters
    retry mode. It goes back to the best point found so far and tries
    ``-momentum`` times the previous step. A successful retry resets the
    running average of squared steps and resumes regular iterations. A retry
    step whose largest absolute component is below ``1e-10`` ends the search.
    With the default ``momentum=0`` the retry step is zero, so for a
    deterministic ``func`` entering retry mode ends the search on the next
    iteration.

    Parameters
    ----------
    func : callable
        Objective; ``np.squeeze(func(x))`` must be a scalar that supports ``>``.
    fprime : callable
        Gradient of ``func`` evaluated at ``x``.
    x0 : array_like
        Starting point. It is converted to an array (integer input is
        promoted to float so in-place updates work) and clamped into the
        box. The caller's object is never modified.
    x_min, x_max : array_like or None
        Element-wise lower / upper bounds; ``None`` means unbounded.
    step_rate : float
        Multiplier applied to the Adadelta step.
    decay : float
        Decay factor of the running averages of squared gradients/steps.
    momentum : float
        Momentum coefficient (also scales the reversed step in retry mode).
    offset : float
        Added inside both square roots for numerical stability.
    iterations : int
        Maximum number of loop iterations (regular or retry).
    tolerance : float
        Stop as soon as an improvement of the best value is smaller than this.
    verbose : bool
        Print per-iteration and summary progress lines (original behaviour).

    Returns
    -------
    tuple
        ``(best_x, best_nll, status)``. ``status`` has the following keys:

        - ``grad_norm``: norm of the most recently evaluated gradient. This
          is not necessarily the gradient at ``best_x``. When no iteration
          ran, it is the norm of the gradient at the clamped start point.
        - ``start_nll``: objective value at the start point.
        - ``done_nll``: value at the last evaluated point, which may be a
          rejected retry.
        - ``min_it``: ``i + 1`` of the iteration that found ``best_x``, or
          ``0`` for the start point.
        - ``min_nll``: the best objective value.
    """
    x0 = np.asarray(x0)
    if not np.issubdtype(x0.dtype, np.inexact):
        # The iterate is updated in place with float steps below.
        x0 = x0.astype(float)
    if x_min is None:
        x_min = np.full(x0.shape, -np.inf, dtype=x0.dtype)
    if x_max is None:
        x_max = np.full(x0.shape, np.inf, dtype=x0.dtype)

    def clamp(x):
        """Project ``x`` element-wise into ``[x_min, x_max]``."""
        return np.where(x < x_min, x_min, np.where(x > x_max, x_max, x))

    def bounce(x, dx):
        """Shorten ``dx`` element-wise so that ``x + dx`` stays in the box."""
        return np.where(x + dx < x_min, x_min - x,
                        np.where(x + dx > x_max, x_max - x, dx))

    # np.where always allocates, so ``x`` never aliases the caller's ``x0``.
    x = clamp(x0)
    start_nll = np.squeeze(func(x))
    nll = start_nll     # bound even if the loop body never runs
    g = None            # most recently evaluated gradient
    best_x, best_it, best_nll = x.copy(), 0, start_nll
    Eg2 = 0             # running average of squared gradients
    Edx2 = 0            # running average of squared steps
    step = 0            # previous total step (momentum + Adadelta)
    retries = 0         # non-zero while in retry mode
    increase = 0        # consecutive non-improving regular iterations
    for i in range(iterations):
        if retries == 0:
            # Regular iteration: momentum step, then gradient-based step.
            step1 = bounce(x, step * momentum)
            x += step1
            g = fprime(x)
            Eg2 = decay * Eg2 + (1 - decay) * g ** 2
            step2 = -(np.sqrt(Edx2 + offset) / np.sqrt(Eg2 + offset) *
                      g * step_rate)
            step2 = bounce(x, step2)
            step = step1 + step2
            Edx2 = decay * Edx2 + (1 - decay) * step ** 2
            x += step2
            nll = np.squeeze(func(x))
            if best_nll > nll:
                improvement = best_nll - nll
                best_x, best_it, best_nll = x.copy(), i + 1, nll
                increase = 0
                if improvement < tolerance:
                    break
            else:
                increase += 1
                if increase >= 5:
                    # Too many non-improving steps: switch to retry mode.
                    retries += 1
                    increase = 0
        else:
            # Retry mode: restart from the best point with a reversed,
            # momentum-scaled copy of the previous step.
            x = best_x.copy()
            step = bounce(x, -momentum * step)
            x1 = x + step
            nll = np.squeeze(func(x1))
            if best_nll > nll:
                improvement = best_nll - nll
                Edx2 = step ** 2
                x = x1
                best_x, best_it, best_nll = x1.copy(), i + 1, nll
                retries = 0
                if improvement < tolerance:
                    break
            elif np.abs(step).max() >= 1e-10:
                retries += 1
            else:
                # The retry step has collapsed; stop and return the best point.
                break
        if verbose:
            print("NLL (%d/%d): %s" % (i + 1, retries, nll))
    if verbose:
        print("NLL (start): %s" % start_nll)
        print("NLL (done): %s" % nll)
    if g is None:
        # No iteration ran (iterations <= 0): report the gradient at the
        # clamped start point instead of failing on an unbound name.
        g = fprime(x)
    status = {
        'grad_norm': la.norm(g),
        'start_nll': start_nll,
        'done_nll': nll,
        'min_it': best_it,
        'min_nll': best_nll}
    return (best_x, best_nll, status)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment