@kevinmartinjos
Created September 14, 2025 18:39
Shibuya et al: Binary MLP for MNIST
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
import numpy as np
from typing import Dict, Any
from torch.optim import Optimizer

class StraightThroughBinaryActivation(torch.autograd.Function):
    """Binary activation with straight-through estimator"""

    @staticmethod
    def forward(ctx, input):
        # Forward pass: apply sign but mapped to {0,1}
        return (torch.sign(input) >= 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Backward pass: straight-through estimator
        # Pass gradients through as-is (identity)
        return grad_output
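# Quick illustrative check of the straight-through estimator above (not from the paper):
# StraightThroughBinaryActivation.apply(torch.tensor([-0.3, 0.0, 1.2])) evaluates to
# tensor([0., 1., 1.]) in the forward pass, while the backward pass returns the incoming
# gradient unchanged, so the layers below still receive a non-zero learning signal.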

class BinaryActivation(nn.Module):
    """Binary activation function: sign(x) mapped to {0,1} with STE"""

    def forward(self, x):
        return StraightThroughBinaryActivation.apply(x)

class BinaryLinear(nn.Module):
    """Linear layer with binary weights constrained to {0, 1} as per paper"""

    def __init__(self, in_features: int, out_features: int):
        super(BinaryLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Initialize binary weights to {0, 1} using Bernoulli(p=0.5) as per paper
        # Make sure requires_grad is True for gradient computation
        self.weight = nn.Parameter(
            torch.bernoulli(torch.full((out_features, in_features), 0.5)),
            requires_grad=True
        )
        # No bias for simplicity
        self.register_parameter('bias', None)

    def forward(self, input):
        # Use the binary weights directly for computation
        # The gradient will be computed through the loss and backpropagated
        return F.linear(input, self.weight)
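# With the Bernoulli(p=0.5) initialization above, roughly half of each weight matrix
# starts out as 1. F.linear(input, self.weight) computes input @ self.weight.T, so an
# input of shape (batch, in_features) produces an output of shape (batch, out_features).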

class BinaryMLP(nn.Module):
    """Multi-layer perceptron following the paper's architecture exactly"""

    def __init__(self, input_size=784, hidden_size=128, num_classes=10, num_layers=4):
        super(BinaryMLP, self).__init__()
        self.layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        # First layer: input -> hidden
        self.layers.append(BinaryLinear(input_size, hidden_size))
        self.batch_norms.append(nn.BatchNorm1d(hidden_size))
        # Hidden layers
        for _ in range(num_layers - 2):
            self.layers.append(BinaryLinear(hidden_size, hidden_size))
            self.batch_norms.append(nn.BatchNorm1d(hidden_size))
        # Output layer
        self.layers.append(BinaryLinear(hidden_size, num_classes))
        self.batch_norms.append(nn.BatchNorm1d(num_classes))
        self.activation = BinaryActivation()
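        # With the defaults (num_layers=4) this builds four BinaryLinear layers:
        # input_size -> hidden_size, two hidden_size -> hidden_size layers, and
        # hidden_size -> num_classes, each paired with its own BatchNorm1d.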

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten
        # Forward pass following Algorithm 1
        for i, (layer, bn) in enumerate(zip(self.layers, self.batch_norms)):
            # W_{l}.h_{l-1}
            x = layer(x)
            # a_{l} = BatchNorm(W_{l}.h_{l-1})
            x = bn(x)
            # h_{l} = sgn(a_{l})
            x = self.activation(x)
            if i == len(self.layers) - 1:
                x = F.log_softmax(x, dim=1)
        return x
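# Shape walk-through (assuming standard MNIST inputs): a batch of images is flattened to
# (batch, 784), each hidden activation is a {0, 1} tensor of shape (batch, hidden_size),
# and the final output is a (batch, num_classes) tensor of log-probabilities.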

class HyperMaskOptimizer(Optimizer):
    """
    Custom optimizer implementing Algorithm 1 from the paper.
    Uses {0,1} binary weights and proper Boolean operations.
    """

    def __init__(self, params, delta=1e-3):
        """
        Args:
            params: Model parameters (binary weights in {0,1})
            delta: Probability for random mask (δ_t in paper)
        """
        defaults = dict(delta=delta)
        super(HyperMaskOptimizer, self).__init__(params, defaults)
        self.delta = delta

    def zero_grad(self, set_to_none: bool = True):
        """Zero out gradients"""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    if set_to_none:
                        p.grad = None
                    else:
                        p.grad.zero_()

    def step(self, closure=None):
        """
        Perform one optimization step following Algorithm 1 from the paper:
        1. Compute the real-valued gradient g_t
        2. Compute the target weight w*_t = ⌊-g_t >= 0⌋
        3. Sample a hypermask m_t from a random distribution
        4. Update: w_t = m̄_t · w_{t-1} + m_t · w*_t
           (entries where m_t = 1 are overwritten with their target values)
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None:
                    continue
                grad = param.grad
                current_weights = param.data.clone()  # w_{t-1} in {0,1}
                # Step 2: Compute target weight w* = ⌊-g >= 0⌋
                # ⌊condition⌋ gives 1 if the condition is True, 0 if False
                target_weights = (-grad >= 0).float()  # w*_t in {0,1}
                # Step 3: Sample hypermask from a random distribution
                # P(M_t = 1) = δ_t (Implementation 5)
                mask = torch.bernoulli(torch.full_like(grad, self.delta))  # m_t in {0,1}
                # Step 4: Update using Boolean operations
                # w_t = m̄_t · w_{t-1} + m_t · w*_t, where · is AND and m̄_t is NOT m_t
                mask_not = 1 - mask  # m̄_t (NOT operation)
                new_weights = (mask_not * current_weights) + (mask * target_weights)
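                # Toy example (illustrative, not from the paper): if current_weights = [1, 0, 1],
                # target_weights = [0, 1, 1] and mask = [1, 0, 0], then new_weights = [0, 0, 1]:
                # only the entry whose mask is 1 is overwritten with its target value, so on
                # average a fraction delta of the weights changes per step.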
                # Write the new binary weights into the parameter in place
                param.data.copy_(new_weights)
                # Ensure weights remain in {0, 1}
                param.data.clamp_(0, 1)
                param.data.round_()
        return loss

def load_mnist_data(batch_size=16384):
    """Load the MNIST dataset from HuggingFace with the batch size used in the paper"""
    dataset = load_dataset("mnist")

    def transform_data(examples):
        # Convert to tensors and normalize to [0, 1]
        images = torch.tensor(np.array(examples['image']), dtype=torch.float32) / 255.0
        labels = torch.tensor(examples['label'], dtype=torch.long)
        return {'image': images, 'label': labels}

    train_dataset = dataset['train'].with_transform(transform_data)
    test_dataset = dataset['test'].with_transform(transform_data)

    def collate_fn(batch):
        images = torch.stack([item['image'] for item in batch])
        labels = torch.tensor([item['label'] for item in batch])
        return images, labels

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    return train_loader, test_loader
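# Each batch from these loaders is an (images, labels) tuple. With batch_size=16384 and
# MNIST's 60,000 training images, one epoch is ceil(60000 / 16384) = 4 batches.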
# Removed hook system - using autograd.Function instead

def train_model(model, train_loader, test_loader, epochs=1000, delta=1e-3, learning_rate=10.0):
    """Train the binary neural network following the paper's hyperparameters"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    # STE is now handled by the custom autograd.Function
    # Get only the binary layer parameters for the HyperMask optimizer
    binary_params = []
    bn_params = []
    for name, param in model.named_parameters():
        if 'layers' in name and 'weight' in name:  # Binary layer weights
            param.requires_grad_(True)
            binary_params.append(param)
        elif 'batch_norms' in name:  # BatchNorm parameters
            param.requires_grad_(True)
            bn_params.append(param)
    # Use custom HyperMask optimizer for binary weights only
    hypermask_optimizer = HyperMaskOptimizer(binary_params, delta=delta)
    # Use standard SGD with the given learning rate for the BatchNorm parameters
    bn_optimizer = torch.optim.SGD(bn_params, lr=learning_rate)
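    # Keeping two optimizers separates the two update rules: the HyperMask optimizer only
    # ever writes {0, 1} values into the binary weights, while the real-valued BatchNorm
    # affine parameters are updated with ordinary SGD.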
    # The model's forward pass already ends in log_softmax, so pair it with NLLLoss
    criterion = nn.NLLLoss()
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            # Zero gradients for both optimizers
            hypermask_optimizer.zero_grad()
            bn_optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            # Apply optimization steps
            hypermask_optimizer.step()
            bn_optimizer.step()
            total_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
        # Report train/test accuracy after every epoch
        test_acc = evaluate_model(model, test_loader, device)
        train_acc = 100 * correct / total
        avg_loss = total_loss / len(train_loader)
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, '
              f'Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%')
        # Verify weights are still binary
        verify_binary_weights(model)

def evaluate_model(model, test_loader, device):
    """Evaluate model accuracy"""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    model.train()
    return 100 * correct / total

def verify_binary_weights(model):
    """Verify that all weights are properly binary {0, 1}"""
    for name, param in model.named_parameters():
        if 'layers' in name and 'weight' in name:  # Only check binary layer weights
            unique_vals = torch.unique(param.data).cpu().numpy()
            # Check if we have only values close to 0 and 1
            is_binary = len(unique_vals) <= 2 and all(
                np.isclose(val, 0.0, atol=1e-6) or np.isclose(val, 1.0, atol=1e-6)
                for val in unique_vals
            )
            if not is_binary:
                print(f"WARNING: {name} has non-binary values: {unique_vals}")
        elif 'batch_norms' in name and 'weight' in name:
            unique_vals = torch.unique(param.data).cpu().numpy()
            # BatchNorm weights are expected to be non-binary, just report for info
            if len(unique_vals) <= 5:  # Only show if small number of unique values
                print(f"INFO: {name} has values: {unique_vals}")

def main():
    """Main training function following the paper's experimental setup"""
    print("Loading MNIST dataset with paper's batch size (16384)...")
    train_loader, test_loader = load_mnist_data(batch_size=16384)
    # Test different model sizes from Table 5
    hidden_sizes = [1024]  # Start with smaller sizes due to memory constraints
    for hidden_size in hidden_sizes:
        print(f"\n{'='*50}")
        print(f"Training Binary MLP with {hidden_size} hidden units")
        print(f"{'='*50}")
        # Initialize model with paper's architecture
        model = BinaryMLP(input_size=784, hidden_size=hidden_size, num_classes=10, num_layers=4)
        print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
        # Verify initial weights are binary
        verify_binary_weights(model)
        # Train the model. Table 8 lists learning rate = 10.0 and delta = 1e-3 as the best
        # values for the small model; the settings below deviate from that for this run.
        train_model(
            model, train_loader, test_loader,
            epochs=200,          # Reduced from 1000 for faster testing
            delta=1e-1,          # Note: differs from the delta = 1e-3 quoted above
            learning_rate=0.01   # Note: differs from the learning rate = 10.0 quoted above
        )
        # Final evaluation
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        final_accuracy = evaluate_model(model, test_loader, device)
        print(f"Final Test Accuracy for {hidden_size} hidden units: {final_accuracy:.2f}%")
        # Final verification
        verify_binary_weights(model)


if __name__ == "__main__":
    main()