@SolomidHero
Last active May 20, 2021 09:50
PatchGAN - Discriminator for frequency features
# Least Squares GAN loss
import torch

def adversarial_loss(scores, as_real=True):
    # LSGAN objective: real targets are 1, fake targets are 0
    if as_real:
        return torch.mean((1 - scores) ** 2)
    return torch.mean(scores ** 2)

def discriminator_loss(fake_scores, real_scores):
    # the discriminator should score fakes as 0 and reals as 1
    loss = adversarial_loss(fake_scores, as_real=False) + adversarial_loss(real_scores, as_real=True)
    return loss
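
# For completeness, a sketch of the matching generator-side LSGAN term
# (not in the original gist): the generator is rewarded when the
# discriminator scores its fake patches as real.
def generator_loss(fake_scores):
    return adversarial_loss(fake_scores, as_real=True)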
# PatchGAN
# ref: https://github.com/jackaduma/CycleGAN-VC2/blob/master/model_tf.py
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.inputConvLayer = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.SiLU()
        )

        # DownSample layers: each halves both the feature and time dimensions
        self.down1 = self.downSample(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=1)
        self.down2 = self.downSample(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=1)
        self.down3 = self.downSample(in_channels=512, out_channels=1024, kernel_size=(3, 3), stride=(2, 2), padding=1)
        # self.down4 = self.downSample(in_channels=1024, out_channels=1024, kernel_size=(1, 5), stride=(1, 1), padding=(0, 2))

        # Output Conv layer: maps features to a single patch-wise score map
        self.outputConvLayer = nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))

    def downSample(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.InstanceNorm2d(num_features=out_channels, affine=True),
            nn.SiLU()
        )

    def forward(self, input):
        # input has shape (batch_size, num_features, time);
        # the discriminator expects (batch_size, 1, num_features, time)
        x = self.inputConvLayer(input.unsqueeze(1))
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        # x = self.down4(x)
        output = self.outputConvLayer(x)
        return output
# Discriminator dimensionality test
input = torch.randn(32, 80, 1337)  # (batch_size, num_features, time)
discriminator = Discriminator()
output = discriminator(input)
print("Discriminator output shape ", output.shape)  # patch-wise scores, (N, 1, C_out, T_out)