Skip to content

Instantly share code, notes, and snippets.

View williamFalcon's full-sized avatar
🎯
Focusing

William Falcon williamFalcon

🎯
Focusing
View GitHub Profile
# Two alternative multi-GPU configurations. Running both, as here, calls
# fit() on the same `model` twice.
# NOTE(review): `Trainer` and `model` are only bound by the snippet that
# follows in this file, so they must already exist before these lines run.
# The keyword names (nb_gpu_nodes, distributed_backend) are from an early
# pytorch_lightning API -- confirm against the pinned version.

# 'ddp' backend: 4 GPUs on each of 4 nodes (nb_gpu_nodes).
trainer = Trainer(
    gpus=4,
    nb_gpu_nodes=4,
    distributed_backend='ddp',
)
trainer.fit(model)

# 'dp' backend: 4 GPUs, no multi-node setting.
trainer = Trainer(
    gpus=4,
    distributed_backend='dp',
)
trainer.fit(model)
from pytorch_lightning import Trainer
# NOTE(review): `MNISTExample` is not defined or imported anywhere in this
# file; presumably a LightningModule defined elsewhere -- confirm.
model = MNISTExample()
# Most basic trainer: Trainer() takes no arguments here, so every option is
# left at the library's default.
trainer = Trainer()
trainer.fit(model)
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from sklearn.metrics import accuracy_score
from pytorch_lightning import Trainer
# Tutorial from: https://keras.io/examples/mnist_cnn/
# ------------
# SECTION 1
# ------------
from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from transformers import BertModel
import torch.nn.functional as F
class BertMNLIFinetuner(pl.LightningModule):
    """LightningModule built around a pretrained BERT model, presumably for MNLI fine-tuning.

    NOTE(review): this copy of the class is incomplete:
      * ``__init__`` never creates the BERT model or any head (only a comment
        remains after the ``super()`` call);
      * ``forward`` is not defined here, yet ``test_step`` and
        ``validation_step`` call ``self.forward``;
      * ``val_dataloader`` has no body, so the class does not parse;
      * ``validation_step`` stops after computing the loss;
      * ``pl`` is not imported anywhere in this file (presumably
        ``import pytorch_lightning as pl``).
    A second definition with the same name appears later in this file.
    Restore the missing parts from the original source before use.
    """

    def __init__(self):
        super(BertMNLIFinetuner, self).__init__()
        # use pretrained BERT
        # NOTE(review): the code that loads the pretrained model is missing
        # from this copy; nothing after the super() call is shown.

    def configure_optimizers(self):
        """Return an Adam optimizer over the trainable parameters only.

        Parameters with ``requires_grad`` set to False are excluded from
        optimization. Uses lr=2e-05 and eps=1e-08.
        """
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-05, eps=1e-08)

    @pl.data_loader
    def train_dataloader(self):
        """Return the training DataLoader.

        NOTE(review): ``bert_mnli_train_dataloader`` is not defined in this
        file; presumably a module-level DataLoader built elsewhere -- confirm.
        ``pl.data_loader`` exists only in early pytorch_lightning releases.
        """
        return bert_mnli_train_dataloader

    @pl.data_loader
    def val_dataloader(self):
        # NOTE(review): the body of this method is missing from this copy,
        # which makes the class fail to parse; restore it from the original
        # source.

    def test_step(self, batch, batch_nb):
        """Compute classification accuracy on one test batch.

        Args:
            batch: 4-tuple unpacked as (input_ids, attention_mask,
                token_type_ids, label).
            batch_nb: batch index (not used in the body).

        Returns:
            dict with key 'test_acc' holding the batch accuracy as a tensor.
        """
        input_ids, attention_mask, token_type_ids, label = batch
        # `attn` (second output of forward) is not used here.
        y_hat, attn = self.forward(input_ids, attention_mask, token_type_ids)
        # torch.max along dim=1 returns (values, indices); y_hat is rebound to
        # the indices, i.e. the predicted class per row. `a` is unused.
        a, y_hat = torch.max(y_hat, dim=1)
        # Move to CPU before handing the tensors to sklearn's accuracy_score
        # (presumably to avoid passing CUDA tensors to sklearn).
        test_acc = accuracy_score(y_hat.cpu(), label.cpu())
        return {'test_acc': torch.tensor(test_acc)}

    def validation_step(self, batch, batch_nb):
        """Run one validation batch and compute its cross-entropy loss.

        NOTE(review): the body appears truncated after the loss computation;
        as shown it computes ``loss`` but returns None. Restore the rest from
        the original source.
        """
        # unpack the batch: (input_ids, attention_mask, token_type_ids, label)
        input_ids, attention_mask, token_type_ids, label = batch
        # forward pass; `attn` is not used in the visible code
        y_hat, attn = self.forward(input_ids, attention_mask, token_type_ids)
        # cross-entropy between y_hat (presumably unnormalized class scores)
        # and the integer labels
        loss = F.cross_entropy(y_hat, label)
from transformers import BertModel
import torch.nn.functional as F
class BertMNLIFinetuner(pl.LightningModule):
def __init__(self):
super(BertMNLIFinetuner, self).__init__()
# use pretrained BERT