# minimum nanogpt mamba: a nanoGPT-style character-level language model with the self-attention head replaced by a Mamba block
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm
from mamba_ssm import Mamba
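# Note: the Mamba block comes from the official mamba_ssm package
# (https://github.com/state-spaces/mamba). Its fused selective-scan kernels are
# CUDA-only, so this script effectively requires a GPU. The package is normally
# installed with `pip install mamba-ssm` (often alongside `causal-conv1d`), but
# check the project README for the exact install steps for your environment.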
# hyperparameters
epochs = 100          # unused below; training length is controlled by max_iters
lr = 1e-3
batch_size = 64
block_size = 256
device = "cuda" if torch.cuda.is_available() else "cpu"
max_iters = 10000
print_iters = 100     # checkpoint and sample every print_iters steps
eval_iters = 10       # batches averaged per loss estimate (also the eval frequency)
eval_interval = 300   # unused below
n_embed = 384
n_heads = 6
n_layers = 6
dropout = 0.2
# ---------
with open("input.txt", "r") as f:
    text = f.read()

# Unique characters in the corpus
chars = sorted(list(set(text)))
print(''.join(chars))
vocab_size = len(chars)
print(vocab_size)
# Tokenizers: character-level encode/decode
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda xx: [stoi[x] for x in xx]
decode = lambda xx: ''.join([itos[x] for x in xx])
# Round-trip sanity check
print(encode("Hello!"))
print(decode(encode("Hello!")))
# Train and validation splits (90/10)
data = torch.tensor(encode(text), dtype=torch.long)
n = int(len(data) * 0.9)
train_data = data[:n]
val_data = data[n:]
def get_batch(split):
    # Sample batch_size random windows of length block_size; targets are the
    # same windows shifted one character to the right.
    data = train_data if split == "train" else val_data
    index = torch.randint(0, len(data) - block_size, (batch_size,))
    x = torch.stack([data[ind:ind + block_size] for ind in index])
    y = torch.stack([data[ind + 1:ind + block_size + 1] for ind in index])
    return x.to(device), y.to(device)
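# Illustrative sanity check (safe to delete): both tensors are
# (batch_size, block_size), and y is x shifted one position to the right.
xb_demo, yb_demo = get_batch("train")
print(xb_demo.shape, yb_demo.shape)
assert torch.equal(xb_demo[:, 1:], yb_demo[:, :-1])
del xb_demo, yb_demo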
@torch.no_grad()
def estimate_loss():
    # Average the loss over eval_iters batches for each split.
    # Note: the 'test' split is served from val_data by get_batch.
    out = {}
    model.eval()
    for split in ['train', 'test']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
class LayerNorm(nn.Module):
    # Hand-rolled layer norm kept for reference; it is not used below
    # (Block uses nn.LayerNorm instead). gamma and beta are registered
    # automatically as nn.Parameter, so no custom parameters() is needed.
    def __init__(self, dim) -> None:
        super().__init__()
        self.eps = 1e-5
        # Learnable scale and shift
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        # Normalize over the channel (last) dimension
        xmean = x.mean(dim=-1, keepdim=True)
        xvar = ((x - xmean) ** 2).mean(dim=-1, keepdim=True)
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        return self.gamma * xhat + self.beta
class FeedForward(nn.Module):
    def __init__(self, n_embed) -> None:
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.ffn(x)
class Block(nn.Module):
    # A pre-norm residual block in which the self-attention head of the original
    # nanoGPT block is replaced by a Mamba SSM layer.
    def __init__(self, n_embed, n_heads) -> None:
        super().__init__()
        self.head_size = n_embed // n_heads  # unused; n_heads is kept only for API compatibility
        self.sa_head = Mamba(
            # This module uses roughly 3 * expand * d_model^2 parameters
            d_model=n_embed,  # Model dimension d_model
            d_state=16,       # SSM state expansion factor
            d_conv=4,         # Local convolution width
            expand=2,         # Block expansion factor
        )  # device placement is handled by model.to(device) below
        self.ffn = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa_head(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
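# Illustrative shape check (assumes a CUDA GPU, since mamba_ssm's kernels are
# CUDA-only): a Block maps (B, T, n_embed) -> (B, T, n_embed), so Mamba blocks
# stack exactly like transformer layers.
if torch.cuda.is_available():
    _blk = Block(n_embed, n_heads=n_heads).to("cuda")
    _demo = torch.randn(2, 8, n_embed, device="cuda")
    print(_blk(_demo).shape)  # expected: torch.Size([2, 8, 384])
    del _blk, _demo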
class BigramNeuralNetwork(nn.Module):
    # Character-level language model (the class name is kept from the nanoGPT
    # bigram starting point): token + position embeddings, a stack of Mamba
    # blocks, and a linear head over the vocabulary.
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)
        self.ffn = FeedForward(n_embed)  # unused; each Block has its own FeedForward
        self.blocks = nn.Sequential(*[Block(n_embed, n_heads=n_heads) for _ in range(n_layers)])

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)                                # (B,T,C_e)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T,C_e)
        x = tok_emb + pos_emb                                                    # (B,T,C_e)
        x = self.blocks(x)                                                       # (B,T,C_e)
        logits = self.lm_head(x)                                                 # (B,T,vocab_size)
        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, C) logits and (N,) targets, so flatten
            # the batch and time dimensions before computing the loss.
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
            logits = logits.view(B, T, C)
        return logits, loss
    def generate(self, idx, max_new_tokens):
        # idx is (B,T); autoregressively sample max_new_tokens characters.
        for i in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]      # crop to the position-embedding range
            logits, loss = self(idx_cond)
            last_timestep = logits[:, -1, :]     # (B, vocab_size)
            probs = F.softmax(last_timestep, dim=1)
            next_index = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_index), dim=1)
        for arr in idx:
            print(decode(arr.cpu().detach().numpy()))
        return idx
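# Illustrative usage (a sketch, not executed here): generation can also be
# seeded with an encoded prompt instead of zeros, e.g.
#   prompt = torch.tensor([encode("The ")], dtype=torch.long, device=device)
#   model.generate(prompt, max_new_tokens=200)
# Any prompt works as long as its characters appear in the training text.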
model = BigramNeuralNetwork(vocab_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# Optionally resume from a checkpoint
checkpoint_path = None  # e.g. "./differentattention/model_40.pt"
epoch = 0
if checkpoint_path:
    checkpoint = torch.load(checkpoint_path)
    print(checkpoint.keys())
    if checkpoint['model_state_dict']:
        model.load_state_dict(checkpoint['model_state_dict'])
    if checkpoint['optimizer_state_dict']:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']

device = "cuda"  # mamba_ssm requires CUDA, so override the earlier CPU fallback
m = model.to(device)
print("Uses device " + device)
MODEL_CHECKPOINT = "./differentattention/model_{iter}.pt"
os.makedirs("./differentattention", exist_ok=True)  # make sure the checkpoint directory exists
losses_data = {"train": [], "test": []}

for step in tqdm(range(epoch, max_iters)):
    if step % eval_iters == 0:
        losses = estimate_loss()
        losses_data['train'].append(losses['train'].cpu().numpy())
        losses_data['test'].append(losses['test'].cpu().numpy())
        print(f"Step {step}, train loss:{losses['train']:.4f}, test loss:{losses['test']:.4f}")
    if step % print_iters == 0:
        # print_iters is a multiple of eval_iters, so `losses` was just refreshed above.
        torch.save({
            'epoch': step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': losses,
        }, MODEL_CHECKPOINT.format(iter=step))
        model.eval()
        with torch.no_grad():
            # Generate a sample from the model
            m.generate(torch.zeros((1, 2), dtype=torch.long, device=device), 1000)
        print(f"Step {step}, train loss:{losses['train']:.4f}, test loss:{losses['test']:.4f}")
        model.train()
    # Get a batch and take one optimization step
    xb, yb = get_batch("train")
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

torch.save(model.state_dict(), "./differentattention/model.pt")

# Generate from the trained model
output = m.generate(torch.zeros((1, 2), dtype=torch.long, device=device), 1000)[0].tolist()
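# Minimal reload sketch (assumes the same hyperparameters and the same input.txt,
# so vocab_size and the architecture match the saved weights):
#   model = BigramNeuralNetwork(vocab_size).to(device)
#   model.load_state_dict(torch.load("./differentattention/model.pt", map_location=device))
#   model.eval()
#   model.generate(torch.zeros((1, 2), dtype=torch.long, device=device), 500)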