Last active
January 25, 2019 21:56
-
-
Save erykml/52bba6103ba1aa233d3b2e2b480423c7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_cnn(model, train_loader, valid_loader,
              criterion, optimizer, n_epochs=30, train_on_gpu=False,
              save_model_on_improvement=True, plot_loss=True,
              save_path='model_mvw.pt'):
    '''
    Train a neural network, validating after every epoch. Can run on CPU or GPU.

    If the criterion is BCEWithLogitsLoss, long (integer) targets are
    automatically converted to float, as that loss requires float targets.

    Inputs:
        model - the network to train (Class-based or Sequential)
        train_loader - DataLoader for the training set
        valid_loader - DataLoader for the validation set
        criterion - loss function
        optimizer - optimizer updating model.parameters()
        n_epochs - number of epochs to train for
        train_on_gpu - boolean; move batches to CUDA and synchronize timings
        save_model_on_improvement - boolean; save state_dict whenever the
            validation loss reaches a new minimum (previously this flag was
            accepted but ignored - fixed)
        plot_loss - boolean; plot train/validation loss curves at the end
        save_path - file path used when saving the model (default keeps the
            original hard-coded 'model_mvw.pt')

    Returns:
        the trained model
    '''
    valid_loss_min = float('inf')  # lowest validation loss seen so far
    train_losses, valid_losses = [], []
    times = []  # per-epoch wall-clock durations (excludes model saving)
    # BCEWithLogitsLoss needs float targets; other criteria (e.g.
    # CrossEntropyLoss) expect long targets and are left untouched.
    needs_float_target = isinstance(criterion, torch.nn.BCEWithLogitsLoss)

    for epoch in range(1, n_epochs + 1):
        train_loss = 0.0
        valid_loss = 0.0
        if train_on_gpu:
            torch.cuda.synchronize()  # ensure pending kernels don't skew timing
        t0 = time.perf_counter()

        # ---- training ----
        model.train()
        for data, target in train_loader:
            if train_on_gpu:
                data, target = data.cuda(), target.cuda()
            # dtype check replaces the old string comparison against
            # 'torch.LongTensor'/'torch.cuda.LongTensor' and covers both devices
            if needs_float_target and target.dtype == torch.long:
                target = target.float()
            # reset gradients of all optimized parameters
            optimizer.zero_grad()
            # squeeze(-1), not squeeze(): a plain squeeze() would also drop a
            # batch dimension of size 1 and crash the loss on 1-sample batches
            output = model(data).squeeze(-1)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            # accumulate sum of per-sample losses for a dataset-wide average
            train_loss += loss.item() * data.size(0)

        # ---- validation ----
        with torch.no_grad():
            model.eval()
            for data, target in valid_loader:
                if train_on_gpu:
                    data, target = data.cuda(), target.cuda()
                if needs_float_target and target.dtype == torch.long:
                    target = target.float()
                output = model(data).squeeze(-1)
                loss = criterion(output, target)
                valid_loss += loss.item() * data.size(0)

        # average losses over the full datasets
        train_loss = train_loss / len(train_loader.dataset)
        train_losses.append(train_loss)
        valid_loss = valid_loss / len(valid_loader.dataset)
        valid_losses.append(valid_loss)

        # record epoch duration (excluding model saving below)
        if train_on_gpu:
            torch.cuda.synchronize()
        t1 = time.perf_counter()
        times.append(t1 - t0)

        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f} \tTraining time: {:.2f} s'.format(
            epoch, train_loss, valid_loss, t1 - t0))

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            if save_model_on_improvement:  # bug fix: flag was previously ignored
                print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(
                    valid_loss_min,
                    valid_loss))
                torch.save(model.state_dict(), save_path)
            valid_loss_min = valid_loss

    if plot_loss:
        plt.plot(train_losses, label='Training loss')
        plt.plot(valid_losses, label='Validation loss')
        plt.legend(frameon=False)
    return model
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment