Created
July 14, 2020 17:27
-
-
Save lgray/2d9a035378c2bce91bed7d54fcd02b8d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import os.path as osp | |
| import math | |
| import numpy as np | |
| import torch | |
| import gc | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch_geometric.transforms as T | |
| from torch.utils.checkpoint import checkpoint | |
| from torch_cluster import knn_graph | |
| from torch_geometric.nn import EdgeConv, NNConv | |
| from torch_geometric.nn.pool.edge_pool import EdgePooling | |
| from torch_geometric.utils import normalized_cut | |
| from torch_geometric.utils import remove_self_loops | |
| from torch_geometric.utils.undirected import to_undirected | |
| from torch_geometric.nn import (graclus, max_pool, max_pool_x, | |
| global_mean_pool, global_max_pool, | |
| global_add_pool) | |
| transform = T.Cartesian(cat=False) | |
| from torch.optim import Optimizer | |
| from torch.optim.lr_scheduler import _LRScheduler | |
| import math | |
| import torch | |
| import sys | |
class ReduceMaxLROnRestart:
    """Restart callback that scales the upper LR multiplier by a fixed ratio.

    Args:
        ratio: factor applied to ``eta_max`` on every call.
    """

    def __init__(self, ratio=0.75):
        self.ratio = ratio

    def __call__(self, eta_min, eta_max):
        """Return ``(eta_min, eta_max * ratio)``; ``eta_min`` is passed through."""
        scaled_max = eta_max * self.ratio
        return eta_min, scaled_max
class ExpReduceMaxLROnIteration:
    """Per-iteration callback that scales ``eta_max`` by ``gamma ** iterations``.

    Args:
        gamma: base of the exponential factor.
    """

    def __init__(self, gamma=1):
        self.gamma = gamma

    def __call__(self, eta_min, eta_max, iterations):
        """Return ``(eta_min, eta_max * gamma ** iterations)``."""
        decay = self.gamma ** iterations
        return eta_min, eta_max * decay
class CosinePolicy:
    """Half-cosine schedule: 1.0 at the start of a period, 0.0 at its end."""

    def __call__(self, t_cur, restart_period):
        """Return the multiplier for position ``t_cur`` within ``restart_period``."""
        phase = math.pi * (t_cur / restart_period)
        return 0.5 * (1. + math.cos(phase))
class ArccosinePolicy:
    """Arccosine schedule: 1.0 at the start of a period, 0.0 at its end."""

    def __call__(self, t_cur, restart_period):
        """Return ``acos(x) / pi`` where ``x`` is the position mapped to [-1, 1]."""
        position = 2 * t_cur / restart_period - 1
        # Clamp so that positions outside the period cannot leave acos' domain.
        clamped = max(-1, min(1, position))
        return math.acos(clamped) / math.pi
class TriangularPolicy:
    """Triangular schedule: linear ramp 0 -> 1, then linear descent 1 -> 0.

    Args:
        triangular_step: fraction of the restart period spent in the
            ascending phase; the remainder is the descending phase.
    """

    def __init__(self, triangular_step=0.5):
        self.triangular_step = triangular_step

    def __call__(self, t_cur, restart_period):
        """Return the multiplier for position ``t_cur`` within ``restart_period``."""
        inflection_point = self.triangular_step * restart_period
        if t_cur < inflection_point:
            return t_cur / inflection_point
        descent_span = restart_period - inflection_point
        if descent_span <= 0:
            # triangular_step == 1.0 leaves no descending phase; the ramp ends
            # at its peak instead of dividing by zero at t_cur == restart_period.
            return 1.0
        return 1.0 - (t_cur - inflection_point) / descent_span
