Successive halving and "stop on plateau" comparison
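Three files compare the successive halving algorithm (SHA) with a "stop on plateau" heuristic: a skorch autoencoder model, the halving and plateau logic (validated against Algorithm 1 of the Hyperband paper), and a noisy-MNIST dataset builder.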
The first file defines the model: a denoising autoencoder and a skorch regressor that scores with the negative loss.
import torch
import torch.nn as nn

import skorch
import skorch.utils
from skorch import NeuralNetRegressor


def _initialize(method, layer, gain=1):
    """Initialize ``layer.weight`` in place with ``method`` from ``torch.nn.init``."""
    weight = layer.weight.data
    _before = weight.clone()
    # Only the Xavier initializers accept a ``gain`` keyword.
    kwargs = {'gain': gain} if 'xavier' in str(method) else {}
    method(weight, **kwargs)
    assert torch.all(weight != _before)
class Autoencoder(nn.Module):
    def __init__(self, activation='ReLU', init='xavier_uniform_',
                 **kwargs):
        super().__init__()
        self.activation = activation
        self.init = init
        self._iters = 0

        init_method = getattr(torch.nn.init, init)
        act_layer = getattr(nn, activation)
        # PReLU does not accept an ``inplace`` argument.
        act_kwargs = {'inplace': True} if self.activation != 'PReLU' else {}

        gain = 1
        if self.activation in ['LeakyReLU', 'ReLU']:
            name = 'leaky_relu' if self.activation == 'LeakyReLU' else 'relu'
            gain = torch.nn.init.calculate_gain(name)

        inter_dim = 28 * 28 // 4
        latent_dim = inter_dim // 4
        layers = [
            nn.Linear(28 * 28, inter_dim),
            act_layer(**act_kwargs),
            nn.Linear(inter_dim, latent_dim),
            act_layer(**act_kwargs),
        ]
        for layer in layers:
            if hasattr(layer, 'weight') and layer.weight.data.dim() > 1:
                _initialize(init_method, layer, gain=gain)
        self.encoder = nn.Sequential(*layers)
        # A deeper decoder that mirrors the encoder:
        # layers = [
        #     nn.Linear(latent_dim, inter_dim),
        #     act_layer(**act_kwargs),
        #     nn.Linear(inter_dim, 28 * 28),
        #     nn.Sigmoid(),
        # ]
        # ...replaced here by a single-layer decoder:
        layers = [
            nn.Linear(latent_dim, 28 * 28),
            nn.Sigmoid(),
        ]
        for layer in layers:
            if hasattr(layer, 'weight') and layer.weight.data.dim() > 1:
                # The decoder ends in a Sigmoid, so use the default gain of 1.
                _initialize(init_method, layer)
        self.decoder = nn.Sequential(*layers)
    def forward(self, x):
        self._iters += 1
        shape = x.size()
        x = x.view(x.shape[0], -1)
        x = self.encoder(x)
        x = self.decoder(x)
        return x.view(shape)
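As a quick sanity check (not part of the gist's training run), ``forward`` flattens each input, reconstructs it, and restores the original shape:

model = Autoencoder()
batch = torch.rand(8, 28 * 28)   # eight flattened MNIST-sized images
out = model(batch)
assert out.shape == batch.shape  # the input shape is preserved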
class NegLossScore(NeuralNetRegressor):
    steps = 0

    def partial_fit(self, *args, **kwargs):
        super().partial_fit(*args, **kwargs)
        self.steps += 1

    def score(self, X, y):
        # Score with the negated training criterion so "higher is better",
        # as scikit-learn's model selection utilities expect.
        X = skorch.utils.to_tensor(X, device=self.device)
        y = skorch.utils.to_tensor(y, device=self.device)
        self.initialize_criterion()
        y_hat = self.predict(X)
        y_hat = skorch.utils.to_tensor(y_hat, device=self.device)
        loss = super().get_loss(y_hat, y, X=X, training=False).item()
        print(f'steps = {self.steps}, loss = {loss}')
        return -1 * loss

    def initialize(self, *args, **kwargs):
        super().initialize(*args, **kwargs)
        # Disable skorch's default callbacks (e.g., per-epoch validation);
        # scoring is handled externally.
        self.callbacks_ = []
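A minimal sketch of how the two classes combine (hypothetical values; the gist drives training through dask-ml rather than directly):

net = NegLossScore(
    Autoencoder,
    module__activation='ReLU',
    max_epochs=1,  # assumed; NeuralNetRegressor's default MSELoss criterion is used
)
# net.partial_fit(noisy, clean) trains for one epoch; net.score(noisy, clean)
# returns the negative reconstruction loss, so higher is better.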
The second file implements the "stop on plateau" rule and the successive halving algorithm, and checks SHA's bookkeeping against Algorithm 1 of the Hyperband paper.
import math
from time import time

import numpy as np
import toolz


def stop_on_plateau(info, patience=10, tol=0.001, max_iter=None):
    """Return {model_id: 1} for models that should keep training, else {model_id: 0}."""
    out = {}
    for ident, records in info.items():
        pf_calls = records[-1]['partial_fit_calls']
        if max_iter is not None and pf_calls > max_iter:
            out[ident] = 0
        elif pf_calls > patience:
            # Scores observed over the last ``patience`` partial_fit calls.
            plateau = {d['partial_fit_calls']: d['score']
                       for d in records
                       if pf_calls - patience <= d['partial_fit_calls']}
            plateau_start = plateau[min(plateau)]
            # Stop if no score in the window beat the window's start by
            # more than ``tol``.
            if all(score < plateau_start + tol for score in plateau.values()):
                out[ident] = 0
            else:
                out[ident] = 1
        else:
            out[ident] = 1
    return out
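Here ``info`` follows the bookkeeping of ``dask_ml.model_selection._incremental.fit``: a dict mapping each model ID to its list of history records. A made-up example of the plateau rule firing:

info = {
    0: [{'partial_fit_calls': i, 'score': 0.9} for i in range(1, 13)],       # flat
    1: [{'partial_fit_calls': i, 'score': 0.08 * i} for i in range(1, 13)],  # improving
}
print(stop_on_plateau(info, patience=10, tol=0.001))  # {0: 0, 1: 1}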
def _hyperband_paper_alg(R, eta=3):
    """
    Algorithm 1 from the Hyperband paper [1]_.

    References
    ----------
    1. "Hyperband: A novel bandit-based approach to hyperparameter
       optimization", 2016 by L. Li, K. Jamieson, G. DeSalvo, A. Rostamizadeh,
       and A. Talwalkar. https://arxiv.org/abs/1603.06560
    """
    s_max = math.floor(math.log(R, eta))
    B = (s_max + 1) * R
    brackets = reversed(range(int(s_max + 1)))
    hists = {}
    for s in brackets:
        n = int(math.ceil(B / R * eta ** s / (s + 1)))
        r = int(R * eta ** -s)
        T = set(range(n))
        hist = {
            "num_models": n,
            "models": {m: 0 for m in range(n)},
            "iters": [],
        }
        for i in range(s + 1):
            n_i = math.floor(n * eta ** -i)
            r_i = np.round(r * eta ** i).astype(int)
            # Train every surviving model to r_i partial_fit calls...
            hist["models"].update({model: r_i for model in T})
            hist["iters"] += [r_i]
            # ...then keep the best 1/eta of them.
            to_keep = math.floor(n_i / eta)
            T = {model for j, model in enumerate(T) if j < to_keep}
        hists["bracket={s}".format(s=s)] = hist
    info = [
        {
            "bracket": k,
            "num_models": hist["num_models"],
            "num_partial_fit_calls": sum(hist["models"].values()),
            "iters": {int(h) for h in hist["iters"]},
        }
        for k, hist in hists.items()
    ]
    return info
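For example, with R = 100 and eta = 3 the most aggressive bracket starts 81 models and evaluates them after 1, 3, 9, 27, and 81 partial_fit calls, keeping roughly the top third each round:

info = _hyperband_paper_alg(100, eta=3)
print(info[0]['num_models'])     # 81
print(sorted(info[0]['iters']))  # [1, 3, 9, 27, 81]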
class SHA:
    def __init__(self, n, r, eta=3, limit=None,
                 patience=np.inf, tol=0.001):
        """
        Perform the successive halving algorithm.

        Parameters
        ----------
        n : int
            Number of models to evaluate initially
        r : int
            Number of times to call partial_fit initially
        eta : float, default=3
            How aggressively to cull models; higher values kill off more
            models each round. The "infinite horizon" theory suggests
            eta = np.e = 2.718... is optimal.
        limit : int, optional
            Maximum number of halving steps to perform.
        patience : int
            Passed to `stop_on_plateau`
        tol : float
            Passed to `stop_on_plateau`
        """
        self.steps = 0
        self.n = n
        self.r = r
        self.eta = eta
        self.meta = []
        self.start = time()
        self.patience = patience
        self.tol = tol
        self.limit = limit

    def fit(self, info):
        n, r, eta = self.n, self.r, self.eta
        n_i = math.floor(n * eta ** -self.steps)
        r_i = np.round(r * eta ** self.steps).astype(int)

        # Initial case: partial_fit has already been called once by the
        # time this callback first runs.
        if r_i == 1:
            # If r_i == 1, that first call already completed this step.
            assert self.steps == 0
            self.steps = 1
            return self.fit(info)
        # This ordering is important; typically r_i == 1 when steps == 0.
        if self.steps == 0:
            # We have r_i - 1 more calls to make to reach r_i in total.
            self.steps = 1
            return {k: r_i - 1 for k in info}

        keep_training = stop_on_plateau(info,
                                        patience=self.patience,
                                        tol=self.tol)
        if sum(keep_training.values()) == 0:
            return keep_training
        info = {k: info[k] for k in keep_training}

        # Keep the n_i models with the best latest scores.
        best = toolz.topk(n_i, info, key=lambda k: info[k][-1]['score'])
        self.steps += 1
        if len(best) in {0, 1} and self.steps > self.limit:
            return {0: 0}
        pf_calls = {k: info[k][-1]['partial_fit_calls'] for k in best}
        addtl_pf_calls = {k: r_i - pf_calls[k] for k in best}
        return addtl_pf_calls
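``SHA.fit`` is shaped to serve as the additional-calls callback of ``dask_ml.model_selection._incremental.fit`` (passed positionally in the ``__main__`` block below): given each model's history it returns how many more ``partial_fit`` calls each surviving model should get, and an all-zero dict stops the search. A sketch with assumed values:

alg = SHA(n=9, r=1, eta=3, limit=3)
# dask-ml repeatedly calls alg.fit(info) and trains each surviving model
# for the requested number of extra partial_fit calls.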
from sklearn.base import BaseEstimator


class _Constant(BaseEstimator):
    """A do-nothing estimator whose score is always ``value``.

    Lets the test below drive SHA's bookkeeping without real training.
    """

    def __init__(self, value=0, meta=None):
        self.value = value
        if meta is None:
            meta = {}
        self.meta = meta
        super().__init__()

    def partial_fit(self, *args, **kwargs):
        return self

    def score(self, *args, **kwargs):
        return self.value
from sklearn.model_selection import ParameterSampler

from dask_ml.datasets import make_classification
from dask_ml.model_selection._incremental import fit
from distributed import Client

if __name__ == "__main__":
    client = Client('localhost:8786')
    X, y = make_classification(n_features=5, n_samples=200, chunks=10)
    R = 100
    eta = 3.0

    info = _hyperband_paper_alg(R, eta=eta)
    # partial_fit is called once before the first score, so every bracket
    # sees at least one call.
    for i in info:
        i['iters'].update({1})

    # Re-run each bracket through SHA and collect the same statistics.
    sh_info = []
    s_max = math.floor(math.log(R, eta))
    B = (s_max + 1) * R
    for s in reversed(np.arange(s_max + 1)):
        n = int(np.ceil(B / R * eta ** s / (s + 1)))
        r = int(np.floor(R * eta ** -s))
        alg = SHA(n, r, limit=s + 1)
        model = _Constant()
        params = {'value': np.linspace(0, 1, num=1000)}
        params_list = list(ParameterSampler(params, n))
        _, _, hist = fit(model, params_list, X, y, X, y, alg.fit)

        # Group the flat history records by model ID.
        ids = {h['model_id'] for h in hist}
        info_hist = {i: [] for i in ids}
        for h in hist:
            info_hist[h['model_id']] += [h]
        hist = info_hist

        calls = {k: max(hi['partial_fit_calls'] for hi in h)
                 for k, h in hist.items()}
        iters = {hi['partial_fit_calls'] for h in hist.values() for hi in h}
        sh_info += [{'bracket': f'bracket={s}',
                     'iters': iters,
                     'num_models': len(hist),
                     'num_partial_fit_calls': sum(calls.values())}]
    # SHA reproduces the Hyperband paper's bookkeeping exactly.
    assert sh_info == info
The third file builds the dataset: noisy and clean copies of flattened MNIST digits.
import random

import numpy as np
import scipy.signal
import skimage.transform
import skimage.util
from keras.datasets import mnist


def noise_img(x):
    noises = [
        {"mode": "s&p", "amount": np.random.uniform(0.0, 0.2)},
        {"mode": "gaussian", "var": np.random.uniform(0.0, 0.15)},
    ]
    # Always use Gaussian noise; use random.choice(noises) to mix in
    # salt-and-pepper noise as well.
    noise = noises[1]
    return skimage.util.random_noise(x, **noise)


def train_formatting(img):
    img = img.reshape(28, 28).astype("float32")
    return img.flat[:]


def blur_img(img):
    """Apply a motion blur: convolve with a short line at a random small angle."""
    assert img.ndim == 1
    n = int(np.sqrt(img.shape[0]))
    img = img.reshape(n, n)
    h = np.zeros((n, n))
    angle = np.random.uniform(-5, 5)
    w = random.choice(range(1, 3))
    h[n // 2, n // 2 - w : n // 2 + w] = 1
    h = skimage.transform.rotate(h, angle)
    h /= h.sum()
    y = scipy.signal.convolve(img, h, mode="same")
    return y.flat[:]
def dataset(n=None):
    (x_train, _), (x_test, _) = mnist.load_data()
    x = np.concatenate((x_train, x_test))  # all 70,000 digits
    if n:
        x = x[:n]
    x = x.astype("float32") / 255.
    x = np.reshape(x, (len(x), 28 * 28))
    y = np.apply_along_axis(train_formatting, 1, x)

    clean = y.copy()
    noisy = y.copy()
    # Only additive noise is applied; add blur_img to ``order`` to also
    # motion-blur the images.
    order = [noise_img]
    random.shuffle(order)
    for fn in order:
        noisy = np.apply_along_axis(fn, 1, noisy)

    noisy = noisy.astype("float32")
    clean = clean.astype("float32")
    # The commented reshapes below produce image-shaped arrays instead:
    # noisy = noisy.reshape(-1, 1, 28, 28).astype("float32")
    # clean = clean.reshape(-1, 1, 28, 28).astype("float32")
    return noisy, clean
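A usage sketch (the first call downloads MNIST via keras):

noisy, clean = dataset(n=1000)
assert noisy.shape == clean.shape == (1000, 28 * 28)
# Train the denoiser on these pairs, e.g. net.partial_fit(noisy, clean).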