Skip to content

Instantly share code, notes, and snippets.

@Adityanagraj
Created July 2, 2020 08:53
Show Gist options
  • Save Adityanagraj/9d0cba52d356d6ade8cd0f266a625845 to your computer and use it in GitHub Desktop.
Save Adityanagraj/9d0cba52d356d6ade8cd0f266a625845 to your computer and use it in GitHub Desktop.
class FishModel(nn.Module):
    """Single-layer linear regression model trained with smooth L1 loss.

    Exposes ``training_step`` / ``validation_step`` / ``validation_epoch_end``
    / ``epoch_end`` hooks, presumably driven by an external fit/evaluate loop
    that is not shown here.
    """

    def __init__(self, in_features=None, out_features=None):
        """Build the single ``nn.Linear`` layer.

        Args:
            in_features: number of input features. When omitted, falls back
                to the module-level ``input_size`` global.
            out_features: number of outputs. When omitted, falls back to the
                module-level ``output_size`` global.
        """
        super().__init__()
        # Fall back to the globals the original implementation read, so the
        # zero-argument ``FishModel()`` call keeps working unchanged.
        if in_features is None:
            in_features = input_size
        if out_features is None:
            out_features = output_size
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, xb):
        """Return the linear layer's predictions for ``xb``.

        ``xb``'s last dimension must equal the layer's ``in_features``
        (an ``nn.Linear`` requirement).
        """
        return self.linear(xb)

    def _loss(self, batch):
        """Return the smooth L1 loss of the model's predictions on ``batch``.

        ``batch`` is an ``(inputs, targets)`` pair. ``targets`` presumably has
        the same shape as the model output; TODO confirm against the data
        loader.
        """
        inputs, targets = batch
        predictions = self(inputs)
        return F.smooth_l1_loss(predictions, targets)

    def training_step(self, batch):
        """Return the differentiable training loss for one ``(inputs, targets)`` batch."""
        return self._loss(batch)

    def validation_step(self, batch):
        """Return ``{'val_loss': loss}`` for one batch.

        The loss is detached from the autograd graph, so the per-batch results
        can be collected without keeping the graph alive.
        """
        return {'val_loss': self._loss(batch).detach()}

    def validation_epoch_end(self, outputs):
        """Average the per-batch losses produced by ``validation_step``.

        Args:
            outputs: list of ``{'val_loss': tensor}`` dicts.

        Returns:
            ``{'val_loss': float}`` holding the mean batch loss.

        Raises:
            RuntimeError: from ``torch.stack`` when ``outputs`` is empty.
        """
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        return {'val_loss': epoch_loss.item()}

    def epoch_end(self, epoch, result, num_epochs, log_every=20):
        """Print the validation loss every ``log_every`` epochs and on the last one.

        Args:
            epoch: zero-based epoch index.
            result: dict containing a ``'val_loss'`` number.
            num_epochs: total number of epochs in the run.
            log_every: positive logging interval, in epochs (default 20).
        """
        is_last = epoch == num_epochs - 1
        if (epoch + 1) % log_every == 0 or is_last:
            print(f"Epoch [{epoch + 1}], val_loss: {result['val_loss']:.4f}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment