Skip to content

Instantly share code, notes, and snippets.

@bkj
Last active December 20, 2017 19:16
Show Gist options
  • Save bkj/212968d95265c5c76ba8d7ba2420529f to your computer and use it in GitHub Desktop.
Save bkj/212968d95265c5c76ba8d7ba2420529f to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
"""
sgdr.py
Code to reproduce Fig. 1 from https://arxiv.org/pdf/1608.03983.pdf
"""
import numpy as np
from matplotlib import pyplot as plt
def power_sum(base, k):
    """Return the geometric sum 1 + base + base**2 + ... + base**k.

    Uses the closed form (base**(k + 1) - 1) / (base - 1), which is undefined
    at base == 1; that case is handled separately (k + 1 terms, each equal
    to 1). `k` may be a NumPy array, in which case the sum is evaluated
    element-wise. k == -1 yields 0 (the empty sum).

    Args:
        base: common ratio of the series (scalar or array).
        k: index of the last term, inclusive.

    Returns:
        The sum as a float, or a float array when `k` is an array.
    """
    if np.ndim(base) == 0 and base == 1:
        # The closed form divides by zero here; every term is 1.
        return k + 1.0
    return (base ** (k + 1) - 1) / (base - 1)
def inv_power_sum(x, base):
    """Solve 1 + base + ... + base**k == x for a real-valued k.

    This is the inverse of `power_sum` with respect to `k`:
    k = log(x * (base - 1) + 1) / log(base) - 1. The result is generally
    non-integer; x == 0 gives -1. For base == 1 the series is simply k + 1,
    so the inverse is x - 1 (the log formula would divide by log(1) == 0).

    Args:
        x: target sum (scalar or NumPy array, evaluated element-wise).
        base: common ratio of the series.

    Returns:
        The real k for which the geometric sum equals `x`.
    """
    if np.ndim(base) == 0 and base == 1:
        # With ratio 1 the sum is k + 1, so invert that directly.
        return x - 1.0
    return np.log(x * (base - 1) + 1) / np.log(base) - 1
def sgdr(progress, period_length=50, lr_max=0.05, lr_min=0, t_mult=1):
    """Cosine-annealed learning rate with warm restarts (SGDR).

    Computes

        lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * T_cur / T_i))

    where T_i is the length of the current period and T_cur is how far
    `progress` has advanced into it. Each period starts at lr_max and decays
    towards lr_min; the next period restarts at lr_max.

    Args:
        progress: position in the schedule, in the same units as
            `period_length` (e.g. epochs). Scalar or NumPy array; the result
            is computed element-wise.
        period_length: length T_0 of the first period.
        lr_max: learning rate at the start of every period.
        lr_min: learning rate approached at the end of every period.
        t_mult: growth factor of the period length after each restart
            (period i has length period_length * t_mult**i). Any value
            <= 1 takes the fixed-period branch, i.e. every period has
            length period_length.

    Returns:
        The learning rate(s), with the shape of `progress`.
    """
    if t_mult > 1:
        # Index of the current period: period i starts at
        # period_length * (1 + t_mult + ... + t_mult**(i - 1)).
        period_id = np.floor(inv_power_sum(progress / period_length, t_mult)) + 1
        # The log-based inverse can round across an integer right at a
        # restart (e.g. log(243) / log(3) can evaluate to just under 5). That
        # would place the restart at the very end of the previous period
        # (lr_min instead of lr_max). Re-check against the exact period
        # offsets and correct the index by one in either direction.
        period_id = period_id - (progress < power_sum(t_mult, period_id - 1) * period_length)
        period_id = period_id + (progress >= power_sum(t_mult, period_id) * period_length)
        offsets = power_sum(t_mult, period_id - 1) * period_length
        period_progress = (progress - offsets) / (t_mult ** period_id * period_length)
    else:
        period_progress = (progress % period_length) / period_length
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(period_progress * np.pi))
# --
# Test: plot several SGDR schedules (Fig. 1 of the SGDR paper).

if __name__ == "__main__":
    # Dense sampling of 0-200 so each restart shows up as a sharp jump.
    progress = np.linspace(0, 200, 10000)

    # (period_length, t_mult, colour) for each plotted schedule.
    schedules = [
        (50, 1, 'lime'),
        (100, 1, 'black'),
        (200, 1, 'grey'),
        (10, 2, 'magenta'),
        (50, 2, 'cyan'),
    ]
    for period_length, t_mult, color in schedules:
        plt.plot(
            progress,
            sgdr(progress, period_length=period_length, t_mult=t_mult),
            c=color,
            label='T_0=%d, T_mult=%d' % (period_length, t_mult),
        )

    # Log-scale y axis; points where lr reaches 0 (the default lr_min)
    # cannot be drawn on it and are clipped by the 1e-5 lower limit.
    plt.yscale('log')
    plt.ylim(1e-5, 1)
    # Stop at 201 so the final tick at 200 (end of the x range) is drawn.
    plt.xticks(np.arange(0, 201, 20))
    plt.xlabel('Epochs')
    plt.ylabel('Learning rate')
    plt.legend()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment