Second example of using ray and deep_architect together - deepcopy protocol error
import os

import ray
from ray import tune
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import deep_architect.modules as mo
import deep_architect.hyperparameters as hp
import deep_architect.helpers.pytorch_support as hpt
import deep_architect.searchers.random as se
from deep_architect.contrib.misc.datasets.loaders import load_mnist

D = hp.Discrete
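
# ---------------------------------------------------------------------------
# deep_architect search-space primitives. Each helper below wraps a PyTorch
# layer as a single-input/single-output module; the hyperparameter values (dh)
# are chosen by the searcher before compile_fn builds the concrete layer.
# ---------------------------------------------------------------------------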
def dense(h_units):
    def compile_fn(di, dh):
        (_, in_features) = di['in'].size()
        m = nn.Linear(in_features, dh['units'])

        def fn(di):
            return {'out': m(di['in'])}

        return fn, [m]

    return hpt.siso_pytorch_module('Dense', compile_fn, {'units': h_units})


def nonlinearity(h_nonlin_name):
    def Nonlinearity(nonlin_name):
        if nonlin_name == 'relu':
            m = nn.ReLU()
        elif nonlin_name == 'tanh':
            m = nn.Tanh()
        elif nonlin_name == 'elu':
            m = nn.ELU()
        else:
            raise ValueError
        return m

    return hpt.siso_pytorch_module_from_pytorch_layer_fn(
        Nonlinearity, {'nonlin_name': h_nonlin_name})


def dropout(h_drop_rate):
    return hpt.siso_pytorch_module_from_pytorch_layer_fn(
        nn.Dropout, {'p': h_drop_rate})


def batch_normalization():
    def compile_fn(di, dh):
        (_, in_features) = di['in'].size()
        bn = nn.BatchNorm1d(in_features)

        def fn(di):
            return {'out': bn(di['in'])}

        return fn, [bn]

    return hpt.siso_pytorch_module('BatchNormalization', compile_fn, {})
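
# ---------------------------------------------------------------------------
# Composite search space: a cell is Dense -> nonlinearity -> an optional
# dropout and an optional batch norm applied in a searchable order, and
# dnn_net stacks a searchable number (1, 2 or 4) of such cells before the
# final classification layer.
# ---------------------------------------------------------------------------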
def dnn_cell(h_num_hidden, h_nonlin_name, h_swap, h_opt_drop, h_opt_bn,
             h_drop_rate):
    return mo.siso_sequential([
        dense(h_num_hidden),
        nonlinearity(h_nonlin_name),
        mo.siso_permutation([
            lambda: mo.siso_optional(lambda: dropout(h_drop_rate), h_opt_drop),
            lambda: mo.siso_optional(batch_normalization, h_opt_bn),
        ], h_swap)
    ])


def dnn_net(num_classes):
    h_nonlin_name = D(['relu', 'tanh', 'elu'])
    h_swap = D([0, 1])
    h_opt_drop = D([0, 1])
    h_opt_bn = D([0, 1])
    return mo.siso_sequential([
        mo.siso_repeat(
            lambda: dnn_cell(D([64, 128, 256, 512, 1024]),
                             h_nonlin_name, h_swap, h_opt_drop, h_opt_bn,
                             D([0.25, 0.5, 0.75])), D([1, 2, 4])),
        dense(D([num_classes]))
    ])
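
# ---------------------------------------------------------------------------
# Data: MNIST is loaded as flat 784-dimensional vectors with integer labels
# (one_hot=False), which matches the cross-entropy loss used below, and is
# wrapped in standard PyTorch DataLoaders. The test split is unused here.
# ---------------------------------------------------------------------------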
def get_dataloaders(batch_size):
    (X_train, y_train, X_val, y_val, X_test, y_test) = load_mnist(
        flatten=True, one_hot=False)
    train_dataset = TensorDataset(torch.FloatTensor(X_train),
                                  torch.LongTensor(y_train))
    val_dataset = TensorDataset(torch.FloatTensor(X_val),
                                torch.LongTensor(y_val))
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                  shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
    return train_dataloader, val_dataloader
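
# ---------------------------------------------------------------------------
# Sample a single architecture from the search space with a random searcher
# and materialize it as an nn.Module by forwarding a dummy batch through it.
# ---------------------------------------------------------------------------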
def sample_model(in_features, num_classes):
    search_space_fn = lambda: dnn_net(num_classes)
    searcher = se.RandomSearcher(search_space_fn)
    inputs, outputs, _, searcher_eval_token = searcher.sample()
    in_tensor = torch.zeros(128, in_features)
    return hpt.PyTorchModel(inputs, outputs, {'in': in_tensor})
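
# ---------------------------------------------------------------------------
# Tune Trainable wrapping one sampled model: _train() runs a single epoch over
# the training set, then evaluates on the validation set and reports both
# accuracies; _save/_restore checkpoint the model's state_dict.
# ---------------------------------------------------------------------------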
class SimpleClassifierTrainable(tune.Trainable):

    def _setup(self, config):
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.batch_size = config["batch_size"]
        self.learning_rate = config.get("lr", 0.01)
        self.train_loader, self.val_loader = get_dataloaders(self.batch_size)
        self.model = config["model"].to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.learning_rate)

    def _train(self):
        self.model.train()
        correct = 0
        train_samples = 0
        for data, target in self.train_loader:
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            output = self.model({'in': data})
            pred = output["out"].data.max(1)[1]
            correct += pred.eq(target).sum().item()
            train_samples += target.numel()
            loss = F.cross_entropy(output["out"], target)
            loss.backward()
            self.optimizer.step()
        train_acc = float(correct) / train_samples

        # compute validation accuracy
        self.model.eval()
        correct = 0
        val_samples = 0
        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model({'in': data})
                pred = output["out"].data.max(1)[1]
                correct += pred.eq(target).sum().item()
                val_samples += target.numel()
        val_acc = float(correct) / val_samples
        print("validation accuracy: %0.4f" % val_acc)
        return {'training_accuracy': train_acc, 'validation_accuracy': val_acc}

    def _save(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        self.model.load_state_dict(torch.load(checkpoint_path))
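
# ---------------------------------------------------------------------------
# Driver. Each trial gets a freshly sampled architecture via tune.sample_from,
# so the config carries a live nn.Module built on the driver process; Ray then
# has to deepcopy/serialize that module when it sets up the trial, which is
# presumably where the deepcopy protocol error described above comes from.
# ---------------------------------------------------------------------------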
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Architecture search example")
    parser.add_argument("--asha", action="store_true",
                        help="If provided, will use ASHA scheduler (default FIFO)")
    args = parser.parse_args()

    ray.init()

    if args.asha:
        print("Using ASHA scheduler")
        from ray.tune.schedulers import ASHAScheduler
        # the scheduler metric must match a key reported by _train()
        scheduler = ASHAScheduler(metric="training_accuracy")
    else:
        from ray.tune.schedulers import FIFOScheduler
        scheduler = FIFOScheduler()

    analysis = tune.run(
        SimpleClassifierTrainable,
        scheduler=scheduler,
        stop={
            "training_accuracy": 0.95,
            "training_iteration": 2,
        },
        resources_per_trial={
            "cpu": 3,
            "gpu": 0.25
        },
        num_samples=4,
        config={
            "model": tune.sample_from(
                lambda spec: sample_model(in_features=784, num_classes=10)),
            "lr": 1e-3,
            "batch_size": 128
        })

    print("Best config is:",
          analysis.get_best_config(metric="validation_accuracy"))