Created September 6, 2021 07:04
Save andreaschandra/a7ec282c9bd18807d2db053c87499b28 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Training routine: train `model` for NUM_EPOCHS epochs and, after each epoch,
# print the mean per-batch loss on the training set and on the test set.
#
# Relies on names defined elsewhere in the file/notebook: `model`,
# `optimizer`, `loss_fn`, `dataset_train`, `dataset_test`, `torch` and
# `DataLoader`.
NUM_EPOCHS = 100
BATCH_SIZE = 2

# The loaders do not depend on the epoch, so build them once rather than on
# every iteration. (No shuffling, matching the original behaviour.)
train_gen = DataLoader(dataset_train, batch_size=BATCH_SIZE)
test_gen = DataLoader(dataset_test, batch_size=BATCH_SIZE)

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = 0
    test_loss = 0

    # Training pass: one optimizer step per batch.
    model.train()
    for batch_index, (x, y) in enumerate(train_gen, 1):
        optimizer.zero_grad()
        y_pred = model(x)
        # squeeze(-1) drops only the trailing output dimension. A bare
        # squeeze() would also drop the batch dimension when the final batch
        # holds a single sample, so y_pred and y would no longer match in shape.
        # NOTE(review): assumes the model output is (batch, 1) — confirm.
        loss = loss_fn(y_pred.squeeze(-1), y)
        # Incremental running mean of the per-batch loss values.
        train_loss += (loss.item() - train_loss) / batch_index
        loss.backward()
        optimizer.step()

    # Evaluation pass: autograd disabled for the whole loop, no parameter updates.
    model.eval()
    with torch.no_grad():
        for batch_index, (x, y) in enumerate(test_gen, 1):
            y_pred = model(x)
            loss = loss_fn(y_pred.squeeze(-1), y)
            test_loss += (loss.item() - test_loss) / batch_index

    print(f"epoch: {epoch}")
    print(f"train loss: {train_loss:.2f} | test loss: {test_loss:.2f}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.