Variance reduction SGD algorithms
Gist by @tvogels, created March 31, 2021 07:59
#%%
import contextlib
import os
import random
import sys
from copy import deepcopy
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib import pyplot as plt
from sklearn import datasets
from torch.utils import data
from mathdict import MathDict
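#%% Sketch of the assumed MathDict interface
# `mathdict` is a small helper that accompanies this gist. MathDict is assumed to
# behave like a dict whose values support elementwise arithmetic (+, -, *, /) with
# other MathDicts and with plain scalars; that is the only interface the code below
# relies on. `_MathDictSketch` is a hypothetical stand-in illustrating that assumed
# interface, included for reference only; the imported MathDict is what is used.
class _MathDictSketch(dict):
    def _apply(self, other, op):
        if isinstance(other, dict):
            return type(self)({k: op(v, other[k]) for k, v in self.items()})
        return type(self)({k: op(v, other) for k, v in self.items()})
    def __add__(self, other): return self._apply(other, lambda a, b: a + b)
    __radd__ = __add__
    def __sub__(self, other): return self._apply(other, lambda a, b: a - b)
    def __mul__(self, other): return self._apply(other, lambda a, b: a * b)
    __rmul__ = __mul__
    def __truediv__(self, other): return self._apply(other, lambda a, b: a / b)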
#%%
MAX_EPOCHS = int(os.getenv('MAX_EPOCHS', 1000))
#%% Fix the seeds
def fix_seeds(seed=10):
    random.seed(seed)
    torch.random.manual_seed(seed)
    np.random.seed(seed)
#%% Create a dataset
fix_seeds()
X, y, optimum = datasets.make_regression(coef=True, noise=1.0, n_features=50, n_informative=20)
X = X.astype(np.float32)
y = y.astype(np.float32)
X = torch.from_numpy(X)
y = torch.from_numpy(y)
generated_dataset = data.TensorDataset(X, y)
#%% Define the initial model
fix_seeds()
start_model = torch.nn.Linear(50, 1)
#%%
class Problem:
    def __init__(self, dataset, start_model, loss=torch.nn.MSELoss()):
        self.model = deepcopy(start_model)
        self.params = self.model.state_dict(keep_vars=True)
        self.loss = loss
        self.dataset = dataset

    def state(self, clone=True):
        """Current set of parameters"""
        if clone:
            return MathDict({key: value.data.clone() for key, value in self.params.items()})
        else:
            return MathDict({key: value.data for key, value in self.params.items()})

    def partial_gradient(self, data_index, at_state=None):
        """Gradient at a single datapoint (or a list of datapoints).
        Optionally pass a parameter state at which to compute it."""
        x, y = self.dataset[data_index]
        with self.temporary_state(at_state):
            self.zero_grad()
            loss = self.loss(self.model(x).view(-1), y.view(-1))
            loss.backward()
            return MathDict({key: value.grad.clone() for key, value in self.params.items()})

    @contextlib.contextmanager
    def temporary_state(self, state):
        """Temporarily use a different set of parameters for evaluations inside this context.
        Does nothing for state == None."""
        if state is not None:
            old_state = self.state(clone=True)
            self.set_state(state)
        try:
            yield
        finally:
            if state is not None:
                self.set_state(old_state)

    def full_gradient(self, at_state=None):
        """Compute a full gradient over the whole dataset"""
        x, y = self.dataset[:]
        with self.temporary_state(at_state):
            self.zero_grad()
            loss = self.loss(self.model(x).view(-1), y.view(-1))
            loss.backward()
            return MathDict({key: value.grad.clone() for key, value in self.params.items()})

    def mean_loss(self):
        """Compute the mean loss over the whole dataset"""
        self.model.eval()
        x, y = self.dataset[:]
        loss = self.loss(self.model(x).view(-1), y.view(-1))
        self.model.train()
        return loss.item()

    def zero_grad(self):
        """Set all gradients on the model to zero. PyTorch accumulates into .grad, so this is needed before every backward pass."""
        for key, value in self.params.items():
            if value.grad is not None:
                value.grad.data.fill_(0.0)

    def set_state(self, new_state):
        """Go to a certain parameter state"""
        for key, value in self.params.items():
            value.data = new_state.get(key)

    def apply_gradient(self, gradient, learning_rate):
        """Make a step in the negative gradient direction"""
        for key, value in self.params.items():
            value.data.add_(gradient.get(key), alpha=-learning_rate)
class RunningAverage:
    """Weighted running average of the values added via `add`.
    update_weight=1 gives the plain mean; update_weight=2 weights the t-th value
    proportionally to t, i.e. a tail-heavier average of the iterates."""
    def __init__(self, update_weight=1):
        self.average = None
        self.counter = 0
        self.update_weight = update_weight

    def add(self, value):
        self.counter += 1
        if self.average is None:
            self.average = deepcopy(value)
        else:
            delta = value - self.average
            self.average += delta * float(self.update_weight) / float(self.counter + self.update_weight - 1)
        if isinstance(self.average, torch.Tensor):
            # .detach() is not in-place; reassign to actually drop the autograd graph
            self.average = self.average.detach()

    def reset(self):
        self.average = None
        self.counter = 0
#%% Find the optimal loss
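# Run full-batch gradient descent for a long time with the tuned step size; the
# resulting loss serves as the reference point for the excess-loss curves below.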
problem = Problem(generated_dataset, start_model)
learning_rate = 0.27
for epoch in range(100 * MAX_EPOCHS):
    # Train one epoch
    gradient = problem.full_gradient()
    problem.apply_gradient(gradient, learning_rate)
optimal_model = deepcopy(problem.model)
loss_at_optimum = problem.mean_loss()
print('Loss at the optimum:', loss_at_optimum)
#%% Keep a record of all the results so we can plot them later
results = []
result_id = 0
def evaluate(problem, method_name, epoch, gradients_accessed, iterations, learning_rate):
    global result_id
    global results
    mean_loss = problem.mean_loss()
    results.append({'method': method_name, 'epoch': epoch, 'loss': mean_loss, 'learning_rate': learning_rate,
                    'gradients_accessed': gradients_accessed, 'iterations': iterations})
    if (epoch % 10) == 0:
        print('{epoch:06d}: {method_name} loss = {mean_loss:.4f}'.format(**vars()))
#%% SGD
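# Plain SGD: at each step, take the gradient at a single (shuffled) datapoint and
# step in its negative direction. The 't-avg' curve additionally reports the loss
# at a weighted running average of the iterates (tail averaging via
# RunningAverage(update_weight=2)), which smooths out the variance of the
# individual SGD iterates.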
problem = Problem(generated_dataset, start_model)
n = len(problem.dataset)
learning_rate = 0.0092
tavg = RunningAverage(update_weight=2)
for epoch in range(MAX_EPOCHS):
    # Train one epoch
    for datapoint in np.random.permutation(len(problem.dataset)):
        partial_grad = problem.partial_gradient(datapoint)
        problem.apply_gradient(partial_grad, learning_rate)
        tavg.add(problem.state(clone=False))
    # Evaluate
    evaluate(problem, 'SGD', epoch, n * epoch, n * epoch, learning_rate)
    with problem.temporary_state(tavg.average):
        evaluate(problem, 'SGD (t-avg)', epoch, n * epoch, n * epoch, learning_rate)
#%% Minibatch SGD
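# Minibatch SGD with batches of size 2: the MSE loss averages the two per-sample
# gradients, which roughly halves the variance of each step at the same
# per-epoch gradient cost (n gradient accesses, n/2 updates).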
problem = Problem(generated_dataset, start_model)
n = len(problem.dataset)
learning_rate = 0.0092
tavg = RunningAverage(update_weight=2)
for epoch in range(MAX_EPOCHS):
    # Train one epoch with minibatches of size 2
    permutation = np.random.permutation(len(problem.dataset))
    for datapoints in zip(permutation[0::2], permutation[1::2]):
        partial_grad = problem.partial_gradient(list(datapoints))
        problem.apply_gradient(partial_grad, learning_rate)
        tavg.add(problem.state(clone=False))
    # Evaluate
    evaluate(problem, 'Minibatch SGD', epoch, n * epoch, n * epoch // 2, learning_rate)
    with problem.temporary_state(tavg.average):
        evaluate(problem, 'Minibatch SGD (t-avg)', epoch, n * epoch, n * epoch // 2, learning_rate)
#%% (Full batch) gradient descent
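# Baseline: deterministic full-batch gradient descent. Each update uses all n
# datapoints, so it is counted as n gradient accesses per iteration in the plots.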
problem = Problem(generated_dataset, start_model)
learning_rate = 0.27
n = len(generated_dataset)
for epoch in range(MAX_EPOCHS):
    # Train one epoch (a single full-batch step)
    gradient = problem.full_gradient()
    problem.apply_gradient(gradient, learning_rate)
    # Evaluate
    evaluate(problem, 'GD', epoch, n * epoch, epoch, learning_rate)
#%% Train with momentum (optimized learning rate)
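# Heavy-ball momentum: keep a decaying accumulation of past stochastic gradients,
#   m_t = g_t + momentum * m_{t-1}
# and step along -learning_rate * m_t. This smooths the stochastic gradients, but
# unlike SAG/SAGA/SVRG below it is not a variance reduction method: the gradient
# noise does not vanish as the iterates approach the optimum.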
problem = Problem(generated_dataset, start_model)
learning_rate = 0.0009
momentum = 0.9
accumulation = 0.0
tavg = RunningAverage(update_weight=2)
n = len(generated_dataset)
for epoch in range(MAX_EPOCHS):
    # Train one epoch
    for datapoint in np.random.permutation(len(generated_dataset)):
        partial_grad = problem.partial_gradient(datapoint)
        accumulation = partial_grad + accumulation * momentum
        problem.apply_gradient(accumulation, learning_rate)
        tavg.add(problem.state(clone=False))
    # Evaluate
    evaluate(problem, 'Momentum', epoch, n * epoch, n * epoch, learning_rate)
    with problem.temporary_state(tavg.average):
        evaluate(problem, 'Momentum (t-avg)', epoch, n * epoch, n * epoch, learning_rate)
#%% SAG
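# SAG (Stochastic Average Gradient): keep a table with the most recently seen
# gradient of every datapoint. At each step, refresh the table entry of the
# sampled datapoint and move along the average of all stored gradients:
#   g_table[i] <- grad_i(x);   x <- x - lr * (1/n) * sum_j g_table[j]
# The running mean is updated incrementally below instead of re-summing the table.
# Because the other table entries are stale, the SAG step is a biased gradient
# estimate, in contrast to SAGA below.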
problem = Problem(generated_dataset, start_model)
learning_rate = 0.005
n = len(generated_dataset)
gradients = [0.0 for i in range(n)]
mean_gradient = 0.0
for epoch in range(MAX_EPOCHS):
    # Train one epoch
    for datapoint in np.random.permutation(n):
        previous_partial_gradient = gradients[datapoint]
        gradients[datapoint] = problem.partial_gradient(datapoint)
        mean_gradient = (gradients[datapoint] - previous_partial_gradient) / n + mean_gradient
        problem.apply_gradient(mean_gradient, learning_rate)
    # Evaluate
    evaluate(problem, 'SAG', epoch, n * epoch, n * epoch, learning_rate)
#%% SAGA
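# SAGA: like SAG, keep a table of the last seen gradient per datapoint, but use
# the unbiased update
#   x <- x - lr * ( grad_i(x) - g_old[i] + (1/n) * sum_j g_old[j] )
# where g_old is the table *before* refreshing entry i. The correction term has
# zero mean, so the step is an unbiased gradient estimate whose variance shrinks
# as the stored gradients approach the gradients at the optimum.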
problem = Problem(generated_dataset, start_model)
n = len(generated_dataset)
learning_rate = 0.0024
gradients = [0.0 for i in range(n)]
mean_gradient = 0.0
for epoch in range(MAX_EPOCHS):
    # Train one epoch
    for datapoint in np.random.permutation(n):
        previous_partial_gradient = gradients[datapoint]
        gradients[datapoint] = problem.partial_gradient(datapoint)
        update = gradients[datapoint] - previous_partial_gradient + mean_gradient
        problem.apply_gradient(update, learning_rate)
        mean_gradient = (gradients[datapoint] - previous_partial_gradient) / n + mean_gradient
    # Evaluate
    evaluate(problem, 'SAGA', epoch, n * epoch, n * epoch, learning_rate)
#%% SVRG
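# SVRG (Stochastic Variance Reduced Gradient): periodically take a snapshot of
# the parameters and compute a full gradient there. In the inner loop, step along
# the variance-reduced estimate
#   grad_i(x) - grad_i(snapshot) + full_grad(snapshot)
# which is unbiased and whose variance vanishes as x and the snapshot approach
# the optimum. Each inner step costs two partial-gradient evaluations (plus one
# full-gradient pass per outer epoch), hence the halved epoch budget below and
# the adjusted gradient-access counter.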
problem = Problem(generated_dataset, start_model)
n = len(generated_dataset)
learning_rate = 0.006
for epoch in range(MAX_EPOCHS // 2):
    # Train one epoch
    # 1. Compute a full gradient at the snapshot point
    snapshot_point = problem.state(clone=True)
    snapshot_gradient = problem.full_gradient()
    # 2. Inner loop of variance-reduced stochastic steps
    for datapoint in np.random.permutation(n):
        cur_grad = problem.partial_gradient(datapoint)
        partial_grad_at_snapshot = problem.partial_gradient(datapoint, at_state=snapshot_point)
        update = snapshot_gradient - partial_grad_at_snapshot + cur_grad
        problem.apply_gradient(update, learning_rate)
    # Evaluate
    evaluate(problem, 'SVRG', epoch, 2 * n * epoch, n * epoch, learning_rate)
#%% Plot
sns.set_style('darkgrid')
sns.set_palette(None)
df = pd.DataFrame(results)
df['excess loss'] = df.loss - loss_at_optimum
pd.to_pickle(df, 'variance_reduction_techniques.out.pickle')
df = df[df['excess loss'] < 1e4]
for group, subdf in df.groupby('method'):
    plt.semilogy('gradients_accessed', 'excess loss', data=subdf, label=group, alpha=0.8, linewidth=1)
plt.xlabel('# gradients accessed')
plt.ylabel('Excess loss (L2)')
plt.legend()
plt.savefig('variance_reduction_techniques.out.pdf', facecolor='w')
#%% Reset results
results = []
result_id = 0