"""
Code below is as per the NAACL transfer learning tutorial:
https://github.com/huggingface/naacl_transfer_learning_tutorial
"""
class TransformerWithClfHead(nn.Module):
def __init__(self, config, fine_tuning_config):
""" Transformer with a classification head. """
super().__init__()
self.config = fine_tuning_config
self.transformer = Transformer(config.embed_dim, config.hidden_dim, config.num_embeddings,
config.num_max_positions, config.num_heads, config.num_layers,
fine_tuning_config.dropout, causal=not config.mlm)
self.classification_head = nn.Linear(config.embed_dim, fine_tuning_config.num_classes)
self.apply(self.init_weights)
def init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding, nn.LayerNorm)):
module.weight.data.normal_(mean=0.0, std=self.config.init_range)
if isinstance(module, (nn.Linear, nn.LayerNorm)) and module.bias is not None:
module.bias.data.zero_()
def forward(self, x, clf_tokens_mask, lm_labels=None, clf_labels=None, padding_mask=None):
hidden_states = self.transformer(x, padding_mask)
clf_tokens_states = (hidden_states * clf_tokens_mask.unsqueeze(-1).float()).sum(dim=0)
clf_logits = self.classification_head(clf_tokens_states)
if clf_labels is not None:
loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(clf_logits.view(-1, clf_logits.size(-1)), clf_labels.view(-1))
return clf_logits, loss
return clf_logits
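

# --- Minimal usage sketch (not part of the original gist) ---
# A hedged illustration of how this head might be instantiated and called.
# The config field values below are placeholders, not the tutorial's actual
# hyperparameters, and the tensors are random stand-ins for tokenized input.
if __name__ == "__main__":
    import torch
    from types import SimpleNamespace

    config = SimpleNamespace(embed_dim=128, hidden_dim=512, num_embeddings=1000,
                             num_max_positions=256, num_heads=8, num_layers=2, mlm=False)
    fine_tuning_config = SimpleNamespace(dropout=0.1, num_classes=2, init_range=0.02)

    model = TransformerWithClfHead(config, fine_tuning_config)

    seq_len, batch_size = 64, 4
    x = torch.randint(0, config.num_embeddings, (seq_len, batch_size))  # token ids, sequence-first as in the tutorial
    clf_tokens_mask = torch.zeros(seq_len, batch_size, dtype=torch.bool)
    clf_tokens_mask[-1] = True  # assume the classification token sits at the last position of each sequence

    clf_labels = torch.randint(0, fine_tuning_config.num_classes, (batch_size,))
    logits, loss = model(x, clf_tokens_mask, clf_labels=clf_labels)
    print(logits.shape, loss.item())  # logits: (batch_size, num_classes)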