
@LevanKvirkvelia
Last active October 3, 2024 17:51
nanoBERT, inspired by @karpathy's nanoGPT
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None) -> torch.Tensor:
        # Compute query, key, value tensors
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Use the fused flash implementation when scaled_dot_product_attention is available (PyTorch >= 2.0)
        if self.flash:
            context_layer = torch.nn.functional.scaled_dot_product_attention(
                query_layer,
                key_layer,
                value_layer,
                attn_mask=attention_mask,
                dropout_p=self.dropout.p if self.training else 0,
                is_causal=False
            )
        else:
            # Compute attention scores
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            attention_scores = attention_scores / math.sqrt(self.attention_head_size)
            if attention_mask is not None:
                attention_scores += attention_mask
            # Convert attention scores to probabilities
            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
            attention_probs = self.dropout(attention_probs)
            # Compute context layer
            context_layer = torch.matmul(attention_probs, value_layer)

        # Merge the heads back: (batch, num_heads, seq_len, head_size) -> (batch, seq_len, hidden_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(context_layer.size()[:-2] + (self.all_head_size,))
        return context_layer

class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

class BertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        attention_outputs = self.self(hidden_states=hidden_states, attention_mask=attention_mask)
        attention_output = self.output(input_tensor=hidden_states, hidden_states=attention_outputs)
        return attention_output

class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = nn.GELU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states

class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

class BertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        attention_output = self.attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask
        )
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None):
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states=hidden_states, attention_mask=attention_mask)
        return hidden_states

class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids=None, token_type_ids=None) -> torch.Tensor:
        b, t = input_ids.size()
        device = input_ids.device
        # Absolute position ids 0..t-1, created on the fly
        position_ids = torch.arange(0, t, dtype=torch.long, device=device)
        # Default to a single token type (all zeros) if none are provided
        if token_type_ids is None:
            token_type_ids = torch.zeros((b, t), dtype=torch.long, device=device)
        # Transform input_ids into word embeddings
        inputs_embeds = self.word_embeddings(input_ids)
        # Create token type embeddings
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        # Sum word embeddings, position embeddings, and token type embeddings
        embeddings = inputs_embeds + token_type_embeddings
        embeddings += self.position_embeddings(position_ids)
        # Normalize and apply dropout
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

class BertConfig:
    def __init__(self):
        self.hidden_size = 384
        self.num_hidden_layers = 12
        self.hidden_dropout_prob = 0.1
        self.attention_probs_dropout_prob = 0.1
        self.layer_norm_eps = 1e-12
        self.intermediate_size = 1536
        self.vocab_size = 30522
        self.max_position_embeddings = 512
        self.type_vocab_size = 2
        self.pad_token_id = 0
        self.num_attention_heads = 12

class BertModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

    def get_extended_attention_mask(self, attention_mask, input_shape):
        # Turn the binary (batch, seq_len) mask into an additive (batch, 1, 1, seq_len) mask:
        # 0.0 for tokens to attend to, -10000.0 for padding tokens.
        extended_attention_mask = attention_mask[:, None, None, :]
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            raise ValueError("You have to specify input_ids")
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, dtype=torch.long, device=input_ids.device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)

        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoder_outputs = self.encoder(embedding_output, extended_attention_mask)

        sequence_output = encoder_outputs
        pooled_output = self.pooler(sequence_output)
        return sequence_output, pooled_output

config = BertConfig()
model = BertModel(config)
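
As a quick smoke test (a minimal sketch, not part of the original gist; the random token ids and shapes below are made up for illustration), the model can be run on dummy inputs to check the output shapes:

batch_size, seq_len = 2, 16
dummy_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
dummy_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

sequence_output, pooled_output = model(input_ids=dummy_ids, attention_mask=dummy_mask)
print(sequence_output.shape)  # torch.Size([2, 16, 384]) with the default config
print(pooled_output.shape)    # torch.Size([2, 384])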
@LevanKvirkvelia
Author

Script to load the pretrained weights:



from transformers import BertModel, BertTokenizerFast

# model_type is the name of a Hugging Face BERT checkpoint whose config matches BertConfig above
print("loading weights from pretrained bert: %s" % model_type)
model_hf = BertModel.from_pretrained(model_type)
sd_hf = model_hf.state_dict()
sd_keys_hf = sd_hf.keys()

sd = model.state_dict()
sd_keys = sd.keys()

# copy while ensuring all of the parameters are aligned and match in names and shapes
assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
for k in sd_keys_hf:
    assert sd_hf[k].shape == sd[k].shape
    with torch.no_grad():
        sd[k].copy_(sd_hf[k])
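
After copying the weights, a quick sanity check is to compare the two models on the same input (a minimal sketch, assuming model_type also names the matching tokenizer and that model was built with a config matching that checkpoint):

tokenizer = BertTokenizerFast.from_pretrained(model_type)
enc = tokenizer("nanoBERT sanity check", return_tensors="pt")

model.eval()     # disable dropout in nanoBERT
model_hf.eval()  # the HF model is already in eval mode after from_pretrained, but be explicit

with torch.no_grad():
    ours, _ = model(input_ids=enc["input_ids"],
                    attention_mask=enc["attention_mask"],
                    token_type_ids=enc["token_type_ids"])
    theirs = model_hf(**enc).last_hidden_state

print(torch.allclose(ours, theirs, atol=1e-4))  # should print True if the copy succeeded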

@LevanKvirkvelia
Author

Don't forget to use model.eval() to disable dropout.
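
For example (a minimal sketch; input_ids and attention_mask stand for whatever batch you are encoding):

model.eval()  # switch dropout layers to inference behavior
with torch.no_grad():
    sequence_output, pooled_output = model(input_ids, attention_mask=attention_mask)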

@vgel

vgel commented Aug 28, 2023

awesome!

@snoop2head

snoop2head commented Nov 1, 2023

Loading script below also works fine for me.
Thanks for the awesome work!

from transformers import BertModel as HFBertModel

hf_model_name = "roberta-base"
hf_model = HFBertModel.from_pretrained(hf_model_name)
hf_state_dict = hf_model.state_dict()

nanoconfig = hf_model.config
nanobert = BertModel(nanoconfig)
nanobert.load_state_dict(hf_state_dict)
