Created November 21, 2019 18:11
-
-
Save n0obcoder/1fe0fd0d518d278493fba3b15a81b3b2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Defining a trainer function which would train and validate the model
def trainer(dataloader_dict, model, loss_fn, optimizer, epochs = 1, log_interval = 1):
    """Train and validate ``model``, alternating phases each epoch.

    Args:
        dataloader_dict: dict with keys ``'train'`` and ``'val'`` mapping to
            ``torch.utils.data.DataLoader`` instances.
        model: the torch module being optimized.
        loss_fn: callable computing the loss from ``(outputs, labels)``.
        optimizer: torch optimizer over ``model``'s parameters.
        epochs: number of epochs to run (default 1).
        log_interval: print the batch loss every this many batches.

    Returns:
        Tuple ``(train_losses, val_losses, batch_train_losses,
        batch_val_losses)`` — per-epoch mean losses and per-batch losses
        for each phase.
    """
    print('Training started...')

    train_losses = []
    val_losses = []
    batch_train_losses = []
    batch_val_losses = []

    for epoch in range(epochs):
        print('epoch >>> {}/{}'.format(epoch + 1, epochs))

        for phase in ['train', 'val']:
            if phase == 'train':
                print('___TRAINING___')
                model.train()
            else:
                print('___VALIDATION___')
                model.eval()

            epoch_loss = 0
            # BUG FIX: the original iterated `loader[phase]` and logged with
            # `train_loader`, but neither name exists — the parameter is
            # `dataloader_dict`. Use it consistently for both phases.
            num_batches = len(dataloader_dict[phase])

            for batch_idx, (inputs, labels) in enumerate(dataloader_dict[phase]):
                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass; gradients are tracked only during training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    batch_loss = loss_fn(outputs, labels)
                    # Weight by batch size so the epoch mean stays exact even
                    # when the final batch is smaller than the rest.
                    epoch_loss += batch_loss.item() * inputs.shape[0]

                    # Saving the batch losses
                    if phase == 'train':
                        batch_train_losses.append(batch_loss.item())
                    else:
                        batch_val_losses.append(batch_loss.item())

                    if phase == 'train':
                        # Backpropagation
                        batch_loss.backward()
                        optimizer.step()

                if (batch_idx + 1) % log_interval == 0:
                    print('batch_loss at batch_idx {}/{}: {}'.format(
                        str(batch_idx).zfill(len(str(num_batches))),
                        num_batches, batch_loss))

            mean_epoch_loss = epoch_loss / len(dataloader_dict[phase].dataset)
            print('>>> {} loss at epoch {}/{}: {}'.format(phase, epoch + 1, epochs, mean_epoch_loss))

            # Storing the losses
            if phase == 'train':
                train_losses.append(mean_epoch_loss)
            else:
                val_losses.append(mean_epoch_loss)

        print('=====' * 5)

    return train_losses, val_losses, batch_train_losses, batch_val_losses
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.