Skip to content

Instantly share code, notes, and snippets.

@vankesteren
Created May 3, 2023 07:58
Show Gist options
  • Save vankesteren/684de6a3f44597ac9e09158457b3c0d0 to your computer and use it in GitHub Desktop.
Save vankesteren/684de6a3f44597ac9e09158457b3c0d0 to your computer and use it in GitHub Desktop.
Smallest possible useful GAN in R torch
# Setup ----
# Seed both random number generators for reproducibility: R's RNG draws the
# real data, torch's RNG initialises the networks and draws the noise.
library(torch)
set.seed(45)
torch_manual_seed(45)

# The true distribution is normal with mean 1 and standard deviation 3;
# we sample N = 500 points from it.
N <- 500
y <- rnorm(N, 1, 3)
# torch works on tensors, so store the data as an N x 1 tensor (one
# observation per row). It is fixed data, so no gradient is required.
y_torch <- torch_tensor(matrix(y), requires_grad = FALSE)

# Look at the data, with the true density overlaid in grey.
plot(density(y), bty = "L", main = "Density of real data y")
curve(dnorm(x, 1, 3), add = TRUE, col = "darkgrey")
rug(y)
# Models ----
# Generator: a single linear map weight * z + bias, i.e. only two
# parameters (a weight and a bias).
generator <- nn_linear(in_features = 1, out_features = 1)
# Discriminator: one hidden layer with two sigmoid units, followed by a
# sigmoid output unit that is trained as the probability of "real".
discriminator <- nn_sequential(
  nn_linear(in_features = 1, out_features = 2),
  nn_sigmoid(),
  nn_linear(in_features = 2, out_features = 1),
  nn_sigmoid()
)

# Binary cross-entropy on probabilities, with target 1 for real data and
# target 0 for fake data (one label per row of y_torch).
criterion <- nn_bce_loss()
is_real <- torch_ones_like(y_torch)
is_fake <- torch_zeros_like(y_torch)
# Optimisers ----
# Two time-scale update rule (Heusel et al., 2017,
# https://arxiv.org/abs/1706.08500): the discriminator learns at twice
# the generator's learning rate.
optg <- optim_adam(params = generator$parameters, lr = 0.01)
optd <- optim_adam(params = discriminator$parameters, lr = 0.02)

# Number of training iterations and the per-iteration loss history.
n_epoch <- 500
glosses <- rep(0, n_epoch)
dlosses <- rep(0, n_epoch)
# Train ----
# Each iteration performs one discriminator update followed by one
# generator update, both on a fresh batch of N standard-normal inputs.
for (i in seq_len(n_epoch)) {
  # generate fake data from fresh noise z ~ N(0, 1)
  inp <- torch_randn(N, 1, requires_grad = FALSE)
  y_fake <- generator(inp)

  # train the discriminator: its loss measures how well it classifies
  # real data as real (1) and fake data as fake (0). The fake batch is
  # detached so this backward pass does not also compute gradients for
  # the generator, which would only be zeroed before they are used.
  discriminator$zero_grad()
  prob_real <- discriminator(y_torch)
  prob_fake <- discriminator(y_fake$detach())
  dloss_real <- criterion(prob_real, is_real)
  dloss_fake <- criterion(prob_fake, is_fake)
  dloss <- dloss_real + dloss_fake
  dloss$backward()
  optd$step()
  dlosses[i] <- dloss$item()

  # train the generator: its loss is low when the (just updated)
  # discriminator classifies the generated data as real. The generator's
  # parameters have not changed since y_fake was computed, so y_fake is
  # reused instead of running the generator again. This backward pass also
  # accumulates gradients in the discriminator; those are cleared by
  # discriminator$zero_grad() at the start of the next iteration.
  generator$zero_grad()
  prob_fake <- discriminator(y_fake)
  gloss <- criterion(prob_fake, is_real)
  gloss$backward()
  optg$step()
  glosses[i] <- gloss$item()

  # print the current state, overwriting the same console line
  if (interactive()) {
    cat("\r iteration", i, "dloss:", dlosses[i], "gloss:", glosses[i])
  }
}
# terminate the "\r" progress line so later output starts on a new line
if (interactive()) {
  cat("\n")
}
# Inspect losses ----
# Generator and discriminator loss side by side. par() returns the previous
# settings, which are restored afterwards instead of forcing mfrow = c(1, 1).
old_par <- par(mfrow = c(1, 2))
plot(glosses, type = "l", ylab = "loss", xlab = "Epoch",
     main = "Generator loss", col = "darkblue", bty = "L")
plot(dlosses, type = "l", ylab = "loss", xlab = "Epoch",
     main = "Discriminator loss", col = "darkred", bty = "L")
par(old_par)
# Inspect discriminator output ----
# Probability of "real" that the trained discriminator assigns to each point
# on a grid; the dashed line marks the 0.5 decision threshold. Evaluated
# under with_no_grad() because this is inference only and needs no graph.
xpred <- seq(-15, 15, length.out = 1000)
prob_pred <- with_no_grad(
  as.numeric(discriminator(torch_tensor(matrix(xpred))))
)
plot(xpred, prob_pred, type = "l", bty = "L",
     xlab = "x", ylab = "P(real)", main = "Discriminator output")
abline(h = 0.5, lty = 2)
# Inspect parameters ----
# The generator computes weight * z + bias with z ~ N(0, 1), so its output
# is normal with mean `bias` and standard deviation |weight|.
(mu_hat <- generator$bias$item())
(sd_hat <- abs(generator$weight$item()))
# Inspect theoretical distributions ----
# Compare the true data-generating density, the normal with the sample's
# mean and sd, and the normal implied by the generator's parameters.
curve(dnorm(x, 1, 3), from = -10, to = 10, bty = "L",
      ylab = "density", main = "Theoretical densities")
curve(dnorm(x, mean(y), sd(y)), add = TRUE, col = "blue")
curve(dnorm(x, mu_hat, sd_hat), add = TRUE, col = "darkgrey")
legend("topright", bty = "n", lty = 1,
       col = c("black", "blue", "darkgrey"),
       legend = c("true N(1, 3)", "sample mean / sd", "generator"))
# Generate fake data ----
# Push 1000 fresh standard-normal draws through the trained generator
# (inference only, so no autograd graph is recorded) and compare the
# density of its output with the true normal(1, 3) density in grey.
z_new <- torch_tensor(matrix(rnorm(1000)))
y_hat <- with_no_grad(as.numeric(generator(z_new)))
plot(density(y_hat), bty = "L", main = "Density of fake data y_hat")
curve(dnorm(x, 1, 3), add = TRUE, col = "darkgrey")
rug(y_hat)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment