import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from sklearn.metrics import accuracy_score
from transformers import BertModel


class BertMNLIFinetuner(pl.LightningModule):

    def __init__(self):
        super(BertMNLIFinetuner, self).__init__()

        # use pretrained BERT
        self.bert = BertModel.from_pretrained('bert-base-cased', output_attentions=True)

        # fine tuner (3 classes: entailment, neutral, contradiction)
        self.W = nn.Linear(self.bert.config.hidden_size, 3)
        self.num_classes = 3

    def forward(self, input_ids, attention_mask, token_type_ids):
        h, _, attn = self.bert(input_ids=input_ids,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids)

        # classify on the final hidden state of the [CLS] token
        h_cls = h[:, 0]
        logits = self.W(h_cls)
        return logits, attn

    def training_step(self, batch, batch_nb):
        # batch
        input_ids, attention_mask, token_type_ids, label = batch

        # fwd
        y_hat, attn = self.forward(input_ids, attention_mask, token_type_ids)

        # loss
        loss = F.cross_entropy(y_hat, label)

        # logs
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        # batch
        input_ids, attention_mask, token_type_ids, label = batch

        # fwd
        y_hat, attn = self.forward(input_ids, attention_mask, token_type_ids)

        # loss
        loss = F.cross_entropy(y_hat, label)

        # acc
        _, y_hat = torch.max(y_hat, dim=1)
        val_acc = accuracy_score(y_hat.cpu(), label.cpu())
        val_acc = torch.tensor(val_acc)

        return {'val_loss': loss, 'val_acc': val_acc}

    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_val_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss, 'avg_val_acc': avg_val_acc}
        return {'avg_val_loss': avg_loss, 'progress_bar': tensorboard_logs}

    def test_step(self, batch, batch_nb):
        input_ids, attention_mask, token_type_ids, label = batch
        y_hat, attn = self.forward(input_ids, attention_mask, token_type_ids)

        _, y_hat = torch.max(y_hat, dim=1)
        test_acc = accuracy_score(y_hat.cpu(), label.cpu())
        return {'test_acc': torch.tensor(test_acc)}

    def test_end(self, outputs):
        avg_test_acc = torch.stack([x['test_acc'] for x in outputs]).mean()
        tensorboard_logs = {'avg_test_acc': avg_test_acc}
        return {'avg_test_acc': avg_test_acc, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-05, eps=1e-08)

    @pl.data_loader
    def train_dataloader(self):
        return bert_mnli_train_dataloader

    @pl.data_loader
    def val_dataloader(self):
        return bert_mnli_val_dataloader

    @pl.data_loader
    def test_dataloader(self):
        return bert_mnli_test_dataloader
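
# --- Usage sketch (not part of the original gist) ---------------------------
# The three module-level dataloaders referenced above (bert_mnli_train_dataloader
# and friends) are assumed to exist but are never defined here. Below is a
# minimal, hedged sketch of how they might be built and how the model might be
# trained, assuming the transformers 2.x / pytorch-lightning 0.x era APIs this
# gist targets, and an MNLI-style dataset of (premise, hypothesis, label)
# triples. `make_mnli_dataloader`, `premises`, `hypotheses`, and `labels` are
# hypothetical names introduced for illustration only.

from torch.utils.data import DataLoader, TensorDataset
from transformers import BertTokenizer


def make_mnli_dataloader(premises, hypotheses, labels, batch_size=32):
    # hypothetical helper: tokenize premise/hypothesis pairs into the
    # (input_ids, attention_mask, token_type_ids, label) tuples each *_step expects
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    enc = [tokenizer.encode_plus(p, h, add_special_tokens=True,
                                 max_length=128, pad_to_max_length=True)
           for p, h in zip(premises, hypotheses)]
    dataset = TensorDataset(
        torch.tensor([e['input_ids'] for e in enc]),
        torch.tensor([e['attention_mask'] for e in enc]),
        torch.tensor([e['token_type_ids'] for e in enc]),
        torch.tensor(labels),
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


# bert_mnli_train_dataloader = make_mnli_dataloader(...)
# bert_mnli_val_dataloader = make_mnli_dataloader(...)
# bert_mnli_test_dataloader = make_mnli_dataloader(...)

model = BertMNLIFinetuner()
trainer = pl.Trainer(gpus=1, max_epochs=3)
trainer.fit(model)
trainer.test(model)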