LeNet-5 (LeCun-style) convolutional network for MNIST in PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import random_split, DataLoader
from multiprocessing import freeze_support


def main():
    # Pick the best available device: CUDA GPU, then Apple MPS, then CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else
                          'mps' if torch.backends.mps.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Standard MNIST normalization (dataset mean 0.1307, std 0.3081).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    full_train = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_ds = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Hold out 5,000 of the 60,000 training images for validation.
    val_size = 5000
    train_size = len(full_train) - val_size
    train_ds, val_ds = random_split(full_train, [train_size, val_size])

    # DataLoader options per device; pinned memory only benefits CUDA transfers.
    if device.type == 'cuda':
        loader_kwargs = {'pin_memory': True, 'num_workers': 5}
    elif device.type == 'mps':
        loader_kwargs = {'num_workers': 5}
    else:
        loader_kwargs = {}
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, batch_size=1000, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_ds, batch_size=1000, shuffle=False, **loader_kwargs)

    class LeNet5(nn.Module):
        """LeNet-5-style CNN adapted for 28x28 MNIST inputs."""
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 6, 5)        # 1x28x28 -> 6x24x24
            self.pool = nn.MaxPool2d(2, 2)         # halves spatial dims
            self.conv2 = nn.Conv2d(6, 16, 5)       # 6x12x12 -> 16x8x8
            self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 16 maps of 4x4 after second pool
            self.fc2 = nn.Linear(120, 84)
            self.dropout = nn.Dropout(0.5)
            self.fc3 = nn.Linear(84, 10)           # 10 digit classes

        def forward(self, x):
            x = self.pool(torch.relu(self.conv1(x)))
            x = self.pool(torch.relu(self.conv2(x)))
            x = torch.flatten(x, 1)
            x = torch.relu(self.fc1(x))
            x = torch.relu(self.fc2(x))
            x = self.dropout(x)
            return self.fc3(x)  # raw logits; CrossEntropyLoss applies softmax

    model = LeNet5().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.95, weight_decay=5e-4)
    # Halve the learning rate when validation loss stops improving.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

    best_val_loss = float('inf')
    epochs_since_improve = 0
    max_epochs = 20

    def train_one_epoch(epoch):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader, start=1):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch} batch {batch_idx} loss {loss.item():.4f}')

    def evaluate(loader):
        """Return (mean per-sample loss, accuracy %) over `loader`."""
        model.eval()
        total, correct, cum_loss = 0, 0, 0.0
        with torch.no_grad():
            for data, target in loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                cum_loss += criterion(output, target).item() * target.size(0)
                _, pred = output.max(1)
                total += target.size(0)
                correct += pred.eq(target).sum().item()
        return cum_loss / total, 100. * correct / total

    for epoch in range(1, max_epochs + 1):
        train_one_epoch(epoch)
        val_loss, val_acc = evaluate(val_loader)
        print(f'Validation: loss {val_loss:.4f} acc {val_acc:.2f}%')
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        print(f'LR now {current_lr:.6f}')
        # Early stopping: checkpoint the best model, stop after 3 epochs without improvement.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_since_improve = 0
            torch.save(model.state_dict(), 'best_lenet5.pth')
        else:
            epochs_since_improve += 1
            if epochs_since_improve >= 3:
                print(f'No improvement for {epochs_since_improve} epochs, stopping early.')
                break

    test_loss, test_acc = evaluate(test_loader)
    print(f'Final Test: loss {test_loss:.4f} acc {test_acc:.2f}%')


if __name__ == '__main__':
    freeze_support()
    main()
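
Usage note: a minimal sketch of how the saved best_lenet5.pth checkpoint could be reloaded for inference. This assumes the LeNet5 class above has been lifted to module scope (or re-declared), since the script defines it inside main(); the random tensor is a hypothetical stand-in for a normalized 28x28 MNIST image.

import torch

# Assumption: LeNet5 is importable, i.e. moved out of main() in the script above.
model = LeNet5()
model.load_state_dict(torch.load('best_lenet5.pth', map_location='cpu'))
model.eval()  # disable dropout for inference

with torch.no_grad():
    # Stand-in for one normalized 28x28 grayscale image (batch of 1).
    dummy = torch.randn(1, 1, 28, 28)
    logits = model(dummy)
    print('Predicted digit:', logits.argmax(dim=1).item())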