# sparse_attention.py -- replaces the self-attention of a HuggingFace BERT/RoBERTa
# model with DeepSpeed's sparse self-attention.
from typing import Optional, Tuple, Union
from copy import deepcopy
import warnings

import torch

from allennlp.common import Registrable
from transformers.models.bert.configuration_bert import BertConfig
from transformers.models.bert.modeling_bert import BertLayer
from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.models.roberta.modeling_roberta import RobertaLayer
from deepspeed.ops.sparse_attention import (
    BertSparseSelfAttention,
    SparsityConfig,
    DenseSparsityConfig,
    FixedSparsityConfig,
    VariableSparsityConfig,
    BigBirdSparsityConfig,
    BSLongformerSparsityConfig,
)
class SparseSelfAttentionLayer(BertSparseSelfAttention):
    """
    Thin wrapper so DeepSpeed's ``BertSparseSelfAttention`` can be dropped into a
    HuggingFace ``BertAttention`` module, which calls ``self.self`` with extra
    arguments (head mask, encoder states, ``output_attentions``, ...) and expects a tuple back.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        *args,
        **kwargs,
    ) -> Tuple[torch.Tensor]:
        # DeepSpeed's sparse attention only supports hidden_states and attention_mask.
        # The extra HuggingFace arguments arrive as None (or False for the default
        # output_attentions), so only warn when a caller actually sets one of them.
        extras = (*args, *kwargs.values())
        if not all(arg is None or arg is False for arg in extras):
            warnings.warn('SparseSelfAttentionLayer only accepts hidden_states and attention_mask.')
        # Return a 1-tuple: BertAttention indexes ``self_outputs[0]`` and concatenates
        # ``self_outputs[1:]``, so it expects a tuple rather than a bare tensor.
        return (super().forward(hidden_states, attention_mask),)
def replace_self_attention(
    model: torch.nn.Module,
    sparsity_config: SparsityConfig,
    model_config: Optional[Union[BertConfig, RobertaConfig]] = None,
) -> torch.nn.Module:
    # Largely follows these:
    # https://github.com/microsoft/DeepSpeed/blob/c5b3f40e8481748f9658a19c2df1f17c5b579919/deepspeed/module_inject/inject.py#L6
    # https://github.com/microsoft/DeepSpeed/blob/c5b3f40e8481748f9658a19c2df1f17c5b579919/deepspeed/ops/sparse_attention/sparse_attention_utils.py#L85
    config = model_config or model.config
    assert isinstance(config, (BertConfig, RobertaConfig)), "Only BERT and RoBERTa are currently supported by DeepSpeed."
    for name, layer in model.named_children():
        if isinstance(layer, (BertLayer, RobertaLayer)):
            # Swap in sparse self-attention, reusing the pretrained query/key/value projections.
            deepspeed_sparse_self_attn = SparseSelfAttentionLayer(config, sparsity_config)
            deepspeed_sparse_self_attn.query = layer.attention.self.query
            deepspeed_sparse_self_attn.key = layer.attention.self.key
            deepspeed_sparse_self_attn.value = layer.attention.self.value
            layer.attention.self = deepspeed_sparse_self_attn
            setattr(model, name, deepcopy(layer))
        else:
            # Recurse into child modules until we hit the encoder layers.
            replace_self_attention(layer, sparsity_config, model_config=config)
    return model
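
# Example (not part of the original gist): a minimal sketch of calling
# ``replace_self_attention`` directly on a plain HuggingFace model, assuming
# DeepSpeed's sparse-attention kernels are installed and a fixed sparsity pattern:
#
#     from transformers import BertModel
#     bert = BertModel.from_pretrained('bert-base-uncased')
#     sparsity = FixedSparsityConfig(num_heads=bert.config.num_attention_heads)
#     bert = replace_self_attention(bert, sparsity)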
class _SparsityConfig(Registrable, SparsityConfig):
    """Makes DeepSpeed's ``SparsityConfig`` constructible from AllenNLP config files."""
    default_implementation = 'base'


_SparsityConfig.register('base')(SparsityConfig)
_SparsityConfig.register('dense')(DenseSparsityConfig)
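# The remaining imported DeepSpeed sparsity configs can be registered the same way
# (a sketch; the registry names below are illustrative and not from the original gist):
#
#     _SparsityConfig.register('fixed')(FixedSparsityConfig)
#     _SparsityConfig.register('variable')(VariableSparsityConfig)
#     _SparsityConfig.register('bigbird')(BigBirdSparsityConfig)
#     _SparsityConfig.register('bslongformer')(BSLongformerSparsityConfig)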
# Second gist file: an AllenNLP TokenEmbedder that applies the replacement above.
# It imports the module above as ``.sparse_attention``.
from typing import Optional

import torch
from overrides import overrides

from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
from allennlp.modules.token_embedders.pretrained_transformer_embedder import PretrainedTransformerEmbedder
from deepspeed.ops.sparse_attention import SparseAttentionUtils

from .sparse_attention import SparseSelfAttentionLayer, _SparsityConfig, replace_self_attention
@TokenEmbedder.register('sparse_transformer')
class SparseTransformerEmbedder(PretrainedTransformerEmbedder):
    def __init__(
        self,
        model_name: str,
        sparsity_config: _SparsityConfig = _SparsityConfig(num_heads=4),
        **kwargs,
    ):
        super().__init__(model_name, **kwargs)
        self._sparsity_config = sparsity_config
        # Swap the pretrained model's dense self-attention for DeepSpeed sparse attention.
        self.transformer_model = replace_self_attention(self.transformer_model, self._sparsity_config)
    import torchsnooper  # debugging aid: traces tensor shapes through forward()

    # @overrides
    @torchsnooper.snoop()
    def forward(
        self,
        token_ids: torch.LongTensor,
        mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
        segment_concat_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:  # type: ignore
        # DeepSpeed's block-sparse kernels need the sequence length to be a multiple of
        # the sparsity block size, so pad the inputs up to the next block boundary first.
        _, token_ids, mask, type_ids, *_ = SparseAttentionUtils.pad_to_block_size(
            block_size=self._sparsity_config.block,
            input_ids=token_ids,
            attention_mask=mask,
            token_type_ids=type_ids,
            position_ids=None,
            inputs_embeds=None,
            pad_token_id=self.transformer_model.config.pad_token_id,
            model_mbeddings=None,  # typo is in DeepSpeed's function definition, not here
        )
        return super().forward(token_ids=token_ids, mask=mask, type_ids=type_ids, segment_concat_mask=segment_concat_mask)
if __name__ == '__main__':
    from allennlp.common import Params

    embedder = TokenEmbedder.from_params(Params({
        "type": "sparse_transformer",
        "model_name": "bert-base-uncased",  # or "roberta-base"
        "sparsity_config": {
            "num_heads": 4
        }
    }))
    print(embedder)
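
    # Smoke test (a sketch, not in the original gist). DeepSpeed's sparse attention
    # kernels are Triton-based and need a GPU, so only run a forward pass when CUDA
    # is available. The output length is padded up to a multiple of the sparsity
    # block size by pad_to_block_size.
    if torch.cuda.is_available():
        embedder = embedder.cuda()
        token_ids = torch.randint(0, embedder.transformer_model.config.vocab_size, (2, 128), device='cuda')
        mask = torch.ones_like(token_ids, dtype=torch.bool)
        embeddings = embedder(token_ids, mask)
        print(embeddings.shape)  # e.g. (2, 128, 768) for bert-base-uncased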