@ThilinaRajapakse · Last active May 20, 2021
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss
from transformers import BertPreTrainedModel, RobertaConfig, RobertaModel
# Internal location in transformers 4.x; used only for the pretrained_model_archive_map class attribute below
from transformers.models.roberta.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
class RobertaForMultiLabelSequenceClassification(BertPreTrainedModel):
    """
    RoBERTa model adapted for multi-label sequence classification.

    Uses a two-layer classification head and an independent sigmoid per label
    (BCEWithLogitsLoss) instead of a softmax over classes.
    """

    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
    base_model_prefix = "roberta"

    def __init__(self, config, pos_weight=None):
        super(RobertaForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.pos_weight = pos_weight

        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Two-layer head: hidden_size -> hidden_size -> num_labels
        self.classifier = nn.Linear(config.hidden_size, config.hidden_size)
        self.classifier_1 = nn.Linear(config.hidden_size, config.num_labels)
        self.relu = nn.ReLU()

        self.init_weights()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
    ):
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )
        # outputs[1] is the pooled representation of the first (<s>) token
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        logits = self.relu(logits)
        logits = self.classifier_1(logits)

        outputs = (logits,) + outputs[2:]

        if labels is not None:
            # Binary cross-entropy per label; pos_weight can up-weight rare positive labels
            loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
            labels = labels.float()
            loss = loss_fct(
                logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)
            )
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)
    def unfreeze(self, start_layer, end_layer):
        """Freeze all of self.roberta, then unfreeze encoder layers start_layer..end_layer (inclusive)."""

        def children(m):
            return m if isinstance(m, (list, tuple)) else list(m.children())

        def set_trainable_attr(m, b):
            m.trainable = b
            for p in m.parameters():
                p.requires_grad = b

        def apply_leaf(m, f):
            c = children(m)
            if isinstance(m, nn.Module):
                f(m)
            if len(c) > 0:
                for l in c:
                    apply_leaf(l, f)

        def set_trainable(l, b):
            apply_leaf(l, lambda m: set_trainable_attr(m, b))

        # Freeze the full model, then unfreeze only the requested encoder layers.
        # For example, the top layer of roberta-large can be unfrozen with
        # set_trainable(model.roberta.encoder.layer[23], True)
        set_trainable(self.roberta, False)
        for i in range(start_layer, end_layer + 1):
            set_trainable(self.roberta.encoder.layer[i], True)
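

# --- Usage sketch (not part of the original gist) ---
# A minimal example of how the class above might be driven end to end, assuming
# the public "roberta-base" checkpoint and a hypothetical 5-label task; the
# texts, labels, pos_weight values, and layer indices below are made up for
# illustration only.
if __name__ == "__main__":
    from transformers import RobertaTokenizer

    num_labels = 5  # hypothetical label count
    config = RobertaConfig.from_pretrained("roberta-base", num_labels=num_labels)
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    # Optional per-label positive-class weights for BCEWithLogitsLoss, shape [num_labels]
    pos_weight = torch.ones(num_labels)

    # Extra keyword arguments are forwarded to __init__ when a config object is passed
    model = RobertaForMultiLabelSequenceClassification.from_pretrained(
        "roberta-base", config=config, pos_weight=pos_weight
    )

    # Train only the classification head plus the top two encoder layers (roberta-base has 12)
    model.unfreeze(start_layer=10, end_layer=11)

    batch = tokenizer(
        ["an example sentence", "another example sentence"],
        padding=True,
        return_tensors="pt",
    )
    labels = torch.tensor([[1, 0, 1, 0, 0], [0, 1, 0, 0, 1]])

    loss, logits = model(
        batch["input_ids"], attention_mask=batch["attention_mask"], labels=labels
    )[:2]
    loss.backward()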