@georgepar
Created January 25, 2022 17:59
from typing import Dict, Optional

import torch
import torch.nn as nn
from torchcrf import CRF
from transformers import AutoModel


class TransformerSlidingWindower(nn.Module):
"""Apply model on a strided sliding window
Good for use with transformer models on long sequences.
We split the sequences in windows of fixed length with an overlap and apply the model on these windows
Some context is preserved because of the overlap.
Below we can see this operation. We have an intut sequence I of length 8.
We create a sliding window, which consists of 2 subsequences S and T of length 5 with overlap 2
and pass each subsequence through BERT.
The output sequence features are pooled (for strided sequence elements). We can use max or mean pooling.
O1 O2 O3 O4 O5 O6 O7 O8
_______________________
POOLING
_______________________
| |
| |
| BERT
| _______|______
BERT | |
_______|______ |
| | | |
S1 S2 S3 S4 S5 |
T1 T2 T3 T4 T5
I1 I2 I3 I4 I5 I6 I7 I8
This class works well with Bert-esque models from the transformers library.
Args:
underlying_model (nn.Module): The model to use as a feature extractor (e.g. BertModel, RobertaModel) etc.
window_size (int): Size of sliding window
stride(int): overlap length
"""
    def __init__(
        self,
        underlying_model: nn.Module,
        window_size: int = 512,
        stride: int = 128,
        stride_aggregation: str = "mean",
        pooler_aggregation: str = "mean",
    ):
        super(TransformerSlidingWindower, self).__init__()
        self.underlying_model = underlying_model
        self.hidden_size = underlying_model.config.hidden_size
        self.window_size = (
            window_size - 2
        )  # Reserve 2 positions for the [CLS] and [SEP] tokens in each window
        self.stride = stride
        assert stride_aggregation in [
            "mean",
            "max",
        ], "Unsupported stride aggregation method. Only [mean, max] are supported"
        assert pooler_aggregation in [
            "mean",
            "max",
        ], "Unsupported pooler aggregation method. Only [mean, max] are supported"
        self.stride_aggregation = stride_aggregation
        self.pooler_aggregation = pooler_aggregation

    def slider(self, sequence_length: int):
        """Yield (start_index, end_index) boundaries for each sliding window."""
        start_index = 0
        end_index = min(sequence_length, self.window_size)
        while True:
            if sequence_length <= self.window_size:
                yield start_index, end_index
                break
            ost = start_index  # output window start
            oet = min(end_index, sequence_length)  # output window end, clipped to the sequence
            start_index += self.window_size - self.stride  # advance by the non-overlapping part
            end_index = start_index + self.window_size
            yield ost, oet
            if oet >= sequence_length:
                break
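
    # For example, with window_size=512 (510 after reserving the [CLS]/[SEP] slots)
    # and stride=128, slider(1000) yields the boundaries (0, 510), (382, 892) and
    # (764, 1000); consecutive windows share 128 positions.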

    def _augment_with_cls(
        self, window_input_ids, window_attention_mask, window_token_type_ids
    ):
        aug_cls = (
            torch.zeros_like(window_input_ids[:, 0]).unsqueeze(1) + 101
        )  # Prepend a [CLS] token (id 101 in the bert-base-uncased vocabulary)
        window_input_ids = torch.cat((aug_cls, window_input_ids), dim=1)
        if window_attention_mask is not None:
            aug_attmask = window_attention_mask[:, 0].clone().unsqueeze(1)
            window_attention_mask = torch.cat(  # Reuse the previous attention_mask value
                (aug_attmask, window_attention_mask), dim=1
            )
        if window_token_type_ids is not None:
            aug_ttids = window_token_type_ids[:, 0].clone().unsqueeze(1)
            window_token_type_ids = torch.cat(  # Reuse the previous token type
                (aug_ttids, window_token_type_ids), dim=1
            )
        return window_input_ids, window_attention_mask, window_token_type_ids

    def _augment_with_sep(
        self, window_input_ids, window_attention_mask, window_token_type_ids
    ):
        aug_sep = (
            torch.zeros_like(window_input_ids[:, 0]).unsqueeze(1) + 102
        )  # Append a [SEP] token (id 102 in the bert-base-uncased vocabulary)
        window_input_ids = torch.cat((window_input_ids, aug_sep), dim=1)
        if window_attention_mask is not None:
            aug_attmask = window_attention_mask[:, -1].clone().unsqueeze(1)
            window_attention_mask = torch.cat(
                (window_attention_mask, aug_attmask), dim=1
            )
        if window_token_type_ids is not None:
            aug_ttids = window_token_type_ids[:, -1].clone().unsqueeze(1)
            window_token_type_ids = torch.cat((window_token_type_ids, aug_ttids), dim=1)
        return window_input_ids, window_attention_mask, window_token_type_ids

    def _aggregator(self, x1, x2, selector):
        if selector == "mean":
            return torch.mean(torch.stack((x1, x2)), dim=0)
        elif selector == "max":
            return torch.maximum(x1, x2)
        else:
            raise ValueError(f"Unsupported aggregation method {selector}")
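
    # e.g. _aggregator(torch.tensor([1.0, 4.0]), torch.tensor([3.0, 2.0]), "mean") -> tensor([2., 3.])
    #      and with "max" -> tensor([3., 4.]); the pooling is applied element-wise.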

    def _aggregate_hidden_states(
        self, start_index, end_index, last_hidden_prev, current_hidden
    ):
        # last_hidden_prev: (B, S, F)
        # current_hidden: (B, W, F)
        last_hidden_whole_sequence = last_hidden_prev.clone()
        if start_index == 0:
            last_hidden_whole_sequence[:, start_index:end_index] = current_hidden
            return last_hidden_whole_sequence
        last_hidden_whole_sequence[
            :, start_index : start_index + self.stride
        ] = self._aggregator(
            last_hidden_whole_sequence[
                :, start_index : start_index + self.stride
            ].clone(),
            current_hidden[:, : self.stride],
            self.stride_aggregation,
        )  # Pool with the previous window's outputs over the overlap
        last_hidden_whole_sequence[
            :, start_index + self.stride : end_index
        ] = current_hidden[
            :, self.stride :
        ]  # Set the rest of the hidden states
        return last_hidden_whole_sequence
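
    # Continuing the slider example above: for the second window (382, 892), positions
    # 382..509 (the 128-token overlap) are pooled with the first window's outputs and
    # positions 510..891 are copied from the current window.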

    def _aggregate_pooler_output(self, up2now_out, current_out):
        return self._aggregator(up2now_out, current_out, self.pooler_aggregation)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        kwargs = {
            "past_key_values": None,  # caching is incompatible with the sliding window; do not use this module for sequential decoding
            "use_cache": False,
            "return_dict": True,
        }
        batch_size, sequence_length = input_ids.shape
        window_attention_mask, window_token_type_ids = None, None
        last_hidden_state = torch.zeros(
            (batch_size, sequence_length, self.hidden_size), dtype=torch.float  # type: ignore
        ).to(input_ids.device)
        pooler_output = torch.zeros(
            (batch_size, self.hidden_size), dtype=torch.float  # type: ignore
        ).to(input_ids.device)
        for start_index, end_index in self.slider(sequence_length):
            window_input_ids = input_ids[:, start_index:end_index]
            if attention_mask is not None:
                window_attention_mask = attention_mask[:, start_index:end_index]
            if token_type_ids is not None:
                window_token_type_ids = token_type_ids[:, start_index:end_index]
            if start_index > 0:
                (
                    window_input_ids,
                    window_attention_mask,
                    window_token_type_ids,
                ) = self._augment_with_cls(
                    window_input_ids, window_attention_mask, window_token_type_ids
                )
            if end_index < sequence_length:
                (
                    window_input_ids,
                    window_attention_mask,
                    window_token_type_ids,
                ) = self._augment_with_sep(
                    window_input_ids, window_attention_mask, window_token_type_ids
                )
            outputs = self.underlying_model(
                input_ids=window_input_ids,
                attention_mask=window_attention_mask,
                token_type_ids=window_token_type_ids,
                **kwargs,
            )
            pooler_output = self._aggregate_pooler_output(
                pooler_output, outputs.pooler_output
            )
            current_last_hidden_state = outputs.last_hidden_state
            if start_index > 0:
                current_last_hidden_state = current_last_hidden_state[
                    :, 1:, :
                ]  # remove the extra [CLS] hidden state
            if end_index < sequence_length:
                current_last_hidden_state = current_last_hidden_state[
                    :, :-1, :
                ]  # remove the extra [SEP] hidden state
            last_hidden_state = self._aggregate_hidden_states(
                start_index,
                end_index,
                last_hidden_state,
                current_last_hidden_state,
            )
        return last_hidden_state, pooler_output
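
# Note: TransformerSlidingWindower.forward returns hidden states of shape
# (batch, sequence_length, hidden_size) covering the full, un-windowed sequence, plus
# a pooled output of shape (batch, hidden_size) aggregated across windows, so
# token-level heads (such as the CRF decoders below) can consume them directly.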


class TransformerDocumentCRF(nn.Module):
    r"""Multitask document tagger.

    A shared sliding-window BERT encoder feeds one linear projection and one CRF
    decoder per task; the total loss is the weighted sum of the per-task CRF losses.

                  [ BERT sliding windower ]
                 O1 O2 O3 ............. ON
                /           |           \
               /            |            \
        [ARG1 CRF]   [Connector CRF]   [ARG2 CRF]
          L_ARG1       L_CONNECTOR       L_ARG2

        LOSS = w_ARG1 * L_ARG1 + w_CONNECTOR * L_CONNECTOR + w_ARG2 * L_ARG2
    """

    def __init__(
        self,
        multitask_num_tags: Dict[str, int],
        multitask_weights: Optional[Dict[str, float]] = None,
        pretrained_model: str = "bert-base-uncased",
        window_size: int = 512,
        stride: int = 128,
        stride_aggregation: str = "mean",
        pooler_aggregation: str = "mean",
    ):
        super(TransformerDocumentCRF, self).__init__()
        underlying_model = AutoModel.from_pretrained(pretrained_model)
        self.sliding_bert = TransformerSlidingWindower(
            underlying_model,
            window_size=window_size,
            stride=stride,
            stride_aggregation=stride_aggregation,
            pooler_aggregation=pooler_aggregation,
        )
        self.hidden_size = self.sliding_bert.hidden_size
        self.projectors = nn.ModuleDict(
            {
                task: nn.Linear(self.hidden_size, num_tags)
                for task, num_tags in multitask_num_tags.items()
            }
        )
        self.decoders = nn.ModuleDict(
            {
                task: CRF(num_tags=num_tags, batch_first=True)
                for task, num_tags in multitask_num_tags.items()
            }
        )
        if multitask_weights is None:
            task_weight = 1.0 / len(multitask_num_tags)
            self.multitask_weights = {  # Same weight for all losses if not provided
                task: task_weight for task in multitask_num_tags.keys()
            }
        else:
            self.multitask_weights = multitask_weights

    def forward(self, input_ids, tags, attention_mask=None, token_type_ids=None):
        emissions, _ = self.sliding_bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        negative_loglik = 0
        task_losses = {}
        for task, decoder in self.decoders.items():
            w = self.multitask_weights[task]
            emissions_logits = self.projectors[task](emissions)
            task_loss = -decoder(
                emissions_logits, tags[task], mask=attention_mask.type(torch.bool)
            )  # torchcrf's CRF returns the log-likelihood, so negate it to get a loss
            negative_loglik += w * task_loss
            task_losses[task] = task_loss
        return negative_loglik, task_losses

    def decode(self, input_ids, attention_mask=None, token_type_ids=None):
        emissions, _ = self.sliding_bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        tags = {}
        for task, decoder in self.decoders.items():
            tags[task] = decoder.decode(
                self.projectors[task](
                    emissions
                )  # , mask=attention_mask.type(torch.bool)
            )
        return tags


if __name__ == "__main__":
    from transformers import AutoTokenizer  # AutoModel is already imported above

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
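
    # Long sample passage used purely as input text to exercise the sliding window;
    # its content is irrelevant to the test.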
text = """
As BERT can only accept/take as input only 512 tokens at a time, we must specify the truncation parameter to True. The add special tokens parameter is just for BERT to add tokens like the start, end, [SEP], and [CLS] tokens. Return_tensors = “pt” is just for the tokenizer to return PyTorch tensors. If you don’t want this to happen(maybe you want it to return a list), then you can remove the parameter and it will return lists.
In the code below, you will see me not adding all the parameters I listed above and this is primarily because this is not necessary as I am not tokenizing text for a real project. In a real machine learning/NLP project, you will want to add these parameters, especially the truncation and padding as we have to do this for each batch in the dataset in a real project.
tokenizer.encode_plus() specifically returns a dictionary of values instead of just a list of values. Because tokenizer.encode_plus() can return many different types of information, like the attention_masks and token type ids, everything is returned in a dictionary format, and if you want to retrieve the specific parts of the encoding, you can do it like this:
Masked Language Modeling works by inserting a mask token at the desired position where you want to predict the best candidate word that would go in that position. You can simply insert the mask token by concatenating it at the desired position in your input like I did above. The Bert Model for Masked Language Modeling predicts the best word/token in its vocabulary that would replace that word. The logits are the output of the BERT Model before a softmax activation function is applied to the output of BERT. In order to get the logits, we have to specify return_dict = True in the parameters when initializing the model, otherwise, the above code will result in a compilation error. After we pass the input encoding into the BERT Model, we can get the logits simply by specifying output.logits, which returns a tensor, and after this we can finally apply a softmax activation function to the logits. By applying a softmax onto the output of BERT, we get probabilistic distributions for each of the words in BERT’s vocabulary. Word’s with a higher probability value will be better candidate replacement words for the mask token. In order to get the tensor of softmax values of all the words in BERT’s vocabulary for replacing the mask token, we can specify the masked token index, which we get using torch.where(). Because in this particular example I am retrieving the top 10 candidate replacement words for the mask token(you can get more than 10 by adjusting the parameter accordingly), I used the torch.topk() function, which allows you to retrieve the top k values in a given tensor, and it returns a tensor containing those top k values. After this, the process becomes relatively simple, as all we have to do is iterate through the tensor, and replace the mask token in the sentence with the candidate token. Here is the output the code above compiles:
Language Modeling works very similarly to Masked language modeling. To start off, we have to download the specific Bert Language Model Head Model, which is essentially a BERT model with a language modeling head on top of it. One additional parameter we have to specify while instantiating this model is the is_decoder = True parameter. We have to specify this parameter if we want to use this model as a standalone model for predicting the next best word in the sequence. The rest of the code is relatively the same as the one in masked language modeling: we have to retrieve the logits of the model, but instead of specifying the index to be that of the masked token, we just have to take the logits of the last hidden state of the model(using -1 index), compute the softmax of these logits, find the largest probability value in the vocabulary, and decode and print this token.
"""
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        truncation=False,
        # padding="max_length",
        return_attention_mask=True,
        return_tensors="pt",
        return_offsets_mapping=True,
    )
    model = TransformerSlidingWindower(
        AutoModel.from_pretrained("bert-base-uncased"),
        window_size=256,
        stride=64,
        stride_aggregation="max",
        pooler_aggregation="max",
    ).cuda()
    last, pool = model(
        input_ids=encoding["input_ids"].cuda(),
        attention_mask=encoding["attention_mask"].cuda(),
    )