Smallest possible useful GAN in R torch
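A complete generative adversarial network in about 60 lines of R: a linear generator with just two parameters and a tiny two-hidden-node discriminator learn to reproduce samples from a normal(1, 3) distribution.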
# setup with seeds for reproducibility
library(torch)
set.seed(45)
torch_manual_seed(45)
# True distribution is normal(1, 3), we sample 500 points
N <- 500
y <- rnorm(N, 1, 3)
# we need a torch tensor of this data to use in torch
y_torch <- torch_tensor(matrix(y), requires_grad = FALSE)
# let's look at the data
plot(density(y), bty = "L", main = "Density of real data y")
curve(dnorm(x, 1, 3), add = TRUE, col = "darkgrey")
rug(y)
# an extremely simple generator with 2 parameters:
# a weight and a bias
generator <- nn_linear(1, 1)
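# Note: with standard-normal input z, the generator output
# w * z + b is itself normal with mean b and sd |w|, so these
# two parameters are exactly enough to match any normal target.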
# an extremely simple discriminator with 2 hidden nodes:
discriminator <- nn_sequential(
  nn_linear(1, 2),
  nn_sigmoid(),
  nn_linear(2, 1),
  nn_sigmoid()
)
# class labels for the discriminator: 1 = real, 0 = fake
is_real <- torch_ones_like(y_torch)
is_fake <- torch_zeros_like(y_torch)
criterion <- nn_bce_loss()
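# nn_bce_loss() computes -mean(t * log(p) + (1 - t) * log(1 - p)),
# i.e., the discriminator is trained as a plain binary classifier.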
# Two time-scale update rule: discriminator learning rate
# is twice as high as the generator learning rate
# https://arxiv.org/abs/1706.08500
optg <- optim_adam(generator$parameters, lr = 1e-2)
optd <- optim_adam(discriminator$parameters, lr = 2e-2)
n_epoch <- 500
dlosses <- numeric(n_epoch)
glosses <- numeric(n_epoch)
# Start training
for (i in 1:n_epoch) {
  # generate fake data from standard-normal latent noise
  inp <- torch_randn(N, 1, requires_grad = FALSE)
  y_fake <- generator(inp)
  # train the discriminator
  discriminator$zero_grad()
  # the discriminator loss is its ability to classify
  # real and fake data correctly; detach the fake data so
  # this update does not backpropagate into the generator
  prob_real <- discriminator(y_torch)
  prob_fake <- discriminator(y_fake$detach())
  dloss_real <- criterion(prob_real, is_real)
  dloss_fake <- criterion(prob_fake, is_fake)
  dloss <- dloss_real + dloss_fake
  dloss$backward()
  optd$step()
  dlosses[i] <- dloss$item()
  # train the generator
  generator$zero_grad()
  # the generator loss is its ability to create
  # data that the discriminator classifies as real
  prob_fake <- discriminator(generator(inp))
  gloss <- criterion(prob_fake, is_real)
  gloss$backward()
  optg$step()
  glosses[i] <- gloss$item()
  # Print current state
  if (interactive())
    cat("\r iteration", i, "dloss:", dlosses[i], "gloss:", glosses[i])
}
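# Optional (not in the original gist): persist the trained networks
# with torch_save() so they can be reloaded later with torch_load().
# torch_save(generator, "generator.pt")
# torch_save(discriminator, "discriminator.pt")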
# inspect losses
par(mfrow = c(1, 2))
plot(glosses, type = "l", ylab = "loss", xlab = "Epoch",
     main = "Generator loss", col = "darkblue", bty = "L")
plot(dlosses, type = "l", ylab = "loss", xlab = "Epoch",
     main = "Discriminator loss", col = "darkred", bty = "L")
par(mfrow = c(1, 1))
# inspect discriminator output over a grid of values
xpred <- seq(-15, 15, length.out = 1000)
prob <- as.numeric(discriminator(torch_tensor(matrix(xpred))))
plot(xpred, prob, type = "l", xlab = "y", ylab = "P(real)", bty = "L")
abline(h = 0.5, lty = 2)
# inspect parameters: the generator computes y = w * z + b with
# z ~ N(0, 1), so b estimates the mean and |w| the sd
(mu_hat <- as.numeric(generator$parameters[["bias"]]))
(sd_hat <- abs(as.numeric(generator$parameters[["weight"]])))
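# compare the GAN estimates with the sample statistics of y
c(mu_hat = mu_hat, mean_y = mean(y))
c(sd_hat = sd_hat, sd_y = sd(y))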
# compare densities: true, fitted to the sample, and GAN-implied
curve(dnorm(x, 1, 3), from = -10, to = 10)
curve(dnorm(x, mean(y), sd(y)), add = TRUE, col = "blue")
curve(dnorm(x, mu_hat, sd_hat), add = TRUE, col = "darkgrey")
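# add a legend (not in the original gist) to label the three curves
legend("topright", c("true N(1, 3)", "normal fit to y", "GAN estimate"),
       col = c("black", "blue", "darkgrey"), lty = 1, bty = "n")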
# generate fake data from fresh noise
y_hat <- as.numeric(generator(torch_tensor(matrix(rnorm(1000)))))
plot(density(y_hat), bty = "L", main = "Density of fake data y_hat")
rug(y_hat)
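# a rough distributional check (not in the original gist):
# a two-sample Kolmogorov-Smirnov test of fake against real data
ks.test(y_hat, y)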