import numpy as np
import torch
import torch.nn as nn

# Assumes `model_vae` (a VAE whose encoder exposes a kl_loss() method) and
# `circles_dl` (a DataLoader yielding batches of images) were defined earlier.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_vae.to(device)

loss_fn = nn.MSELoss(reduction='none')
optim = torch.optim.Adam(model_vae.parameters(), lr=0.0003)
num_epochs = 30
train_losses = []
# Weight of the reconstruction (MSE) term relative to the KL term;
# raising it favors reconstruction quality over latent regularization
reconstruction_loss_factor = 1
for epoch in range(1, num_epochs + 1):
    batch_losses = []
    for i, (x, _) in enumerate(circles_dl):
        model_vae.train()
        x = x.to(device)

        # Step 1 - Computes our model's predicted output - forward pass
        yhat = model_vae(x)

        # Step 2 - Computes the loss
        # reduce (sum) over pixels (dim=[1, 2, 3])
        # and then reduce (sum) over batch (dim=0)
        loss = loss_fn(yhat, x).sum(dim=[1, 2, 3]).sum(dim=0)
        # reduce (sum) over z (dim=1)
        # and then reduce (sum) over batch (dim=0)
        # (see the sketch of a possible kl_loss() after the loop)
        kl_loss = model_vae.enc.kl_loss().sum(dim=1).sum(dim=0)
        # we're adding the KL loss to the (weighted) reconstruction MSE loss
        total_loss = reconstruction_loss_factor * loss + kl_loss

        # Step 3 - Computes gradients
        total_loss.backward()

        # Step 4 - Updates parameters using gradients and the learning rate
        optim.step()
        optim.zero_grad()

        batch_losses.append(np.array([total_loss.item(),
                                      loss.item(),
                                      kl_loss.item()]))

    # Average over batches
    train_losses.append(np.array(batch_losses).mean(axis=0))
    print(f'Epoch {epoch:03d} | Loss >> {train_losses[-1][0]:.4f}/'
          f'{train_losses[-1][1]:.4f}/{train_losses[-1][2]:.4f}')
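
# For reference, a minimal sketch of what the encoder's kl_loss() might compute:
# the closed-form KL divergence between the approximate posterior N(mu, sigma^2)
# and the standard Normal prior N(0, 1), per sample and per latent dimension,
# so the loop above can sum over dim=1 (z) and then dim=0 (batch).
# The names `mu` and `log_var` are illustrative assumptions; the actual encoder
# implementation is not shown in this gist.
def kl_divergence_per_dim(mu, log_var):
    # KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2)
    # shapes: mu, log_var -> [batch_size, z_dim]
    return -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())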