Created January 26, 2022 20:35
Use T5 Encoder for Sequence Classification with small linear head
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import SequenceClassifierOutput
class T5EncoderClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # T5Config has neither classifier_dropout nor hidden_dropout_prob, so fall back
        # to its dropout_rate when the more specific attributes are not set on the config.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = getattr(config, "hidden_dropout_prob", config.dropout_rate)
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, hidden_states, **kwargs):
        hidden_states = hidden_states[:, 0, :]  # take the first token's state (T5 has no <s>/[CLS] token)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states
class T5EncoderForSequenceClassification(nn.Module):
    """Use an in-memory (already initialized) T5 encoder to do sequence classification."""

    def __init__(self, t5_encoder, config):
        super().__init__()
        self.num_labels = config.num_labels
        self.config = config
        self.encoder = t5_encoder  # already-initialized T5EncoderModel
        # Either we are in eval mode, in which case the following does nothing,
        # or we are training but only want to fine-tune the classifier head,
        # not the encoder, so freeze every encoder parameter.
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.classifier = T5EncoderClassificationHead(config)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_hidden_states=None,
        output_attentions=None,
        return_dict=None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is
            computed (mean-square loss); if `config.num_labels > 1`, a classification loss is
            computed (cross-entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # T5EncoderModel's tuple output is (last_hidden_state, hidden_states, attentions),
            # so everything after index 0 is extra state to pass through alongside the logits.
            output = (logits,) + encoder_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
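For context, here is a minimal usage sketch. The checkpoint name ("t5-small"), the example text, and num_labels = 3 are illustrative assumptions, not part of the gist:

import torch
from transformers import T5Config, T5EncoderModel, T5Tokenizer

# Illustrative assumptions: a small public checkpoint and three target classes.
config = T5Config.from_pretrained("t5-small")
config.num_labels = 3
config.classifier_dropout = 0.1  # set explicitly so the head does not fall back to other dropout attributes

encoder = T5EncoderModel.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderForSequenceClassification(encoder, config)

inputs = tokenizer(["an example sentence to classify"], return_tensors="pt", padding=True)
labels = torch.tensor([1])  # long dtype -> single_label_classification branch

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=labels,
)
print(outputs.logits.shape, outputs.loss)

Since the encoder parameters are frozen in __init__, you would typically pass only model.classifier.parameters() to the optimizer when fine-tuning.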
Instead of using hidden_states[:, 0, :], the post below suggests using the mean of the last_hidden_state as the sentence representation:
https://stackoverflow.com/questions/64579258/sentence-embedding-using-t5
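For reference, a sketch of a mask-aware mean-pooled variant of the head (an assumption-laden alternative, not the gist author's code; the wrapper's forward would also need to pass attention_mask through to self.classifier):

class T5EncoderMeanPoolClassificationHead(nn.Module):
    """Like T5EncoderClassificationHead, but mean-pools the encoder states."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(getattr(config, "classifier_dropout", None) or config.dropout_rate)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        # Mask-aware mean over the sequence dimension instead of taking the first token.
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq, 1)
            pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        else:
            pooled = hidden_states.mean(dim=1)
        pooled = self.dropout(pooled)
        pooled = torch.tanh(self.dense(pooled))
        pooled = self.dropout(pooled)
        return self.out_proj(pooled)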
Hi @sam-writer,
How would I use this for multi-label text classification? Could you share an example?
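For what it's worth, a minimal multi-label sketch under the same illustrative assumptions as above: the key points are float multi-hot labels and num_labels > 1, which route the loss to BCEWithLogitsLoss.

import torch
from transformers import T5Config, T5EncoderModel, T5Tokenizer

config = T5Config.from_pretrained("t5-small")
config.num_labels = 4
config.problem_type = "multi_label_classification"  # or leave None; float labels trigger the same branch
config.classifier_dropout = 0.1

encoder = T5EncoderModel.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderForSequenceClassification(encoder, config)

inputs = tokenizer(["a text that covers several topics"], return_tensors="pt", padding=True)
labels = torch.tensor([[1.0, 0.0, 1.0, 0.0]])  # multi-hot float targets, one column per label

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=labels,
)
probs = torch.sigmoid(outputs.logits)  # per-label probabilities; threshold (e.g. at 0.5) to predict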