Created
May 5, 2020 19:55
-
-
Save burrussmp/fc66d82a1979644697dff2e38465f7b9 to your computer and use it in GitHub Desktop.
Training and validation for annotation segmentation network
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Per-class weights for the training loss (passed as ``weight`` to
# nn.CrossEntropyLoss in train()). CrossEntropyLoss needs exactly one weight
# per class, so this implies NUM_CLASSES output classes — it must match the
# model's number of output channels.
# Class 0 gets weight 0.25 (all others 1.0). Because CrossEntropyLoss scales
# each element's loss by the weight of its *target* class, errors on positions
# labelled 0 count less. This compensates for the frequency imbalance between
# label 0 and the other labels in the training set.
NUM_CLASSES = 27
weights = np.ones(NUM_CLASSES)
weights[0] = 0.25
class_weights = torch.FloatTensor(weights)
# Only move to the GPU when one exists, so that importing this module does not
# raise on CPU-only machines (behaviour on CUDA machines is unchanged).
if torch.cuda.is_available():
    class_weights = class_weights.cuda()
def train(model, device, train_loader, optimizer, epoch):
    """Train a model for a single epoch.

    @params
    model: PyTorch.nn.Module
        A segmentation model
    device: PyTorch.device
        Which device to allocate tensors (GPU, CPU, TPU, etc.)
    train_loader: PyTorch.DataLoader
        DataLoader class initialized with training data set; each batch is a
        mapping with 'src' (input) and 'target' entries
    optimizer: PyTorch.optimizer
        A training optimizer (ex. SGD, Adam, Adagrad, etc)
    epoch: int
        Epoch # that is currently being evaluated (used only for logging)
    @return
    avg_loss: PyTorch.tensor
        Detached tensor (on the same device as the model outputs) holding the
        mean per-batch weighted cross-entropy loss of the epoch; the float 0.0
        if train_loader yields no batches
    """
    model.train()  # training mode
    # Move the module-level class weights onto ``device`` so the loss does not
    # fail with a CPU/GPU device mismatch (no-op if they are already there).
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
    total_loss = 0.0
    num_batches = len(train_loader)
    for batch_idx, loaded in enumerate(train_loader):  # one optimizer step per batch
        data = loaded['src'].to(device)
        # NOTE(review): assumes 'target' is one-hot encoded along dim 1; the
        # index of the max entry becomes the class label — confirm.
        target = loaded['target'].to(device).max(1)[1]
        optimizer.zero_grad()  # set gradients to zero
        output = model(data.float())  # get the outputs of the model
        loss = criterion(output, target)
        loss.backward()  # accumulate the gradient
        optimizer.step()  # update model params from the stored gradients using the optimizer rules
        # detach() so the running sum does not keep every batch's autograd
        # graph alive for the whole epoch (unbounded memory growth otherwise).
        total_loss += loss.detach()
        if batch_idx % 20 == 0:  # provide updates on training process
            print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tAvg Loss: {:.6f}'.format(
                epoch, batch_idx, num_batches,
                100. * batch_idx / num_batches, total_loss / (batch_idx + 1)))
            # Debug: number of positions in this batch whose predicted class is not 0.
            print(torch.sum(output.max(1)[1] != 0))
    avg_loss = total_loss / (num_batches + 1e-8)
    return avg_loss
def validate(model, device, loader):
    """Evaluate a model for a single pass over ``loader`` without storing gradients.

    @params
    model: PyTorch.nn.Module
        A segmentation model
    device: PyTorch.device
        Device the input and target tensors are moved to (GPU, CPU, TPU, etc.)
    loader: PyTorch.DataLoader
        DataLoader initialized with the validation data set; each batch is a
        mapping with 'src' (input) and 'target' entries
    @return
    avg_loss: PyTorch.tensor
        Mean per-batch (unweighted) cross-entropy loss, on the same device as
        the model outputs; the float 0.0 if loader yields no batches
    """
    model.eval()  # inference mode for layers such as dropout / batch-norm
    # NOTE(review): unlike train(), no class weights are applied here, so the
    # two reported losses are not directly comparable — confirm intended.
    loss_fn = nn.CrossEntropyLoss()
    running_loss = 0.0
    with torch.no_grad():  # no autograd graph is recorded during evaluation
        for batch in loader:  # iterate across the validation dataset
            inputs = batch['src'].to(device)
            one_hot = batch['target'].to(device)
            logits = model(inputs.float())
            # class label = index of the max entry along dim 1 of the target
            running_loss += loss_fn(logits, one_hot.max(1)[1])
    n_batches = len(loader)
    avg_loss = running_loss / (n_batches + 1e-8)
    print('\nTest set: Average loss: {:.4f}\n'.format(avg_loss))
    return avg_loss
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment