Train and evaluate a BERT text classifier on the IMDb movie-review sentiment dataset
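The script below fine-tunes bert-base-uncased on the full IMDb splits (25,000 training and 25,000 test reviews), so a run is slow on CPU and still takes a while on MPS; its only dependencies are torch, transformers, and datasets. Not part of the gist: if you just want to verify the pipeline end to end first, one option is to build a subsampled DatasetDict with the standard shuffle/select methods from datasets and return it from load_data() instead (the subset sizes here are arbitrary assumptions):

# Optional smoke test (not in the original gist): fine-tune on a small random
# subset so one epoch finishes quickly. The subset sizes are arbitrary choices.
from datasets import DatasetDict, load_dataset

raw = load_dataset("imdb")
small = DatasetDict({
    "train": raw["train"].shuffle(seed=42).select(range(2000)),
    "test": raw["test"].shuffle(seed=42).select(range(500)),
})
print(small)

Swapping the subset into load_data() leaves the rest of the script unchanged.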
from typing import Any

import torch
from datasets import Dataset, DatasetDict, load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# --- 1. CONFIGURATION CONSTANTS ---
MODEL_NAME = "bert-base-uncased"  # The Hugging Face model to use
MAX_LENGTH = 128  # Max length for tokenization
BATCH_SIZE = 16  # Batch size for training
DEVICE = "cpu"  # Default to CPU; main() switches to MPS when available


def load_data(dataset_name="imdb") -> DatasetDict:
    """
    Load and prepare the dataset from the Hugging Face Hub.
    """
    print(f"Loading dataset: {dataset_name}...")
    # Load the train and test splits for the chosen dataset
    dataset: DatasetDict = load_dataset(dataset_name)
    return dataset


def preprocess_function(examples, tokenizer: AutoTokenizer):
    """
    Tokenization function to preprocess the text data.
    """
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    )


def main():
    """
    Main entry point for the text classification project.
    """
    # 1. Check for Mac Metal Performance Shaders (MPS) for Apple Silicon
    #    acceleration (e.g. an M3 Max chip)
    if torch.backends.mps.is_available():
        global DEVICE
        DEVICE = "mps"
        print(f"🔥 Found MPS device. Using {DEVICE} for acceleration.")
    else:
        print("Using CPU.")

    # 2. Load Data and Tokenizer
    raw_datasets: DatasetDict = load_data()

    # Instantiate the tokenizer (needed for all text processing)
    tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Tokenize datasets
    tokenized_datasets: DatasetDict = raw_datasets.map(
        lambda x: preprocess_function(x, tokenizer),
        batched=True,
        remove_columns=["text"],
    )
    print("Data tokenization complete.")
    print(tokenized_datasets)

    # 3. Model Loading (to ensure the environment works)
    model: AutoModelForSequenceClassification = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=2
    )
    model.to(DEVICE)  # Move the model to the chosen device (CPU or MPS)
    print("\nSetup complete. Starting training...")

    # 4. Prepare DataLoaders
    train_dataset: Dataset = tokenized_datasets["train"].with_format("torch")
    test_dataset: Dataset = tokenized_datasets["test"].with_format("torch")
    train_loader: DataLoader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader: DataLoader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    # 5. Optimizer
    optimizer: torch.optim.AdamW = torch.optim.AdamW(model.parameters(), lr=2e-5)

    # 6. Training Loop
    EPOCHS: int = 2
    for epoch in range(EPOCHS):
        model.train()  # Set model to training mode
        total_loss: float = 0.0
        num_batches: int = len(train_loader)

        # Iterate over batches from the training DataLoader
        for batch_idx, batch in enumerate(train_loader):
            # Batch items are already torch tensors; move them to the device
            input_ids: torch.Tensor = batch["input_ids"].to(DEVICE)
            attention_mask: torch.Tensor = batch["attention_mask"].to(DEVICE)
            labels: torch.Tensor = batch["label"].to(DEVICE)

            # Forward pass: compute model outputs and loss
            outputs: Any = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss: torch.Tensor = outputs.loss  # Extract the loss value from outputs

            optimizer.zero_grad()  # Reset gradients from the previous step
            loss.backward()  # Backpropagate to compute gradients
            optimizer.step()  # Update model weights

            total_loss += loss.item()  # Accumulate batch loss

            # Print progress every 100 batches
            if (batch_idx + 1) % 100 == 0 or (batch_idx + 1) == num_batches:
                print(f"Epoch {epoch + 1} | Batch {batch_idx + 1}/{num_batches} | Loss: {loss.item():.4f}")

        avg_loss: float = total_loss / num_batches  # Average loss for the epoch
        print(f"Epoch {epoch + 1}/{EPOCHS} - Training loss: {avg_loss:.4f}")

    # 7. Evaluation Loop
    model.eval()  # Set model to evaluation mode (disables dropout, etc.)
    correct: int = 0
    total: int = 0

    # Disable gradient calculation for evaluation (faster, less memory)
    with torch.no_grad():
        for batch in test_loader:
            # Batch items are already torch tensors; move them to the device
            input_ids: torch.Tensor = batch["input_ids"].to(DEVICE)
            attention_mask: torch.Tensor = batch["attention_mask"].to(DEVICE)
            labels: torch.Tensor = batch["label"].to(DEVICE)

            # Forward pass (no labels needed for prediction)
            outputs: Any = model(input_ids=input_ids, attention_mask=attention_mask)
            preds: torch.Tensor = torch.argmax(outputs.logits, dim=1)  # Get predicted classes
            correct += (preds == labels).sum().item()  # Count correct predictions
            total += labels.size(0)  # Count total samples

    accuracy: float = correct / total if total > 0 else 0.0  # Compute accuracy
    print(f"Test Accuracy: {accuracy:.4f}")


if __name__ == "__main__":
    main()
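The script stops after printing test accuracy, so the fine-tuned weights are discarded when the process exits. Not the author's code, just a minimal sketch of how the trained model could be persisted and reused with the standard save_pretrained/from_pretrained API, written as lines appended to the end of main(); the output directory name and the example review are made-up assumptions:

    # Sketch (not in the original gist): persist the fine-tuned model and tokenizer,
    # then reload them and classify a single review. "imdb-bert" is an arbitrary
    # directory name; the IMDb dataset uses label 1 for positive and 0 for negative.
    output_dir = "imdb-bert"
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    reloaded_model = AutoModelForSequenceClassification.from_pretrained(output_dir)
    reloaded_tokenizer = AutoTokenizer.from_pretrained(output_dir)
    inputs = reloaded_tokenizer(
        "A surprisingly touching film with a great cast.",
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    )
    with torch.no_grad():
        logits = reloaded_model(**inputs).logits
    print("positive" if logits.argmax(dim=-1).item() == 1 else "negative")

The reloaded model and inputs stay on CPU here, which is fine for single predictions; move both to DEVICE if you want accelerated inference.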