Shibuya et al: Binary MLP for MNIST
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
import numpy as np
from torch.optim import Optimizer

class StraightThroughBinaryActivation(torch.autograd.Function):
    """Binary activation with straight-through estimator"""
    @staticmethod
    def forward(ctx, input):
        # Forward pass: binarize to {0, 1} (sign mapped to {0, 1})
        return (input >= 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Backward pass: straight-through estimator
        # Pass gradients through as-is (identity)
        return grad_output


class BinaryActivation(nn.Module):
    """Binary activation function: sign(x) mapped to {0,1} with STE"""
    def forward(self, x):
        return StraightThroughBinaryActivation.apply(x)
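
# Illustration (a sketch added for clarity, not part of the original gist):
# the STE binarizes in the forward pass but is transparent to gradients.
#
#   x = torch.tensor([-0.7, 0.0, 1.3], requires_grad=True)
#   y = BinaryActivation()(x)        # tensor([0., 1., 1.])
#   y.sum().backward()
#   x.grad                           # tensor([1., 1., 1.]) -- identity gradient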

class BinaryLinear(nn.Module):
    """Linear layer with binary weights constrained to {0, 1} as per paper"""
    def __init__(self, in_features: int, out_features: int):
        super(BinaryLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Initialize binary weights to {0, 1} using Bernoulli(p=0.5) as per paper
        # Make sure requires_grad is True for gradient computation
        self.weight = nn.Parameter(
            torch.bernoulli(torch.full((out_features, in_features), 0.5)),
            requires_grad=True
        )
        # No bias for simplicity
        self.register_parameter('bias', None)

    def forward(self, input):
        # Use the binary weights directly for computation
        # The gradient will be computed through the loss and backpropagated
        return F.linear(input, self.weight)
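
# Illustration (a sketch, not part of the original gist): the Bernoulli(0.5)
# initialization means roughly half of each weight matrix starts at 1.
#
#   layer = BinaryLinear(784, 128)
#   layer.weight.unique()            # tensor([0., 1.])
#   layer.weight.mean().item()       # ~0.5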

class BinaryMLP(nn.Module):
    """Multi-layer perceptron following paper's architecture exactly"""
    def __init__(self, input_size=784, hidden_size=128, num_classes=10, num_layers=4):
        super(BinaryMLP, self).__init__()
        self.layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        # First layer: input -> hidden
        self.layers.append(BinaryLinear(input_size, hidden_size))
        self.batch_norms.append(nn.BatchNorm1d(hidden_size))
        # Hidden layers
        for _ in range(num_layers - 2):
            self.layers.append(BinaryLinear(hidden_size, hidden_size))
            self.batch_norms.append(nn.BatchNorm1d(hidden_size))
        # Output layer
        self.layers.append(BinaryLinear(hidden_size, num_classes))
        self.batch_norms.append(nn.BatchNorm1d(num_classes))
        self.activation = BinaryActivation()

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten
        # Forward pass following Algorithm 1
        for i, (layer, bn) in enumerate(zip(self.layers, self.batch_norms)):
            # W_l . h_{l-1}
            x = layer(x)
            # a_l = BatchNorm(W_l . h_{l-1})
            x = bn(x)
            # h_l = sgn(a_l)
            x = self.activation(x)
            if i == len(self.layers) - 1:
                x = F.log_softmax(x, dim=1)
        return x
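
# Architecture note (added for clarity): with the defaults num_layers=4 and
# hidden_size=128, the stack is
#   784 -> 128 -> 128 -> 128 -> 10
# i.e. four BinaryLinear layers, each followed by BatchNorm1d and the binary
# activation, with log_softmax applied after the final layer.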

class HyperMaskOptimizer(Optimizer):
    """
    Custom optimizer implementing Algorithm 1 from the paper.
    Uses {0,1} binary weights and proper Boolean operations.
    """
    def __init__(self, params, delta=1e-3):
        """
        Args:
            params: Model parameters (binary weights in {0,1})
            delta: Probability for random mask (δ_t in paper)
        """
        defaults = dict(delta=delta)
        super(HyperMaskOptimizer, self).__init__(params, defaults)
        self.delta = delta

    def zero_grad(self, set_to_none: bool = True):
        """Zero out gradients"""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    if set_to_none:
                        p.grad = None
                    else:
                        p.grad.zero_()

    def step(self, closure=None):
        """
        Perform one optimization step following Algorithm 1 from the paper:
        1. Compute the real-valued gradient g_t
        2. Compute the target weight w*_t = ⌊-g_t >= 0⌋
        3. Sample the hypermask m_t with P(M_t = 1) = δ_t
        4. Update: w_t = m̄_t · w_{t-1} + m_t · w*_t
           (keep the previous weight where the mask is 0, move to the
           target weight where the mask is 1)
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None:
                    continue
                grad = param.grad
                current_weights = param.data.clone()  # w_{t-1} in {0,1}
                # Step 2: Compute target weight w* = ⌊-g >= 0⌋
                # ⌊condition⌋ gives 1 if condition is True, 0 if False
                target_weights = (-grad >= 0).float()  # w*_t in {0,1}
                # Step 3: Sample hypermask from random distribution
                # P(M_t = 1) = δ_t (Implementation 5)
                mask = torch.bernoulli(torch.full_like(grad, self.delta))  # m_t in {0,1}
                # Step 4: Update using Boolean operations
                # w_t = m̄_t · w_{t-1} + m_t · w*_t
                # where · is AND and m̄_t is NOT m_t
                mask_not = 1 - mask  # m̄_t (NOT operation)
                new_weights = (mask_not * current_weights) + (mask * target_weights)
                # Write the new binary weights back in place
                param.data.copy_(new_weights)
                # Ensure weights remain in {0, 1}
                param.data.clamp_(0, 1)
                param.data.round_()
        return loss
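
# Worked example (added for clarity, not from the original gist): for a single
# weight with w_{t-1} = 0 and gradient g = -0.4, the target is
# w* = ⌊-(-0.4) >= 0⌋ = 1; with delta = 1e-3 the weight flips to 1 with
# probability 0.001 on this step and otherwise stays at 0.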

def load_mnist_data(batch_size=16384):
    """Load MNIST dataset from HuggingFace with correct batch size from paper"""
    dataset = load_dataset("mnist")

    def transform_data(examples):
        # Convert to tensors and normalize to [0,1]
        images = torch.tensor(np.array(examples['image']), dtype=torch.float32) / 255.0
        labels = torch.tensor(examples['label'], dtype=torch.long)
        return {'image': images, 'label': labels}

    train_dataset = dataset['train'].with_transform(transform_data)
    test_dataset = dataset['test'].with_transform(transform_data)

    def collate_fn(batch):
        images = torch.stack([item['image'] for item in batch])
        labels = torch.tensor([item['label'] for item in batch])
        return images, labels

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    return train_loader, test_loader
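
# Shape note (added for clarity): with batch_size=16384 each training batch is
# a stack of images of shape (16384, 28, 28) in [0, 1] and labels of shape
# (16384,); BinaryMLP.forward flattens the images to (batch, 784).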

def train_model(model, train_loader, test_loader, epochs=1000, delta=1e-3, learning_rate=0.01):
    """Train the binary neural network following the paper's hyperparameters"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    # The STE is handled by the custom autograd.Function (StraightThroughBinaryActivation)
    # Collect only the binary layer parameters for the HyperMask optimizer
    binary_params = []
    bn_params = []
    for name, param in model.named_parameters():
        if 'layers' in name and 'weight' in name:  # Binary layer weights
            param.requires_grad_(True)
            binary_params.append(param)
        elif 'batch_norms' in name:  # BatchNorm parameters
            param.requires_grad_(True)
            bn_params.append(param)
    # Use the custom HyperMask optimizer for the binary weights only
    hypermask_optimizer = HyperMaskOptimizer(binary_params, delta=delta)
    # Use standard SGD for the BatchNorm parameters
    bn_optimizer = torch.optim.SGD(bn_params, lr=learning_rate)
    # The model outputs log-probabilities (log_softmax), so use NLLLoss;
    # CrossEntropyLoss would apply log_softmax a second time
    criterion = nn.NLLLoss()
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            # Zero gradients for both optimizers
            hypermask_optimizer.zero_grad()
            bn_optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            # Apply optimization steps
            hypermask_optimizer.step()
            bn_optimizer.step()
            total_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
        # Evaluate on the test set after every epoch
        test_acc = evaluate_model(model, test_loader, device)
        train_acc = 100 * correct / total
        avg_loss = total_loss / len(train_loader)
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, '
              f'Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%')
        # Verify the weights are still binary
        verify_binary_weights(model)

def evaluate_model(model, test_loader, device):
    """Evaluate model accuracy"""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    model.train()
    return 100 * correct / total

def verify_binary_weights(model):
    """Verify that all binary layer weights are properly binary {0, 1}"""
    for name, param in model.named_parameters():
        if 'layers' in name and 'weight' in name:  # Only check binary layer weights
            unique_vals = torch.unique(param.data).cpu().numpy()
            # Check if we have only values close to 0 and 1
            is_binary = len(unique_vals) <= 2 and all(
                np.isclose(val, 0.0, atol=1e-6) or np.isclose(val, 1.0, atol=1e-6)
                for val in unique_vals
            )
            if not is_binary:
                print(f"WARNING: {name} has non-binary values: {unique_vals}")
        elif 'batch_norms' in name and 'weight' in name:
            unique_vals = torch.unique(param.data).cpu().numpy()
            # BatchNorm weights are expected to be non-binary; just report for info
            if len(unique_vals) <= 5:  # Only show if there are few unique values
                print(f"INFO: {name} has values: {unique_vals}")

def main():
    """Main training function following the paper's experimental setup"""
    print("Loading MNIST dataset with the paper's batch size (16384)...")
    train_loader, test_loader = load_mnist_data(batch_size=16384)
    # Model sizes from Table 5 of the paper
    hidden_sizes = [1024]  # Start with a single size due to memory constraints
    for hidden_size in hidden_sizes:
        print(f"\n{'='*50}")
        print(f"Training Binary MLP with {hidden_size} hidden units")
        print(f"{'='*50}")
        # Initialize model with the paper's architecture
        model = BinaryMLP(input_size=784, hidden_size=hidden_size, num_classes=10, num_layers=4)
        print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
        # Verify initial weights are binary
        verify_binary_weights(model)
        # Train (the paper's Table 8 lists the best learning rate and delta per model size)
        train_model(
            model, train_loader, test_loader,
            epochs=200,         # reduced from 1000 for faster testing
            delta=1e-1,         # hypermask flip probability δ_t
            learning_rate=0.01  # SGD learning rate for the BatchNorm parameters
        )
        # Final evaluation
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        final_accuracy = evaluate_model(model, test_loader, device)
        print(f"Final Test Accuracy for {hidden_size} hidden units: {final_accuracy:.2f}%")
        # Final verification
        verify_binary_weights(model)


if __name__ == "__main__":
    main()
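
# Optional sanity check (an illustrative sketch, not part of the original gist):
# run a single HyperMask step on random data and confirm the weights stay binary.
#
#   torch.manual_seed(0)
#   model = BinaryMLP(input_size=784, hidden_size=32, num_classes=10, num_layers=4)
#   binary = [p for n, p in model.named_parameters() if 'layers' in n and 'weight' in n]
#   opt = HyperMaskOptimizer(binary, delta=1e-1)
#   x, y = torch.rand(64, 784), torch.randint(0, 10, (64,))
#   loss = nn.NLLLoss()(model(x), y)
#   opt.zero_grad(); loss.backward(); opt.step()
#   verify_binary_weights(model)     # should print no WARNING lines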