LeNet-5 style convolutional network (LeCun et al.) for MNIST in PyTorch, with a held-out validation split, ReduceLROnPlateau scheduling, and early stopping
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import random_split, DataLoader
from multiprocessing import freeze_support


def main():
    # Pick the fastest available backend: CUDA GPU, Apple MPS, or CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else
                          'mps' if torch.backends.mps.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Standard MNIST per-channel mean/std normalization.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    full_train = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_ds = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Hold out 5,000 of the 60,000 training images for validation.
    val_size = 5000
    train_size = len(full_train) - val_size
    train_ds, val_ds = random_split(full_train, [train_size, val_size])

    # pin_memory only speeds up host-to-CUDA transfers; MPS just gets worker processes.
    if device.type == 'cuda':
        loader_kwargs = {'pin_memory': True, 'num_workers': 5}
    elif device.type == 'mps':
        loader_kwargs = {'num_workers': 5}
    else:
        loader_kwargs = {}

    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, batch_size=1000, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_ds, batch_size=1000, shuffle=False, **loader_kwargs)

    class LeNet5(nn.Module):
        def __init__(self):
            super().__init__()
            # 28x28 input -> conv5 -> 24x24 -> pool -> 12x12 -> conv5 -> 8x8 -> pool -> 4x4
            self.conv1 = nn.Conv2d(1, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 4 * 4, 120)
            self.fc2 = nn.Linear(120, 84)
            self.dropout = nn.Dropout(0.5)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            x = self.pool(torch.relu(self.conv1(x)))
            x = self.pool(torch.relu(self.conv2(x)))
            x = torch.flatten(x, 1)
            x = torch.relu(self.fc1(x))
            x = torch.relu(self.fc2(x))
            x = self.dropout(x)
            return self.fc3(x)

    model = LeNet5().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.95, weight_decay=5e-4)
    # Halve the learning rate once validation loss stops improving.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

    best_val_loss = float('inf')
    epochs_since_improve = 0
    max_epochs = 20

    def train_one_epoch(epoch):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader, start=1):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch} batch {batch_idx} loss {loss.item():.4f}')

    def evaluate(loader):
        # Returns (mean loss, accuracy %) over the given loader.
        model.eval()
        total, correct, cum_loss = 0, 0, 0.0
        with torch.no_grad():
            for data, target in loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                cum_loss += criterion(output, target).item() * target.size(0)
                _, pred = output.max(1)
                total += target.size(0)
                correct += pred.eq(target).sum().item()
        return cum_loss / total, 100. * correct / total

    for epoch in range(1, max_epochs + 1):
        train_one_epoch(epoch)
        val_loss, val_acc = evaluate(val_loader)
        print(f'Validation: loss {val_loss:.4f} acc {val_acc:.2f}%')
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        print(f'LR now {current_lr:.6f}')
        # Early stopping: checkpoint the best model, stop after 3 epochs without improvement.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_since_improve = 0
            torch.save(model.state_dict(), 'best_lenet5.pth')
        else:
            epochs_since_improve += 1
            if epochs_since_improve >= 3:
                print(f'No improvement for {epochs_since_improve} epochs, stopping early.')
                break

    test_loss, test_acc = evaluate(test_loader)
    print(f'Final Test: loss {test_loss:.4f} acc {test_acc:.2f}%')


if __name__ == '__main__':
    # freeze_support() guards against worker-process respawn issues when the script is frozen on Windows.
    freeze_support()
    main()
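
Note: the loop above checkpoints the best-validation weights to best_lenet5.pth, but the final test pass runs on whatever weights the last epoch left in memory. If the reported test numbers should instead reflect the best checkpoint, here is a minimal sketch of how the last two lines of main() could be replaced (same names as in the script; map_location keeps the loaded tensors on the selected device):

    # Sketch: restore the best-validation checkpoint before the final test pass.
    model.load_state_dict(torch.load('best_lenet5.pth', map_location=device))
    test_loss, test_acc = evaluate(test_loader)
    print(f'Best-checkpoint Test: loss {test_loss:.4f} acc {test_acc:.2f}%')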