Converting MBart to Longformer
convert_mbart_to_longformer.py
import argparse
import logging
import os
import copy

from transformers import MBart50Tokenizer
from transformers import MBartForConditionalGeneration, AutoTokenizer
# from transformers.modeling_bart import shift_tokens_right
from longformer_encoder_decoder import LongformerSelfAttentionForMBart, LongformerEncoderDecoderConfig
from longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def create_long_model(
    save_model_to,
    base_model,
    tokenizer_name_or_path,
    attention_window,
    max_pos
):
    # model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
    # tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
    model = MBartForConditionalGeneration.from_pretrained(base_model)
    tokenizer = MBart50Tokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=max_pos)
    config = LongformerEncoderDecoderConfig.from_pretrained(base_model)
    model.config = config

    # in MBart the attention dropout is called attention_dropout, but LongformerSelfAttention
    # expects attention_probs_dropout_prob, so set it here
    config.attention_probs_dropout_prob = config.attention_dropout
    config.architectures = ['LongformerEncoderDecoderForConditionalGeneration', ]

    # extend position embeddings
    tokenizer.model_max_length = max_pos
    tokenizer.init_kwargs['model_max_length'] = max_pos
    current_max_pos, embed_size = model.model.encoder.embed_positions.weight.shape
    assert current_max_pos == config.max_position_embeddings + 2

    config.max_encoder_position_embeddings = max_pos
    config.max_decoder_position_embeddings = config.max_position_embeddings
    del config.max_position_embeddings
    max_pos += 2  # NOTE: MBart has positions 0 and 1 reserved, so the embedding size is max position + 2
    assert max_pos >= current_max_pos

    # allocate a larger position embedding matrix for the encoder
    new_encoder_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size)
    # copy the pretrained position embeddings over and over to initialize the new positions
    k = 2
    step = current_max_pos - 2
    while k < max_pos - 1:
        new_encoder_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:]
        k += step
    model.model.encoder.embed_positions.weight.data = new_encoder_pos_embed

    # the decoder position embeddings are left at their original size; extending them would look like this:
    # new_decoder_pos_embed = model.model.decoder.embed_positions.weight.new_empty(max_pos, embed_size)
    # # copy position embeddings over and over to initialize the new position embeddings
    # k = 2
    # step = current_max_pos - 2
    # while k < max_pos - 1:
    #     new_decoder_pos_embed[k:(k + step)] = model.model.decoder.embed_positions.weight[2:]
    #     k += step
    # model.model.decoder.embed_positions.weight.data = new_decoder_pos_embed

    # replace the encoder's `modeling_mbart.MBartAttention` modules with `LongformerSelfAttention`
    config.attention_window = [attention_window] * config.num_hidden_layers
    config.attention_dilation = [1] * config.num_hidden_layers
    for i, layer in enumerate(model.model.encoder.layers):
        longformer_self_attn_for_mbart = LongformerSelfAttentionForMBart(config, layer_id=i)

        # reuse the pretrained local attention projections
        longformer_self_attn_for_mbart.longformer_self_attn.query = layer.self_attn.q_proj
        longformer_self_attn_for_mbart.longformer_self_attn.key = layer.self_attn.k_proj
        longformer_self_attn_for_mbart.longformer_self_attn.value = layer.self_attn.v_proj

        # initialize the global attention projections with copies of the local ones
        longformer_self_attn_for_mbart.longformer_self_attn.query_global = copy.deepcopy(layer.self_attn.q_proj)
        longformer_self_attn_for_mbart.longformer_self_attn.key_global = copy.deepcopy(layer.self_attn.k_proj)
        longformer_self_attn_for_mbart.longformer_self_attn.value_global = copy.deepcopy(layer.self_attn.v_proj)

        longformer_self_attn_for_mbart.output = layer.self_attn.out_proj

        layer.self_attn = longformer_self_attn_for_mbart

    # save model and tokenizer
    logger.info(f'saving model to {save_model_to}')
    model.save_pretrained(save_model_to)
    tokenizer.save_pretrained(save_model_to)
    return model, tokenizer


# def mask_test(args):
#     tokenizer = MBartTokenizer.from_pretrained(args.save_model_to)
#     TXT = "My friends are <mask> but they eat too many carbs."
#     model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(args.save_model_to)
#     model.model.encoder.config.gradient_checkpointing = True
#     model.model.decoder.config.gradient_checkpointing = True
#     data = tokenizer([TXT], return_tensors='pt', padding='max_length', max_length=2048)
#     input_ids = data['input_ids']
#     attention_mask = data['attention_mask']
#     decoder_input_ids = shift_tokens_right(input_ids[:, :5], tokenizer.pad_token_id)
#     logits = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False)[0]
#     masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
#     probs = logits[0, masked_index].softmax(dim=0)
#     values, predictions = probs.topk(5)
#     print(tokenizer.convert_ids_to_tokens(predictions))


def summary_test(args):
    tokenizer = MBart50Tokenizer.from_pretrained(args.save_model_to)
    # TXT = "My friends are <mask> but they eat too many carbs."
    model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(args.save_model_to)

    # gradient checkpointing flags; uncommented because they may fix the problem
    # with the arguments of the forward function
    model.model.encoder.config.gradient_checkpointing = True
    model.model.decoder.config.gradient_checkpointing = True

    # ARTICLE_TO_SUMMARIZE = "My friends are cool, but they eat too much carbs."
    with open('article_es.txt', 'r') as file:
        ARTICLE_TO_SUMMARIZE = file.read().replace('\n', '')
    inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=4096, return_tensors='pt', padding="max_length", truncation=True)

    # generate summary
    print(inputs['input_ids'])
    print('length input ids:', inputs['input_ids'].size())
    print('w = ', model.model.config.attention_window)
    summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=50, early_stopping=True)
    print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])


def main():
    parser = argparse.ArgumentParser(
        description="Convert MBart to Longformer. Replaces the MBart encoder's self-attention with LongformerSelfAttention"
    )
    parser.add_argument(
        '--base_model',
        type=str,
        default='facebook/mbart-large-50',
        help='The name or path of the base model you want to convert'
    )
    parser.add_argument(
        '--tokenizer_name_or_path',
        type=str,
        default='facebook/mbart-large-50',
        help='The name or path of the tokenizer'
    )
    parser.add_argument(
        '--save_model_to',
        type=str,
        required=True,
        help='The path to save the converted model'
    )
    parser.add_argument(
        '--attention_window',
        type=int,
        default=512,
        help='attention window size for Longformer self-attention (one sided)'
    )
    parser.add_argument(
        '--max_pos',
        type=int,
        default=4096,
        help='maximum encoder positions'
    )
    args = parser.parse_args()

    if not os.path.exists(args.save_model_to):
        os.mkdir(args.save_model_to)

    create_long_model(
        save_model_to=args.save_model_to,
        base_model=args.base_model,
        tokenizer_name_or_path=args.tokenizer_name_or_path,
        attention_window=args.attention_window,
        max_pos=args.max_pos
    )

    summary_test(args)


if __name__ == "__main__":
    main()
longformer_encoder_decoder.py
from typing import List, Optional, Tuple, Dict

from torch import nn, Tensor
# from longformer.longformer import LongformerSelfAttention
from transformers import LongformerSelfAttention
from transformers import MBartConfig, MBartForConditionalGeneration
from transformers.models.mbart.modeling_mbart import MBartLearnedPositionalEmbedding


class LongformerEncoderDecoderForConditionalGeneration(MBartForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        if config.attention_mode == 'n2':
            pass  # do nothing, keep the regular MBart self-attention
        else:
            # hard-coded to the extended encoder length (4096) and the mbart-large hidden size (1024)
            self.model.encoder.embed_positions = MBartLearnedPositionalEmbedding(4096, 1024)
            for i, layer in enumerate(self.model.encoder.layers):
                layer.self_attn = LongformerSelfAttentionForMBart(config, layer_id=i)


class LongformerEncoderDecoderConfig(MBartConfig):
    def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
                 autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
                 gradient_checkpointing: bool = False, **kwargs):
        """
        Args:
            attention_window: list of attention window sizes of length = number of layers.
                window size = number of attention locations on each side.
                For an effective window size of 512, use `attention_window=[256]*num_layers`,
                which is 256 on each side.
            attention_dilation: list of attention dilations of length = number of layers.
                attention dilation of `1` means no dilation.
            autoregressive: use autoregressive attention or attend to both sides
            attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM implementation of
                Longformer self-attention, 'sliding_chunks' for another implementation of Longformer
                self-attention
        """
        super().__init__(**kwargs)
        self.attention_window = attention_window
        self.attention_dilation = attention_dilation
        self.autoregressive = autoregressive
        self.attention_mode = attention_mode
        self.gradient_checkpointing = gradient_checkpointing
        assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']


class LongformerSelfAttentionForMBart(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.embed_dim = config.d_model
        self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
        self.output = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states=None,
        attention_mask=None,
        layer_head_mask=None,
        output_attentions=False
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # delegate to LongformerSelfAttention, passing the MBart arguments straight through
        outputs = self.longformer_self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            is_index_masked=None,
            is_index_global_attn=None,
            is_global_attn=None,
            output_attentions=output_attentions,
        )

        attn_output = self.output(outputs[0].transpose(0, 1))

        return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)
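For illustration only (a sketch, not part of the gist): LongformerEncoderDecoderConfig is a plain MBartConfig plus the Longformer-specific fields, so it can also be filled in by hand, mirroring what create_long_model() does in the conversion script; the 512 window below is simply the script's default.

from longformer_encoder_decoder import LongformerEncoderDecoderConfig

config = LongformerEncoderDecoderConfig.from_pretrained('facebook/mbart-large-50')
config.attention_probs_dropout_prob = config.attention_dropout  # name LongformerSelfAttention expects
config.attention_window = [512] * config.encoder_layers         # one (one-sided) window per encoder layer
config.attention_dilation = [1] * config.encoder_layers         # 1 means no dilation
# config.attention_mode defaults to 'sliding_chunks'; 'n2' keeps the regular MBart attention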
python convert_mbart_to_longformer.py --save_model_to model_dir \
    --base_model facebook/mbart-large-50 \
    --tokenizer_name_or_path facebook/mbart-large-50
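For reference, a minimal sketch (not part of the original gist) of loading the checkpoint that the command above writes to model_dir and summarizing a long document; model_dir and the placeholder text are assumptions, and the generation settings simply mirror summary_test() in the conversion script.

from transformers import MBart50Tokenizer
from longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration

# load the converted tokenizer and model saved by convert_mbart_to_longformer.py
tokenizer = MBart50Tokenizer.from_pretrained('model_dir')
model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained('model_dir')

text = "..."  # placeholder for a long input document
# pad to the full 4096-token encoder length, as summary_test() does
inputs = tokenizer([text], max_length=4096, padding='max_length', truncation=True, return_tensors='pt')
summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=50, early_stopping=True)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))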
Hey man, I'm not sure you're still looking to solve this, but I was struggling with the same problem and, thanks to this repo https://github.com/Taeksu-Kim/longformer_kobart, I've managed to generate summaries with my brand new Longformer BART.
Two things need to be changed: