Skip to content

Instantly share code, notes, and snippets.

@thistleknot
Last active June 24, 2024 06:59
Show Gist options
  • Save thistleknot/a5f629b71f77dc1682c8c54cc48ef770 to your computer and use it in GitHub Desktop.
Save thistleknot/a5f629b71f77dc1682c8c54cc48ef770 to your computer and use it in GitHub Desktop.
Mamba GPT with a subword tokenizer
# -*- coding: utf-8 -*-
"""SimplerMambaSSM.ipynb
Automatically generated by Colaboratory.
#pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
Original file is located at
https://colab.research.google.com/drive/1g9qpeVcFa0ca0cnhmqusO4RZtQdh9umY
"""
#!pip install mamba-ssm causal-conv1d
#!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
#!mkdir differentattention
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from tqdm import tqdm
from mamba_ssm import Mamba
import nltk
import pandas as pd
from tokenizers import Tokenizer, models, trainers
from nltk.corpus import brown
import math
from transformers import AutoTokenizer
import random
from sklearn.model_selection import train_test_split
import numpy as np
import os
from datasets import load_dataset
import wandb
from torch.optim.lr_scheduler import LambdaLR
from torch.cuda.amp import GradScaler, autocast
# Keep wandb logging local (WANDB_MODE=offline), then open the run.
os.environ.update({"WANDB_MODE": "offline"})
wandb.init(project="Selective State Space Attention")

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Uses device {device}")
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
# An "epoch" here is a count of training iterations, not a guaranteed pass
# over every record: max_iters = round(len(packed_train) / batch_size) * epochs,
# and every iteration samples its batch at random.
#
# Author's notes for a 16 GB GPU (settings -> approximate wall time). "bs"
# appears twice; presumably batch size, then block size:
#   max:  epochs 40, bs 1, bs 4096, n_emb 1536, n_heads 32, n_layers 13 -> 24 h
#   max:  epochs 10, bs 1, bs 4096, n_emb 1536, n_heads 32, n_layers 13 -> 6 h
#   test: epochs 40, bs 4, bs 1024, n_emb 1536, n_heads 32, n_layers 13 -> 1.5 h

# Data / tokenizer
custom_tokenizer = False  # True: train a BPE tokenizer on Brown; False: load a pretrained one
sample = True  # True: train on a random subset of the corpora

# Optimization schedule
epochs = 40
peak_lr = 1e-3
initial_lr = peak_lr / 10  # warm-up starting point
desired_lr = peak_lr / 100  # final learning rate after cosine decay
WEIGHT_DECAY = 0.1
grad_clip = 1.0  # max global gradient norm
half = False  # True: autocast + GradScaler mixed precision

# Batching / evaluation
batch_size = 8
block_size = 512  # context length in tokens; author's note: 1024@20 (45 min) 15GB VRAM
max_token_len = block_size // 2  # only referenced by commented-out code in this file
eval_iters = 10  # every eval_iters-th iteration evaluates instead of training
num_evals = 10  # not referenced elsewhere in this file

# Model size
n_embed = 1536
n_heads = 32
n_layers = 13
dropout = 0.2

# Starting iteration of the training loop; replaced by checkpoint["epoch"] on resume.
epoch = 0
def _halvings(start, floor=4):
    """Yield start, start // 2, start // 4, ... while the value is >= floor."""
    while start >= floor:
        yield start
        start //= 2


