from torch import nn
from transformers import AutoModel


class Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = AutoModel.from_pretrained(...)  # fill in a checkpoint name
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Multi-sample dropout: five dropout rates feed one shared linear head.
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.2)
        self.dropout3 = nn.Dropout(0.3)
        self.dropout4 = nn.Dropout(0.4)
        self.dropout5 = nn.Dropout(0.5)
        # num_labels is a standard attribute on Hugging Face configs.
        self.output = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fn = ...  # fill in a loss function

    def forward(self, input_ids, attention_mask, labels):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)

        # For inference, remove everything between the `# ---` markers
        # and replace it with `logits = self.output(output[0])`.
        # ---
        output = self.dropout(output[0])  # output[0] is the last hidden state
        logits1 = self.output(self.dropout1(output))
        logits2 = self.output(self.dropout2(output))
        logits3 = self.output(self.dropout3(output))
        logits4 = self.output(self.dropout4(output))
        logits5 = self.output(self.dropout5(output))
        logits = (logits1 + logits2 + logits3 + logits4 + logits5) / 5

        loss1 = self.loss_fn(logits1, labels)
        loss2 = self.loss_fn(logits2, labels)
        loss3 = self.loss_fn(logits3, labels)
        loss4 = self.loss_fn(logits4, labels)
        loss5 = self.loss_fn(logits5, labels)
        loss = (loss1 + loss2 + loss3 + loss4 + loss5) / 5
        # ---

        return logits, loss
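
The gist leaves the checkpoint, label count, and loss function as placeholders, so here is a minimal, hypothetical usage sketch for a single training step. The `bert-base-uncased` checkpoint, two token-level labels, and a flattened cross-entropy loss are illustrative assumptions, not the author's choices.

import torch
from torch import nn
from transformers import AutoConfig, AutoTokenizer

# Hypothetical setup: any encoder checkpoint should work here.
checkpoint = "bert-base-uncased"
config = AutoConfig.from_pretrained(checkpoint, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Assumes the `...` in AutoModel.from_pretrained(...) above was
# replaced with `checkpoint`.
model = Model(config)

# Token-level cross-entropy: logits come out as (batch, seq_len,
# num_labels), so flatten both logits and labels before the loss.
model.loss_fn = lambda logits, labels: nn.functional.cross_entropy(
    logits.reshape(-1, config.num_labels), labels.reshape(-1)
)

batch = tokenizer(["multi-sample dropout demo"], return_tensors="pt")
labels = torch.zeros_like(batch["input_ids"])  # dummy per-token labels

logits, loss = model(batch["input_ids"], batch["attention_mask"], labels)
loss.backward()  # each of the five dropout heads contributes 1/5 of the gradient

One note on the inference shortcut in the forward pass: in `model.eval()` mode every `nn.Dropout` becomes the identity, so all five heads would return identical logits anyway; the suggested replacement simply skips the four redundant passes through the linear layer.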