This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build a Trainer with gpus=4, nb_gpu_nodes=4 and the 'ddp' distributed
# backend, then fit `model` (defined elsewhere).
# NOTE(review): presumably this means 4 GPUs on each of 4 nodes -- confirm
# against the Trainer documentation for the installed Lightning version.
trainer = Trainer(gpus=4, nb_gpu_nodes=4, distributed_backend='ddp')
trainer.fit(model)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build a Trainer with gpus=4 and the 'dp' distributed backend, then fit
# `model` (defined elsewhere).
# NOTE(review): presumably single-machine DataParallel across 4 GPUs -- confirm
# against the Trainer documentation for the installed Lightning version.
trainer = Trainer(gpus=4, distributed_backend='dp')
trainer.fit(model)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pytorch_lightning import Trainer

# MNISTExample is defined elsewhere; instantiate it with no arguments.
model = MNISTExample()

# most basic trainer, uses good defaults
trainer = Trainer()
trainer.fit(model)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import torch | |
from torch.nn import functional as F | |
from torch.utils.data import DataLoader | |
from torchvision.datasets import MNIST | |
from torchvision import transforms | |
from sklearn.metrics import accuracy_score | |
from pytorch_lightning import Trainer |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Tutorial from: https://keras.io/examples/mnist_cnn/ | |
# ------------ | |
# SECTION 1 | |
# ------------ | |
from __future__ import print_function | |
import keras | |
from keras.datasets import mnist | |
from keras.models import Sequential | |
from keras.layers import Dense, Dropout, Flatten | |
from keras.layers import Conv2D, MaxPooling2D |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import BertModel | |
import torch.nn.functional as F | |
class BertMNLIFinetuner(pl.LightningModule): | |
def __init__(self): | |
super(BertMNLIFinetuner, self).__init__() | |
# use pretrained BERT |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def configure_optimizers(self):
    """Create the Adam optimizer used for training.

    Only parameters with ``requires_grad`` set are handed to the optimizer,
    so frozen parameters are left out of the update step entirely.

    Returns:
        torch.optim.Adam configured with ``lr=2e-05`` and ``eps=1e-08``.
    """
    trainable_params = [p for p in self.parameters() if p.requires_grad]
    return torch.optim.Adam(trainable_params, lr=2e-05, eps=1e-08)
@pl.data_loader
def train_dataloader(self):
    """Return the training data loader.

    Hands back the module-level ``bert_mnli_train_dataloader`` object, which
    is built elsewhere; nothing is constructed here.
    """
    return bert_mnli_train_dataloader
@pl.data_loader | |
def val_dataloader(self): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_step(self, batch, batch_nb):
    """Score one test batch and report its accuracy.

    Args:
        batch: 4-tuple ``(input_ids, attention_mask, token_type_ids, label)``.
        batch_nb: Index of the batch; not used in this method.

    Returns:
        dict with the single key ``'test_acc'`` holding the batch accuracy
        wrapped in a tensor.
    """
    input_ids, attention_mask, token_type_ids, label = batch

    # forward() returns a pair; only the first element is needed here
    # (presumably per-class scores -- the second is ignored).
    scores, _ = self.forward(input_ids, attention_mask, token_type_ids)

    # Predicted class = index of the largest score along dim 1.
    _, predictions = torch.max(scores, dim=1)

    # accuracy_score operates on host data, so both tensors go to the CPU.
    # Arguments are passed in sklearn's documented (y_true, y_pred) order.
    test_acc = accuracy_score(label.cpu(), predictions.cpu())
    return {'test_acc': torch.tensor(test_acc)}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def validation_step(self, batch, batch_nb): | |
# batch | |
input_ids, attention_mask, token_type_ids, label = batch | |
# fwd | |
y_hat, attn = self.forward(input_ids, attention_mask, token_type_ids) | |
# loss | |
loss = F.cross_entropy(y_hat, label) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import BertModel | |
import torch.nn.functional as F | |
class BertMNLIFinetuner(pl.LightningModule): | |
def __init__(self): | |
super(BertMNLIFinetuner, self).__init__() | |
# use pretrained BERT |