Created November 26, 2018 18:45
import numpy as np
import torch

# Assumes model, criterion, optimizer, trainloader, validloader,
# n_epochs, and checkpoint_path are defined elsewhere

# Early stopping details
n_epochs_stop = 5
min_val_loss = np.inf
epochs_no_improve = 0

# Main loop
for epoch in range(n_epochs):
    # Initialize validation loss for epoch
    val_loss = 0

    # Training loop
    model.train()
    for data, targets in trainloader:
        # Clear gradients left over from the previous batch
        optimizer.zero_grad()
        # Generate predictions
        out = model(data)
        # Calculate loss
        loss = criterion(out, targets)
        # Backpropagation
        loss.backward()
        # Update model parameters
        optimizer.step()

    # Validation loop (runs once per epoch, without gradient tracking)
    model.eval()
    with torch.no_grad():
        for data, targets in validloader:
            # Generate predictions
            out = model(data)
            # Calculate loss
            loss = criterion(out, targets)
            # Accumulate as a Python float rather than a tensor
            val_loss += loss.item()

    # Average validation loss over the validation batches
    val_loss = val_loss / len(validloader)

    # If the validation loss is at a minimum
    if val_loss < min_val_loss:
        # Save the model
        torch.save(model, checkpoint_path)
        epochs_no_improve = 0
        min_val_loss = val_loss
    else:
        epochs_no_improve += 1
        # Check early stopping condition
        if epochs_no_improve == n_epochs_stop:
            print('Early stopping!')
            # Load in the best model
            # (newer PyTorch versions may require weights_only=False here)
            model = torch.load(checkpoint_path)
            # Leave the epoch loop so training actually stops
            break
Yes, the validation runs once per epoch. This is generally how I've seen early stopping implemented; if the validation loss does not decrease for 5 epochs, the training stops and we use the best model.
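The patience logic described here can also be pulled out of the training loop so it is easier to check on its own. Below is a minimal sketch in plain Python; the `EarlyStopper` class, its `step` method, and the validation losses are made up for illustration and are not part of the gist or of PyTorch.

```python
class EarlyStopper:
    """Hypothetical helper: track validation loss and signal when to stop."""

    def __init__(self, patience=5):
        self.patience = patience            # epochs to wait without improvement
        self.best_loss = float("inf")       # lowest validation loss seen so far
        self.epochs_no_improve = 0

    def step(self, val_loss):
        """Record one epoch's validation loss.

        Returns a pair (improved, should_stop).
        """
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.epochs_no_improve = 0
            return True, False
        self.epochs_no_improve += 1
        return False, self.epochs_no_improve >= self.patience


# Made-up validation losses, one per epoch
losses = [0.9, 0.7, 0.6, 0.65, 0.61, 0.62, 0.63, 0.64, 0.5]

stopper = EarlyStopper(patience=5)
stopped_at = None
for epoch, val_loss in enumerate(losses):
    improved, should_stop = stopper.step(val_loss)
    # In a real loop, a checkpoint would be saved whenever improved is True
    if should_stop:
        stopped_at = epoch
        break
```

In a real training loop, `step` would be called once per epoch right after the averaged validation loss is computed, with the checkpoint saved when `improved` is true and the best checkpoint reloaded when `should_stop` is true.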
Is the indentation correct on this?? It seems like it goes through the entire epoch loop before entering the validation loop.