Train a MAML model with Torchmeta and higher v0.2.
# Based on the code in https://github.com/tristandeleu/pytorch-meta/tree/master/examples/maml
# Basically, we only use the dataset loaders/helpers from TorchMeta and replace usage of MetaModules
# with normal pytorch nn.Modules, letting higher deal with making the inner loop unrollable and the
# optimizers differentiable. This makes it easier to use another optimizer than SGD, or any arbitrary
# third-party model, when doing MAML using this codebase.
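#
# Note (added comment): higher does this by building a stateless "functional"
# copy of the model whose parameters are plain tensors living in the autograd
# graph, and by wrapping the optimizer so that each step produces new
# parameter tensors instead of updating them in place. That is what makes the
# unrolled inner loop differentiable end to end.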
import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

import higher  # tested with higher v0.2
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader

logger = logging.getLogger(__name__)


def conv3x3(in_channels, out_channels, **kwargs):
    # BatchNorm uses per-batch statistics only (track_running_stats=False),
    # matching the original MAML implementation.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs),
        nn.BatchNorm2d(out_channels, momentum=1., track_running_stats=False),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )


class ConvolutionalNeuralNetwork(nn.Module):
    def __init__(self, in_channels, out_features, hidden_size=64):
        super(ConvolutionalNeuralNetwork, self).__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size
        self.features = nn.Sequential(
            conv3x3(in_channels, hidden_size),
            conv3x3(hidden_size, hidden_size),
            conv3x3(hidden_size, hidden_size),
            conv3x3(hidden_size, hidden_size)
        )
        self.classifier = nn.Linear(hidden_size, out_features)

    def forward(self, inputs):
        features = self.features(inputs)
        features = features.view((features.size(0), -1))
        logits = self.classifier(features)
        return logits
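
# Shape sanity check (added example, not in the original gist): for 28x28
# Omniglot inputs each conv3x3 block halves the spatial dims via MaxPool2d(2),
# so 28 -> 14 -> 7 -> 3 -> 1 and the flattened feature size is exactly
# hidden_size, matching the Linear classifier:
#
#   model = ConvolutionalNeuralNetwork(1, 5, hidden_size=64)
#   assert model(torch.zeros(2, 1, 28, 28)).shape == (2, 5)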


def get_accuracy(logits, targets):
    """Compute the accuracy (after adaptation) of MAML on the test/query points

    Parameters
    ----------
    logits : `torch.FloatTensor` instance
        Outputs/logits of the model on the query points. This tensor has shape
        `(num_examples, num_classes)`.

    targets : `torch.LongTensor` instance
        A tensor containing the targets of the query points. This tensor has
        shape `(num_examples,)`.

    Returns
    -------
    accuracy : `torch.FloatTensor` instance
        Mean accuracy on the query points
    """
    _, predictions = torch.max(logits, dim=-1)
    return torch.mean(predictions.eq(targets).float())
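
# Quick check (added example):
#   get_accuracy(torch.tensor([[2., 1.], [0., 3.]]), torch.tensor([0, 0]))
#   -> tensor(0.5000)  (first prediction correct, second wrong)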


def train(args):
    logger.warning('This script is an example to showcase the data-loading '
                   'features of Torchmeta in conjunction with using higher to '
                   'make models "unrollable" and optimizers differentiable, '
                   'and as such has been very lightly tested.')

    dataset = omniglot(args.folder,
                       shots=args.num_shots,
                       ways=args.num_ways,
                       shuffle=True,
                       test_shots=15,
                       meta_train=True,
                       download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(1,
                                       args.num_ways,
                                       hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()

    # Plain (non-differentiable) optimizers: higher wraps the inner SGD per
    # task, while Adam updates the shared initialisation in the outer loop.
    inner_optimizer = torch.optim.SGD(model.parameters(), lr=args.step_size)
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(zip(train_inputs, train_targets,
                                                         test_inputs, test_targets)):
                # copy_initial_weights=False lets the gradient of the query
                # loss flow back to the initial parameters of `model`, which
                # is exactly what MAML meta-learns.
                with higher.innerloop_ctx(model, inner_optimizer,
                                          copy_initial_weights=False) as (fmodel, diffopt):
                    # Inner loop: one differentiable SGD step on the support set.
                    train_logit = fmodel(train_input)
                    inner_loss = F.cross_entropy(train_logit, train_target)
                    diffopt.step(inner_loss)

                    # Outer objective: query loss of the adapted model.
                    test_logit = fmodel(test_input)
                    outer_loss += F.cross_entropy(test_logit, test_target)

                    with torch.no_grad():
                        accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)
            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches:
                break
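
    # Evaluation sketch (added comment, not part of the original gist): the
    # same adaptation loop can be reused for meta-testing by loading the
    # held-out split with omniglot(..., meta_test=True) instead of
    # meta_train=True and skipping the meta_optimizer.step() call.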
    # Save model
    if args.output_folder is not None:
        filename = os.path.join(args.output_folder, 'maml_omniglot_'
                                '{0}shot_{1}way.th'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser('Model-Agnostic Meta-Learning (MAML)')

    parser.add_argument('folder', type=str,
                        help='Path to the folder the data is downloaded to.')
    parser.add_argument('--num-shots', type=int, default=5,
                        help='Number of examples per class (k in "k-shot", default: 5).')
    parser.add_argument('--num-ways', type=int, default=5,
                        help='Number of classes per task (N in "N-way", default: 5).')
    parser.add_argument('--step-size', type=float, default=0.4,
                        help='Step-size for the gradient step for adaptation (default: 0.4).')
    parser.add_argument('--hidden-size', type=int, default=64,
                        help='Number of channels for each convolutional layer (default: 64).')
    parser.add_argument('--output-folder', type=str, default=None,
                        help='Path to the output folder for saving the model (optional).')
    parser.add_argument('--batch-size', type=int, default=16,
                        help='Number of tasks in a mini-batch of tasks (default: 16).')
    parser.add_argument('--num-batches', type=int, default=100,
                        help='Number of batches the model is trained over (default: 100).')
    parser.add_argument('--num-workers', type=int, default=1,
                        help='Number of workers for data loading (default: 1).')
    parser.add_argument('--download', action='store_true',
                        help='Download the Omniglot dataset in the data folder.')
    parser.add_argument('--use-cuda', action='store_true',
                        help='Use CUDA if available.')

    args = parser.parse_args()
    args.device = torch.device('cuda' if args.use_cuda
                               and torch.cuda.is_available() else 'cpu')

    train(args)
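
# Example invocation (added; the script name and paths are placeholders):
#   python maml_higher_example.py /path/to/data --download --use-cuda \
#       --num-shots 5 --num-ways 5 --output-folder /tmp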