Created
May 29, 2021 06:46
-
-
Save rish-16/73d89289743b57cd3dd7c370396deec3 to your computer and use it in GitHub Desktop.
A guide on Colab TPU training using PyTorch XLA (Part 8)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Helper function to compute the test accuracy at the end of an epoch.
def get_test_stats(model, loader):
    """Return the classification accuracy of ``model`` on ``loader``, in percent.

    Args:
        model: callable module producing logits; it is switched to eval mode
            and left in eval mode on return (the training loop switches it
            back with ``model.train()``).
        loader: iterable of ``(x, y)`` batches, e.g. a torch_xla per-device
            loader. Assumes logits are shaped (batch, num_classes) so that
            ``argmax`` over dim 1 gives the predicted class -- TODO confirm.

    Returns:
        ``100 * correct / total_samples`` as a float, or ``0.0`` when the
        loader yields no samples (instead of dividing by zero).
    """
    model.eval()  # switch to eval mode (affects layers such as dropout/batch-norm)
    total_samples = 0
    correct = 0
    # Evaluation needs no gradients; skipping autograd saves memory and work.
    with torch.no_grad():
        for x, y in loader:
            logits = model(x)
            preds = torch.argmax(logits, 1)
            # Keep the running count as a device tensor; converting once at
            # the end avoids a device-to-host sync per batch on TPU/XLA.
            correct += torch.eq(y, preds).sum()
            # Count the actual batch size: the final batch may be smaller
            # than the configured batch size.
            total_samples += y.size(0)
    if total_samples == 0:
        return 0.0
    return 100.0 * int(correct) / total_samples
EPOCHS = 10  # feel free to change

for epoch in range(EPOCHS):
    # (optional) track the running mean of the batch losses for logging.
    # The sum is kept as a device tensor: calling .item() on every step would
    # force a device-to-host transfer each iteration, which is slow on TPU.
    running_loss = 0.0
    steps = 0

    # get_test_stats() puts the model into eval mode at the end of each
    # epoch, so switch back to train mode here.
    model.train()

    # get the specialised parallel train loader for this XLA device
    para_loader_train = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
    for batch_idx, (x, y) in enumerate(para_loader_train):
        steps += 1
        output = model(x)
        loss = criterion(output, y)

        optimizer.zero_grad()
        loss.backward()
        xm.optimizer_step(optimizer)

        # detach() so the running sum does not keep the autograd graph alive.
        running_loss += loss.detach()

        # print stuff out to console every 20 batches (master process only)
        if batch_idx % 20 == 0:
            xm.master_print(
                f'{batch_idx} | RunningLoss={running_loss.item() / steps} | Loss={loss.item()}',
                flush=True,
            )

    xm.master_print(f"Finished training epoch {epoch}")

    # get the specialised parallel test loader and evaluate
    para_loader_test = pl.ParallelLoader(test_loader, [device]).per_device_loader(device)
    val_accuracy = get_test_stats(model, para_loader_test)
    # NOTE(review): if this loop runs under xmp.spawn on several cores, each
    # core only evaluates its own shard of test_loader and only the master's
    # value is printed; consider aggregating with xm.mesh_reduce -- confirm.
    xm.master_print(f"Validation Accuracy: {val_accuracy}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment