Simple character-level pretraining in MLX. Gets roughly a billion tokens/day for an 18M-parameter model on one M3 Max.
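Expects a data.jsonl in the working directory: one JSON object per line with a "text" field, which is the only field the script reads. An illustrative line (not real data):
{"text": "Once upon a time there was a tiny language model."}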
import json
import random
import mlx.optimizers as optim
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from tqdm import tqdm
import time
from datetime import datetime
import os
from mlx_lm.models.llama import Model, ModelArgs
from mlx.utils import tree_flatten
random.seed(42)
np.random.seed(42)
mx.random.seed(42)
MAX_CONTEXT_SIZE = 512
docs = []
with open('data.jsonl', 'r') as f:
    for line in f:
        d = json.loads(line)
        text = d["text"]
        # Break this text into chunks of length MAX_CONTEXT_SIZE
        for i in range(0, len(text), MAX_CONTEXT_SIZE):
            chunk_text = text[i : i + MAX_CONTEXT_SIZE]
            # Append each chunk as a separate doc
            docs.append({"text": chunk_text})
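# docs now holds fixed-size character chunks, e.g. a 1,300-character
# document becomes three docs of 512, 512, and 276 characters.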
NORMAL_VOCAB_SIZE = 256
special_token_map = {
    '<pad>': NORMAL_VOCAB_SIZE,
    '<bos>': NORMAL_VOCAB_SIZE + 1,
    '<eos>': NORMAL_VOCAB_SIZE + 2,
    '<ctrl1>': NORMAL_VOCAB_SIZE + 3,
    '<ctrl2>': NORMAL_VOCAB_SIZE + 4,
    '<ctrl3>': NORMAL_VOCAB_SIZE + 5,
    '<ctrl4>': NORMAL_VOCAB_SIZE + 6,
    '<ctrl5>': NORMAL_VOCAB_SIZE + 7,
}
PAD_TOKEN = special_token_map['<pad>']
BOS_TOKEN = special_token_map['<bos>']
EOS_TOKEN = special_token_map['<eos>']
VOCAB_SIZE = NORMAL_VOCAB_SIZE + len(special_token_map)
def tokenize(text):
    # Character-level "tokenizer": just the raw UTF-8 bytes
    return list(text.encode('utf-8'))
def tokenize_doc(doc):
    return [BOS_TOKEN] + tokenize(doc['text'])[:MAX_CONTEXT_SIZE] + [EOS_TOKEN]
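# e.g. tokenize("hi") -> [104, 105] (UTF-8 bytes), and
# tokenize_doc({"text": "hi"}) -> [257, 104, 105, 258] with BOS=257, EOS=258.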
def generate_training_batch(indices):
    batch = [tokenize_doc(docs[i]) for i in indices]
    max_len = max(len(x) for x in batch)
    for i in range(len(batch)):
        batch[i] += [PAD_TOKEN] * (max_len - len(batch[i]))
    return batch
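# Each sequence is right-padded with PAD_TOKEN to the length of the longest
# sequence in its batch, so the batch can be turned into a rectangular array.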
# Sort docs by length so each batch groups similarly sized sequences
# (less padding), then shuffle the order in which batches are visited.
idx = sorted(range(len(docs)), key=lambda i: len(docs[i]['text']))
BATCH_SIZE = 16
batch_idx = [
    idx[i : i + BATCH_SIZE]
    for i in range(0, len(idx) - BATCH_SIZE + 1, BATCH_SIZE)
]
indices = np.random.permutation(len(batch_idx))
args = ModelArgs(
    model_type="llama",
    hidden_size=384,
    num_hidden_layers=8,
    intermediate_size=384 * 4,
    num_attention_heads=8,
    rms_norm_eps=1e-5,
    vocab_size=VOCAB_SIZE,
    head_dim=None,
    max_position_embeddings=None,
    num_key_value_heads=None,
    attention_bias=False,
    mlp_bias=False,
    rope_theta=10000,
    rope_traditional=False,
    rope_scaling=None,
    tie_word_embeddings=True
)
model = Model(args)
p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**3
print(f"Model has {p}K parameters")
def default_loss(model, inputs, targets):
    """
    Computes per-token cross-entropy, masks out PAD positions,
    and normalizes by the number of unmasked tokens.
    """
    pad_token = PAD_TOKEN
    # 1) Forward pass
    logits = model(inputs)
    logits = logits.astype(mx.float32)
    # 2) Per-token cross-entropy
    loss = nn.losses.cross_entropy(logits, targets)
    # 3) Mask out the PAD tokens
    pad_mask = (targets != pad_token)
    # 4) Apply the mask to the per-token cross-entropy
    loss = loss * pad_mask
    # 5) Normalize by the number of tokens that were *not* masked out
    ntoks = pad_mask.sum()
    loss = loss.sum() / ntoks
    return loss, ntoks
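# Without the mask, right-padded PAD targets would be scored like real tokens
# and would dominate the average loss on short sequences; dividing by ntoks
# keeps the loss an average over real tokens only.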
iters = len(batch_idx)
warmup_steps = 1000
warmup = optim.linear_schedule(0, 1e-3, steps=warmup_steps)
cosine = optim.cosine_decay(1e-3, iters)
lr_schedule = optim.join_schedules([warmup, cosine], [warmup_steps])
optimizer = optim.AdamW(learning_rate=lr_schedule)
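# The learning rate ramps linearly from 0 to 1e-3 over the first 1000 steps,
# then follows a cosine decay toward 0 for the rest of training.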
os.makedirs('logs', exist_ok=True)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_filename = f'logs/training_log_{timestamp}.txt'
chkpoint_dir = f'checkpoints_{timestamp}'
os.makedirs(chkpoint_dir, exist_ok=True)
# Create a progress bar for the entire training loop
progress_bar = tqdm(range(iters), desc="Training")
total_tokens = 0
loss_value_and_grad = nn.value_and_grad(model, default_loss)
with open(log_filename, 'w') as log_file:
    # Write initial training setup
    log_file.write(f"Training started at {timestamp}\n")
    log_file.write(f"Total iterations: {iters}\n")
    log_file.write("=" * 50 + "\n\n")
    for step in progress_bar:
        batch = generate_training_batch(batch_idx[indices[step]])
        batch = mx.array(batch)
        (lvalue, toks), grad = loss_value_and_grad(model, batch[:, :-1], batch[:, 1:])
        total_tokens += toks
        optimizer.update(model, grad)
        mx.eval(lvalue)
        mx.metal.clear_cache()
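        # mx.eval(lvalue) materializes the loss (and the lazy graph it depends
        # on); mx.metal.clear_cache() releases cached GPU buffers between steps
        # to keep memory usage in check.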
        # Log metrics every step
        if step % 1 == 0:
            current_lr = lr_schedule(step)
            log_message = (
                f"Step {step}: "
                f"loss={lvalue:.3e}, "
                f"toks_per_step={toks}, "
                f"total_tokens={(total_tokens/10**3):.2f}K, "
                f"lr={current_lr:.3e}\n"
            )
            log_file.write(log_message)
            log_file.flush()  # Ensure logs are written immediately
            # Update progress bar description
            progress_bar.set_description(
                f"Loss: {lvalue:.2f} | Perplexity: {np.exp(lvalue):.2f} | "
                f"Tokens/step: {toks} | Total tokens: {(total_tokens/10**3):.2f}K | "
                f"LR: {current_lr:.3e} | "
                f"Tokens/sec: {(total_tokens / (1000 * (time.time() - progress_bar.start_t))):.2f}K"
            )
        if step % 500 == 0:
            weights = dict(tree_flatten(model.parameters()))
            mx.save_safetensors(f'{chkpoint_dir}/model_{timestamp}_{step}.safetensors', weights)
final_weights = dict(tree_flatten(model.parameters()))
mx.save_safetensors(f'{chkpoint_dir}/model_{timestamp}_final.safetensors', final_weights)
# Write final summary
with open(log_filename, 'a') as log_file:
    log_file.write("\n" + "=" * 50 + "\n")
    log_file.write(f"Training completed at {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
    log_file.write(f"Final loss: {lvalue:.3e}\n")
    log_file.write(f"Total tokens processed: {(total_tokens/10**3):.2f}K\n")