@WillKoehrsen
Created November 26, 2018 18:45
import numpy as np
import torch

# Early stopping details
n_epochs_stop = 5
min_val_loss = np.inf
epochs_no_improve = 0

# Main loop (assumes model, criterion, optimizer, trainloader,
# validloader, n_epochs, and checkpoint_path are defined elsewhere)
for epoch in range(n_epochs):
    # Initialize validation loss for the epoch
    val_loss = 0

    # Training loop
    model.train()
    for data, targets in trainloader:
        # Clear gradients accumulated from the previous batch
        optimizer.zero_grad()
        # Generate predictions
        out = model(data)
        # Calculate loss
        loss = criterion(out, targets)
        # Backpropagation
        loss.backward()
        # Update model parameters
        optimizer.step()

    # Validation loop (gradients are not needed here)
    model.eval()
    with torch.no_grad():
        for data, targets in validloader:
            # Generate predictions
            out = model(data)
            # Calculate loss; .item() extracts the Python float
            loss = criterion(out, targets)
            val_loss += loss.item()

    # Average validation loss over the validation batches
    val_loss = val_loss / len(validloader)

    # If the validation loss is at a new minimum
    if val_loss < min_val_loss:
        # Save the model and reset the patience counter
        torch.save(model, checkpoint_path)
        epochs_no_improve = 0
        min_val_loss = val_loss
    else:
        epochs_no_improve += 1
        # Check early stopping condition
        if epochs_no_improve == n_epochs_stop:
            print('Early stopping!')
            # Load in the best model and stop training
            model = torch.load(checkpoint_path)
            break
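
A note on the checkpointing calls above: torch.save(model, checkpoint_path) pickles the entire model object, which is what the gist does. The PyTorch documentation recommends saving only the weights instead; a minimal sketch of that variant, reusing the same checkpoint_path:

# Save only the learned parameters rather than the whole pickled module
torch.save(model.state_dict(), checkpoint_path)
# Later, restore the weights into an already-constructed model of the same class
model.load_state_dict(torch.load(checkpoint_path))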
@filmo commented Sep 12, 2019

Is the indentation correct on this? It seems like it goes through the entire training loop before entering the validation loop.

@WillKoehrsen (Author)

Yes, the validation loop runs once per epoch. This is generally how I've seen early stopping implemented: if the validation loss does not decrease for 5 consecutive epochs, training stops and we fall back to the best saved model.
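
Stripped of the training code, the patience counter on its own looks like this (a minimal, self-contained sketch; the val_losses list is made-up data purely for illustration):

min_val_loss = float('inf')
epochs_no_improve = 0
n_epochs_stop = 5

# Hypothetical per-epoch validation losses, for illustration only
val_losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65]

for epoch, val_loss in enumerate(val_losses):
    if val_loss < min_val_loss:
        # Improvement: record the new best loss and reset the counter
        min_val_loss = val_loss
        epochs_no_improve = 0
    else:
        # No improvement this epoch
        epochs_no_improve += 1
        if epochs_no_improve == n_epochs_stop:
            print(f'Early stopping at epoch {epoch}!')
            break

Here the loss stops improving after epoch 2, so the counter reaches 5 at epoch 7 and the loop stops.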
