vox2vox model losses
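
Training-step excerpt for a vox2vox (3D pix2pix-style) GAN: one discriminator update, gated by the discriminator's accuracy on the current batch, followed by one generator update. Names such as generator, discriminator, real_A (input volume), real_B (target volume), valid, fake, opt.d_threshold, lambda_voxel, and the optimizers are assumed to be defined earlier in the full training script; a hedged setup sketch follows the excerpt.
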
import torch

# ---------------------
#  Train discriminator (the optimizer is stepped only while its
#  batch accuracy is at or below opt.d_threshold; see below)
# ---------------------
# Real loss
fake_B = generator(real_A)
pred_real = discriminator(real_B, real_A)
loss_real = criterion_GAN(pred_real, valid)
# Fake loss
pred_fake = discriminator(fake_B.detach(), real_A)
loss_fake = criterion_GAN(pred_fake, fake)
# Total loss
loss_D = 0.5 * (loss_real + loss_fake)
# Per-prediction discriminator accuracy: a real prediction counts as correct
# when >= 0.5, a fake prediction when <= 0.5 (assumes the discriminator
# outputs probabilities in [0, 1])
d_real_acu = torch.ge(pred_real.squeeze(), 0.5).float()
d_fake_acu = torch.le(pred_fake.squeeze(), 0.5).float()
d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))
# Skip the discriminator update once it gets too accurate, so it
# does not overpower the generator
if d_total_acu <= opt.d_threshold:
    optimizer_D.zero_grad()
    loss_D.backward()
    optimizer_D.step()
    discriminator_update = True  # flag recording that D was updated this batch
# ------------------
#  Train generator
# ------------------
# Clear gradients on both networks before the generator step
optimizer_D.zero_grad()
optimizer_G.zero_grad()
# GAN loss (fake_B is not detached here, so gradients flow back to the generator)
fake_B = generator(real_A)
pred_fake = discriminator(fake_B, real_A)
loss_GAN = criterion_GAN(pred_fake, valid)
# Voxel-wise loss
loss_voxel = criterion_voxelwise(fake_B, real_B)
# Total loss
loss_G = loss_GAN + lambda_voxel * loss_voxel
loss_G.backward()
optimizer_G.step()
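
For context, below is a minimal, self-contained setup sketch for the names the excerpt assumes. Every concrete choice in it (the stand-in Conv3d networks, BCE adversarial loss, L1 voxel-wise loss, lambda_voxel = 100, Adam with pix2pix-style hyperparameters, the 0.8 accuracy threshold, and the dummy batch) is an illustrative assumption, not necessarily what the gist's full script uses.

# --- setup sketch (assumptions throughout) ---
import argparse
import torch
import torch.nn as nn

# Tiny stand-in 3D networks so the sketch runs end to end; a real vox2vox
# model would use something like a 3D U-Net generator and a 3D PatchGAN
# discriminator instead.
generator = nn.Conv3d(1, 1, kernel_size=3, padding=1)
discriminator_net = nn.Sequential(
    nn.Conv3d(2, 1, kernel_size=3, padding=1),  # candidate + condition channels
    nn.Sigmoid(),  # probabilities in [0, 1], matching the 0.5 accuracy cutoff
)

def discriminator(volume, condition):
    # The excerpt calls discriminator(x, real_A): concatenate the candidate
    # volume with its conditioning input along the channel axis
    return discriminator_net(torch.cat((volume, condition), dim=1))

criterion_GAN = nn.BCELoss()        # adversarial loss on sigmoid outputs (assumption)
criterion_voxelwise = nn.L1Loss()   # L1 voxel-wise term, as in pix2pix (assumption)
lambda_voxel = 100                  # assumed weight on the voxel-wise term

optimizer_G = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator_net.parameters(), lr=2e-4, betas=(0.5, 0.999))

opt = argparse.Namespace(d_threshold=0.8)  # assumed accuracy gate for D updates

# Dummy (N, C, D, H, W) batch standing in for one DataLoader iteration
real_A = torch.randn(2, 1, 8, 8, 8)
real_B = torch.randn(2, 1, 8, 8, 8)

# Adversarial targets shaped like the discriminator output
with torch.no_grad():
    out_shape = discriminator(real_B, real_A).shape
valid = torch.ones(out_shape)
fake = torch.zeros(out_shape)

With this setup in place, the excerpt above executes as one training step; in practice real_A and real_B would come from a DataLoader loop over paired volumes.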