# Candidate size buckets for create_batches, largest first:
# block_size, block_size // 2, ..., down to the last value that is >= 4.
block_sizes_list = list(_halvings(block_size))
# Value the halving stopped at (first value below 4), or block_size if none qualified.
mod_block_size = block_sizes_list[-1] // 2 if block_sizes_list else block_size
print(block_sizes_list)
# Loss history per split, appended to by the training loop.
losses_data = {"train": [], "test": []}
best_model_path = "./best_model.pt"  # NOTE(review): not referenced elsewhere in this file
# Checkpoint to resume from, e.g. "./differentattention/model_40.pt"; None starts fresh.
checkpoint_path = None
if checkpoint_path:
    # Load once, from checkpoint_path. map_location lets a GPU-saved
    # checkpoint load on whatever device this run uses.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    print(sorted(checkpoint))
    # NOTE(review): `model` and `optimizer` are only constructed further down
    # in this script; to actually resume, this restore must run after they exist.
    if checkpoint.get("model_state_dict"):
        # A state dict has no .to(); load_state_dict copies the values into the
        # module's existing tensors, which already live on the right device.
        model.load_state_dict(checkpoint["model_state_dict"])
    if checkpoint.get("optimizer_state_dict"):
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # The training loop stores its iteration number under "epoch".
    epoch = checkpoint.get("epoch", epoch)
# Corpora: NLTK's Brown corpus plus four Hugging Face datasets.
nltk.download("brown")
essays_ = essays = load_dataset(path="qwedsacf/ivypanda-essays")
books_ = load_dataset(path="suolyer/pile_books3")
wiki_ = load_dataset(path="EleutherAI/wikitext_document_level", name="wikitext-103-v1")
quotes_ = load_dataset(path="Abirate/english_quotes")

