Sample training function
# helper function to train a model
import torch
from torch import nn, optim


def train_model(model, trainloader):
    '''
    Train the model and print the training log.

    INPUT:
        model - initialized PyTorch model ready for training.
        trainloader - PyTorch dataloader for the training data.
    '''
    # set up training
    # define the loss function (expects log-probabilities from the model)
    criterion = nn.NLLLoss()
    # define the learning rate
    learning_rate = 0.003
    # define the number of epochs
    epochs = 5
    # initialize the optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # run training and print the loss to make sure the model is actually fitting the training set
    print('Training the model. Make sure that loss decreases after each epoch.\n')
    for e in range(epochs):
        running_loss = 0
        for images, labels in trainloader:
            # flatten each batch of images into vectors of pixel values
            images = images.view(images.shape[0], -1)
            log_ps = model(images)
            loss = criterion(log_ps, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # print the average loss after each epoch to make sure it is decreasing
        print(f"Epoch {e + 1}/{epochs}. Training loss: {running_loss / len(trainloader):.4f}")
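A minimal usage sketch (not part of the original gist): the dataset (MNIST from torchvision) and the network architecture below are assumptions chosen to match train_model, since the function flattens each image to 784 features and uses NLLLoss, which expects log-probabilities from the model.

# usage sketch: build a trainloader and a simple classifier, then call train_model
import torch
from torch import nn
from torchvision import datasets, transforms

# standard tensor conversion and normalization for grayscale 28x28 images
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# simple feedforward classifier ending in LogSoftmax, as required by NLLLoss
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
    nn.LogSoftmax(dim=1),
)

train_model(model, trainloader)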