"""
Basic training implementation - Miners can optimize this!
Usage:
1. Run setup: uv run local_test/setup_benchmark.py
2. Test locally: uv run local_test/train.py
3. Submit when ready!
"""
import json
import time
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM


@dataclass
class InnerStepsResult:
    """Required return type from inner_steps function."""

    final_logits: torch.Tensor  # Output logits from the last forward pass
    total_tokens: int  # Total tokens processed across all steps
    final_loss: float  # Loss value from the last training step

def inner_steps(model, data_iterator, optimizer, num_steps, device):
    """
    Run training steps and return results.

    Args:
        model: Pre-loaded model (already on device, in train mode)
        data_iterator: Iterator yielding batches of shape (batch_size, seq_len)
        optimizer: Pre-configured optimizer
        num_steps: Number of training steps to run
        device: Target device (cuda or cpu)

    Returns:
        InnerStepsResult with outputs for verification
    """
    total_tokens = 0
    final_logits = None
    final_loss = 0.0

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_float32_matmul_precision("high")  # or "medium"

    for step in range(num_steps):
        # Get batch
        batch = next(data_iterator)
        if batch.device != device:
            batch = batch.to(device, non_blocking=True)

        # Prepare inputs and labels: predict token t+1 from tokens up to t
        input_ids = batch[:, :-1]
        labels = batch[:, 1:]

        # Forward pass
        outputs = model(input_ids)
        logits = outputs.logits if hasattr(outputs, "logits") else outputs

        # Compute loss
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=-100,
        )

        # Backward pass
        loss.backward()

        # Update weights
        optimizer.step()
        optimizer.zero_grad()

        # Track metrics
        total_tokens += batch.numel()
        if step == num_steps - 1:
            final_logits = logits.detach().float()
            final_loss = loss.item()

    return InnerStepsResult(
        final_logits=final_logits,
        total_tokens=total_tokens,
        final_loss=final_loss,
    )
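
# =============================================================================
# OPTIONAL SPEEDUPS -- an illustrative sketch, not part of the required
# contract. Two common PyTorch 2.x levers a miner might try: torch.compile
# (traces the model into fused kernels; the first step pays the compile cost)
# and fused AdamW (runs the parameter update as a single CUDA kernel). Whether
# either helps depends on the model and GPU, so measure before relying on it.
# =============================================================================
def build_optimized(model, lr=1e-4):
    """Wrap a pre-loaded model with torch.compile and a fused AdamW (CUDA only)."""
    if torch.cuda.is_available():
        model = torch.compile(model)  # kernel fusion; first call is slow
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, fused=True)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    return model, optimizer
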
# =============================================================================
# LOCAL TESTING - Run this file to test your implementation
# =============================================================================
if __name__ == "__main__":
    print("=" * 60)
    print("TESTING train.py - Basic Implementation")
    print("=" * 60)
    print()

    # Load configuration
    hparams_path = Path(__file__).parent.parent / "hparams" / "hparams.json"
    hparams = {}
    if hparams_path.exists():
        with open(hparams_path) as f:
            hparams = json.load(f)
    batch_size = hparams.get("benchmark_batch_size", 16)
    num_steps = hparams.get("eval_steps", 5)
    num_evals = hparams.get("evaluation_runs", 5)
    print(f"Batch size: {batch_size}")
    print(f"Steps per eval: {num_steps}")
    print(f"Evaluations: {num_evals}")
    print()

    # Check paths
    project_root = Path(__file__).parent.parent
    model_path = project_root / "benchmark" / "model"
    data_path = project_root / "benchmark" / "data" / "train.pt"
    if not model_path.exists() or not data_path.exists():
        print("Setup required! Run: uv run local_test/setup_benchmark.py")
        raise SystemExit(1)

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    print()

    # Load model
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    model.gradient_checkpointing_enable()  # Required to fit in GPU memory
    model.train()
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print()

    # Load data
    print("Loading data...")
    data = torch.load(data_path, weights_only=True)
    print(f"Samples: {data.shape[0]:,}, Sequence length: {data.shape[1]}")
    print()

    # Create data iterator (cycles through the dataset indefinitely)
    def create_iterator():
        idx = 0
        while True:
            end_idx = idx + batch_size
            if end_idx > data.shape[0]:
                idx = 0
                end_idx = batch_size
            yield data[idx:end_idx]
            idx = end_idx
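
    # Optional: a pinned-memory variant of the iterator. Illustrative sketch
    # only -- pin_memory() lets the non_blocking copy in inner_steps overlap
    # host-to-device transfer with compute; it only helps on CUDA and costs
    # extra host RAM for the pinned pages.
    def create_pinned_iterator():
        for batch in create_iterator():
            yield batch.pin_memory() if torch.cuda.is_available() else batch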
    # Create optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Warmup
    print("Warmup...")
    _ = inner_steps(model, create_iterator(), optimizer, num_steps=2, device=device)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    print()

    # Reset optimizer so warmup state doesn't leak into the timed runs
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Run evaluations
    print(f"Running {num_evals} evaluations...")
    for i in range(num_evals):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        result = inner_steps(model, create_iterator(), optimizer, num_steps, device)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        print(
            f"  Eval {i + 1}: {elapsed:.3f}s, tokens={result.total_tokens:,}, loss={result.final_loss:.4f}"
        )
    print()
    print("Done!")