# Early-stopping state, updated by the training loop.
best_perplexity, evaluations_since_improvement = float("inf"), 0
# Model components (the tokenizer is set up further below)
class SelfAttentionHead(nn.Module):
    """One causal attention head whose weights use softplus instead of softmax.

    Reads the module-level hyperparameters ``n_embed``, ``block_size``,
    ``dropout`` and ``device``. Softplus weights are not normalized to sum to 1.
    """

    def __init__(self, head_size):
        super().__init__()
        self.keys = nn.Linear(n_embed, head_size)
        self.queries = nn.Linear(n_embed, head_size)
        self.values = nn.Linear(n_embed, head_size)
        self.head_size = head_size
        self.n_embed = n_embed
        # Lower-triangular causal mask; registered as a buffer so it follows module.to(...).
        self.register_buffer(
            "tril", torch.tril(torch.ones((block_size, block_size))).to(device)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over a (B, T, n_embed) input and return (B, T, head_size)."""
        _, T, _ = x.shape
        k = self.keys(x)  # (B,T,head_size)
        q = self.queries(x)  # (B,T,head_size)
        v = self.values(x)  # (B,T,head_size)
        # Row t holds query t's scores over keys 0..T-1. Scale by the per-head
        # width (head_size), not the embedding width of x.
        wei = q @ k.transpose(-2, -1) * self.head_size ** -0.5  # (B,T,T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        # softplus(x) = log(1 + exp(x)) without overflow for large scores;
        # masked (-inf) positions become exactly 0.
        wei = F.softplus(wei)
        wei = self.dropout(wei)
        return wei @ v  # (B,T,head_size)
class LayerNorm(nn.Module):
    """Layer normalization over the last (feature) dimension with a learnable affine.

    ``gamma`` and ``beta`` are registered ``nn.Parameter``s, so the inherited
    ``nn.Module.parameters()`` already yields them. The previous override
    dropped the standard ``recurse`` argument and has been removed.
    """

    def __init__(self, dim) -> None:
        super().__init__()
        self.eps = 1e-5
        # params
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        # Normalize over the feature axis (dim=-1), the axis gamma/beta are
        # shaped for. dim=1 would normalize across time for (B, T, C) input.
        xmean = x.mean(dim=-1, keepdim=True)
        xvar = ((x - xmean) ** 2).mean(dim=-1, keepdim=True)  # biased variance
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        self.out = self.gamma * xhat + self.beta
        return self.out
class MultiHeadAttention(nn.Module):
    """Run several SelfAttentionHead modules in parallel and mix their outputs.

    The head outputs are concatenated on the last axis and projected with an
    ``n_embed -> n_embed`` linear layer, so ``n_heads * head_size`` is expected
    to equal ``n_embed``.
    """

    def __init__(self, n_heads, head_size) -> None:
        super().__init__()
        head_list = [SelfAttentionHead(head_size) for _ in range(n_heads)]
        self.heads = nn.ModuleList(head_list)
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        _batch, _steps, _width = x.shape  # expects a 3-D (B, T, C) input
        per_head = [attend(x) for attend in self.heads]
        mixed = self.proj(torch.cat(per_head, dim=-1))
        return self.dropout(mixed)
class FeedForward(nn.Module):
    """Position-wise MLP: widen 4x, ReLU, project back, then dropout (module-level rate)."""

    def __init__(self, n_embed) -> None:
        super().__init__()
        hidden = 4 * n_embed
        layers = [
            nn.Linear(n_embed, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embed),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x):
        out = self.ffn(x)
        return out
class Block(nn.Module):
    """Pre-norm residual block: a Mamba sequence mixer followed by a feed-forward MLP."""

    def __init__(self, n_embed, n_heads) -> None:
        super().__init__()
        # Only meaningful for the disabled attention mixer below.
        self.head_size = n_embed // n_heads
        # self.sa_head = MultiHeadAttention(n_heads, self.head_size)
        # Mamba replaces attention as the token mixer. Device placement is left
        # to the caller's model.to(device) instead of a hard-coded "cuda".
        self.sa_head = Mamba(
            # This module uses roughly 3 * expand * d_model^2 parameters
            d_model=n_embed,  # Model dimension d_model
            d_state=16,  # SSM state expansion factor
            d_conv=4,  # Local convolution width
            expand=2,  # Block expansion factor
        )
        self.ffn = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa_head(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
class BigramNeuralNetwork(nn.Module):
    """Token + learned position embeddings -> stack of Mamba Blocks -> LM head.

    Despite its name, this is a full sequence model; the name is kept for
    compatibility. Reads module-level ``n_embed``, ``block_size``, ``n_heads``,
    ``n_layers`` and (in ``generate``) ``decode``.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        # NOTE(review): sa_head and ffn are never called in forward(); they only
        # add parameters and optimizer state. Kept so saved state_dicts still load.
        self.sa_head = MultiHeadAttention(4, int(n_embed / 4))
        self.lm_head = nn.Linear(n_embed, vocab_size)
        self.ffn = FeedForward(n_embed)
        self.blocks = nn.Sequential(
            *[Block(n_embed, n_heads=n_heads) for _ in range(n_layers)]
        )

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)`` for (B, T) token ids; loss is None without targets.

        T must not exceed ``block_size`` (size of the position table).
        """
        _, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B,T,C_e)
        # Build positions on idx's device so the model works wherever it was moved.
        pos_emb = self.position_embedding_table(
            torch.arange(T, device=idx.device)
        )  # (T,C_e)
        x = self.blocks(tok_emb + pos_emb)  # (B,T,C_e)
        logits = self.lm_head(x)  # (B,T,vocab_size)
        if targets is None:
            return logits, None
        # reshape (unlike view) also accepts non-contiguous tensors.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
        )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens after ``idx`` (B, T) and return (B, T + new).

        Prints each decoded sequence. Runs without autograd so memory does not
        grow with every sampled step; the caller chooses train/eval mode.
        """
        for _ in range(max_new_tokens):
            # The position table only covers block_size positions.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_index = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_index), dim=1)
        for arr in idx:
            print(decode(arr.cpu().detach().numpy()))
        return idx
def chunk_data_with_stride(data, block_size, stride):
    """Split *data* into (possibly overlapping) windows of exactly ``block_size`` items.

    Windows start at 0, stride, 2*stride, ...; the last window may end on the
    final element. A trailing remainder shorter than ``block_size`` is dropped.

    :param data: Any sliceable sequence (list, str, tensor, array).
    :param block_size: Window length, must be > 0.
    :param stride: Distance between consecutive window starts, must be > 0.
    :return: List of windows ``data[i:i + block_size]``.
    :raises ValueError: if ``block_size`` or ``stride`` is not positive.
    """
    if block_size <= 0 or stride <= 0:
        raise ValueError("block_size and stride must be positive")
    # "+ 1" so the window ending exactly at len(data) is included; the bound
    # len(data) - block_size alone skipped it (len 4, block 4 gave no chunks).
    last_start = len(data) - block_size
    return [data[i: i + block_size] for i in range(0, last_start + 1, stride)]
def count_parameters(model):
    """Return how many scalar weights of *model* are trainable (requires_grad)."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def encode(text):
    """Tokenize *text* with the global tokenizer (no special tokens); returns a squeezed tensor."""
    ids = tokenizer.encode(
        text,
        return_tensors="pt",
        add_special_tokens=False,
    )
    return ids.squeeze()
def decode(token_ids):
    """Convert token ids back to text with the global tokenizer, dropping special tokens."""
    text = tokenizer.decode(token_ids, skip_special_tokens=True)
    return text
# Warm-up + cosine-annealing schedule, intended for torch.optim.lr_scheduler.LambdaLR.
def lr_lambda(
    iter,
    *,
    warmup_iters=None,
    total_iters=None,
    peak=None,
    initial=None,
    final=None,
):
    """Return the learning-rate *multiplier* for step ``iter``.

    LambdaLR multiplies the optimizer's base lr (``peak_lr`` in this script)
    by this value, so the schedule is returned as a fraction of ``peak``:

    * warm-up (``iter < warmup_iters``): linear from ``initial`` up to ``peak``;
    * then cosine decay from ``peak`` to ``final`` at ``total_iters``, held at
      ``final`` afterwards.

    The keyword arguments default to the module-level ``warmup_iterations``,
    ``max_iters``, ``peak_lr``, ``initial_lr`` and ``desired_lr``.
    (``iter`` shadows the builtin; the name is kept for compatibility.)
    """
    warmup_iters = warmup_iterations if warmup_iters is None else warmup_iters
    total_iters = max_iters if total_iters is None else total_iters
    peak = peak_lr if peak is None else peak
    initial = initial_lr if initial is None else initial
    final = desired_lr if final is None else final

    if iter < warmup_iters:
        lr = initial + (peak - initial) * (iter + 1) / warmup_iters
    else:
        decay_span = total_iters - warmup_iters
        # Avoid dividing by a zero-length decay phase, and clamp past the end so
        # the cosine does not climb back up after total_iters.
        if decay_span <= 0:
            progress = 1.0
        else:
            progress = min((iter - warmup_iters) / decay_span, 1.0)
        lr = final + 0.5 * (peak - final) * (1 + math.cos(math.pi * progress))
    return lr / peak
def create_batches(records, lengths, block_size, block_sizes_list):
    """
    Greedily pack records into sequences whose total length fits in ``block_size``.

    Every record is used exactly once. For each new sequence the sizes in
    ``block_sizes_list`` are tried in order (expected to be descending). While
    at least ``n`` slots remain, a uniformly random unused record of length
    ``<= n`` is appended. The sequence is closed when nothing fits or fewer
    than ``min(block_sizes_list)`` slots remain.

    :param records: List of records.
    :param lengths: Corresponding lengths of the records.
    :param block_size: Maximum total length of each packed sequence.
    :param block_sizes_list: Candidate sizes to consider, in descending order.
    :return: List of sequences, each a list of records.
    :raises ValueError: if some record can never be placed (it is longer than
        every usable size). Previously this looped forever.
    """
    from bisect import bisect_right

    # Unused record indices sorted by length, so "records of length <= n" is
    # always a prefix found by binary search instead of an O(n) scan per pick.
    remaining = sorted(range(len(records)), key=lambda i: lengths[i])
    remaining_lengths = [lengths[i] for i in remaining]
    smallest = min(block_sizes_list) if block_sizes_list else None

    sequences = []
    while remaining:
        sequence = []
        avail_space = block_size
        for n in block_sizes_list:
            while avail_space >= n:
                n_fit = bisect_right(remaining_lengths, n)
                if not n_fit:
                    break  # no unused record fits this size; try the next one
                pick = random.randrange(n_fit)  # uniform over the fitting records
                remaining_lengths.pop(pick)
                selected_index = remaining.pop(pick)
                sequence.append(records[selected_index])
                avail_space -= lengths[selected_index]
            if avail_space < smallest:  # no smaller sizes available
                break
        if not sequence:
            raise ValueError(
                f"{len(remaining)} record(s) cannot fit any size in "
                f"block_sizes_list={block_sizes_list!r} with block_size={block_size}"
            )
        sequences.append(sequence)
    return sequences
def get_batch(data):
    """Sample ``batch_size`` random (input, target) windows from 1-D *data*.

    Each input is ``data[s:s + block_size]``; its target is the same window
    shifted right by one token. Both stacks are returned on ``device``.
    """
    starts = torch.randint(0, len(data) - block_size, (batch_size,))
    inputs, targets = [], []
    for s in starts:
        inputs.append(data[s: s + block_size])
        targets.append(data[s + 1: s + block_size + 1])
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)
@torch.no_grad()
def estimate_loss(X, Y):
    """Evaluate the global ``model`` on one batch without gradient tracking.

    :param X: Input token ids passed to ``model``.
    :param Y: Target token ids.
    :return: ``[loss, perplexity]`` as Python floats, with perplexity = exp(loss).
    """
    was_training = model.training
    model.eval()  # disable dropout for the estimate
    try:
        _, loss = model(X, Y)
    finally:
        # Restore the caller's mode rather than always switching to train mode.
        model.train(was_training)
    return [loss.item(), torch.exp(loss).item()]
# Raw document texts for each corpus.
brown_ = [" ".join(brown.words(fid)) for fid in brown.fileids()]
wiki = list(wiki_["train"]["page"])
books = [row["text"] for row in books_["validation"]]
essays = essays_["train"]["TEXT"]
quotes = [row["quote"] for row in quotes_["train"]]
if custom_tokenizer:
    from transformers import PreTrainedTokenizerFast

    # Train a BPE tokenizer on the Brown corpus, one document per line.
    corpus_text = "\n".join([" ".join(brown.words(fileid)) for fileid in brown.fileids()])
    raw_tokenizer = Tokenizer(models.BPE())
    trainer = trainers.BpeTrainer(
        special_tokens=["<unk>", "<pad>", "<bos>", "<eos>"], vocab_size=32000
    )
    raw_tokenizer.train_from_iterator(corpus_text.split("\n"), trainer=trainer)
    # Wrap it in the transformers API the rest of the script relies on
    # (batch_encode_plus, eos_token_id, decode(skip_special_tokens=...)).
    # A bare tokenizers.Tokenizer provides none of these.
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=raw_tokenizer,
        unk_token="<unk>",
        pad_token="<pad>",
        bos_token="<bos>",
        eos_token="<eos>",
    )
else:
    # Load pre-trained tokenizer
    tokenizer = AutoTokenizer.from_pretrained("v1olet/v1olet_marcoroni-go-bruins-merge-7B")

# len() includes added special tokens; tokenizer.vocab_size does not, which can
# leave valid token ids outside the embedding table.
vocab_size = len(tokenizer)
# Use the tokenizer's real special-token strings when it defines them.
bos_token = tokenizer.bos_token if tokenizer.bos_token else "<BOS>"
eos_token = tokenizer.eos_token if tokenizer.eos_token else "<EOS>"
# Assemble the training corpus: one text document per element.
if sample:
    # All of Brown, plus as many wiki pages and essays as Brown has documents,
    # and twice as many quotes.
    records_ = [
        *brown_,
        *random.sample(wiki, len(brown_)),
        *random.sample(essays, len(brown_)),
        *random.sample(quotes, len(brown_) * 2),
    ]
else:
    # `wiki` holds page texts; unpacking `wiki_` (a DatasetDict) would only
    # yield its split names.
    records_ = [*brown_, *wiki, *books, *essays, *quotes]
random.shuffle(records_)

# Tokenize every document and terminate it with EOS, so documents packed
# together later remain separated.
tokenized_corpus = tokenizer.batch_encode_plus(records_)["input_ids"]
tokenized_records = [[*t, tokenizer.eos_token_id] for t in tokenized_corpus]
# Keep only documents that fit in one block (EOS included).
filtered = [f for f in tokenized_records if len(f) <= block_size]
# Document-level 90/10 train/validation split (no random_state, so it varies per run).
train_tokenized, val_tokenized = train_test_split(filtered, train_size=0.9)
train_lens = [len(t) for t in train_tokenized]
val_lens = [len(t) for t in val_tokenized]

# Pack the training records once, only to size the schedule: one "epoch" is
# round(len(sampled_train) / batch_size) iterations.
sampled_train = create_batches(train_tokenized, train_lens, block_size, block_sizes_list)
max_iters = int(np.round(len(sampled_train) / batch_size) * epochs)
print("max_iters", max_iters)
print_iters = int(np.round(max_iters / epochs))
warmup_iterations = int(np.round(max_iters / 20))  # 5% of all iterations
print("warmup_iterations", warmup_iterations)
print(len(sampled_train))
print(batch_size)
print(epochs)
print(len(sampled_train) / batch_size)
print((len(sampled_train) / batch_size) * epochs)

# Best-model tracking starts after one epoch's worth of iterations.
check_perplexity_iter = int(np.round(max_iters / epochs))
# Early-stopping patience: one epoch's worth of evaluations (one every eval_iters iterations).
patience = int(np.round(max_iters / epochs / eval_iters))
# Move the model to its device *before* building the optimizer, as the
# torch.optim documentation recommends.
model = BigramNeuralNetwork(vocab_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr, weight_decay=WEIGHT_DECAY)
# Learning-rate scheduler (disabled). LambdaLR multiplies the base lr by the
# value lr_lambda returns.
#scheduler = LambdaLR(optimizer, lr_lambda)
# Loss scaling is only used on the half-precision path; disabled otherwise.
scaler = GradScaler(enabled=half)
total_params = count_parameters(model)
print(f"Total number of parameters in the model: {total_params}")
def precalculate_lengths(tokenized_sequences):
    """Return the length of each sequence, preserving order."""
    return list(map(len, tokenized_sequences))
# Per-record token counts for both splits (same values as train_lens / val_lens).
train_lengths, val_lengths = (
    precalculate_lengths(split) for split in (train_tokenized, val_tokenized)
)
def _packed_batch_tensor(tokenized, lengths):
    """Return the ``batch_size`` fullest packs of *tokenized* as one 1-D int64 tensor.

    Records are packed with ``create_batches``. Each selected pack is flattened,
    right-padded with EOS to ``block_size``, and the rows are concatenated.
    """
    packs = create_batches(tokenized, lengths, block_size, block_sizes_list)
    pack_lengths = [sum(len(record) for record in pack) for pack in packs]
    rows = []
    for i in np.argsort(pack_lengths)[-batch_size:]:
        tokens = np.hstack(packs[i]).astype(np.int64)
        row = np.full(block_size, tokenizer.eos_token_id, dtype=np.int64)
        row[: len(tokens)] = tokens  # right-pad with EOS
        rows.append(torch.from_numpy(row))
    return torch.cat(rows)


# Best-model checkpoints are written here; create the directory up front.
checkpoint_dir = "./differentattention"
os.makedirs(checkpoint_dir, exist_ok=True)
# Set True to print a sampled generation every print_iters iterations.
sample_generations = False

for step in tqdm(range(epoch, max_iters)):
    current_lr = optimizer.param_groups[0]["lr"]

    if step % eval_iters == 0:
        # Evaluation step (no parameter update) on the held-out split, so test
        # loss and perplexity measure unseen data.
        val_padded = _packed_batch_tensor(val_tokenized, val_lens)
        x_te, y_te = get_batch(val_padded)
        losses = estimate_loss(x_te, y_te)
        loss, perplexity = losses
        losses_data["test"].append(loss)
        print(f"Step {step}, current_lr:{current_lr:.10f}, test loss:{loss:.4f}, perplexity:{perplexity:.4f}")
        wandb.log(
            {
                "iteration": step,
                "eval_loss": loss,
                "perplexity": perplexity,
                "current_lr": current_lr,
                "evaluations_since_improvement": evaluations_since_improvement,
            }
        )
        # Start tracking the best model after one epoch's worth of iterations.
        if step >= check_perplexity_iter:
            if perplexity < best_perplexity:
                best_perplexity = perplexity
                evaluations_since_improvement = 0
                torch.save(
                    {
                        "epoch": step,  # holds the iteration; key kept for compatibility
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "loss": losses,
                    },
                    f"{checkpoint_dir}/model_{step}_{best_perplexity}.pt",
                )  # Save the best model
            else:
                evaluations_since_improvement += 1
            # Early stopping is disabled; to enable it, break here once
            # evaluations_since_improvement >= patience.
    else:
        # Training step on a freshly packed training batch.
        train_padded = _packed_batch_tensor(train_tokenized, train_lens)
        x_tr, y_tr = get_batch(train_padded)
        if half:
            # Forward pass with autocast for mixed precision
            with autocast():
                logits, loss = model(x_tr, y_tr)
        else:
            logits, loss = model(x_tr, y_tr)

        optimizer.zero_grad(set_to_none=True)
        if half:
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
        # scheduler.step()  # enable together with the LambdaLR scheduler

        # .float() upcasts a half-precision loss; it is a no-op for float32.
        print_loss = loss.float().item()
        losses_data["train"].append(print_loss)
        print(f"Step {step}, current_lr:{current_lr:.10f}, train loss:{print_loss:.4f}")
        wandb.log(
            {
                "iteration": step,
                "train_loss": print_loss,
                "current_lr": current_lr,
                "evaluations_since_improvement": evaluations_since_improvement,
            }
        )

    if sample_generations and step % print_iters == 0:
        model.eval()
        with torch.no_grad():
            output = model.generate(
                torch.zeros((1, 2), dtype=torch.long, device=device), 1000
            )[0].tolist()
        print(output)
        model.train()
# Finish wandb run
wandb.finish()
# Save the final weights; the directory may not exist yet.
os.makedirs("./differentattention", exist_ok=True)
torch.save(model.state_dict(), "./differentattention/model.pt")

# Sample from the trained model. eval() disables dropout, and no_grad keeps
# autograd from recording all 1000 generation steps. generate() prints the text.
model.eval()
with torch.no_grad():
    output = model.generate(
        torch.zeros((1, 2), dtype=torch.long).to(device),
        1000,
    )[0].tolist()

print("Training Losses:")
# i counts training steps only; evaluation iterations are not included.
for i, loss in enumerate(losses_data["train"]):
    print(f"Iteration {i}: Train Loss = {loss}")

# Print Testing Losses
print("\nTesting Losses:")
# Evaluations happen every eval_iters iterations; labels assume training started at 0.
for i, loss in enumerate(losses_data["test"]):
    print(f"Iteration {i * eval_iters}: Eval Loss = {loss}")

# Convert the evaluation losses to a DataFrame and save them as CSV.
df = pd.DataFrame({'Iteration': [i * eval_iters for i in range(len(losses_data["test"]))],
                   'Eval Loss': losses_data["test"]})
csv_file_path = 'testing_losses.csv'
df.to_csv(csv_file_path, index=False)
print(f"Testing losses have been saved to {csv_file_path}")
@thistleknot
Copy link
Author

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment