Variance-reduction SGD algorithms: SGD, minibatch SGD, full-batch gradient descent, momentum SGD, SAG, SAGA, and SVRG, compared on a synthetic least-squares regression problem.
#%%
import contextlib
import os
import random
from copy import deepcopy

import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib import pyplot as plt
from sklearn import datasets
from torch.utils import data

from mathdict import MathDict

#%%
MAX_EPOCHS = int(os.getenv('MAX_EPOCHS', 1000))
#%% Fix the seeds
def fix_seeds(seed=10):
    random.seed(seed)
    torch.random.manual_seed(seed)
    np.random.seed(seed)
#%% Create a dataset
fix_seeds()
X, y, optimum = datasets.make_regression(coef=True, noise=1.0, n_features=50, n_informative=20)
X = torch.from_numpy(X.astype(np.float32))
y = torch.from_numpy(y.astype(np.float32))
generated_dataset = data.TensorDataset(X, y)

#%% Define the initial model
fix_seeds()
start_model = torch.nn.Linear(50, 1)

#%%
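# `Problem` wraps a model and a dataset and exposes the pieces the optimizers below
# need: per-datapoint ("partial") gradients, full-batch gradients, the mean loss,
# and in-place parameter updates. Gradients and parameter states are returned as
# MathDicts so they can be added and scaled with ordinary arithmetic.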
class Problem():
    def __init__(self, dataset, start_model, loss=torch.nn.MSELoss()):
        self.model = deepcopy(start_model)
        self.params = self.model.state_dict(keep_vars=True)
        self.loss = loss
        self.dataset = dataset

    def state(self, clone=True):
        """Current set of parameters"""
        if clone:
            return MathDict({key: value.data.clone() for key, value in self.params.items()})
        else:
            return MathDict({key: value.data for key, value in self.params.items()})

    def partial_gradient(self, data_index, at_state=None):
        """Gradient at one datapoint (or a list of datapoints).
        Optionally, pass a parameter state at which to compute it."""
        x, y = self.dataset[data_index]
        with self.temporary_state(at_state):
            self.zero_grad()
            loss = self.loss(self.model(x).view(-1), y.view(-1))
            loss.backward()
            return MathDict({key: value.grad.clone() for key, value in self.params.items()})

    @contextlib.contextmanager
    def temporary_state(self, state):
        """Temporarily use a different set of parameters for evaluations inside this context.
        Does nothing for state == None."""
        if state is not None:
            old_state = self.state(clone=True)
            self.set_state(state)
        yield
        if state is not None:
            self.set_state(old_state)

    def full_gradient(self, at_state=None):
        """Compute a full-batch gradient"""
        x, y = self.dataset[:]
        with self.temporary_state(at_state):
            self.zero_grad()
            loss = self.loss(self.model(x).view(-1), y.view(-1))
            loss.backward()
            return MathDict({key: value.grad.clone() for key, value in self.params.items()})

    def mean_loss(self):
        """Compute the mean loss over the whole dataset"""
        self.model.eval()
        x, y = self.dataset[:]
        loss = self.loss(self.model(x).view(-1), y.view(-1))
        self.model.train()
        return loss.item()

    def zero_grad(self):
        """Set all gradients on the model to zero. PyTorch autodiff accumulates gradients otherwise."""
        for key, value in self.params.items():
            if value.grad is not None:
                value.grad.data.fill_(0.0)

    def set_state(self, new_state):
        """Go to a certain parameter state"""
        for key, value in self.params.items():
            value.data = new_state.get(key)

    def apply_gradient(self, gradient, learning_rate):
        """Make a step in the negative gradient direction"""
        for key, value in self.params.items():
            value.data.add_(gradient.get(key), alpha=-learning_rate)
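
# Running average of the iterates. With the default update_weight=1 this is a plain
# running mean; with update_weight=2 (used below) iterate k receives weight
# proportional to k, so later iterates count more in the average.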
class RunningAverage:
    def __init__(self, update_weight=1):
        self.average = None
        self.counter = 0
        self.update_weight = update_weight

    def add(self, value):
        self.counter += 1
        if self.average is None:
            self.average = deepcopy(value)
        else:
            delta = value - self.average
            self.average += delta * float(self.update_weight) / float(self.counter + self.update_weight - 1)
            if type(self.average) == torch.Tensor:
                self.average = self.average.detach()

    def reset(self):
        self.average = None
        self.counter = 0
#%% Find the optimal loss
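# Run full-batch gradient descent for a long time (100 * MAX_EPOCHS steps) to get a
# reference value for the optimal loss; the plots below show loss minus this value.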
problem = Problem(generated_dataset, start_model)
learning_rate = 0.27
for epoch in range(100 * MAX_EPOCHS):
    # One full-gradient step
    gradient = problem.full_gradient()
    problem.apply_gradient(gradient, learning_rate)
optimal_model = deepcopy(problem.model)
loss_at_optimum = problem.mean_loss()
print('Loss at the optimum:', loss_at_optimum)
#%% Keep a record of all the results so we can plot them later
results = []
result_id = 0


def evaluate(problem, method_name, epoch, gradients_accessed, iterations, learning_rate):
    mean_loss = problem.mean_loss()
    results.append({
        'method': method_name,
        'epoch': epoch,
        'loss': mean_loss,
        'learning_rate': learning_rate,
        'gradients_accessed': gradients_accessed,
        'iterations': iterations,
    })
    if epoch % 10 == 0:
        print('{epoch:06d}: {method_name} loss = {mean_loss:.4f}'.format(**vars()))
#%% SGD
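# Plain SGD: one step per datapoint, sampled without replacement within each epoch.
# We also maintain a weighted running average of the iterates and evaluate that
# averaged model separately (reported as "t-avg").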
problem = Problem(generated_dataset, start_model)
n = len(problem.dataset)
learning_rate = 0.0092
tavg = RunningAverage(update_weight=2)
for epoch in range(MAX_EPOCHS):
    # Train one epoch
    for datapoint in np.random.permutation(len(problem.dataset)):
        partial_grad = problem.partial_gradient(datapoint)
        problem.apply_gradient(partial_grad, learning_rate)
        tavg.add(problem.state(clone=False))
    # Evaluate
    evaluate(problem, 'SGD', epoch, n * epoch, n * epoch, learning_rate)
    with problem.temporary_state(tavg.average):
        evaluate(problem, 'SGD (t-avg)', epoch, n * epoch, n * epoch, learning_rate)
#%% Minibatch SGD
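# Same as SGD, but each step uses a minibatch of two datapoints (consecutive pairs
# from a fresh permutation), so an epoch takes roughly n / 2 steps while still
# touching n gradients.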
problem = Problem(generated_dataset, start_model)
n = len(problem.dataset)
learning_rate = 0.0092
tavg = RunningAverage(update_weight=2)
for epoch in range(MAX_EPOCHS):
    # Train one epoch
    permutation = np.random.permutation(len(problem.dataset))
    for datapoints in zip(permutation[0::2], permutation[1::2]):
        partial_grad = problem.partial_gradient(list(datapoints))
        problem.apply_gradient(partial_grad, learning_rate)
        tavg.add(problem.state(clone=False))
    # Evaluate
    evaluate(problem, 'Minibatch SGD', epoch, n * epoch, n * epoch, learning_rate)
    with problem.temporary_state(tavg.average):
        evaluate(problem, 'Minibatch SGD (t-avg)', epoch, n * epoch, n * epoch, learning_rate)
#%% (Full batch) gradient descent
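# Deterministic baseline: one full-batch gradient step per epoch. It accesses n
# gradients per epoch but performs only a single parameter update.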
problem = Problem(generated_dataset, start_model)
n = len(problem.dataset)
learning_rate = 0.27
for epoch in range(MAX_EPOCHS):
    # Train one epoch
    gradient = problem.full_gradient()
    problem.apply_gradient(gradient, learning_rate)
    # Evaluate
    evaluate(problem, 'GD', epoch, n * epoch, epoch, learning_rate)
#%% Train with momentum (optimized learning rate)
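# SGD with classical (heavy-ball) momentum: the update direction is a geometrically
# decaying accumulation of past partial gradients,
#   accumulation = partial_grad + momentum * accumulation.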
problem = Problem(generated_dataset, start_model)
n = len(generated_dataset)
learning_rate = 0.0009
momentum = 0.9
accumulation = 0.0
tavg = RunningAverage(update_weight=2)
for epoch in range(MAX_EPOCHS):
    # Train one epoch
    for datapoint in np.random.permutation(len(generated_dataset)):
        partial_grad = problem.partial_gradient(datapoint)
        accumulation = partial_grad + accumulation * momentum
        problem.apply_gradient(accumulation, learning_rate)
        tavg.add(problem.state(clone=False))
    # Evaluate
    evaluate(problem, 'Momentum', epoch, n * epoch, n * epoch, learning_rate)
    with problem.temporary_state(tavg.average):
        evaluate(problem, 'Momentum (t-avg)', epoch, n * epoch, n * epoch, learning_rate)
#%% SAG
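# SAG (Stochastic Average Gradient): keep a table with the most recently seen gradient
# of every datapoint and step along the running mean of that table. When datapoint i is
# revisited, its table entry is refreshed and the mean is updated incrementally.
# The step direction is slightly biased but has low variance.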
problem = Problem(generated_dataset, start_model)
learning_rate = 0.005
n = len(generated_dataset)
gradients = [0.0 for i in range(n)]
mean_gradient = 0.0
for epoch in range(MAX_EPOCHS):
    # Train one epoch
    for datapoint in np.random.permutation(n):
        previous_partial_gradient = gradients[datapoint]
        gradients[datapoint] = problem.partial_gradient(datapoint)
        mean_gradient = (gradients[datapoint] - previous_partial_gradient) / n + mean_gradient
        problem.apply_gradient(mean_gradient, learning_rate)
    # Evaluate
    evaluate(problem, 'SAG', epoch, n * epoch, n * epoch, learning_rate)
#%% SAGA
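# SAGA: like SAG, but the step uses the unbiased estimate
#   new_gradient_i - stored_gradient_i + mean(stored gradients),
# and only afterwards are the stored table and its mean updated. This removes SAG's
# bias while keeping the variance reduction.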
problem = Problem(generated_dataset, start_model)
n = len(generated_dataset)
learning_rate = 0.0024
gradients = [0.0 for i in range(n)]
mean_gradient = 0.0
for epoch in range(MAX_EPOCHS):
    # Train one epoch
    for datapoint in np.random.permutation(n):
        previous_partial_gradient = gradients[datapoint]
        gradients[datapoint] = problem.partial_gradient(datapoint)
        update = gradients[datapoint] - previous_partial_gradient + mean_gradient
        problem.apply_gradient(update, learning_rate)
        mean_gradient = (gradients[datapoint] - previous_partial_gradient) / n + mean_gradient
    # Evaluate
    evaluate(problem, 'SAGA', epoch, n * epoch, n * epoch, learning_rate)
#%% SVRG
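# SVRG: at the start of every outer iteration, snapshot the parameters and compute the
# full gradient at that snapshot. Each inner step then uses the variance-reduced direction
#   grad_i(current) - grad_i(snapshot) + full_gradient(snapshot),
# which is unbiased while its variance shrinks as the iterates approach the snapshot.
# The loop runs MAX_EPOCHS // 2 outer iterations and logs 2 * n gradient accesses per
# outer iteration to account for the extra gradient evaluations.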
problem = Problem(generated_dataset, start_model)
n = len(generated_dataset)
learning_rate = 0.006
for epoch in range(MAX_EPOCHS // 2):
    # Train one epoch
    # 1. Compute a full gradient at the snapshot point
    snapshot_point = problem.state(clone=True)
    snapshot_gradient = problem.full_gradient()
    # 2. Take variance-reduced stochastic steps
    for datapoint in np.random.permutation(n):
        cur_grad = problem.partial_gradient(datapoint)
        partial_grad_at_snapshot = problem.partial_gradient(datapoint, at_state=snapshot_point)
        update = snapshot_gradient - partial_grad_at_snapshot + cur_grad
        problem.apply_gradient(update, learning_rate)
    # Evaluate
    evaluate(problem, 'SVRG', epoch, 2 * n * epoch, n * epoch, learning_rate)
#%% Plot
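# Plot the excess loss (mean loss minus the optimum found above) against the number of
# gradients accessed, on a logarithmic scale, one curve per method.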
sns.set_style('darkgrid')
sns.set_palette(None)
df = pd.DataFrame(results)
df['excess loss'] = df.loss - loss_at_optimum
pd.to_pickle(df, 'variance_reduction_techniques.out.pickle')
df = df[df['excess loss'] < 1e4]
for group, subdf in df.groupby('method'):
    plt.semilogy('gradients_accessed', 'excess loss', data=subdf, label=group, alpha=0.8, linewidth=1)
plt.xlabel('# gradients accessed')
plt.ylabel('Excess loss (L2)')
plt.legend()
plt.savefig('variance_reduction_techniques.out.pdf', facecolor='w')

#%% Reset results