@N8python
Created January 15, 2025 23:41
Simple character-level pretraining in MLX. Gets roughly a billion tokens/day for an 18M-parameter model on one M3 Max.
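The script reads a data.jsonl from the working directory. Judging from the loader below, each line is expected to be a JSON object with a "text" field (any other fields are ignored), e.g.:

{"text": "Call me Ishmael. Some years ago, never mind how long precisely..."}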
import json
import random
import mlx.optimizers as optim
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from tqdm import tqdm
import time
from datetime import datetime
import os
from mlx_lm.models.llama import Model, ModelArgs
from mlx.utils import tree_flatten
random.seed(42)
np.random.seed(42)
mx.random.seed(42)
MAX_CONTEXT_SIZE = 512
docs = []
with open('data.jsonl', 'r') as f:
    for line in f:
        d = json.loads(line)
        text = d["text"]
        # Break this text into chunks of length MAX_CONTEXT_SIZE
        for i in range(0, len(text), MAX_CONTEXT_SIZE):
            chunk_text = text[i : i + MAX_CONTEXT_SIZE]
            # Append each chunk as a separate doc
            docs.append({"text": chunk_text})
NORMAL_VOCAB_SIZE = 256
special_token_map = {
    '<pad>': NORMAL_VOCAB_SIZE,
    '<bos>': NORMAL_VOCAB_SIZE + 1,
    '<eos>': NORMAL_VOCAB_SIZE + 2,
    '<ctrl1>': NORMAL_VOCAB_SIZE + 3,
    '<ctrl2>': NORMAL_VOCAB_SIZE + 4,
    '<ctrl3>': NORMAL_VOCAB_SIZE + 5,
    '<ctrl4>': NORMAL_VOCAB_SIZE + 6,
    '<ctrl5>': NORMAL_VOCAB_SIZE + 7,
}
PAD_TOKEN = special_token_map['<pad>']
BOS_TOKEN = special_token_map['<bos>']
EOS_TOKEN = special_token_map['<eos>']
VOCAB_SIZE = NORMAL_VOCAB_SIZE + len(special_token_map)
def tokenize(s):
    # Byte-level "character" tokenization: each UTF-8 byte is one token id (0-255)
    return list(s.encode('utf-8'))

def tokenize_doc(doc):
    return [BOS_TOKEN] + tokenize(doc['text'])[:MAX_CONTEXT_SIZE] + [EOS_TOKEN]
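# Illustrative check of the scheme above (ids follow from the byte values and the
# special-token map): tokenize("Hi!") -> [72, 105, 33], and
# tokenize_doc({"text": "Hi!"}) -> [257, 72, 105, 33, 258], i.e. <bos> ... <eos>.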
def generate_training_batch(indices):
    batch = [tokenize_doc(docs[i]) for i in indices]
    max_len = max(len(x) for x in batch)
    # Right-pad every sequence to the longest one in the batch
    for i in range(len(batch)):
        batch[i] += [PAD_TOKEN] * (max_len - len(batch[i]))
    return batch
# Sort document indices by text length so each batch groups similar-length docs (less padding)
idx = sorted(range(len(docs)), key=lambda i: len(docs[i]["text"]))
BATCH_SIZE = 16
batch_idx = [
    idx[i : i + BATCH_SIZE]
    for i in range(0, len(idx) - BATCH_SIZE + 1, BATCH_SIZE)
]
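# Batching sketch: with doc lengths [3, 9, 4, 8] and BATCH_SIZE = 2, the length-sorted
# order pairs lengths (3, 4) and (8, 9), so little padding is needed per batch.
# Any trailing docs that don't fill a complete batch are dropped.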
indices = np.random.permutation(len(batch_idx))
args = ModelArgs(
    model_type="llama",
    hidden_size=384,
    num_hidden_layers=8,
    intermediate_size=384 * 4,
    num_attention_heads=8,
    rms_norm_eps=1e-5,
    vocab_size=VOCAB_SIZE,
    head_dim=None,
    max_position_embeddings=None,
    num_key_value_heads=None,
    attention_bias=False,
    mlp_bias=False,
    rope_theta=10000,
    rope_traditional=False,
    rope_scaling=None,
    tie_word_embeddings=True
)
model = Model(args)
p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**3
print(f"Model has {p}K parameters")
def default_loss(model, inputs, targets):
    """
    Computes per-token cross-entropy, masks out PAD tokens,
    and averages over the remaining (unmasked) tokens.
    """
    pad_token = PAD_TOKEN
    # Forward pass
    logits = model(inputs)
    logits = logits.astype(mx.float32)
    loss = nn.losses.cross_entropy(logits, targets)
    # Mask out the PAD tokens
    pad_mask = (targets != pad_token)
    # Apply the mask to the per-token cross-entropy
    loss = loss * pad_mask
    # Normalize by the number of tokens that were *not* masked out
    ntoks = pad_mask.sum()
    loss = loss.sum() / ntoks
    return loss, ntoks
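# Masking sketch (hypothetical values): for targets [[72, 105, 258, 256, 256]] with
# PAD_TOKEN = 256, pad_mask is [[1, 1, 1, 0, 0]], so the loss is averaged over 3 tokens.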
iters = len(batch_idx)
warmup_steps = 1000
warmup = optim.linear_schedule(0, 1e-3, steps=warmup_steps)
cosine = optim.cosine_decay(1e-3, iters)
lr_schedule = optim.join_schedules([warmup, cosine], [warmup_steps])
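# Learning-rate shape: linear warmup from 0 to 1e-3 over the first 1000 steps, then a
# cosine decay from 1e-3 toward 0, with the decay window set to the total iteration count.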
optimizer = optim.AdamW(learning_rate=lr_schedule)
os.makedirs('logs', exist_ok=True)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_filename = f'logs/training_log_{timestamp}.txt'
chkpoint_dir = f'checkpoints_{timestamp}'
os.makedirs(chkpoint_dir, exist_ok=True)
# Create a progress bar for the entire training loop
progress_bar = tqdm(range(iters), desc="Training")
total_tokens = 0
loss_value_and_grad = nn.value_and_grad(model, default_loss)
with open(log_filename, 'w') as log_file:
    # Write initial training setup
    log_file.write(f"Training started at {timestamp}\n")
    log_file.write(f"Total iterations: {iters}\n")
    log_file.write("=" * 50 + "\n\n")
    for step in progress_bar:
        batch = generate_training_batch(batch_idx[indices[step]])
        batch = mx.array(batch)
        # Predict token t+1 from tokens up to t
        (lvalue, toks), grad = loss_value_and_grad(model, batch[:, :-1], batch[:, 1:])
        total_tokens += toks
        optimizer.update(model, grad)
        mx.eval(lvalue)
        mx.metal.clear_cache()
        # Log metrics every step
        if step % 1 == 0:
            current_lr = lr_schedule(step)
            log_message = (
                f"Step {step}: "
                f"loss={lvalue:.3e}, "
                f"toks_per_step={toks}, "
                f"total_tokens={(total_tokens/10**3):.2f}K, "
                f"lr={current_lr:.3e}\n"
            )
            log_file.write(log_message)
            log_file.flush()  # Ensure logs are written immediately
            # Update progress bar description
            progress_bar.set_description(
                f"Loss: {lvalue:.2f} | Perplexity: {np.exp(lvalue):.2f} | Tokens/step: {toks} | "
                f"Total tokens: {(total_tokens/10**3):.2f}K | LR: {current_lr:.3e} | "
                f"Tokens/sec: {(total_tokens / (1000 * (time.time() - progress_bar.start_t))):.2f}K"
            )
        if step % 500 == 0:
            weights = dict(tree_flatten(model.parameters()))
            mx.save_safetensors(f'{chkpoint_dir}/model_{timestamp}_{step}.safetensors', weights)

final_weights = dict(tree_flatten(model.parameters()))
mx.save_safetensors(f'{chkpoint_dir}/model_{timestamp}_final.safetensors', final_weights)
# Write final summary
with open(log_filename, 'a') as log_file:
    log_file.write("\n" + "=" * 50 + "\n")
    log_file.write(f"Training completed at {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
    log_file.write(f"Final loss: {lvalue:.3e}\n")
    log_file.write(f"Total tokens processed: {(total_tokens/10**3):.2f}K\n")