Modula GPT tutorial script with the dataloader seed exposed as a command-line argument.
# hello-gpt.py
""" | |
Adapted from modula's Hello GPT tutorial: | |
https://github.com/modula-systems/modula/blob/ede2ba72a1b9de3e1f44156db058b5c32c682941/examples/hello-gpt.ipynb | |
This script simply exposes dataloader seed as command line argument to test | |
training sensitivity to seed. | |
Usage: | |
python hello-gpt.py --seed 0 | |
""" | |
from absl import app, flags | |
flags.DEFINE_integer("seed", 0, "training seed") | |
flags.DEFINE_integer("log_interval", 10, "log interval") | |
flags.DEFINE_integer("val_interval", 200, "val interval") | |
FLAGS = flags.FLAGS | |
def main(unused_argv):
    # First, let's download the Shakespeare dataset. The task will be to predict the next character.
    context = 64
    batch_size = 12
    seed = FLAGS.seed

    from data.shakespeare import load_shakespeare

    data = load_shakespeare(context, batch_size, seed=seed)
    train_loader = data["train_loader"]
    val_loader = data["val_loader"]
    encode = data["encode"]
    decode = data["decode"]

    # Let's peek at an example to verify the data loaded correctly!
    for inputs, targets in train_loader:
        print("Input shape:", inputs.shape)
        print("Target shape:", targets.shape)
        print("First input sequence:", inputs[0][:10], "...")
        print("First target sequence:", targets[0][:10], "...")
        print("\nDecoded input:", decode(inputs[0]))
        print("\nDecoded target:", decode(targets[0]))
        break
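    # (Added sanity check, not in the original tutorial:) encode and decode are inverse
    # character-level maps built from the corpus vocabulary, so a round trip should be
    # lossless for any string whose characters appear in the Shakespeare text.
    assert decode(encode("To be, or not to be")) == "To be, or not to be"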
    # Transformer hyperparameters
    vocab_size = 65
    num_heads = 4
    d_embed = 128
    d_query = 32
    d_value = 32
    num_blocks = 4
    attention_scale = 1
    final_scale = 1

    # Training hyperparameters
    lr = 0.1
    beta = 0.95
    steps = 2001
    log_interval = FLAGS.log_interval
    val_interval = FLAGS.val_interval
    val_iters = 20
    # Next up, we'll define the *attention* module and *residual blocks*.

    # ## Attention in Modula
    #
    # In Modula, we'll define attention by stringing together several bond modules to do the parameterless computations. The roadmap is:
    # * Map `(batch, token, d_embed)` into `(batch, head, token, d_query)` (and the same for key and value) via `Linear` and `SplitIntoHeads`
    # * Use Rotary Positional Embeddings (RoPE) on the query and the key via `Rope`
    # * Map `query` and `key` into attention similarities of shape `(batch, head, token, token)` via `AttentionQK`
    # * Apply a causal mask and then a softmax to create attention scores via `CausalMask` and `Softmax`
    # * Use the attention scores to create output vectors via `ApplyAttentionScores`, then `MergeHeads` and `Linear`
    #
    # The main difference from a standard transformer is that `AttentionQK` uses $1/d_\text{head}$ scaling instead of the standard $1/\sqrt{d_\text{head}}$. The reason for this is to provide Lipschitz guarantees for attention that are independent of $d_\text{head}$. For more information, see Appendix B.6 of [Scalable Optimization in the Modular Norm](https://arxiv.org/pdf/2405.14813).
    #
    # And here's the implementation:
    from modula.atom import Linear
    from modula.bond import SplitIntoHeads, MergeHeads, Rope, AttentionQK, CausalMask, Softmax, ApplyAttentionScores, GeLU

    def Attention(num_heads, d_embed, d_query, d_value, attention_scale):
        """Multi-head attention"""
        # For keys, queries, and values we add a heads dimension. For the out projection, we remove heads.
        # Remember modules compose right-to-left, and the argument order is Linear(d_out, d_in)! And @ means compose.
        Q = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)
        K = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)
        V = SplitIntoHeads(num_heads) @ Linear(num_heads * d_value, d_embed)
        W = Linear(d_embed, num_heads * d_value) @ MergeHeads()

        # Read right-to-left: rotate (Q, K) with RoPE, apply Q @ K.T, mask, softmax (with a scale we can choose).
        AttentionScores = Softmax(attention_scale) @ CausalMask() @ AttentionQK() @ Rope(d_query) @ (Q, K)

        # Read right-to-left: apply the attention scores, multiply by 1/3 to fix the module's sensitivity at 1, and project back to d_embed.
        return W @ (1/3 * ApplyAttentionScores()) @ (V, AttentionScores)
    # Let's check that the sensitivity is 1 at initialization.
    # print(Attention(num_heads, d_embed, d_query, d_value, attention_scale))

    # ## Residual blocks in Modula
    #
    # To implement the rest of our transformer, the roadmap is:
    # * Embed the input tokens
    # * Apply residual blocks for attention and the MLP
    # * Project out
    #
    # All that's left is to set up the residual blocks. In Modula, we define residual connections using a convex combination. If $L$ is the number of residual blocks, then we use a convex combination of the identity and the block to get $x \mapsto \frac{L-1}{L} \cdot x + \frac{1}{L} \cdot \textsf{block}(x)$. The purpose is to create a Lipschitz guarantee that is independent of the number of blocks. For more information, see Proposition 4 of [Scalable Optimization in the Modular Norm](https://arxiv.org/pdf/2405.14813).
    #
    # In short, these changes enable Lipschitz guarantees on our transformer even as we scale the width and the depth!
    from modula.abstract import Identity
    from modula.atom import Embed

    def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, blocks_mass=5, attention_scale=1.0, final_scale=1.0):
        # Set embed to have mass 1. This controls the proportion of feature learning that it contributes to the whole network.
        embed = Embed(d_embed, vocab_size)
        embed.tare()

        # Let's create the attention and MLP layers.
        att = Attention(num_heads, d_embed, d_query, d_value, attention_scale)
        mlp = Linear(d_embed, 4*d_embed) @ GeLU() @ Linear(4*d_embed, d_embed)

        # For our residual connections, L = 2*num_blocks because each block has two residual connections.
        # With num_blocks = 4, for example, each residual connection computes (7/8) * x + (1/8) * block(x).
        att_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * att
        mlp_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * mlp

        # We can use powers of a module to compose it with itself many times!
        blocks = (mlp_block @ att_block) ** num_blocks

        # Set all transformer blocks to have mass 5 (by default).
        # So 5/7 of the change in the network output is due to the blocks,
        # and 2/7 of the change in output is due to the embedding and out projection.
        blocks.tare(absolute=blocks_mass)

        out = final_scale * Linear(vocab_size, d_embed)

        return out @ blocks @ embed
    # And finally we are ready to construct our GPT!
    model = GPT(
        vocab_size=vocab_size,
        num_heads=num_heads,
        d_embed=d_embed,
        d_query=d_query,
        d_value=d_value,
        num_blocks=num_blocks,
        attention_scale=attention_scale,
        final_scale=final_scale,
    )
    model.jit()
    print(model)
    # ## Loss function and training
    #
    # To train our transformer we'll use the cross entropy loss, which we can compute by decomposing the softmax:
    #
    # $$
    # -\log(\text{target probability}) = -\log(\text{softmax}(\text{logits})_\text{target}) = -\text{logit}_\text{target} + \operatorname{logsumexp}(\text{logits})
    # $$
    import jax
    import jax.numpy as jnp

    def cross_entropy_loss(w, inputs, targets):
        # We use the logsumexp trick for stable cross entropy.
        logits = model(inputs, w)  # shape is [batch, seq_len, vocab_size]
        batch_indices = jnp.arange(logits.shape[0])[:, None]  # shape is [batch, 1]
        seq_indices = jnp.arange(logits.shape[1])[None, :]  # shape is [1, seq_len]
        # This indexing selects out logits[b, s, targets[b, s]], which is the target logit.
        losses = -logits[batch_indices, seq_indices, targets] + jax.nn.logsumexp(logits, axis=-1)  # shape is [batch, seq_len]
        return losses.mean()

    loss_and_grad = jax.jit(jax.value_and_grad(cross_entropy_loss))
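    # (Added note, not in the original tutorial:) if the initial predictions are roughly uniform
    # over the 65 characters, the cross entropy is ln(vocab_size) = ln(65) ≈ 4.17 nats, which is a
    # useful reference point for the first logged loss when checking the data and model pipeline.
    print(f"Loss for uniform predictions: {float(jnp.log(vocab_size)):.3f} nats")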
    # And we're ready to train!
    key = jax.random.PRNGKey(0)
    w = model.initialize(key)

    step = 0
    momentum = [0 * weight for weight in w]
    lr_schedule = lambda step: lr * (steps - step) / steps  # linear decay to zero

    for inputs, targets in train_loader:
        loss, grad_w = loss_and_grad(w, inputs, targets)
        momentum = [beta * m + (1 - beta) * g_w for m, g_w in zip(momentum, grad_w)]
        # Dualize the momentum to get the update direction in Modula's modular-norm geometry.
        d_w = model.dualize(momentum)
        w = [weight - lr_schedule(step) * d_weight for weight, d_weight in zip(w, d_w)]

        if step % log_interval == 0:
            print(f"Step {step}: loss {loss}")

        if step % val_interval == 0:
            val_losses = []
            for val_inputs, val_targets in val_loader:
                val_loss, _ = loss_and_grad(w, val_inputs, val_targets)
                val_losses.append(val_loss)
                if len(val_losses) >= val_iters:
                    break
            print(f"[Seed {seed}] Step {step} --> val loss {sum(val_losses)/len(val_losses)}")

        step += 1
        if step >= steps:
            break

    print('=' * 100)


if __name__ == "__main__":
    app.run(main)
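
# (Added sketch, not part of the original gist:) a minimal helper for comparing runs across
# seeds, assuming this file is saved as hello-gpt.py as in the Usage note above. It simply
# launches one training subprocess per seed; it is defined here but never called.
def sweep_seeds(seeds=(0, 1, 2)):
    import subprocess
    for s in seeds:
        subprocess.run(["python", "hello-gpt.py", "--seed", str(s)], check=True)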

# data/shakespeare.py
""" | |
Adapted from modula examples: | |
https://github.com/modula-systems/modula/blob/ede2ba72a1b9de3e1f44156db058b5c32c682941/examples/data/shakespeare.py | |
This script adds `seed` as an optional argument to `load_shakespeare` method | |
instead of keeping it hardcoded to 0 as in the original script. | |
""" | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
import os | |
import pickle | |
import requests | |
from typing import Tuple, Dict, Any, Iterator | |
class TokenDataset: | |
"""JAX dataset for uint16 tokens.""" | |
def __init__(self, data_path: str, context_length: int): | |
self.data = np.memmap(data_path, dtype=np.uint16, mode='r') | |
self.context_length = context_length | |
self._length = len(self.data) - self.context_length - 1 | |
def __getitem__(self, idx): | |
input_seq = jnp.array(self.data[idx:idx+self.context_length].astype(np.int32)) | |
target_seq = jnp.array(self.data[idx+1:idx+self.context_length+1].astype(np.int32)) | |
return input_seq, target_seq | |
def __len__(self): | |
return self._length | |
class DataLoader:
    """JAX dataloader for uint16 tokens."""

    def __init__(self, dataset: TokenDataset, batch_size: int, shuffle: bool = False, drop_last: bool = True, seed: int = 0):
        """Initialize the dataloader.

        Args:
            dataset: Dataset to load from
            batch_size: Number of samples per batch
            shuffle: Whether to shuffle the dataset
            drop_last: Whether to drop the last incomplete batch
            seed: Random seed for shuffling
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.key = jax.random.PRNGKey(seed)

    def __iter__(self) -> Iterator[Tuple[jnp.ndarray, jnp.ndarray]]:
        """Create an iterator over the dataset."""
        indices = jnp.arange(len(self.dataset))
        if self.shuffle:
            self.key, subkey = jax.random.split(self.key)
            indices = jax.random.permutation(subkey, indices)

        # Calculate number of batches
        if self.drop_last:
            num_batches = len(self.dataset) // self.batch_size
        else:
            num_batches = (len(self.dataset) + self.batch_size - 1) // self.batch_size

        for i in range(num_batches):
            start_idx = i * self.batch_size
            end_idx = min(start_idx + self.batch_size, len(self.dataset))
            batch_indices = indices[start_idx:end_idx]

            # Get samples for this batch
            xs, ys = [], []
            for idx in batch_indices:
                x, y = self.dataset[int(idx)]
                xs.append(x)
                ys.append(y)

            # Stack into batch
            x_batch = jnp.stack(xs)
            y_batch = jnp.stack(ys)
            yield x_batch, y_batch
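
# (Added illustration, not part of the original file:) shuffling is driven by a jax PRNGKey
# derived from `seed`, so two loaders built over the same dataset with the same seed visit
# it in the same order. A small check of that property, defined here but never called:
def _check_seed_determinism(dataset: TokenDataset, batch_size: int = 4, seed: int = 0) -> bool:
    loader_a = DataLoader(dataset, batch_size=batch_size, shuffle=True, seed=seed)
    loader_b = DataLoader(dataset, batch_size=batch_size, shuffle=True, seed=seed)
    (x_a, _), (x_b, _) = next(iter(loader_a)), next(iter(loader_b))
    return bool((x_a == x_b).all())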
def download_shakespeare_data(data_dir: str) -> None:
    """Download and prepare the Shakespeare dataset if it doesn't exist.

    Adapted from Karpathy's nanoGPT: https://github.com/karpathy/nanoGPT

    Args:
        data_dir: Directory to store the Shakespeare data
    """
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    # Download the tiny shakespeare dataset
    input_file_path = os.path.join(data_dir, 'input.txt')
    if not os.path.exists(input_file_path):
        print("Downloading Shakespeare dataset...")
        data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        with open(input_file_path, 'w') as f:
            f.write(requests.get(data_url).text)

    # Check if the processed files already exist
    if (os.path.exists(os.path.join(data_dir, 'train.bin')) and
            os.path.exists(os.path.join(data_dir, 'val.bin')) and
            os.path.exists(os.path.join(data_dir, 'meta.pkl'))):
        return

    print("Processing Shakespeare dataset...")
    with open(input_file_path, 'r') as f:
        data = f.read()
    print(f"Length of dataset in characters: {len(data):,}")

    # Get all the unique characters that occur in this text
    chars = sorted(list(set(data)))
    vocab_size = len(chars)
    print(f"Vocabulary size: {vocab_size:,}")

    # Create a mapping from characters to integers and back
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}

    # Create the train and val splits
    n = len(data)
    train_data = data[:int(n*0.9)]
    val_data = data[int(n*0.9):]

    # Encode both splits to integers
    train_ids = [stoi[c] for c in train_data]
    val_ids = [stoi[c] for c in val_data]
    print(f"Train has {len(train_ids):,} tokens")
    print(f"Val has {len(val_ids):,} tokens")

    # Export to bin files
    train_ids = np.array(train_ids, dtype=np.uint16)
    val_ids = np.array(val_ids, dtype=np.uint16)
    train_ids.tofile(os.path.join(data_dir, 'train.bin'))
    val_ids.tofile(os.path.join(data_dir, 'val.bin'))

    # Save the meta information as well, to help us encode/decode later
    meta = {
        'vocab_size': vocab_size,
        'itos': itos,
        'stoi': stoi,
    }
    with open(os.path.join(data_dir, 'meta.pkl'), 'wb') as f:
        pickle.dump(meta, f)

    print("Shakespeare dataset processing complete.")
def load_shakespeare(context_length: int, batch_size: int, shuffle: bool = True, seed: int = 0) -> Dict[str, Any]:
    """Load the Shakespeare dataset and create dataloaders.

    Args:
        context_length: Length of context window for prediction
        batch_size: Number of samples per batch
        shuffle: Whether to shuffle the training data
        seed: Random seed for shuffling the training data

    Returns:
        Dictionary containing train_loader, val_loader, and meta information
    """
    # Determine the data directory
    script_dir = os.path.dirname(os.path.abspath(__file__))
    data_dir = os.path.join(script_dir, 'shakespeare')

    # Check if the Shakespeare data exists, and download it if it doesn't
    download_shakespeare_data(data_dir)

    # Load the meta information
    with open(os.path.join(data_dir, 'meta.pkl'), 'rb') as f:
        meta = pickle.load(f)

    # Create the datasets
    train_dataset = TokenDataset(os.path.join(data_dir, 'train.bin'), context_length)
    val_dataset = TokenDataset(os.path.join(data_dir, 'val.bin'), context_length)

    # Create the dataloaders (only the train loader is shuffled, so only it uses the seed)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, seed=seed)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, seed=seed)

    return {
        'train_loader': train_loader,
        'val_loader': val_loader,
        'meta': meta,
        'vocab_size': meta['vocab_size'],
        'encode': lambda s: [meta['stoi'][c] for c in s],
        'decode': lambda l: ''.join([meta['itos'][int(i)] for i in l]),
    }
# Example usage
if __name__ == "__main__":
    # Load the data with context length of 8 and batch size of 4
    data = load_shakespeare(context_length=8, batch_size=4)

    # Get the first batch from the training loader
    for x_batch, y_batch in data['train_loader']:
        print("Input shape:", x_batch.shape)
        print("Target shape:", y_batch.shape)

        # Print the first sequence in the batch
        print("First input sequence:", x_batch[0])
        print("First target sequence:", y_batch[0])

        # Decode the first sequence
        print("Decoded input:", data['decode'](x_batch[0]))
        print("Decoded target:", data['decode'](y_batch[0]))
        break