Skip to content

Instantly share code, notes, and snippets.

@williamFalcon
Last active January 23, 2021 08:09
Show Gist options
  • Save williamFalcon/146a016187c9a2edbd2f97a63cbc2d0d to your computer and use it in GitHub Desktop.
Save williamFalcon/146a016187c9a2edbd2f97a63cbc2d0d to your computer and use it in GitHub Desktop.
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel
class BertMNLIFinetuner(pl.LightningModule):
    """Fine-tune a pretrained BERT encoder with a linear classification head.

    A single linear layer (``self.W``) maps the final hidden state of the first
    token of each sequence to ``num_classes`` logits. The default of 3 classes
    matches the MNLI label set this module is named for.

    Args:
        num_classes: Number of output classes. Defaults to 3.
        pretrained_model_name: Checkpoint name or path passed to
            ``BertModel.from_pretrained``. Defaults to ``'bert-base-cased'``.

    NOTE(review): no ``configure_optimizers`` is visible in this snippet, and
    Lightning needs one to train. Confirm that it is defined elsewhere.
    """

    def __init__(self, num_classes=3, pretrained_model_name='bert-base-cased'):
        super().__init__()
        # Pretrained BERT encoder. Attention weights are requested so that
        # forward() can return them alongside the logits.
        self.bert = BertModel.from_pretrained(pretrained_model_name,
                                              output_attentions=True)
        # Fine-tuning head. The input width is read from the encoder's own
        # config so it always matches the loaded checkpoint.
        self.W = nn.Linear(self.bert.config.hidden_size, num_classes)
        self.num_classes = num_classes

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Run BERT and the linear head on a batch of tokenized inputs.

        Args:
            input_ids: Token ids, passed straight to ``self.bert``.
            attention_mask: Attention mask, passed straight to ``self.bert``.
            token_type_ids: Segment ids, passed straight to ``self.bert``.

        Returns:
            ``(logits, attn)``:
            - ``logits``: the head's output for the first token of each sequence.
            - ``attn``: the attention weights returned by the encoder.
        """
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        # Index positionally rather than tuple-unpacking. Recent transformers
        # releases return a dict-like ModelOutput, and unpacking it yields its
        # keys (strings) instead of tensors. Integer indexing works for both
        # ModelOutput and plain-tuple returns. Entry 0 is the last hidden
        # state; the attentions are the last entry because
        # output_attentions=True.
        h = outputs[0]
        attn = outputs[-1]
        # Hidden state of the first ([CLS]) token of each sequence.
        # Assumes h is (batch, seq_len, hidden) -- TODO confirm.
        h_cls = h[:, 0]
        logits = self.W(h_cls)
        return logits, attn

    def training_step(self, batch, batch_nb):
        """Compute the cross-entropy training loss for one batch.

        Args:
            batch: Sequence unpacked as
                ``(input_ids, attention_mask, token_type_ids, label)``.
            batch_nb: Index of the batch. Unused here.

        Returns:
            Dict with ``'loss'`` and a ``'log'`` dict holding ``'train_loss'``.
            Presumably the ``'log'`` dict is consumed by Lightning's dict-return
            logging API; confirm against the installed Lightning version.
        """
        input_ids, attention_mask, token_type_ids, label = batch
        # Call the module itself (not .forward) so nn.Module hooks run.
        y_hat, _attn = self(input_ids, attention_mask, token_type_ids)
        loss = F.cross_entropy(y_hat, label)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment