First example of using ray and deep_architect together - ray logging does not work
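# Random search over a small dense-network search space (deep_architect), with
# each sampled architecture trained as a Ray Tune trial on MNIST.
#
# Note (assumption): the "ray logging does not work" symptom mentioned in the
# gist description is consistent with the scheduler metric name not matching
# any key returned by _train() below; see the comment at the ASHAScheduler call.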
import os

import ray
from ray import tune

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import deep_architect.modules as mo
import deep_architect.hyperparameters as hp
import deep_architect.helpers.pytorch_support as hpt
import deep_architect.searchers.random as se
from deep_architect.contrib.misc.datasets.loaders import load_mnist

D = hp.Discrete  # shorthand for discrete hyperparameters

def dense(h_units):

    def compile_fn(di, dh):
        # infer the input feature dimension from the example input tensor
        (_, in_features) = di['in'].size()
        m = nn.Linear(in_features, dh['units'])

        def fn(di):
            return {'out': m(di['in'])}

        return fn, [m]

    return hpt.siso_pytorch_module('Dense', compile_fn, {'units': h_units})

def nonlinearity(h_nonlin_name):

    def Nonlinearity(nonlin_name):
        if nonlin_name == 'relu':
            m = nn.ReLU()
        elif nonlin_name == 'tanh':
            m = nn.Tanh()
        elif nonlin_name == 'elu':
            m = nn.ELU()
        else:
            raise ValueError("Unknown nonlinearity: %s" % nonlin_name)
        return m

    return hpt.siso_pytorch_module_from_pytorch_layer_fn(
        Nonlinearity, {'nonlin_name': h_nonlin_name})

def dropout(h_drop_rate):
    return hpt.siso_pytorch_module_from_pytorch_layer_fn(
        nn.Dropout, {'p': h_drop_rate})

def batch_normalization():

    def compile_fn(di, dh):
        (_, in_features) = di['in'].size()
        bn = nn.BatchNorm1d(in_features)

        def fn(di):
            return {'out': bn(di['in'])}

        return fn, [bn]

    return hpt.siso_pytorch_module('BatchNormalization', compile_fn, {})

def dnn_cell(h_num_hidden, h_nonlin_name, h_swap, h_opt_drop, h_opt_bn,
             h_drop_rate):
    # Dense -> nonlinearity -> (optional dropout, optional batch norm),
    # with h_swap deciding the order of the two optional sub-modules.
    return mo.siso_sequential([
        dense(h_num_hidden),
        nonlinearity(h_nonlin_name),
        mo.siso_permutation([
            lambda: mo.siso_optional(lambda: dropout(h_drop_rate), h_opt_drop),
            lambda: mo.siso_optional(batch_normalization, h_opt_bn),
        ], h_swap)
    ])

def dnn_net(num_classes):
    h_nonlin_name = D(['relu', 'tanh', 'elu'])
    h_swap = D([0, 1])
    h_opt_drop = D([0, 1])
    h_opt_bn = D([0, 1])
    return mo.siso_sequential([
        mo.siso_repeat(
            lambda: dnn_cell(D([64, 128, 256, 512, 1024]),
                             h_nonlin_name, h_swap, h_opt_drop, h_opt_bn,
                             D([0.25, 0.5, 0.75])), D([1, 2, 4])),
        dense(D([num_classes]))
    ])

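# The search space above is a chain of 1, 2, or 4 dnn_cell blocks followed by a
# final Dense layer with `num_classes` units; hidden width, nonlinearity,
# dropout rate, and the presence/order of dropout and batch norm are all
# hyperparameters chosen by the searcher.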
def get_dataloaders(batch_size):
    (X_train, y_train, X_val, y_val, X_test, y_test) = load_mnist(
        flatten=True, one_hot=False)
    train_dataset = TensorDataset(torch.FloatTensor(X_train),
                                  torch.LongTensor(y_train))
    val_dataset = TensorDataset(torch.FloatTensor(X_val),
                                torch.LongTensor(y_val))
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                  shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
    return train_dataloader, val_dataloader

def sample_model(in_features, num_classes):
    search_space_fn = lambda: dnn_net(num_classes)
    searcher = se.RandomSearcher(search_space_fn)
    inputs, outputs, _, searcher_eval_token = searcher.sample()
    in_tensor = torch.zeros(128, in_features)

    def model():
        return hpt.PyTorchModel(inputs, outputs, {'in': in_tensor})

    return model

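# sample_model returns a zero-argument factory rather than an instantiated
# network, so the sampled architecture can be placed in the Tune config and
# each trial can build its own nn.Module in _setup via config["model"]().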
class SimpleClassifierTrainable(tune.Trainable):
    # Uses the older Ray Tune class API: _setup / _train / _save / _restore.

    def _setup(self, config):
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.batch_size = config["batch_size"]
        self.learning_rate = config.get("lr", 0.01)
        self.train_loader, self.val_loader = get_dataloaders(self.batch_size)
        self.model = config["model"]().to(self.device)
        # Note: self.criterion is unused; _train calls F.cross_entropy directly.
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.learning_rate)

    def _train(self):
        # one training epoch
        self.model.train()
        correct = 0
        train_samples = 0
        for data, target in self.train_loader:
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            output = self.model({'in': data})
            pred = output["out"].data.max(1)[1]
            correct += pred.eq(target).sum().item()
            train_samples += target.numel()
            loss = F.cross_entropy(output["out"], target)
            loss.backward()
            self.optimizer.step()
        train_acc = float(correct) / train_samples

        # compute validation accuracy
        self.model.eval()
        correct = 0
        val_samples = 0
        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model({'in': data})
                pred = output["out"].data.max(1)[1]
                correct += pred.eq(target).sum().item()
                val_samples += target.numel()
        val_acc = float(correct) / val_samples
        print("validation accuracy: %0.4f" % val_acc)
        # These keys are what Tune logs; stop and scheduler criteria must use them.
        return {'training_accuracy': train_acc, 'validation_accuracy': val_acc}

    def _save(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        self.model.load_state_dict(torch.load(checkpoint_path))

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Architecture search example")
    parser.add_argument("--asha", action="store_true",
                        help="If provided, use the ASHA scheduler (default: FIFO)")
    args = parser.parse_args()

    ray.init()

    if args.asha:
        print("Using ASHA scheduler")
        from ray.tune.schedulers import ASHAScheduler
        # The metric must be a key returned by _train(); the original
        # "train_mean_accuracy" is never reported, so ASHA had nothing to act on.
        scheduler = ASHAScheduler(metric="validation_accuracy", mode="max")
    else:
        from ray.tune.schedulers import FIFOScheduler
        scheduler = FIFOScheduler()

    analysis = tune.run(
        SimpleClassifierTrainable,
        scheduler=scheduler,
        stop={
            "training_accuracy": 0.95,
            "training_iteration": 2,
        },
        resources_per_trial={
            "cpu": 3,
            "gpu": 0.25
        },
        num_samples=4,
        config={
            "model": tune.sample_from(
                lambda spec: sample_model(in_features=784, num_classes=10)),
            "lr": 1e-3,
            "batch_size": 128
        })

    print("Best config is:",
          analysis.get_best_config(metric="validation_accuracy"))