# sparse_attention.py
from typing import Optional, Union
from copy import deepcopy
import warnings

import torch

from allennlp.common import Registrable
from transformers.models.bert.configuration_bert import BertConfig
from transformers.models.bert.modeling_bert import BertLayer
from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.models.roberta.modeling_roberta import RobertaLayer
from deepspeed.ops.sparse_attention import (
    BertSparseSelfAttention,
    SparsityConfig,
    DenseSparsityConfig,
    FixedSparsityConfig,
    VariableSparsityConfig,
    BigBirdSparsityConfig,
    BSLongformerSparsityConfig,
)

class SparseSelfAttentionLayer(BertSparseSelfAttention):
    """DeepSpeed's BertSparseSelfAttention, made drop-in compatible with HuggingFace layers.

    HuggingFace self-attention modules are called with extra arguments (head_mask,
    encoder_hidden_states, past key/values, ...) that the DeepSpeed layer does not
    accept, so anything beyond hidden_states and attention_mask is dropped with a warning.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        *args,
        **kwargs,
    ):
        extras = (*args, *kwargs.values())
        if not all(arg is None for arg in extras):
            warnings.warn('SparseSelfAttentionLayer only accepts hidden_states and attention_mask.')
        return super().forward(hidden_states, attention_mask)


def replace_self_attention(
    model: torch.nn.Module,
    sparsity_config: SparsityConfig,
    model_config: Optional[Union[BertConfig, RobertaConfig]] = None,
):
    """Recursively swap every BERT/RoBERTa self-attention module for a sparse one."""
    # Largely follows:
    # https://github.com/microsoft/DeepSpeed/blob/c5b3f40e8481748f9658a19c2df1f17c5b579919/deepspeed/module_inject/inject.py#L6
    # https://github.com/microsoft/DeepSpeed/blob/c5b3f40e8481748f9658a19c2df1f17c5b579919/deepspeed/ops/sparse_attention/sparse_attention_utils.py#L85
    config = model_config or model.config
    assert isinstance(config, (BertConfig, RobertaConfig)), "Only BERT and RoBERTa are currently supported by DeepSpeed."

    for name, layer in model.named_children():
        if isinstance(layer, (BertLayer, RobertaLayer)):
            # Replace the dense self-attention but keep the pretrained Q/K/V projections.
            deepspeed_sparse_self_attn = SparseSelfAttentionLayer(config, sparsity_config)
            deepspeed_sparse_self_attn.query = layer.attention.self.query
            deepspeed_sparse_self_attn.key = layer.attention.self.key
            deepspeed_sparse_self_attn.value = layer.attention.self.value
            layer.attention.self = deepspeed_sparse_self_attn
            setattr(model, name, deepcopy(layer))
        else:
            # Recurse into child modules until we reach the encoder layers.
            replace_self_attention(layer, sparsity_config, model_config=config)
    return model
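

# Usage sketch (illustrative): swap sparse attention into a plain HuggingFace
# model. The model name and sparsity settings here are arbitrary examples.
#
#   from transformers import BertModel
#   model = BertModel.from_pretrained('bert-base-uncased')
#   model = replace_self_attention(model, FixedSparsityConfig(num_heads=12, block=16))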


class _SparsityConfig(Registrable, SparsityConfig):
    """Registrable wrapper so sparsity configs can be specified in AllenNLP config files."""
    default_implementation = 'base'


_SparsityConfig.register('base')(SparsityConfig)
_SparsityConfig.register('dense')(DenseSparsityConfig)
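# The other imported DeepSpeed sparsity configs can be exposed the same way;
# the registry keys below are illustrative choices, not fixed names.
_SparsityConfig.register('fixed')(FixedSparsityConfig)
_SparsityConfig.register('variable')(VariableSparsityConfig)
_SparsityConfig.register('bigbird')(BigBirdSparsityConfig)
_SparsityConfig.register('bsl')(BSLongformerSparsityConfig)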


# Second gist file: an AllenNLP TokenEmbedder built on the utilities above
# (imported below as the module `sparse_attention`).
from typing import Optional

import torch
from overrides import overrides

from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
from allennlp.modules.token_embedders.pretrained_transformer_embedder import PretrainedTransformerEmbedder
from deepspeed.ops.sparse_attention import SparseAttentionUtils

from .sparse_attention import SparseSelfAttentionLayer, _SparsityConfig, replace_self_attention


@TokenEmbedder.register('sparse_transformer')
class SparseTransformerEmbedder(PretrainedTransformerEmbedder):
    """A PretrainedTransformerEmbedder whose self-attention layers use DeepSpeed sparse attention."""

    def __init__(
        self,
        model_name: str,
        sparsity_config: _SparsityConfig = _SparsityConfig(num_heads=4),
        **kwargs,
    ):
        super().__init__(model_name, **kwargs)
        self._sparsity_config = sparsity_config
        self.transformer_model = replace_self_attention(self.transformer_model, self._sparsity_config)

    @overrides
    def forward(
        self,
        token_ids: torch.LongTensor,
        mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
        segment_concat_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:  # type: ignore
        # DeepSpeed's block-sparse kernels require the sequence length to be a
        # multiple of the block size, so pad the inputs before embedding them.
        # Note that the returned embeddings keep the padded sequence length.
        _, token_ids, mask, type_ids, *_ = SparseAttentionUtils.pad_to_block_size(
            block_size=self._sparsity_config.block,
            input_ids=token_ids,
            attention_mask=mask,
            token_type_ids=type_ids,
            position_ids=None,
            inputs_embeds=None,
            pad_token_id=self.transformer_model.config.pad_token_id,
            model_mbeddings=None,  # typo is in the DeepSpeed function definition, not here
        )
        return super().forward(token_ids=token_ids, mask=mask, type_ids=type_ids, segment_concat_mask=segment_concat_mask)


if __name__ == '__main__':
    from allennlp.common import Params

    embedder = TokenEmbedder.from_params(Params({
        "type": "sparse_transformer",
        "model_name": "bert-base-uncased",  # or "roberta-base"
        "sparsity_config": {
            "num_heads": 4
        }
    }))
    print(embedder)
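
    # A forward-pass smoke test would look roughly like the sketch below
    # (left commented out: DeepSpeed's sparse attention kernels are Triton/GPU-only,
    # and the dummy shapes and values are illustrative assumptions):
    #
    #   token_ids = torch.randint(0, embedder.transformer_model.config.vocab_size, (1, 64)).cuda()
    #   mask = torch.ones_like(token_ids, dtype=torch.bool)
    #   embeddings = embedder.cuda()(token_ids, mask)
    #   print(embeddings.shape)  # (1, padded_seq_len, hidden_size)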