
@renxida
Created September 9, 2025 01:15
RunPod ML Experiments: SSH troubleshooting, MNIST training, and auto-termination setup

RunPod ML Experiment Automation

This setup allows you to run ML experiments on RunPod with automatic pod termination and persistent storage for artifacts.

Features

  • ✅ Automatic pod termination after experiment completion
  • ✅ Persistent storage for artifacts and logs
  • ✅ Both Docker and Dockerless (template) deployment options
  • ✅ Comprehensive logging and metrics tracking
  • ✅ Cost-efficient (pay only for actual compute time)

Files Overview

  • train_mnist.py - Main training script with comprehensive logging
  • run_experiment.sh - Dockerless runner with auto-termination
  • deploy_runpod.py - Python script to deploy experiments
  • Dockerfile & entrypoint.sh - For Docker-based deployment
  • requirements.txt - Python dependencies

Setup Options

Option 1: Dockerless (Recommended for Quick Experiments)

Pros: Faster setup, no Docker knowledge needed, direct SSH access.
Cons: Less reproducible, manual dependency management.

Option 2: Docker (Recommended for Production)

Pros: Fully reproducible, version-locked dependencies, CI/CD ready.
Cons: Requires Docker setup, slower initial deployment.

Quick Start Guide

Prerequisites

  1. RunPod account with API key
  2. Python 3.8+ installed locally
  3. (Optional) Docker for building custom images

Step 1: Get RunPod API Key

  1. Go to RunPod Settings
  2. Generate an API key
  3. Save it securely (a storage sketch follows)
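
The direct-API script later in this gist (create_pod.py) reads the key from a local .runpod_api_key file. A minimal sketch for storing the key and sanity-checking it against the GraphQL endpoint (the myself query is an assumption about RunPod's schema, not confirmed by this gist):

import requests
from pathlib import Path

# Store the key once, with owner-only permissions (assumes a POSIX system)
key_file = Path(".runpod_api_key")
key_file.write_text("YOUR_API_KEY")
key_file.chmod(0o600)

# Sanity-check the key against the same endpoint create_pod.py uses
api_key = key_file.read_text().strip()
resp = requests.post(
    f"https://api.runpod.io/graphql?api_key={api_key}",
    json={"query": "query { myself { id } }"},  # assumed schema field
    timeout=30,
)
print(resp.status_code, resp.json())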

Step 2: Install RunPod SDK

pip install runpod

Step 3: Deploy Experiment (Dockerless)

# Deploy using PyTorch template
python deploy_runpod.py \
    --api-key YOUR_API_KEY \
    --experiment-name "mnist_test_1" \
    --mode template \
    --gpu-type "NVIDIA GeForce RTX 3090" \
    --volume-size 50 \
    --epochs 10 \
    --batch-size 64 \
    --learning-rate 0.001

Step 4: Manual Deployment via RunPod UI

  1. Go to RunPod Console
  2. Click "Deploy" → "GPU Pod"
  3. Select PyTorch template
  4. Choose GPU (e.g., RTX 3090)
  5. Set Network Volume to 50GB (for persistent storage)
  6. Deploy the pod
  7. SSH into the pod and run:
# Upload files to pod (from local machine)
scp train_mnist.py run_experiment.sh root@POD_IP:/workspace/scripts/

# SSH into pod
ssh root@POD_IP

# Inside the pod, run experiment
cd /workspace
bash scripts/run_experiment.sh "my_experiment" 10 64 0.001

Docker Deployment (Alternative)

Build and Push Docker Image

# Build image
docker build -t yourusername/mnist-runpod:latest .

# Push to Docker Hub
docker push yourusername/mnist-runpod:latest

# Deploy with Docker
python deploy_runpod.py \
    --api-key YOUR_API_KEY \
    --experiment-name "mnist_docker_1" \
    --mode docker \
    --docker-image yourusername/mnist-runpod:latest \
    --gpu-type "NVIDIA GeForce RTX 3090" \
    --volume-size 50

Experiment Artifacts

All experiments save the following to /workspace/artifacts/EXPERIMENT_NAME/:

artifacts/
└── mnist_experiment_20240315_120000/
    ├── models/
    │   ├── best_model.pth
    │   ├── final_model.pth
    │   └── checkpoint_epoch_5.pth
    ├── logs/
    │   └── training_20240315_120000.log
    ├── metrics/
    │   └── training_metrics.json
    └── summary.txt
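
To inspect a finished run, you can load the metrics JSON and the best checkpoint written by train_mnist.py. A sketch assuming the layout above (the experiment directory name is just an example):

import json
from pathlib import Path

import torch

# Example experiment directory matching the tree above
exp_dir = Path("/workspace/artifacts/mnist_experiment_20240315_120000")

# training_metrics.json is written by train_mnist.py at the end of training
metrics = json.loads((exp_dir / "metrics" / "training_metrics.json").read_text())
print("Best accuracy:", metrics["best_accuracy"])
print("Epoch times:", metrics["epoch_times"])

# best_model.pth is a dict holding model/optimizer state plus accuracy fields
ckpt = torch.load(exp_dir / "models" / "best_model.pth", map_location="cpu")
print("Best epoch:", ckpt["epoch"], "test accuracy:", ckpt["test_accuracy"])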

Cost Optimization Tips

  1. Use Spot Instances: 50-80% cheaper than on-demand
  2. Auto-termination: Ensures you don't pay for idle time
  3. Right-size GPU: Use smaller GPUs for testing (e.g., RTX 3070)
  4. Persistent Volume: Keep data between runs without re-downloading

Monitoring Experiments

Via RunPod UI

  • Go to Pods section to see status
  • Check logs in pod terminal

Via Python Script

python deploy_runpod.py \
    --api-key YOUR_API_KEY \
    --experiment-name "test" \
    --mode template \
    --monitor  # Enables monitoring

Via SSH

# SSH into running pod
ssh root@POD_IP

# Check logs
tail -f /workspace/artifacts/*/logs/*.log

# Check GPU usage
nvidia-smi
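
Via the GraphQL API

The pod-status query used by create_pod.py (included later in this gist) also reports GPU utilization, so a pod can be polled without SSH. A minimal polling sketch, assuming create_pod.py is importable and YOUR_POD_ID is replaced with a real pod ID:

import time

from create_pod import get_pod_status  # helper defined later in this gist

api_key = open(".runpod_api_key").read().strip()
pod_id = "YOUR_POD_ID"  # placeholder

for _ in range(10):  # poll ten times, 30 seconds apart
    status = get_pod_status(api_key, pod_id)
    if status and status.get("runtime"):
        for gpu in status["runtime"].get("gpus") or []:
            print(f"GPU {gpu['id']}: {gpu['gpuUtilPercent']}% utilization")
    else:
        print("Pod not running yet")
    time.sleep(30)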

Advanced Configuration

Custom Hyperparameters

Edit train_mnist.py to add more hyperparameters (a sketch follows the list):

  • Optimizer types
  • Learning rate schedules
  • Model architectures
  • Data augmentation
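
For instance, an optimizer flag could be wired in like this (a sketch only; the --optimizer flag is not part of the current script, and nn.Linear stands in for SimpleCNN):

import argparse

import torch.nn as nn
import torch.optim as optim

parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--optimizer', type=str, default='adam',
                    choices=['adam', 'sgd'], help='hypothetical new flag')
args = parser.parse_args()

model = nn.Linear(10, 2)  # stand-in for SimpleCNN

# Dispatch on the new flag when building the optimizer
if args.optimizer == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
else:
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
print(type(optimizer).__name__)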

Multiple Experiments

Run parallel experiments with different configs:

# Experiment 1
python deploy_runpod.py --experiment-name "lr_0.01" --learning-rate 0.01 &

# Experiment 2  
python deploy_runpod.py --experiment-name "lr_0.001" --learning-rate 0.001 &

# Experiment 3
python deploy_runpod.py --experiment-name "batch_128" --batch-size 128 &

Using Network Storage

Network volumes persist across pods. To reuse data:

  1. Create a network volume in RunPod
  2. Mount it to /workspace/data
  3. Download datasets once, reuse across experiments (see the sketch below)
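
Reuse is automatic because torchvision skips the download when the dataset files already exist; pointing the data root at the mounted volume is enough. A minimal sketch:

from pathlib import Path

from torchvision import datasets, transforms

data_dir = Path("/workspace/data")  # the mounted network volume

# download=True is a no-op if the MNIST files are already on the volume
train = datasets.MNIST(data_dir, train=True, download=True,
                       transform=transforms.ToTensor())
print(len(train), "training samples available from persistent storage")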

Troubleshooting

SSH Connection Issues

  • Problem: sign_and_send_pubkey: no mutual signature supported
  • Cause: SSH client compatibility (especially Termux)
  • Solution: Use the Jupyter interface instead (http://POD_IP:8888, password: rp12345)
  • Alternative: Try RSA keys instead of ed25519
  • Details: See runpod_ssh_troubleshooting.md

Pod Doesn't Auto-Terminate

  • Check that the RUNPOD_POD_ID environment variable is set
  • Ensure runpodctl is available in the pod (see the quick check below)
  • Check the logs for errors
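
A quick in-pod check for both conditions (a sketch using only the standard library):

import os
import shutil

# Both must be present for run_experiment.sh to stop the pod
print("RUNPOD_POD_ID:", os.environ.get("RUNPOD_POD_ID", "NOT SET"))
print("runpodctl path:", shutil.which("runpodctl") or "NOT FOUND")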

Can't Access Artifacts

  • Verify the persistent volume is mounted at /workspace
  • Check that the volume isn't full
  • Ensure you have write permissions (see the sketch below)
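
These can be verified from inside the pod with the standard library (sketch):

import os
import shutil

usage = shutil.disk_usage("/workspace")
print(f"Volume free: {usage.free / 1e9:.1f} GB of {usage.total / 1e9:.1f} GB")
print("Writable:", os.access("/workspace", os.W_OK))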

Training Fails

  • Check GPU memory with nvidia-smi (or from Python, as sketched below)
  • Reduce the batch size if you hit out-of-memory (OOM) errors
  • Check that Python dependencies are installed correctly
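
Free GPU memory can also be checked from Python before lowering the batch size (a sketch; torch.cuda.mem_get_info requires a reasonably recent PyTorch):

import torch

if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info()  # bytes on the current device
    print(f"GPU memory free: {free / 1e9:.2f} GB of {total / 1e9:.2f} GB")
else:
    print("No CUDA device visible")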

Best Practices

  1. Always use persistent volumes for artifacts
  2. Set reasonable timeouts to avoid runaway costs
  3. Test locally first with small epochs
  4. Version control your experiment configs
  5. Monitor GPU usage to optimize batch sizes
  6. Use structured logging for easy debugging

Example Workflow

  1. Develop and test locally (1-2 epochs)
  2. Deploy to RunPod with small GPU (test run)
  3. Verify artifacts are saved correctly
  4. Scale up to larger GPU and full training
  5. Download results from persistent volume
  6. Pod auto-terminates, saving costs

License

MIT License - Feel free to adapt for your needs!

create_pod.py

#!/usr/bin/env python3
"""
Direct RunPod API pod creation without SDK dependencies
"""
import json
import sys
import time

import requests


def create_pod_direct_api(api_key, pod_config):
    """Create pod using direct GraphQL API calls"""
    # GraphQL mutation to create pod
    mutation = """
    mutation {
      podFindAndDeployOnDemand(
        input: {
          cloudType: SECURE
          gpuCount: 1
          volumeInGb: %d
          containerDiskInGb: 20
          minVcpuCount: 2
          minMemoryInGb: 15
          gpuTypeId: "%s"
          name: "%s"
          imageName: "runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04"
          dockerArgs: ""
          ports: "22/tcp"
          volumeMountPath: "/workspace"
          env: [
            {key: "JUPYTER_PASSWORD", value: "rp12345"}
          ]
        }
      ) {
        id
        imageName
        env
        machineId
        machine {
          podHostId
        }
      }
    }
    """ % (pod_config['volume_size'], pod_config['gpu_type'], pod_config['name'])

    headers = {'Content-Type': 'application/json'}
    url = f"https://api.runpod.io/graphql?api_key={api_key}"
    payload = {'query': mutation}

    try:
        print(f"Creating pod: {pod_config['name']}")
        print(f"GPU: {pod_config['gpu_type']}")
        print(f"Volume: {pod_config['volume_size']}GB")
        response = requests.post(url, headers=headers, json=payload, timeout=30)
        if response.status_code == 200:
            result = response.json()
            if 'errors' in result:
                print(f"GraphQL errors: {result['errors']}")
                return None
            pod_data = result['data']['podFindAndDeployOnDemand']
            print("✓ Pod created successfully!")
            print(f"Pod ID: {pod_data['id']}")
            return pod_data
        else:
            print(f"HTTP Error: {response.status_code}")
            print(f"Response: {response.text}")
            return None
    except Exception as e:
        print(f"Error creating pod: {e}")
        return None


def get_pod_status(api_key, pod_id):
    """Get pod status"""
    query = """
    query {
      pod(input: {podId: "%s"}) {
        id
        name
        runtime {
          uptimeInSeconds
          ports {
            ip
            isIpPublic
            privatePort
            publicPort
            type
          }
          gpus {
            id
            gpuUtilPercent
          }
        }
        machine {
          podHostId
        }
      }
    }
    """ % pod_id

    headers = {'Content-Type': 'application/json'}
    url = f"https://api.runpod.io/graphql?api_key={api_key}"
    try:
        response = requests.post(url, headers=headers, json={'query': query}, timeout=30)
        if response.status_code == 200:
            result = response.json()
            return result['data']['pod'] if 'data' in result else None
    except Exception as e:
        print(f"Error getting pod status: {e}")
    return None


def main():
    api_key = open('.runpod_api_key').read().strip()

    # Pod configuration
    pod_config = {
        'name': f'mnist-test-{int(time.time())}',
        'gpu_type': 'NVIDIA GeForce RTX 3070',  # Try RTX 3070 first
        'volume_size': 20,
    }

    print("RunPod Direct API Pod Creation")
    print("=" * 40)

    # Create pod
    pod = create_pod_direct_api(api_key, pod_config)
    if not pod:
        print("Failed to create pod. Trying alternative GPU types...")
        # Try other GPU types
        gpu_types = [
            'NVIDIA GeForce RTX 3090',
            'NVIDIA GeForce RTX 4090',
            'NVIDIA RTX A4000',
        ]
        for gpu_type in gpu_types:
            print(f"\nTrying {gpu_type}...")
            pod_config['gpu_type'] = gpu_type
            pod_config['name'] = f'mnist-test-{gpu_type.replace(" ", "-").lower()}-{int(time.time())}'
            pod = create_pod_direct_api(api_key, pod_config)
            if pod:
                break

    if not pod:
        print("Failed to create pod with any GPU type")
        return 1

    pod_id = pod['id']
    print(f"\nWaiting for pod {pod_id} to be ready...")

    # Wait for pod to be ready
    max_wait = 300  # 5 minutes
    wait_time = 0
    while wait_time < max_wait:
        status = get_pod_status(api_key, pod_id)
        if status and status.get('runtime'):
            runtime = status['runtime']
            ports = runtime.get('ports', [])
            if ports:
                ssh_port = None
                ip = None
                for port in ports:
                    if port['privatePort'] == 22:
                        ssh_port = port.get('publicPort')
                        ip = port.get('ip')
                        break
                if ssh_port and ip:
                    print("\n✓ Pod is ready!")
                    print(f"SSH: ssh root@{ip} -p {ssh_port}")
                    print(f"Pod ID: {pod_id}")
                    # Save connection details
                    connection_info = {
                        'pod_id': pod_id,
                        'ip': ip,
                        'port': ssh_port,
                        'ssh_command': f'ssh root@{ip} -p {ssh_port}',
                    }
                    with open('pod_connection.json', 'w') as f:
                        json.dump(connection_info, f, indent=2)
                    print("Connection details saved to: pod_connection.json")
                    return 0
        print(f"Waiting... ({wait_time}s)")
        time.sleep(10)
        wait_time += 10

    print("Timeout waiting for pod to be ready")
    return 1


if __name__ == "__main__":
    sys.exit(main())
inline_experiment.py

# Complete MNIST experiment, intended to be pasted directly into a Jupyter cell
import os
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

print("=== RunPod MNIST Quick Experiment ===")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Simple CNN
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('/workspace/data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('/workspace/data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Model
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Starting training for 3 epochs...")

# Training loop
for epoch in range(3):
    model.train()
    train_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx >= 100:  # Limit to 100 batches for quick test
            break
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        if batch_idx % 20 == 0:
            print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}')

    # Test
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    accuracy = 100. * correct / len(test_dataset)
    print(f'Epoch {epoch+1} completed. Test Accuracy: {accuracy:.2f}%')

# Save model
os.makedirs('/workspace/artifacts', exist_ok=True)
torch.save(model.state_dict(), '/workspace/artifacts/mnist_model.pth')
print("Model saved to /workspace/artifacts/mnist_model.pth")

# Save result summary
with open('/workspace/artifacts/experiment_result.txt', 'w') as f:
    f.write("MNIST Experiment Results\n")
    f.write(f"Timestamp: {datetime.now()}\n")
    f.write(f"Final Accuracy: {accuracy:.2f}%\n")
    f.write(f"Device: {device}\n")
    f.write(f"PyTorch: {torch.__version__}\n")

print("=== Experiment Complete ===")
print(f"Final Test Accuracy: {accuracy:.2f}%")
print("Artifacts saved to /workspace/artifacts/")
print("\nTo auto-terminate pod, run:")
print("import os; os.system('runpodctl stop pod $RUNPOD_POD_ID')")
run_experiment.sh

#!/bin/bash
# RunPod Experiment Runner Script (Dockerless version)
# This script runs on a RunPod pod with PyTorch template
# It automatically terminates the pod after completion

set -e  # Exit on error

echo "================================================"
echo "RunPod ML Experiment Runner"
echo "Pod ID: $RUNPOD_POD_ID"
echo "Started at: $(date)"
echo "================================================"

# Configuration
WORKSPACE_DIR="/workspace"
ARTIFACTS_DIR="$WORKSPACE_DIR/artifacts"
DATA_DIR="$WORKSPACE_DIR/data"
SCRIPTS_DIR="$WORKSPACE_DIR/scripts"

# Parse command line arguments
EXPERIMENT_NAME="${1:-mnist_auto_$(date +%Y%m%d_%H%M%S)}"
EPOCHS="${2:-10}"
BATCH_SIZE="${3:-64}"
LEARNING_RATE="${4:-0.001}"

echo "Configuration:"
echo "  Experiment: $EXPERIMENT_NAME"
echo "  Epochs: $EPOCHS"
echo "  Batch Size: $BATCH_SIZE"
echo "  Learning Rate: $LEARNING_RATE"
echo ""

# Create directories
echo "Setting up directories..."
mkdir -p "$ARTIFACTS_DIR"
mkdir -p "$DATA_DIR"
mkdir -p "$SCRIPTS_DIR"

# Check if the training script exists; if not, copy it from /tmp
if [ ! -f "$SCRIPTS_DIR/train_mnist.py" ]; then
    echo "Training script not found. Downloading..."
    # You can wget/curl your script from GitHub or another source
    # For now, we'll assume it's been uploaded
    if [ -f "/tmp/train_mnist.py" ]; then
        cp /tmp/train_mnist.py "$SCRIPTS_DIR/"
    else
        echo "ERROR: train_mnist.py not found in /tmp/"
        echo "Please upload the training script first"
        exit 1
    fi
fi

# Install dependencies if needed
echo "Checking Python dependencies..."
if ! python -c "import torch" 2>/dev/null; then
    echo "Installing PyTorch..."
    pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
fi

# Additional dependencies
pip install numpy matplotlib tqdm

# Run the training script
echo ""
echo "Starting training..."
echo "================================================"
# Temporarily disable exit-on-error so a failed run still reaches the
# summary and auto-termination steps below; with set -e active, a non-zero
# exit here would abort the script before the exit code could be recorded
set +e
python "$SCRIPTS_DIR/train_mnist.py" \
    --experiment-name "$EXPERIMENT_NAME" \
    --epochs "$EPOCHS" \
    --batch-size "$BATCH_SIZE" \
    --lr "$LEARNING_RATE" \
    --output-dir "$ARTIFACTS_DIR" \
    --data-dir "$DATA_DIR"
TRAINING_EXIT_CODE=$?
set -e

echo "================================================"
echo "Training completed with exit code: $TRAINING_EXIT_CODE"

# Create completion marker
echo "Creating completion marker..."
COMPLETION_FILE="$ARTIFACTS_DIR/$EXPERIMENT_NAME/COMPLETED"
echo "Experiment completed at: $(date)" > "$COMPLETION_FILE"
echo "Exit code: $TRAINING_EXIT_CODE" >> "$COMPLETION_FILE"

# List generated artifacts
echo ""
echo "Generated artifacts:"
ls -la "$ARTIFACTS_DIR/$EXPERIMENT_NAME/"

# Calculate total size
ARTIFACTS_SIZE=$(du -sh "$ARTIFACTS_DIR/$EXPERIMENT_NAME" | cut -f1)
echo "Total artifacts size: $ARTIFACTS_SIZE"

# Final summary
echo ""
echo "================================================"
echo "Experiment Summary:"
echo "  Name: $EXPERIMENT_NAME"
echo "  Status: $([ $TRAINING_EXIT_CODE -eq 0 ] && echo 'SUCCESS' || echo 'FAILED')"
echo "  Artifacts: $ARTIFACTS_DIR/$EXPERIMENT_NAME"
echo "  Size: $ARTIFACTS_SIZE"
echo "  Completed: $(date)"
echo "================================================"

# Auto-terminate the pod if RUNPOD_POD_ID is set
if [ ! -z "$RUNPOD_POD_ID" ]; then
    echo ""
    echo "Auto-terminating pod in 10 seconds..."
    echo "Artifacts are saved to persistent storage at: $ARTIFACTS_DIR/$EXPERIMENT_NAME"
    echo "Press Ctrl+C to cancel auto-termination"
    sleep 10
    echo "Stopping pod $RUNPOD_POD_ID..."
    runpodctl stop pod "$RUNPOD_POD_ID"
else
    echo "Not running on RunPod (RUNPOD_POD_ID not set), skipping auto-termination"
fi

exit $TRAINING_EXIT_CODE

RunPod SSH Troubleshooting Guide

Summary

Successfully deployed a RunPod GPU pod from the command line, but encountered SSH client compatibility issues with Termux. The Jupyter interface works as an excellent alternative.

Issue Analysis

What We Tried

  1. ✅ Generated ed25519 SSH keys
  2. ✅ Added public key to RunPod account
  3. ✅ Used various SSH connection methods
  4. ❌ Direct SSH connection failed
  5. ❌ SCP/SFTP file transfers failed

Error Messages Encountered

sign_and_send_pubkey: no mutual signature supported
Error: Your SSH client doesn't support PTY
subsystem request failed on channel 0
Unable to negotiate with 3.209.22.108 port 22: no matching host key type found. 
Their offer: rsa-sha2-512,rsa-sha2-256,ssh-rsa

Root Cause Analysis

NOT RunPod's Fault

  • ✅ RunPod uses standard RSA host keys (normal)
  • ✅ RunPod accepts ed25519 user keys (we added successfully)
  • ✅ RunPod SSH server configuration is standard

Termux SSH Client Issues

  • ❌ Limited signature algorithm support
  • ❌ Poor PTY (pseudo-terminal) support
  • ❌ Overly restrictive protocol negotiation
  • ❌ Compatibility issues with some SSH servers

Technical Details

SSH Client Version:

OpenSSH_10.0p2, OpenSSL 3.5.2 5 Aug 2025

Host Key Mismatch:

  • RunPod offers: rsa-sha2-512,rsa-sha2-256,ssh-rsa
  • Our client expected: ssh-ed25519
  • Solution: Accept RSA host keys

Partial Success Command:

ssh -o StrictHostKeyChecking=no \
    -o HostKeyAlgorithms=+ssh-rsa,rsa-sha2-256,rsa-sha2-512 \
    -o PubkeyAcceptedKeyTypes=+ssh-rsa,ssh-ed25519 \
    -T root@POD_IP "echo connected"

Result: Connects but still has PTY limitations.

Successful Workaround: Jupyter Interface

Why Jupyter is Better for ML Experiments

  • ✅ Visual progress bars and plots
  • ✅ Interactive debugging capabilities
  • ✅ Cell-by-cell execution
  • ✅ No SSH compatibility issues
  • ✅ Built-in file manager
  • ✅ Persistent sessions

Connection Details

  • URL: http://POD_IP:8888
  • Password: rp12345 (set during pod creation)
  • Default Port: Usually 8888; check pod status for the actual public port (see the sketch below)
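
The public Jupyter port can be read from the same ports query create_pod.py uses to find the SSH port. A sketch (YOUR_POD_ID is a placeholder):

from create_pod import get_pod_status  # helper defined earlier in this gist

api_key = open(".runpod_api_key").read().strip()
status = get_pod_status(api_key, "YOUR_POD_ID")  # placeholder pod ID

# runtime may be missing while the pod is still starting
runtime = (status or {}).get("runtime") or {}
for port in runtime.get("ports") or []:
    if port["privatePort"] == 8888:
        print(f"Jupyter: http://{port['ip']}:{port['publicPort']}")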

Inline Experiment Approach

Created inline_experiment.py with complete MNIST training code that can be pasted directly into a Jupyter notebook.

Alternative SSH Solutions (For Future Reference)

1. Try Different SSH Client

# If available, try different SSH implementations
# dropbear, putty, etc.

2. Generate RSA Keys Instead

ssh-keygen -t rsa -b 4096 -f ~/.ssh/id_rsa_runpod -N ""
# Add id_rsa_runpod.pub to RunPod account

3. Use RunPod Web Terminal

  • Available in RunPod web console
  • Browser-based terminal access
  • No local SSH client issues

4. Force SSH Compatibility Mode

ssh -o StrictHostKeyChecking=no \
    -o HostKeyAlgorithms=+ssh-rsa \
    -o PubkeyAcceptedKeyTypes=+ssh-rsa,ssh-ed25519 \
    -o KexAlgorithms=+diffie-hellman-group14-sha256 \
    -T user@host "command"

Deployment Success Summary

What Worked

  1. Pod Creation: ✅ Direct API calls with requests
  2. GPU Selection: ✅ RTX 3090 (RTX 3070 was unavailable)
  3. Volume Configuration: ✅ 20GB persistent storage at /workspace
  4. Image: ✅ runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04

Pod Details

  • Pod ID: togx64g2fw0r49
  • SSH Format: root@213.192.2.115
  • IP: 213.192.2.115
  • Jupyter: http://213.192.2.115:8888
  • Cost: ~$0.50-0.80/hour for RTX 3090

Files Created

  • create_pod.py - Direct API pod creation
  • inline_experiment.py - Complete MNIST experiment for Jupyter
  • pod_connection.json - Connection details
  • get_pod_logs.py - API-based pod monitoring

Lessons Learned

  1. SSH isn't always necessary - Jupyter is often better for ML
  2. Client compatibility varies - Don't assume SSH "just works"
  3. API approach is reliable - Direct GraphQL calls bypass CLI issues
  4. Plan for alternatives - Have multiple access methods ready
  5. RunPod's setup is standard - Issues are usually client-side

Best Practices for RunPod

For Reliable Access

  1. Primary: Use Jupyter interface for experiments
  2. Backup: Try multiple SSH approaches
  3. Alternative: Use RunPod web terminal
  4. Testing: Verify SSH before long deployments

For Cost Efficiency

  1. Auto-termination: Always implement pod stopping
  2. Spot instances: Use when possible (50-80% savings)
  3. Right-sizing: Start with smaller GPUs for testing
  4. Monitoring: Track usage and terminate idle pods

Future Improvements

  1. Better SSH client: Research alternatives to Termux SSH
  2. WebSocket approach: Direct connection to pod via browser
  3. File sync: Alternative upload methods (WebDAV, etc.)
  4. Automated deployment: Scripts that handle multiple access methods

Date: September 8, 2025
Environment: Termux on Android, ARM64
RunPod Account: Active with ed25519 SSH key

train_mnist.py

#!/usr/bin/env python3
"""
MNIST Training Script for RunPod
Trains a simple CNN on MNIST dataset and saves artifacts to persistent storage
"""
import os
import sys
import json
import logging
import argparse
import time
from datetime import datetime
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np


# Setup logging
def setup_logging(log_dir):
    """Setup logging to both file and console"""
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = log_dir / f"training_{timestamp}.log"
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(sys.stdout)
        ]
    )
    return logging.getLogger(__name__)


# Simple CNN Model
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


def train_epoch(model, device, train_loader, optimizer, criterion, epoch, logger):
    """Train for one epoch"""
    model.train()
    train_loss = 0
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        if batch_idx % 100 == 0:
            logger.info(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                        f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    avg_loss = train_loss / len(train_loader)
    accuracy = 100. * correct / len(train_loader.dataset)
    return avg_loss, accuracy


def test(model, device, test_loader, criterion, logger):
    """Evaluate model on test set"""
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader)
    accuracy = 100. * correct / len(test_loader.dataset)
    logger.info(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
                f'({accuracy:.2f}%)')
    return test_loss, accuracy


def main():
    parser = argparse.ArgumentParser(description='MNIST Training on RunPod')
    parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train')
    parser.add_argument('--batch-size', type=int, default=64, help='input batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--output-dir', type=str, default='/workspace/artifacts',
                        help='directory to save outputs')
    parser.add_argument('--data-dir', type=str, default='/workspace/data',
                        help='directory for dataset')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--experiment-name', type=str, default=None,
                        help='experiment name for tracking')
    args = parser.parse_args()

    # Setup directories
    output_dir = Path(args.output_dir)
    data_dir = Path(args.data_dir)

    # Create experiment directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    exp_name = args.experiment_name or f"mnist_experiment_{timestamp}"
    exp_dir = output_dir / exp_name
    exp_dir.mkdir(parents=True, exist_ok=True)

    # Setup subdirectories
    models_dir = exp_dir / "models"
    logs_dir = exp_dir / "logs"
    metrics_dir = exp_dir / "metrics"
    for dir_path in [models_dir, logs_dir, metrics_dir]:
        dir_path.mkdir(parents=True, exist_ok=True)

    # Setup logging
    logger = setup_logging(logs_dir)
    logger.info(f"Starting experiment: {exp_name}")
    logger.info(f"Arguments: {vars(args)}")

    # Set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    if torch.cuda.is_available():
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Data loading
    logger.info("Loading MNIST dataset...")
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(data_dir, train=False, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)
    logger.info(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

    # Model, optimizer, and loss
    model = SimpleCNN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    # Training metrics
    metrics = {
        'config': vars(args),
        'device': str(device),
        'timestamp': timestamp,
        'train_losses': [],
        'train_accuracies': [],
        'test_losses': [],
        'test_accuracies': [],
        'epoch_times': []
    }

    # Training loop
    logger.info("Starting training...")
    best_accuracy = 0
    for epoch in range(1, args.epochs + 1):
        start_time = time.time()
        train_loss, train_acc = train_epoch(model, device, train_loader, optimizer, criterion, epoch, logger)
        test_loss, test_acc = test(model, device, test_loader, criterion, logger)
        epoch_time = time.time() - start_time

        # Record metrics
        metrics['train_losses'].append(train_loss)
        metrics['train_accuracies'].append(train_acc)
        metrics['test_losses'].append(test_loss)
        metrics['test_accuracies'].append(test_acc)
        metrics['epoch_times'].append(epoch_time)

        logger.info(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
                    f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%, Time: {epoch_time:.2f}s")

        # Save best model
        if test_acc > best_accuracy:
            best_accuracy = test_acc
            best_model_path = models_dir / "best_model.pth"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'test_accuracy': test_acc,
                'train_accuracy': train_acc,
                'test_loss': test_loss,
                'train_loss': train_loss,
            }, best_model_path)
            logger.info(f"Saved best model with accuracy: {test_acc:.2f}%")

        # Save checkpoint every 5 epochs
        if epoch % 5 == 0:
            checkpoint_path = models_dir / f"checkpoint_epoch_{epoch}.pth"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, checkpoint_path)
            logger.info(f"Saved checkpoint at epoch {epoch}")

    # Save final model
    final_model_path = models_dir / "final_model.pth"
    torch.save({
        'epoch': args.epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'final_test_accuracy': test_acc,
        'best_test_accuracy': best_accuracy,
    }, final_model_path)

    # Save metrics
    metrics['best_accuracy'] = best_accuracy
    metrics['final_accuracy'] = test_acc
    metrics['total_training_time'] = sum(metrics['epoch_times'])
    metrics_file = metrics_dir / "training_metrics.json"
    with open(metrics_file, 'w') as f:
        json.dump(metrics, f, indent=2)

    logger.info(f"Training completed! Best accuracy: {best_accuracy:.2f}%")
    logger.info(f"All artifacts saved to: {exp_dir}")
    logger.info(f"Total training time: {sum(metrics['epoch_times']):.2f} seconds")

    # Create summary file
    summary_file = exp_dir / "summary.txt"
    with open(summary_file, 'w') as f:
        f.write(f"Experiment: {exp_name}\n")
        f.write(f"Timestamp: {timestamp}\n")
        f.write(f"Device: {device}\n")
        f.write(f"Epochs: {args.epochs}\n")
        f.write(f"Batch Size: {args.batch_size}\n")
        f.write(f"Learning Rate: {args.lr}\n")
        f.write(f"Best Test Accuracy: {best_accuracy:.2f}%\n")
        f.write(f"Final Test Accuracy: {test_acc:.2f}%\n")
        f.write(f"Total Training Time: {sum(metrics['epoch_times']):.2f} seconds\n")

    return 0


if __name__ == "__main__":
    sys.exit(main())