Created June 19, 2020 10:14
Save aliwaqas333/5d53e4a85a43f32db9e2a778b7d49fb1 to your computer and use it in GitHub Desktop.
Image Classification base class
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ImageClassificationBase(nn.Module):
    """Shared per-batch and per-epoch training/validation hooks for classifiers.

    Subclasses provide the network itself by implementing ``forward``; this
    base class only adds the helpers a training loop calls. The loss is
    ``F.cross_entropy``, so ``forward`` must return raw (unnormalised) class
    scores -- presumably shaped ``(batch, num_classes)``; TODO confirm against
    the concrete subclass.

    Note: ``validation_step`` calls a module-level ``accuracy(out, labels)``
    helper that is defined outside this class.
    """

    def training_step(self, batch):
        """Return the cross-entropy loss for one training batch.

        Args:
            batch: An ``(images, labels)`` pair. ``labels`` holds integer class
                indices; it is cast to ``long`` because ``F.cross_entropy``
                requires int64 class-index targets.

        Returns:
            A scalar loss tensor still attached to the autograd graph, so the
            caller can call ``.backward()`` on it.
        """
        images, labels = batch
        out = self(images)  # Generate predictions (logits)
        loss = F.cross_entropy(out, labels.long())  # Calculate loss
        return loss

    def validation_step(self, batch):
        """Compute loss and accuracy for one validation batch.

        Args:
            batch: An ``(images, labels)`` pair, as in ``training_step``.

        Returns:
            ``{'val_loss': <detached scalar tensor>, 'val_acc': <whatever
            accuracy() returns>}``. The loss is detached so that collecting
            these dicts over many batches does not keep autograd graphs alive.
        """
        images, labels = batch
        out = self(images)  # Generate predictions (logits)
        loss = F.cross_entropy(out, labels.long())  # Calculate loss
        # ``accuracy`` is defined elsewhere in the project; it receives the
        # labels exactly as they arrived in the batch (no dtype cast).
        acc = accuracy(out, labels)
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation metrics into epoch-level numbers.

        Args:
            outputs: A non-empty sequence of dicts as returned by
                ``validation_step``. Metric values may be scalar tensors or
                plain Python numbers.

        Returns:
            ``{'val_loss': float, 'val_acc': float}``. Each value is the
            unweighted mean over batches, so a smaller final batch counts as
            much as a full one.

        Raises:
            ValueError: If ``outputs`` is empty (there is nothing to average).
        """
        if not outputs:
            # torch.stack([]) would otherwise fail with an opaque RuntimeError.
            raise ValueError("validation_epoch_end() requires at least one batch output")
        # as_tensor returns tensors unchanged and wraps plain numbers, so
        # stack() works whether accuracy() returned a tensor or a float.
        epoch_loss = torch.stack([torch.as_tensor(x['val_loss']) for x in outputs]).mean()  # Combine losses
        epoch_acc = torch.stack([torch.as_tensor(x['val_acc']) for x in outputs]).mean()  # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line summary of an epoch.

        Args:
            epoch: Epoch index to display.
            result: Mapping with ``'train_loss'``, ``'val_loss'`` and
                ``'val_acc'`` numeric entries (a missing key raises KeyError).
        """
        print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['train_loss'], result['val_loss'], result['val_acc']))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.