class CyclicLRWithRestarts(_LRScheduler):
    """Cyclic learning-rate schedule with warm restarts.

    The learning rate of every param group is interpolated between its
    ``minimum_lr`` and ``initial_lr`` by a policy function of the position
    inside the current restart period. The weight-decay normalisation from
    https://arxiv.org/abs/1711.05101 is currently disabled: each group's
    base ``weight_decay`` is written back unchanged on every step.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        batch_size: minibatch size
        epoch_size: training samples per epoch
        restart_period: epoch count in the first restart period
        t_mult: multiplication factor by which the next restart period will expand/shrink
        last_epoch: index of the last epoch; -1 starts a fresh schedule
        verbose: print a message on every restart
        policy: ["cosine", "arccosine", "triangular", "triangular2", "exp_range"]
        policy_fn: optional callable ``(t_cur, restart_period) -> multiplier``;
            when given it takes precedence over ``policy``
        min_lr: minimum allowed learning rate
        eta_on_restart_cb: callback executed on every restart, adjusts max or min lr
        eta_on_iteration_cb: callback executed on every iteration, adjusts max or min lr
        gamma: exponent used in "exp_range" policy
        triangular_step: adjusts ratio of increasing/decreasing phases for triangular policy

    Raises:
        TypeError: if ``optimizer`` is not an ``Optimizer``.
        KeyError: when resuming and a param group lacks ``initial_lr``.
        ValueError: if ``policy`` is not a known name and no ``policy_fn`` is given.

    Example:
        >>> scheduler = CyclicLRWithRestarts(optimizer, 32, 1024, restart_period=5, t_mult=1.2)
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>         ...
        >>>         optimizer.zero_grad()
        >>>         loss.backward()
        >>>         optimizer.step()
        >>>         scheduler.batch_step()
        >>>     validate(...)
    """

    def __init__(self, optimizer, batch_size, epoch_size, restart_period=100,
                 t_mult=2, last_epoch=-1, verbose=False,
                 policy="cosine", policy_fn=None, min_lr=1e-7,
                 eta_on_restart_cb=None, eta_on_iteration_cb=None,
                 gamma=1.0, triangular_step=0.5):
        # NOTE: the base-class __init__ is not called; this class keeps all of
        # its own stepping state.
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an"
                                   " optimizer".format(i))
        # Default the floor in both the fresh and the resume path, so a resumed
        # optimizer without 'minimum_lr' does not fail on the lookup below.
        for group in optimizer.param_groups:
            group.setdefault('minimum_lr', min_lr)

        self.base_lrs = [group['initial_lr'] for group
                         in optimizer.param_groups]
        self.min_lrs = [group['minimum_lr'] for group
                        in optimizer.param_groups]
        self.base_weight_decays = [group['weight_decay'] for group
                                   in optimizer.param_groups]

        self.policy = policy
        self.eta_on_restart_cb = eta_on_restart_cb
        self.eta_on_iteration_cb = eta_on_iteration_cb
        if policy_fn is not None:
            self.policy_fn = policy_fn
        elif self.policy == "cosine":
            self.policy_fn = CosinePolicy()
        elif self.policy == "arccosine":
            self.policy_fn = ArccosinePolicy()
        elif self.policy == "triangular":
            self.policy_fn = TriangularPolicy(triangular_step=triangular_step)
        elif self.policy == "triangular2":
            # Triangular shape whose peak is halved on every restart.
            self.policy_fn = TriangularPolicy(triangular_step=triangular_step)
            self.eta_on_restart_cb = ReduceMaxLROnRestart(ratio=0.5)
        elif self.policy == "exp_range":
            # Triangular shape whose peak is scaled by gamma on every iteration.
            self.policy_fn = TriangularPolicy(triangular_step=triangular_step)
            self.eta_on_iteration_cb = ExpReduceMaxLROnIteration(gamma=gamma)
        else:
            # Fail at construction instead of with an AttributeError on the
            # first step.
            raise ValueError("unknown policy {!r}; expected one of 'cosine', "
                             "'arccosine', 'triangular', 'triangular2', "
                             "'exp_range' or an explicit policy_fn"
                             .format(policy))

        self.last_epoch = last_epoch
        self.batch_size = batch_size
        self.epoch_size = epoch_size

        self.iteration = 0
        self.total_iterations = 0

        self.t_mult = t_mult
        self.verbose = verbose
        self.restart_period = math.ceil(restart_period)
        self.restarts = 0
        self.t_epoch = -1
        self.epoch = -1

        # Multiplier range fed to the policy; callbacks may shrink it.
        self.eta_min = 0
        self.eta_max = 1

        self.end_of_period = False
        self.batch_increments = []
        self._set_batch_increment()

    def _on_restart(self):
        """Let the restart callback, if any, adjust ``(eta_min, eta_max)``."""
        if self.eta_on_restart_cb is not None:
            self.eta_min, self.eta_max = self.eta_on_restart_cb(self.eta_min,
                                                                self.eta_max)

    def _on_iteration(self):
        """Let the iteration callback, if any, adjust ``(eta_min, eta_max)``."""
        if self.eta_on_iteration_cb is not None:
            self.eta_min, self.eta_max = self.eta_on_iteration_cb(
                self.eta_min, self.eta_max, self.total_iterations)

    def get_lr(self, t_cur):
        """Compute per-group ``(lr, weight_decay)`` pairs for position ``t_cur``.

        Also advances the restart bookkeeping: once ``t_epoch`` reaches
        ``restart_period`` the period is rescaled by ``t_mult``, ``t_epoch``
        is reset to 0 and the restart callback runs.

        Returns:
            An iterator of ``(lr, weight_decay)`` tuples, one per param group.
        """
        eta_t = (self.eta_min + (self.eta_max - self.eta_min)
                 * self.policy_fn(t_cur, self.restart_period))

        lrs = [min_lr + (base_lr - min_lr) * eta_t for base_lr, min_lr
               in zip(self.base_lrs, self.min_lrs)]
        # Weight-decay normalisation is disabled; base values pass through.
        weight_decays = list(self.base_weight_decays)

        if (self.t_epoch + 1) % self.restart_period < self.t_epoch:
            self.end_of_period = True

        if self.t_epoch % self.restart_period < self.t_epoch:
            if self.verbose:
                print("Restart {} at epoch {}".format(self.restarts + 1,
                                                      self.last_epoch))
            self.restart_period = math.ceil(self.restart_period * self.t_mult)
            self.restarts += 1
            self.t_epoch = 0
            self._on_restart()
            self.end_of_period = False

        return zip(lrs, weight_decays)

    def _set_batch_increment(self):
        """Reset the in-epoch counter and rebuild the fractional epoch offsets.

        One offset is needed for the call made by ``step()`` plus one per
        minibatch (a partial trailing batch counts as a batch).
        """
        d, r = divmod(self.epoch_size, self.batch_size)
        batches_in_epoch = d + 2 if r > 0 else d + 1
        self.iteration = 0
        self.batch_increments = torch.linspace(0, 1, batches_in_epoch).tolist()

    def step(self):
        """Advance one epoch and apply the schedule for its first position."""
        self.last_epoch += 1
        self.t_epoch += 1
        self._set_batch_increment()
        self.batch_step()

    def batch_step(self):
        """Advance one minibatch and write lr / weight_decay to the optimizer.

        Raises:
            StopIteration: if called more often per epoch than the configured
                ``epoch_size`` / ``batch_size`` allow.
        """
        # Only the offset lookup is guarded, so an IndexError raised inside a
        # user callback is not misreported as a size mismatch.
        try:
            t_cur = self.t_epoch + self.batch_increments[self.iteration]
        except IndexError:
            raise StopIteration("Epoch size and batch size used in the "
                                "training loop and while initializing "
                                "scheduler should be the same.")

        self._on_iteration()
        self.iteration += 1
        self.total_iterations += 1

        for param_group, (lr, weight_decay) in zip(self.optimizer.param_groups,
                                                   self.get_lr(t_cur)):
            param_group['lr'] = lr
            param_group['weight_decay'] = weight_decay
def normalized_cut_2d(edge_index, pos):
    """Return normalized-cut edge weights based on endpoint distances.

    Each edge is first weighted by the Euclidean (L2) distance between the
    ``pos`` rows of its two endpoints; those weights are then handed to
    ``normalized_cut`` together with the node count ``pos.size(0)``.
    """
    source, target = edge_index
    offsets = pos[source] - pos[target]
    edge_length = torch.norm(offsets, p=2, dim=1)
    node_count = pos.size(0)
    return normalized_cut(edge_index, edge_length, num_nodes=node_count)
class DynamicReductionNetwork(nn.Module):
    """Two-stage clustering network over k-nearest-neighbour graphs.

    This model clusters nearest neighbour graphs in two steps. The latent
    space is trained to group useful features at each level of aggregation,
    which allows single quantities to be regressed from complex point counts
    in a location and orientation invariant way. One encoding layer is used
    to abstract away the input features.

    Args:
        input_dim: number of input features per node.
        hidden_dim: width of the latent node representation.
        output_dim: number of values produced per graph.
        k: number of neighbours used when building each kNN graph.
        aggr: aggregation scheme passed to both ``EdgeConv`` layers.
        norm: per-feature scale multiplied onto the raw inputs; stored as a
            learnable parameter.
    """

    def __init__(self, input_dim=5, hidden_dim=64, output_dim=1, k=16, aggr='add',
                 norm=torch.tensor([1./500., 1./500., 1./54., 1/25., 1./1000.])):
        super().__init__()

        # Clone so the parameter owns its storage. The default ``norm`` tensor
        # is created once at definition time; wrapping it directly would make
        # every default-constructed instance share (and train) the same data.
        self.datanorm = nn.Parameter(norm.clone().detach())

        self.k = k
        start_width = 2 * hidden_dim
        middle_width = 3 * hidden_dim // 2

        self.inputnet = nn.Sequential(
            nn.Linear(input_dim, hidden_dim//2),
            nn.ELU(),
            nn.Linear(hidden_dim//2, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU()
        )

        # Both edge networks take inputs of width 2 * hidden_dim.
        convnn1 = nn.Sequential(nn.Linear(start_width, middle_width),
                                nn.ELU(),
                                nn.Linear(middle_width, hidden_dim),
                                nn.ELU(),
                                )
        convnn2 = nn.Sequential(nn.Linear(start_width, middle_width),
                                nn.ELU(),
                                nn.Linear(middle_width, hidden_dim),
                                nn.ELU(),
                                )

        self.edgeconv1 = EdgeConv(nn=convnn1, aggr=aggr)
        self.edgeconv2 = EdgeConv(nn=convnn2, aggr=aggr)

        self.output = nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                    nn.ELU(),
                                    nn.Linear(hidden_dim, hidden_dim//2),
                                    nn.ELU(),
                                    nn.Linear(hidden_dim//2, output_dim))

    def forward(self, data):
        """Run the network on a batch; ``data`` is modified in place.

        Returns the output head applied to one pooled vector per graph, with
        a trailing singleton dimension squeezed away.
        """
        data.x = self.datanorm * data.x
        data.x = self.inputnet(data.x)

        # Stage 1: kNN graph in the latent space, edge convolution, then
        # graclus clustering weighted by latent-space distances.
        data.edge_index = to_undirected(knn_graph(data.x, self.k, data.batch, loop=False, flow=self.edgeconv1.flow))
        data.x = self.edgeconv1(data.x, data.edge_index)

        weight = normalized_cut_2d(data.edge_index, data.x)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        data.edge_attr = None
        data = max_pool(cluster, data)

        # Stage 2: same procedure on the coarsened graph.
        data.edge_index = to_undirected(knn_graph(data.x, self.k, data.batch, loop=False, flow=self.edgeconv2.flow))
        data.x = self.edgeconv2(data.x, data.edge_index)

        weight = normalized_cut_2d(data.edge_index, data.x)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        x, batch = max_pool_x(cluster, data.x, data.batch)

        # Collapse every graph to a single feature vector.
        x = global_max_pool(x, batch)

        return self.output(x).squeeze(-1)
| import os.path as osp | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch_geometric.datasets import MNISTSuperpixels | |
| import torch_geometric.transforms as T | |
| from torch_geometric.data import DataLoader | |
| import warnings | |
# Suppress all warnings for the rest of the run.
warnings.simplefilter('ignore')

batch_size = 256
path = osp.join('./', '..', 'data', 'MNIST')
# Cartesian edge attributes, stored without concatenating to existing ones.
transform = T.Cartesian(cat=False)
# NOTE(review): the positional True/False presumably selects the train/test
# split of MNISTSuperpixels -- confirm against the torch_geometric signature.
train_dataset = MNISTSuperpixels(path, True, transform=transform)
test_dataset = MNISTSuperpixels(path, False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# The LR scheduler needs the number of training samples per epoch.
epoch_size = len(train_dataset)
print(epoch_size, batch_size)

d = train_dataset
print('features ->', d.num_features)
print('classes ->',d.num_classes)

# The hidden width is the single required command-line argument; exit with a
# usage message instead of a bare IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit('usage: {} HIDDEN_DIM'.format(sys.argv[0]))
hidden_dim = int(sys.argv[1])
print('hidden_dim = %d' % hidden_dim)
class Net(nn.Module):
    """Classifier that applies log-softmax to DynamicReductionNetwork logits.

    The width and class count come from the module-level ``hidden_dim`` and
    ``d`` (the training dataset).
    """

    def __init__(self):
        super().__init__()
        # One scale per input feature (three features in total).
        feature_scale = torch.tensor([1., 1./27., 1./27.])
        self.drn = DynamicReductionNetwork(
            input_dim=3,
            hidden_dim=hidden_dim,
            k=4,
            output_dim=d.num_classes,
            aggr='add',
            norm=feature_scale,
        )

    def forward(self, data):
        """Return per-class log-probabilities for every graph in ``data``."""
        return F.log_softmax(self.drn(data), dim=1)
def print_model_summary(model):
    """Print the model's string form followed by its total parameter count."""
    parameter_total = 0
    for tensor in model.parameters():
        parameter_total += tensor.numel()
    print('Model: \n%s\nParameters: %i' % (model, parameter_total))
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = Net()
model = model.to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-3,
)

# Cosine schedule; the first restart period is 400 epochs long.
scheduler = CyclicLRWithRestarts(
    optimizer,
    batch_size,
    epoch_size,
    restart_period=400,
    t_mult=1.2,
    policy="cosine",
)

print_model_summary(model)
def train(epoch):
    """Run one training epoch over ``train_loader``.

    ``epoch`` is accepted for the caller's convenience and is not read here.
    The scheduler is stepped once at the start of the epoch and once after
    every minibatch.
    """
    model.train()
    scheduler.step()

    for batch in train_loader:
        batch = batch.to(device)

        # Keep only nodes whose original feature is positive.
        keep = (batch.x > 0.).squeeze()
        # Node features become [original feature, position].
        features = torch.cat([batch.x, batch.pos], dim=-1)
        batch.x = features[keep, :]
        batch.pos = batch.pos[keep, :]
        batch.batch = batch.batch[keep]

        optimizer.zero_grad()
        log_probs = model(batch)
        loss = F.nll_loss(log_probs, batch.y)
        loss.backward()
        optimizer.step()
        scheduler.batch_step()
def test():
    """Evaluate the model on ``test_loader`` and return the accuracy.

    Returns:
        Fraction of test graphs whose highest-scoring class equals the label.
    """
    model.eval()
    correct = 0
    # Evaluation needs no gradients; skipping autograd bookkeeping avoids
    # building a graph for every test batch.
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            # Same preprocessing as in training: drop nodes whose original
            # feature is not positive, then use [feature, position] as input.
            mask = (data.x > 0.).squeeze()
            data.x = torch.cat([data.x, data.pos], dim=-1)
            data.x = data.x[mask, :]
            data.pos = data.pos[mask, :]
            data.batch = data.batch[mask]
            pred = model(data).max(1)[1]
            correct += pred.eq(data.y).sum().item()
    return correct / len(test_dataset)
# Main loop: epochs 1 through 400, each followed by a full pass over the
# test set whose accuracy is printed.
for epoch in range(1, 401):
    train(epoch)
    test_acc = test()
    print('Epoch: {:02d}, Test: {:.4f}'.format(epoch, test_acc))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment