modula GPT tutorial script with seed exposed as a command line argument
"""
Adapted from modula's Hello GPT tutorial:
https://github.com/modula-systems/modula/blob/ede2ba72a1b9de3e1f44156db058b5c32c682941/examples/hello-gpt.ipynb
This script simply exposes the dataloader seed as a command-line argument so that
training sensitivity to the seed can be tested.
Usage:
python hello-gpt.py --seed 0
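The logging/validation interval flags defined below can be overridden the same way, e.g.:
python hello-gpt.py --seed 1 --log_interval 50 --val_interval 500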
"""
from absl import app, flags
flags.DEFINE_integer("seed", 0, "training seed")
flags.DEFINE_integer("log_interval", 10, "log interval")
flags.DEFINE_integer("val_interval", 200, "val interval")
FLAGS = flags.FLAGS
def main(unused_argv):
# First, let's download the Shakespeare dataset. The task will be to predict the next character.
context = 64
batch_size = 12
seed = FLAGS.seed
from data.shakespeare import load_shakespeare
data = load_shakespeare(context, batch_size, seed=seed)
train_loader = data["train_loader"]
val_loader = data["val_loader"]
encode = data["encode"]
decode = data["decode"]
# Let's peek at an example to verify the data loaded correctly!
for inputs, targets in train_loader:
print("Input shape:", inputs.shape)
print("Target shape:", targets.shape)
print("First input sequence:", inputs[0][:10], "...")
print("First target sequence:", targets[0][:10], "...")
print("\nDecoded input:", decode(inputs[0]))
print("\nDecoded target:", decode(targets[0]))
break
# Transformer hyperparameters
vocab_size = 65
num_heads = 4
d_embed = 128
d_query = 32
d_value = 32
num_blocks = 4
attention_scale = 1
final_scale = 1
# Training hyperparameters
lr = 0.1
beta = 0.95
steps = 2001
log_interval = FLAGS.log_interval
val_interval = FLAGS.val_interval
val_iters = 20
# Next up, we'll define the *attention* module and *residual blocks*.
# ## Attention in Modula
# In Modula, we'll define attention by stringing together several bond modules to do the parameterless computations. The roadmap is:
# * Map `(batch, token, d_embed)` into `(batch, head, token, d_query)` (and same for key and value) via `Linear` and `SplitIntoHeads`
# * Use Rotary Positional Embeddings (RoPE) on the query and the key via `Rope`
# * Map `query` and `key` into attention similarities of shape `(batch, head, token, token)` via `AttentionQK`
# * Use a causal mask and then softmax to create attention scores via `CausalMask` and `Softmax`
# * Use the attention scores to create output vectors via `ApplyAttentionScores`, then `MergeHeads` and `Linear`
#
# The main difference from a standard transformer is that `AttentionQK` uses $1/d_\text{head}$ scaling instead of the standard $1/\sqrt{d_\text{head}}$. This scaling yields Lipschitz guarantees for attention that are independent of $d_\text{head}$. For more information, see Appendix B.6 of [Scalable Optimization in the Modular Norm](https://arxiv.org/pdf/2405.14813).
#
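# As a point of reference, here is a minimal plain-JAX sketch of the similarity
# computation with 1/d_head scaling (illustrative only; it is not how modula's
# `AttentionQK` is implemented internally, it just demonstrates the scaling rule
# described above):
import jax
import jax.numpy as jnp
_key = jax.random.PRNGKey(0)
_q = jax.random.normal(_key, (batch_size, num_heads, context, d_query))
_k = jax.random.normal(_key, (batch_size, num_heads, context, d_query))
_sims = jnp.einsum("bhid,bhjd->bhij", _q, _k) / d_query  # 1/d_head, not 1/sqrt(d_head)
assert _sims.shape == (batch_size, num_heads, context, context)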
# And here's the implementation:
from modula.atom import Linear
from modula.bond import SplitIntoHeads, MergeHeads, Rope, AttentionQK, CausalMask, Softmax, ApplyAttentionScores, GeLU
def Attention(num_heads, d_embed, d_query, d_value, attention_scale):
"""Multi-head attention"""
# For keys, queries, and values we add a heads dimension. For the out projection, we remove heads.
# Remember modules compose right-to-left, and the order is Linear(d_out, d_in)! And @ means compose.
Q = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)
K = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)
V = SplitIntoHeads(num_heads) @ Linear(num_heads * d_value, d_embed)
W = Linear(d_embed, num_heads * d_value) @ MergeHeads()
# Read right-to-left: rotate (Q, K) with RoPE, apply Q @ K.T, mask, softmax (with a scale we can choose).
AttentionScores = Softmax(attention_scale) @ CausalMask() @ AttentionQK() @ Rope(d_query) @ (Q, K)
# Read right-to-left: apply attention scores, multiply by 1/3 to fix the sensitivity to 1, project back to d_embed.
return W @ (1/3 * ApplyAttentionScores()) @ (V, AttentionScores)
# Let's check that the sensitivity is 1 at initialization.
# print(Attention(num_heads, d_embed, d_query, d_value, attention_scale))
# ## Residual blocks in Modula
#
# To implement the rest of our transformer, the roadmap is:
# * Embed the input tokens
# * Apply residual blocks for attention and the MLP
# * Project out
#
# All that's left is to set up the residual blocks. In Modula, we define residual connections using a convex combination. If $L$ is the number of residual blocks, then we use a convex combination of the identity and the block to get $x \mapsto \frac{L-1}{L} \cdot x + \frac{1}{L} \cdot \textsf{block}(x)$. The purpose is to create a Lipschitz guarantee that is independent of the number of blocks. For more information, see Proposition 4 of [Scalable Optimization in the Modular Norm](https://arxiv.org/pdf/2405.14813).
#
# In short, these changes enable Lipschitz guarantees on our transformer even as we scale the width and the depth!
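# A tiny numerical sketch of the convex-combination residual described above
# (illustrative only; `_block` stands in for any residual branch). With
# L = 2 * num_blocks residual connections, each step computes (L-1)/L * x + 1/L * block(x):
_L = 2 * num_blocks
_residual_step = lambda x, _block: (1 - 1 / _L) * x + (1 / _L) * _block(x)
assert _residual_step(1.0, lambda x: 0.0) == 1 - 1 / _L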
from modula.abstract import Identity
from modula.atom import Embed
def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, blocks_mass=5, attention_scale=1.0, final_scale=1.0):
# Set embed to have mass 1. This controls the proportion of feature learning that it contributes to the whole network.
embed = Embed(d_embed, vocab_size)
embed.tare()
# Let's create attention and MLP layers.
att = Attention(num_heads, d_embed, d_query, d_value, attention_scale)
mlp = Linear(d_embed, 4*d_embed) @ GeLU() @ Linear(4*d_embed, d_embed)
# For our residual connections, L = 2*num_blocks because each block has two residual connections.
att_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * att
mlp_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * mlp
# We can use powers of a module to compose it with itself many times!
blocks = (mlp_block @ att_block) ** num_blocks
# Set all transformer blocks to have mass 5 (by default).
# So 5/7 of the change in the network output is due to the blocks,
# and 2/7 of the change in output is due to the embedding and out projection.
blocks.tare(absolute=blocks_mass)
out = final_scale * Linear(vocab_size, d_embed)
return out @ blocks @ embed
# And finally we are ready to construct our GPT!
model = GPT(
vocab_size=vocab_size,
num_heads=num_heads,
d_embed=d_embed,
d_query=d_query,
d_value=d_value,
num_blocks=num_blocks,
attention_scale=attention_scale,
final_scale=final_scale,
)
model.jit()
print(model)
# ## Loss function and training
#
# To train our transformer we'll use cross entropy loss, which we can compute by decomposing the softmax:
#
# $$
# -\log(\text{target probability}) = -\log(\text{softmax}(\text{logits})_\text{target}) = -\text{logit}_\text{target} + \text{log\,sum\,exp}(\text{logits})
# $$
import jax
import jax.numpy as jnp
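# Quick numerical check of the decomposition above on a toy logit vector
# (illustrative only): -log_softmax(x)[t] == -x[t] + logsumexp(x).
_toy_logits = jnp.array([2.0, -1.0, 0.5])
_toy_target = 0
assert jnp.allclose(-jax.nn.log_softmax(_toy_logits)[_toy_target],
                    -_toy_logits[_toy_target] + jax.nn.logsumexp(_toy_logits))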
def cross_entropy_loss(w, inputs, targets):
# We use the logsumexp trick for stable cross entropy
logits = model(inputs, w) # shape is [batch, seq_len, vocab_size]
batch_indices = jnp.arange(logits.shape[0])[:, None] # shape is [batch, 1]
seq_indices = jnp.arange(logits.shape[1])[None, :] # shape is [1, seq_len]
# This indexing selects out logits[b, s, targets[b, s]], which is the target logit
losses = -logits[batch_indices, seq_indices, targets] + jax.nn.logsumexp(logits, axis=-1) # shape is [batch, seq_len]
return losses.mean()
loss_and_grad = jax.jit(jax.value_and_grad(cross_entropy_loss))
# And we're ready to train!
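# Note that the weight-initialization key below stays fixed at 0; only the
# dataloader seed (FLAGS.seed) varies across runs, so run-to-run differences
# come from the data ordering alone.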
key = jax.random.PRNGKey(0)
w = model.initialize(key)
step = 0
momentum = [0 * weight for weight in w]
lr_schedule = lambda step: lr * (steps - step) / steps
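# Training loop, following the modula tutorial: keep an exponential moving average
# of the gradients, dualize it (model.dualize) to get an update direction suited to
# the modular norm, and step with a linearly decaying learning rate.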
for inputs, targets in train_loader:
loss, grad_w = loss_and_grad(w, inputs, targets)
momentum = [beta * m + (1 - beta) * g_w for m, g_w in zip(momentum, grad_w)]
d_w = model.dualize(momentum)
w = [weight - lr_schedule(step) * d_weight for weight, d_weight in zip(w, d_w)]
if step % log_interval == 0:
print(f"Step {step}: loss {loss}")
if step % val_interval == 0:
val_losses = []
for val_inputs, val_targets in val_loader:
loss, _ = loss_and_grad(w, val_inputs, val_targets)
val_losses.append(loss)
if len(val_losses) >= val_iters:
break
print(f"[Seed {seed}] Step {step} --> val loss {sum(val_losses)/len(val_losses)}")
step += 1
if step >= steps:
break
print('='*100)
if __name__ == "__main__":
app.run(main)
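
# ---------------------------------------------------------------------------
# data/shakespeare.py: the `data.shakespeare` module imported by the training
# script above (the file path is inferred from that import).
# ---------------------------------------------------------------------------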
"""
Adapted from modula examples:
https://github.com/modula-systems/modula/blob/ede2ba72a1b9de3e1f44156db058b5c32c682941/examples/data/shakespeare.py
This script adds `seed` as an optional argument to the `load_shakespeare` function
instead of keeping it hardcoded to 0 as in the original script.
"""
import jax
import jax.numpy as jnp
import numpy as np
import os
import pickle
import requests
from typing import Tuple, Dict, Any, Iterator
class TokenDataset:
"""JAX dataset for uint16 tokens."""
def __init__(self, data_path: str, context_length: int):
self.data = np.memmap(data_path, dtype=np.uint16, mode='r')
self.context_length = context_length
self._length = len(self.data) - self.context_length - 1
def __getitem__(self, idx):
input_seq = jnp.array(self.data[idx:idx+self.context_length].astype(np.int32))
target_seq = jnp.array(self.data[idx+1:idx+self.context_length+1].astype(np.int32))
return input_seq, target_seq
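# For example (illustrative values): with tokens [5, 1, 4, 9, 2] and context_length 3,
# dataset[0] returns inputs [5, 1, 4] and targets [1, 4, 9], i.e. the same window
# shifted by one position so that each position predicts the next token.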
def __len__(self):
return self._length
class DataLoader:
"""JAX dataloader for uint16 tokens."""
def __init__(self, dataset: TokenDataset, batch_size: int, shuffle: bool = False, drop_last: bool = True, seed: int = 0):
"""Initialize the dataloader.
Args:
dataset: Dataset to load from
batch_size: Number of samples per batch
shuffle: Whether to shuffle the dataset
drop_last: Whether to drop the last incomplete batch
seed: Random seed for shuffling
"""
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.key = jax.random.PRNGKey(seed)
def __iter__(self) -> Iterator[Tuple[jnp.ndarray, jnp.ndarray]]:
"""Create an iterator over the dataset."""
indices = jnp.arange(len(self.dataset))
if self.shuffle:
self.key, subkey = jax.random.split(self.key)
indices = jax.random.permutation(subkey, indices)
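# Splitting the key here gives each pass over the loader a fresh permutation,
# while the full sequence of permutations remains determined by the `seed`
# passed to __init__.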
# Calculate number of batches
if self.drop_last:
num_batches = len(self.dataset) // self.batch_size
else:
num_batches = (len(self.dataset) + self.batch_size - 1) // self.batch_size
for i in range(num_batches):
start_idx = i * self.batch_size
end_idx = min(start_idx + self.batch_size, len(self.dataset))
batch_indices = indices[start_idx:end_idx]
# Get samples for this batch
xs, ys = [], []
for idx in batch_indices:
x, y = self.dataset[int(idx)]
xs.append(x)
ys.append(y)
# Stack into batch
x_batch = jnp.stack(xs)
y_batch = jnp.stack(ys)
yield x_batch, y_batch
def download_shakespeare_data(data_dir: str) -> None:
"""Download and prepare the Shakespeare dataset if it doesn't exist.
Adapted from Karpathy's nanoGPT: https://github.com/karpathy/nanoGPT
Args:
data_dir: Directory to store the Shakespeare data
"""
if not os.path.exists(data_dir):
os.makedirs(data_dir)
# Download the tiny shakespeare dataset
input_file_path = os.path.join(data_dir, 'input.txt')
if not os.path.exists(input_file_path):
print("Downloading Shakespeare dataset...")
data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
with open(input_file_path, 'w') as f:
f.write(requests.get(data_url).text)
# Check if processed files already exist
if (os.path.exists(os.path.join(data_dir, 'train.bin')) and
os.path.exists(os.path.join(data_dir, 'val.bin')) and
os.path.exists(os.path.join(data_dir, 'meta.pkl'))):
return
print("Processing Shakespeare dataset...")
with open(input_file_path, 'r') as f:
data = f.read()
print(f"Length of dataset in characters: {len(data):,}")
# Get all the unique characters that occur in this text
chars = sorted(list(set(data)))
vocab_size = len(chars)
print(f"Vocabulary size: {vocab_size:,}")
# Create a mapping from characters to integers
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
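# For example (illustrative): if the text were "ab\n", chars would be ['\n', 'a', 'b'],
# stoi would be {'\n': 0, 'a': 1, 'b': 2}, and the string "ba" would encode to [2, 1].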
# Create the train and test splits
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]
# Encode both to integers
train_ids = [stoi[c] for c in train_data]
val_ids = [stoi[c] for c in val_data]
print(f"Train has {len(train_ids):,} tokens")
print(f"Val has {len(val_ids):,} tokens")
# Export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(data_dir, 'train.bin'))
val_ids.tofile(os.path.join(data_dir, 'val.bin'))
# Save the meta information as well, to help us encode/decode later
meta = {
'vocab_size': vocab_size,
'itos': itos,
'stoi': stoi,
}
with open(os.path.join(data_dir, 'meta.pkl'), 'wb') as f:
pickle.dump(meta, f)
print("Shakespeare dataset processing complete.")
def load_shakespeare(context_length: int, batch_size: int, shuffle: bool = True, seed: int = 0) -> Dict[str, Any]:
"""Load the Shakespeare dataset and create dataloaders.
Args:
context_length: Length of context window for prediction
batch_size: Number of samples per batch
shuffle: Whether to shuffle the training data
seed: Random seed used to shuffle the training data
Returns:
Dictionary containing train_loader, val_loader, and meta information
"""
# Determine the data directory
script_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(script_dir, 'shakespeare')
# Check if the Shakespeare data exists, download if it doesn't
download_shakespeare_data(data_dir)
# Load meta information
with open(os.path.join(data_dir, 'meta.pkl'), 'rb') as f:
meta = pickle.load(f)
# Create datasets
train_dataset = TokenDataset(os.path.join(data_dir, 'train.bin'), context_length)
val_dataset = TokenDataset(os.path.join(data_dir, 'val.bin'), context_length)
# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, seed=seed)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, seed=seed)
return {
'train_loader': train_loader,
'val_loader': val_loader,
'meta': meta,
'vocab_size': meta['vocab_size'],
'encode': lambda s: [meta['stoi'][c] for c in s],
'decode': lambda l: ''.join([meta['itos'][int(i)] for i in l])
}
# Example usage
if __name__ == "__main__":
# Load the data with context length of 8 and batch size of 4
data = load_shakespeare(context_length=8, batch_size=4)
# Get the first batch from the training loader
for x_batch, y_batch in data['train_loader']:
print("Input shape:", x_batch.shape)
print("Target shape:", y_batch.shape)
# Print the first sequence in the batch
print("First input sequence:", x_batch[0])
print("First target sequence:", y_batch[0])
# Decode the first sequence
print("Decoded input:", data['decode'](x_batch[0]))
print("Decoded target:", data['decode'](y_batch[0]))
break