@huangsam · Created January 13, 2026 07:27
Running a neural net to train and evaluate against IMDb ratings
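"""Fine-tune bert-base-uncased on the IMDb reviews dataset and report test accuracy."""
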
from typing import Any

import torch
from datasets import Dataset, DatasetDict, load_dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

# --- 1. CONFIGURATION CONSTANTS ---
MODEL_NAME = "bert-base-uncased"  # The Hugging Face model to use
MAX_LENGTH = 128  # Max length for tokenization
BATCH_SIZE = 16  # Batch size for training
DEVICE = "cpu"  # Default to CPU; main() switches to "mps" on Apple Silicon when available


def load_data(dataset_name: str = "imdb") -> DatasetDict:
    """Load and prepare the dataset from the Hugging Face Hub."""
    print(f"Loading dataset: {dataset_name}...")
    # Load all available splits (train/test/unsupervised for IMDb)
    dataset: DatasetDict = load_dataset(dataset_name)
    return dataset
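
# Optional: for quicker experiments, subsample before tokenizing. A sketch using
# the datasets API (the 2_000 example count is an arbitrary choice, not from the
# original script):
#   raw_datasets["train"] = raw_datasets["train"].shuffle(seed=42).select(range(2_000))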


def preprocess_function(examples, tokenizer: PreTrainedTokenizerBase):
    """Tokenize the text data so the model can consume it."""
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    )
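
# Illustration (not executed): for bert-base-uncased, the call below would return a
# dict with "input_ids", "token_type_ids", and "attention_mask", each padded or
# truncated to MAX_LENGTH:
#   preprocess_function({"text": ["A great movie!"]}, tokenizer)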


def main():
    """Main entry point for the text classification project."""
    global DEVICE

    # 1. Check for Mac Metal Performance Shaders (MPS) to leverage Apple M-series chips
    if torch.backends.mps.is_available():
        DEVICE = "mps"
        print(f"🔥 Found MPS device. Using {DEVICE} for acceleration.")
    else:
        print("Using CPU.")

    # 2. Load Data and Tokenizer
    raw_datasets: DatasetDict = load_data()

    # Instantiate the tokenizer (needed for all text processing)
    tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Tokenize all splits; drop the raw "text" column since the model only needs token ids
    tokenized_datasets: DatasetDict = raw_datasets.map(
        lambda x: preprocess_function(x, tokenizer),
        batched=True,
        remove_columns=["text"],
    )
    print("Data tokenization complete.")
    print(tokenized_datasets)

    # 3. Model Loading (to ensure environment works)
    model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=2
    )
    model.to(DEVICE)  # Move the model to the chosen device (CPU or MPS)
    print("\nSetup complete. Starting training...")

    # 4. Prepare DataLoaders
    train_dataset: Dataset = tokenized_datasets["train"].with_format("torch")
    test_dataset: Dataset = tokenized_datasets["test"].with_format("torch")
    train_loader: DataLoader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader: DataLoader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    # 5. Optimizer
    optimizer: torch.optim.AdamW = torch.optim.AdamW(model.parameters(), lr=2e-5)

    # 6. Training Loop
    EPOCHS: int = 2
    for epoch in range(EPOCHS):
        model.train()  # Set model to training mode
        total_loss: float = 0.0
        num_batches: int = len(train_loader)
        # Iterate over batches from the training DataLoader
        for batch_idx, batch in enumerate(train_loader):
            # Batch items are already torch tensors; move them to the device
            input_ids: torch.Tensor = batch["input_ids"].to(DEVICE)
            attention_mask: torch.Tensor = batch["attention_mask"].to(DEVICE)
            labels: torch.Tensor = batch["label"].to(DEVICE)

            # Forward pass: compute model outputs and loss
            outputs: Any = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss: torch.Tensor = outputs.loss  # Extract the loss value from outputs

            optimizer.zero_grad()  # Reset gradients from previous step
            loss.backward()  # Backpropagate to compute gradients
            optimizer.step()  # Update model weights

            total_loss += loss.item()  # Accumulate batch loss

            # Print progress every 100 batches (and on the final batch)
            if (batch_idx + 1) % 100 == 0 or (batch_idx + 1) == num_batches:
                print(f"Epoch {epoch + 1} | Batch {batch_idx + 1}/{num_batches} | Loss: {loss.item():.4f}")

        avg_loss: float = total_loss / num_batches  # Average loss for the epoch
        print(f"Epoch {epoch + 1}/{EPOCHS} - Training loss: {avg_loss:.4f}")

    # 7. Evaluation Loop
    model.eval()  # Set model to evaluation mode (disables dropout, etc.)
    correct: int = 0
    total: int = 0
    # Disable gradient calculation for evaluation (faster, less memory)
    with torch.no_grad():
        for batch in test_loader:
            # Batch items are already torch tensors; move them to the device
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["label"].to(DEVICE)

            # Forward pass (no labels needed for prediction)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds: torch.Tensor = torch.argmax(outputs.logits, dim=1)  # Get predicted class

            correct += (preds == labels).sum().item()  # Count correct predictions
            total += labels.size(0)  # Count total samples

    accuracy: float = correct / total if total > 0 else 0.0  # Compute accuracy
    print(f"Test Accuracy: {accuracy:.4f}")


if __name__ == "__main__":
    main()
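
# To run (assumes the torch, transformers, and datasets packages are installed;
# "imdb_bert.py" is a hypothetical filename for this gist):
#   python imdb_bert.py
# Note: the first run downloads bert-base-uncased (~440 MB) and the IMDb dataset.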