from typing import Dict, Optional

import torch
import torch.nn as nn
from torchcrf import CRF
from transformers import AutoModel


class TransformerSlidingWindower(nn.Module):
    """Apply a model on a strided sliding window.

    Good for use with transformer models on long sequences.
    We split the sequence into windows of fixed length with an overlap and apply the
    model on each window. Some context is preserved because of the overlap.

    The diagram below illustrates this operation. We have an input sequence I of length 8.
    We create a sliding window, which consists of 2 subsequences S and T of length 5 with
    an overlap of 2, and pass each subsequence through BERT.
    The output features of the overlapping (strided) positions are pooled using max or
    mean pooling.

        O1 O2 O3 O4 O5 O6 O7 O8
        _______________________
                POOLING
        _______________________
           |              |
           |              |
           |             BERT
           |        ______|______
          BERT      |           |
        ___|____    |           |
        |      |    |           |
        S1 S2 S3 S4 S5          |
                 T1 T2 T3 T4 T5
        I1 I2 I3 I4 I5 I6 I7 I8

    This class works well with BERT-like models from the transformers library.

    Args:
        underlying_model (nn.Module): The model to use as a feature extractor
            (e.g. BertModel, RobertaModel).
        window_size (int): Size of the sliding window.
        stride (int): Overlap length between consecutive windows.
        stride_aggregation (str): Pooling for hidden states in overlapping regions ("mean" or "max").
        pooler_aggregation (str): Pooling for the pooler outputs across windows ("mean" or "max").
    """

    def __init__(
        self,
        underlying_model: nn.Module,
        window_size: int = 512,
        stride: int = 128,
        stride_aggregation: str = "mean",
        pooler_aggregation: str = "mean",
    ):
        super(TransformerSlidingWindower, self).__init__()
        self.underlying_model = underlying_model
        self.hidden_size = underlying_model.config.hidden_size
        self.window_size = (
            window_size - 2
        )  # Keep 2 positions for [CLS] and [SEP] in each subsequence
        self.stride = stride
        assert stride_aggregation in [
            "mean",
            "max",
        ], "Unsupported stride aggregation method. Only [mean, max] are supported"
        assert pooler_aggregation in [
            "mean",
            "max",
        ], "Unsupported pooler aggregation method. Only [mean, max] are supported"
        self.stride_aggregation = stride_aggregation
        self.pooler_aggregation = pooler_aggregation

    def slider(self, sequence_length: int):
        start_index = 0
        end_index = min(sequence_length, self.window_size)

        while True:
            if sequence_length <= self.window_size:
                yield start_index, end_index
                break

            ost = start_index
            oet = min(end_index, sequence_length)
            start_index += self.window_size - self.stride
            end_index = start_index + self.window_size
            yield ost, oet

            if oet >= sequence_length:
                break
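
    # Illustrative example (not part of the original gist): with window_size=512
    # (510 after reserving room for [CLS]/[SEP]) and stride=128, slider(900) yields
    # (0, 510), (382, 892), (764, 900); consecutive windows overlap by `stride` positions.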

    def _augment_with_cls(
        self, window_input_ids, window_attention_mask, window_token_type_ids
    ):
        # 101 is the [CLS] token id in the bert-base-uncased vocabulary
        aug_cls = (
            torch.zeros_like(window_input_ids[:, 0]).unsqueeze(1) + 101
        )  # Add CLS tokens
        window_input_ids = torch.cat((aug_cls, window_input_ids), dim=1)

        if window_attention_mask is not None:
            aug_attmask = window_attention_mask[:, 0].clone().unsqueeze(1)
            window_attention_mask = torch.cat(  # Use previous attention_mask value
                (aug_attmask, window_attention_mask), dim=1
            )

        if window_token_type_ids is not None:
            aug_ttids = window_token_type_ids[:, 0].clone().unsqueeze(1)
            window_token_type_ids = torch.cat(  # Use previous token type
                (aug_ttids, window_token_type_ids), dim=1
            )

        return window_input_ids, window_attention_mask, window_token_type_ids

    def _augment_with_sep(
        self, window_input_ids, window_attention_mask, window_token_type_ids
    ):
        # 102 is the [SEP] token id in the bert-base-uncased vocabulary
        aug_sep = (
            torch.zeros_like(window_input_ids[:, 0]).unsqueeze(1) + 102
        )  # Add SEP tokens
        window_input_ids = torch.cat((window_input_ids, aug_sep), dim=1)

        if window_attention_mask is not None:
            aug_attmask = window_attention_mask[:, -1].clone().unsqueeze(1)
            window_attention_mask = torch.cat(
                (window_attention_mask, aug_attmask), dim=1
            )

        if window_token_type_ids is not None:
            aug_ttids = window_token_type_ids[:, -1].clone().unsqueeze(1)
            window_token_type_ids = torch.cat(
                (window_token_type_ids, aug_ttids), dim=1
            )

        return window_input_ids, window_attention_mask, window_token_type_ids

    def _aggregator(self, x1, x2, selector):
        if selector == "mean":
            return torch.mean(torch.stack((x1, x2)), dim=0)
        elif selector == "max":
            return torch.maximum(x1, x2)
        else:
            raise ValueError(f"Unsupported aggregation method {selector}")

    def _aggregate_hidden_states(
        self, start_index, end_index, last_hidden_prev, current_hidden
    ):
        # last_hidden_prev: (B, S, F)
        # current_hidden: (B, W, F)
        last_hidden_whole_sequence = last_hidden_prev.clone()

        if start_index == 0:
            last_hidden_whole_sequence[:, start_index:end_index] = current_hidden

            return last_hidden_whole_sequence

        last_hidden_whole_sequence[
            :, start_index : start_index + self.stride
        ] = self._aggregator(
            last_hidden_whole_sequence[
                :, start_index : start_index + self.stride
            ].clone(),
            current_hidden[:, : self.stride],
            self.stride_aggregation,
        )  # Aggregate (mean/max) with the previous window's output in the overlap
        last_hidden_whole_sequence[
            :, start_index + self.stride : end_index
        ] = current_hidden[
            :, self.stride :
        ]  # Set the rest of the hidden states

        return last_hidden_whole_sequence
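
    # Illustrative example (not part of the original gist): continuing the windows above,
    # when the window (382, 892) is processed, positions 382..509 (the stride-sized overlap)
    # are mean/max-pooled with the previous window's hidden states, while positions 510..891
    # are overwritten with the fresh hidden states.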

    def _aggregate_pooler_output(self, up2now_out, current_out):
        return self._aggregator(up2now_out, current_out, self.pooler_aggregation)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        kwargs = {
            "past_key_values": None,  # Incompatible arguments. DON'T USE THIS FOR SEQUENTIAL DECODING
            "use_cache": False,
            "return_dict": True,
        }
        batch_size, sequence_length = input_ids.shape
        window_attention_mask, window_token_type_ids = None, None
        last_hidden_state = torch.zeros(
            (batch_size, sequence_length, self.hidden_size), dtype=torch.float  # type: ignore
        ).to(input_ids.device)
        pooler_output = torch.zeros(
            (batch_size, self.hidden_size), dtype=torch.float  # type: ignore
        ).to(input_ids.device)

        for start_index, end_index in self.slider(sequence_length):
            window_input_ids = input_ids[:, start_index:end_index]

            if attention_mask is not None:
                window_attention_mask = attention_mask[:, start_index:end_index]

            if token_type_ids is not None:
                window_token_type_ids = token_type_ids[:, start_index:end_index]

            if start_index > 0:
                (
                    window_input_ids,
                    window_attention_mask,
                    window_token_type_ids,
                ) = self._augment_with_cls(
                    window_input_ids, window_attention_mask, window_token_type_ids
                )

            if end_index < sequence_length:
                (
                    window_input_ids,
                    window_attention_mask,
                    window_token_type_ids,
                ) = self._augment_with_sep(
                    window_input_ids, window_attention_mask, window_token_type_ids
                )

            outputs = self.underlying_model(
                input_ids=window_input_ids,
                attention_mask=window_attention_mask,
                token_type_ids=window_token_type_ids,
                **kwargs,
            )

            pooler_output = self._aggregate_pooler_output(
                pooler_output, outputs.pooler_output
            )

            current_last_hidden_state = outputs.last_hidden_state

            if start_index > 0:
                current_last_hidden_state = current_last_hidden_state[
                    :, 1:, :
                ]  # remove extra cls hidden

            if end_index < sequence_length:
                current_last_hidden_state = current_last_hidden_state[
                    :, :-1, :
                ]  # remove extra sep hidden

            last_hidden_state = self._aggregate_hidden_states(
                start_index,
                end_index,
                last_hidden_state,
                current_last_hidden_state,
            )

        return last_hidden_state, pooler_output


class TransformerDocumentCRF(nn.Module):
    """Multitask sequence tagger: a sliding-window BERT encoder feeds one linear
    projection + CRF decoder per task, and the task losses are combined with a
    weighted sum.

                  [ BERT Sliding windower ]
                 O1 O2 O3 ............. ON
             _______________|_______________
            |               |               |
        [ARG1 CRF]   [Connector CRF]   [ARG2 CRF]
          L_ARG1       L_CONNECTOR       L_ARG2

        LOSS = w_ARG1 * L_ARG1 + w_CONNECTOR * L_CONNECTOR + w_ARG2 * L_ARG2
    """

    def __init__(
        self,
        multitask_num_tags: Dict[str, int],
        multitask_weights: Optional[Dict[str, float]] = None,
        pretrained_model: str = "bert-base-uncased",
        window_size: int = 512,
        stride: int = 128,
        stride_aggregation: str = "mean",
        pooler_aggregation: str = "mean",
    ):
        super(TransformerDocumentCRF, self).__init__()
        underlying_model = AutoModel.from_pretrained(pretrained_model)
        self.sliding_bert = TransformerSlidingWindower(
            underlying_model,
            window_size=window_size,
            stride=stride,
            stride_aggregation=stride_aggregation,
            pooler_aggregation=pooler_aggregation,
        )
        self.hidden_size = self.sliding_bert.hidden_size
        self.projectors = nn.ModuleDict(
            {
                task: nn.Linear(self.hidden_size, num_tags)
                for task, num_tags in multitask_num_tags.items()
            }
        )
        self.decoders = nn.ModuleDict(
            {
                task: CRF(num_tags=num_tags, batch_first=True)
                for task, num_tags in multitask_num_tags.items()
            }
        )

        if multitask_weights is None:
            task_weight = 1.0 / len(multitask_num_tags)
            self.multitask_weights = {  # Same weight for all losses if not provided
                task: task_weight for task in multitask_num_tags.keys()
            }
        else:
            self.multitask_weights = multitask_weights

    def forward(self, input_ids, tags, attention_mask=None, token_type_ids=None):
        emissions, _ = self.sliding_bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        negative_loglik = 0
        task_losses = {}

        for task, decoder in self.decoders.items():
            w = self.multitask_weights[task]
            emissions_logits = self.projectors[task](emissions)
            # torchcrf's CRF returns the log-likelihood; negate it to obtain a loss to minimize
            task_loss = -decoder(
                emissions_logits, tags[task], mask=attention_mask.type(torch.bool)
            )
            negative_loglik += w * task_loss
            task_losses[task] = task_loss

        return negative_loglik, task_losses

    def decode(self, input_ids, attention_mask=None, token_type_ids=None):
        emissions, _ = self.sliding_bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        tags = {}

        for task, decoder in self.decoders.items():
            tags[task] = decoder.decode(
                self.projectors[task](
                    emissions
                )  # , mask=attention_mask.type(torch.bool)
            )

        return tags


if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    text = """
    As BERT can only accept/take as input only 512 tokens at a time, we must specify the truncation parameter to True. The add special tokens parameter is just for BERT to add tokens like the start, end, [SEP], and [CLS] tokens. Return_tensors = “pt” is just for the tokenizer to return PyTorch tensors. If you don’t want this to happen(maybe you want it to return a list), then you can remove the parameter and it will return lists.
    In the code below, you will see me not adding all the parameters I listed above and this is primarily because this is not necessary as I am not tokenizing text for a real project. In a real machine learning/NLP project, you will want to add these parameters, especially the truncation and padding as we have to do this for each batch in the dataset in a real project.
    tokenizer.encode_plus() specifically returns a dictionary of values instead of just a list of values. Because tokenizer.encode_plus() can return many different types of information, like the attention_masks and token type ids, everything is returned in a dictionary format, and if you want to retrieve the specific parts of the encoding, you can do it like this:
    Masked Language Modeling works by inserting a mask token at the desired position where you want to predict the best candidate word that would go in that position. You can simply insert the mask token by concatenating it at the desired position in your input like I did above. The Bert Model for Masked Language Modeling predicts the best word/token in its vocabulary that would replace that word. The logits are the output of the BERT Model before a softmax activation function is applied to the output of BERT. In order to get the logits, we have to specify return_dict = True in the parameters when initializing the model, otherwise, the above code will result in a compilation error. After we pass the input encoding into the BERT Model, we can get the logits simply by specifying output.logits, which returns a tensor, and after this we can finally apply a softmax activation function to the logits. By applying a softmax onto the output of BERT, we get probabilistic distributions for each of the words in BERT’s vocabulary. Word’s with a higher probability value will be better candidate replacement words for the mask token. In order to get the tensor of softmax values of all the words in BERT’s vocabulary for replacing the mask token, we can specify the masked token index, which we get using torch.where(). Because in this particular example I am retrieving the top 10 candidate replacement words for the mask token(you can get more than 10 by adjusting the parameter accordingly), I used the torch.topk() function, which allows you to retrieve the top k values in a given tensor, and it returns a tensor containing those top k values. After this, the process becomes relatively simple, as all we have to do is iterate through the tensor, and replace the mask token in the sentence with the candidate token. Here is the output the code above compiles:
    Language Modeling works very similarly to Masked language modeling. To start off, we have to download the specific Bert Language Model Head Model, which is essentially a BERT model with a language modeling head on top of it. One additional parameter we have to specify while instantiating this model is the is_decoder = True parameter. We have to specify this parameter if we want to use this model as a standalone model for predicting the next best word in the sequence. The rest of the code is relatively the same as the one in masked language modeling: we have to retrieve the logits of the model, but instead of specifying the index to be that of the masked token, we just have to take the logits of the last hidden state of the model(using -1 index), compute the softmax of these logits, find the largest probability value in the vocabulary, and decode and print this token.
    """

    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        truncation=False,
        # padding="max_length",
        return_attention_mask=True,
        return_tensors="pt",
        return_offsets_mapping=True,
    )

    model = TransformerSlidingWindower(
        AutoModel.from_pretrained("bert-base-uncased"),
        window_size=256,
        stride=64,
        stride_aggregation="max",
        pooler_aggregation="max",
    ).cuda()

    last, pool = model(
        input_ids=encoding["input_ids"].cuda(),
        attention_mask=encoding["attention_mask"].cuda(),
    )
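
    # Sketch (untested, illustrative): wiring the same encoding into TransformerDocumentCRF
    # for multitask CRF decoding. The task names and tag counts below are hypothetical
    # placeholders, not values from the original gist.
    doc_model = TransformerDocumentCRF(
        multitask_num_tags={"arg1": 3, "connector": 3, "arg2": 3},
        pretrained_model="bert-base-uncased",
        window_size=256,
        stride=64,
    ).cuda()

    with torch.no_grad():
        predicted_tags = doc_model.decode(
            input_ids=encoding["input_ids"].cuda(),
            attention_mask=encoding["attention_mask"].cuda(),
        )

    # One decoded tag sequence per task for the single input document
    print({task: len(seqs[0]) for task, seqs in predicted_tags.items()})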