Created
April 27, 2021 01:31
-
-
Save 123epsilon/44406cc3230bc71f71b8efdfb441f0f0 to your computer and use it in GitHub Desktop.
GAN Generator and Discriminator Models, simple
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Binary classifier that determines whether data points come from the original distribution or the fake (generated) distribution.
class Discriminator(nn.Module):
    """Binary classifier that scores points as original or generated.

    The network is a plain stack of fully connected layers: an input
    projection, ``n_layers`` hidden blocks (each a ``Linear`` followed by a
    ``LeakyReLU``), and a single-unit head passed through ``Sigmoid``.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the input projection and of every hidden block.
        n_layers: Number of hidden ``Linear``/``LeakyReLU`` blocks placed
            between the input projection and the output head. ``0`` connects
            the input projection directly to the head.
    """

    def __init__(self, input_dim: int = 2, hidden_dim: int = 28, n_layers: int = 3) -> None:
        super().__init__()
        self.input = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.LeakyReLU())
        # Build the ModuleList directly so the hidden blocks are registered as
        # submodules from the start (no intermediate plain Python list).
        self.layers = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU())
                for _ in range(n_layers)
            ]
        )
        self.output = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())

    def forward(self, x):
        """Run ``x`` through the network.

        Args:
            x: Input whose last dimension has size ``input_dim``.

        Returns:
            The ``Sigmoid`` output of the head; its last dimension has size 1.
        """
        o = self.input(x)
        for layer in self.layers:
            o = layer(o)
        return self.output(o)
# Generator: takes in random noise and outputs a fake point.
class Generator(nn.Module):
    """Map random noise to a fake data point.

    The network is a plain stack of fully connected layers: an input
    projection, ``n_layers`` hidden blocks (each a ``Linear`` followed by a
    ``LeakyReLU``), and a final ``Linear`` with no activation.

    Args:
        z_dim: Size of the last dimension of the noise input.
        hidden_dim: Width of the input projection and of every hidden block.
        n_layers: Number of hidden ``Linear``/``LeakyReLU`` blocks placed
            between the input projection and the output layer. ``0`` connects
            the input projection directly to the output layer.
        out_dim: Size of the last dimension of the generated output.
    """

    def __init__(self, z_dim: int = 1, hidden_dim: int = 28, n_layers: int = 3, out_dim: int = 2) -> None:
        super().__init__()
        self.input = nn.Sequential(nn.Linear(z_dim, hidden_dim), nn.LeakyReLU())
        # Build the ModuleList directly so the hidden blocks are registered as
        # submodules from the start (no intermediate plain Python list).
        self.layers = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU())
                for _ in range(n_layers)
            ]
        )
        # No activation on the output: the final Linear result is returned as-is.
        self.output = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Run the noise ``x`` through the network.

        Args:
            x: Input whose last dimension has size ``z_dim``.

        Returns:
            The output of the final ``Linear`` layer; its last dimension has
            size ``out_dim``.
        """
        o = self.input(x)
        for layer in self.layers:
            o = layer(o)
        return self.output(o)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment