import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss

from transformers import BertPreTrainedModel, RobertaConfig, RobertaModel
# Import path valid for transformers 4.x; the archive list is only needed for
# the (legacy) pretrained_model_archive_map class attribute below.
from transformers.models.roberta.modeling_roberta import (
    ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
)
class RobertaForMultiLabelSequenceClassification(BertPreTrainedModel):
    """
    RoBERTa model adapted for multi-label sequence classification.

    A two-layer classification head is applied to the pooled output, and
    BCEWithLogitsLoss is used so that each label is predicted independently.
    """

    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
    base_model_prefix = "roberta"

    def __init__(self, config, pos_weight=None):
        super(RobertaForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        # Optional per-label positive-class weights for BCEWithLogitsLoss,
        # useful when the label distribution is imbalanced.
        self.pos_weight = pos_weight

        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Classification head: hidden_size -> hidden_size -> num_labels, with a ReLU in between.
        self.classifier = nn.Linear(config.hidden_size, config.hidden_size)
        self.classifier_1 = nn.Linear(config.hidden_size, config.num_labels)
        self.relu = nn.ReLU()

        self.init_weights()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
    ):
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )
        # Pooled representation of the first token (<s>).
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        logits = self.relu(logits)
        logits = self.classifier_1(logits)

        outputs = (logits,) + outputs[2:]

        if labels is not None:
            # One binary classification per label; pos_weight rescales the positive class.
            loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
            labels = labels.float()
            loss = loss_fct(
                logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)
            )
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)
    def unfreeze(self, start_layer, end_layer):
        """Freeze the whole encoder, then make layers start_layer..end_layer trainable again."""

        def children(m):
            return m if isinstance(m, (list, tuple)) else list(m.children())

        def set_trainable_attr(m, b):
            m.trainable = b
            for p in m.parameters():
                p.requires_grad = b

        def apply_leaf(m, f):
            c = children(m)
            if isinstance(m, nn.Module):
                f(m)
            if len(c) > 0:
                for l in c:
                    apply_leaf(l, f)

        def set_trainable(l, b):
            apply_leaf(l, lambda m: set_trainable_attr(m, b))

        # To unfreeze only the last layer of a 24-layer model, call
        # set_trainable(model.roberta.encoder.layer[23], True)
        set_trainable(self.roberta, False)
        for i in range(start_layer, end_layer + 1):
            set_trainable(self.roberta.encoder.layer[i], True)
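
# --- Usage sketch (not part of the original gist; the checkpoint name, label
# count, weights, and inputs below are illustrative assumptions). It shows how
# the model can be loaded from a pretrained checkpoint, how pos_weight can be
# supplied for imbalanced labels, and how unfreeze() limits encoder training
# to the top layers while the classification head stays trainable. ---
if __name__ == "__main__":
    num_labels = 5  # assumed label count, for illustration only
    config = RobertaConfig.from_pretrained("roberta-base", num_labels=num_labels)

    # Per-label positive-class weights, e.g. derived from inverse label
    # frequencies; dummy values here.
    pos_weight = torch.ones(num_labels)

    # Downloads roberta-base weights; extra kwargs are forwarded to __init__.
    model = RobertaForMultiLabelSequenceClassification.from_pretrained(
        "roberta-base", config=config, pos_weight=pos_weight
    )

    # Freeze the encoder except layers 10-11 (roberta-base has 12 layers, 0-11).
    model.unfreeze(10, 11)

    input_ids = torch.tensor([[0, 31414, 232, 2]])  # "<s> Hello world </s>"
    labels = torch.tensor([[1.0, 0.0, 1.0, 0.0, 0.0]])
    loss, logits = model(input_ids, labels=labels)[:2]
    print(loss.item(), logits.shape)  # scalar loss, logits of shape (1, num_labels)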