Gist kkew3/266418d81a26df929dc4071c43554205, last active January 26, 2025 08:00
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A useless trick: Initializing a GAN to be Gaussian"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch import autograd\n",
    "from torch import distributions as D\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "Over the weekend I came across [this](https://github.com/aifromphytsai/gan_sampling) repo, \"A simple GAN to generate samples from Gaussian distribution\". Good news: to generate Gaussian samples, you don't need to train the network at all. Simply initializing the weights properly is enough!\n",
    "\n",
    "Why would we want to initialize a GAN to be Gaussian in the first place? One possible reason is that a Gaussian may already be a reasonable approximation of the probability density we aim to estimate, e.g. an exponential tilting of a Gaussian base distribution: $p(x) \\propto \\exp(f(x)) p_0(x)$. Instead of learning from scratch, it may then be a good idea to kick-start from a Gaussian.\n",
    "\n",
    "## Gaussian GAN\n",
    "\n",
    "Parameterizing the GAN generator as an MLP, our goal is to initialize the MLP so that, given standard Gaussian noise as input, it produces Gaussian samples with a desired mean and diagonal covariance.\n",
    "\n",
    "Denote the $K$-layer MLP as $f(\\boldsymbol x) = \\mathbf W \\phi(\\boldsymbol x) + \\boldsymbol b$, where $\\phi$ is the mapping computed by the first $K-1$ layers and $(\\mathbf W,\\boldsymbol b)$ are the parameters of the last layer. Let the activation function be $\\tanh$. Recall that an affine transformation of a Gaussian is again Gaussian. Thus, we want $\\phi(\\boldsymbol x)$ to be approximately a linear transformation, and then initialize $(\\mathbf W,\\boldsymbol b)$ so that the output has the intended mean and variance."
   ]
  },
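  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick numerical sanity check of the fact used above (an illustrative sketch, not part of the original derivation; the matrix `A_demo`, the offset `b_demo` and the sample size are arbitrary choices): if $\\boldsymbol x \\sim \\mathcal N(\\mathbf 0, \\mathbf I)$, then $\\mathbf A \\boldsymbol x + \\boldsymbol b \\sim \\mathcal N(\\boldsymbol b, \\mathbf A \\mathbf A^\\top)$, so the empirical mean and covariance of the transformed samples should be close to $\\boldsymbol b$ and $\\mathbf A \\mathbf A^\\top$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Illustrative affine map y = A x + b (arbitrary values).\n",
    "A_demo = torch.tensor([[1.0, 0.5, 0.0], [0.0, 2.0, -1.0]])\n",
    "b_demo = torch.tensor([2.0, 4.0])\n",
    "# Push standard Gaussian samples through the map.\n",
    "x_demo = torch.randn(200000, 3)\n",
    "y_demo = x_demo @ A_demo.T + b_demo\n",
    "# Empirical mean vs. b, and empirical covariance vs. A A^T.\n",
    "print(y_demo.mean(dim=0), b_demo)\n",
    "print(torch.cov(y_demo.T), A_demo @ A_demo.T)"
   ]
  },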
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "x = torch.linspace(-1, 1, 1000)\n",
    "plt.plot(x, x.tanh(), label='tanh')\n",
    "plt.plot(x, x, label='identity')\n",
    "plt.legend()\n",
    "plt.grid()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will try to confine most pre-activations of each layer to the range $[-\\alpha, \\alpha]$, in which $\\tanh$ is approximately the identity function. However, we do not want to choose $\\alpha$ too small, since that would suppress the nonlinearity of the network too much and prevent it from learning complex patterns ([Bradley, 2009](https://www.ri.cmu.edu/pub_files/2010/5/dbradley_thesis.pdf); [Glorot & Bengio, 2010](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)). From the plot above, $\\alpha = 0.25$ looks like a sensible choice. (Well, probably not 🙃; but if we increase $\\alpha$ too much, the output will no longer be Gaussian! Hence the \"useless\" in the title.)\n",
    "\n",
    "Our first step is to randomly initialize each layer of $\\phi$ while keeping it approximately linear in its input. We proceed layer by layer. For a given layer, suppose its input $\\boldsymbol x \\in \\mathbb R^n$ follows $\\mathcal N(\\mathbf 0, \\mathbf I)$. For each pre-activation $u = \\boldsymbol w^\\top \\boldsymbol x$, we want $|u| \\le \\alpha$ with probability 0.95, which means $1.96 \\sigma_u \\le \\alpha$ (see [confidence interval](https://en.wikipedia.org/wiki/Confidence_interval)). Here $\\sigma_v$ denotes the standard deviation of a variable $v$. Since $\\boldsymbol x \\sim \\mathcal N(\\mathbf 0, \\mathbf I)$ implies $\\boldsymbol w^\\top \\boldsymbol x \\sim \\mathcal N(0, \\boldsymbol w^\\top \\boldsymbol w)$, we need $\\sigma_u^2 = \\boldsymbol w^\\top \\boldsymbol w \\le (\\alpha / 1.96)^2$. Because $\\boldsymbol w$ is itself drawn at random (typically from a Gaussian), we impose this requirement in expectation: $\\mathbb E[\\boldsymbol w^\\top \\boldsymbol w] = \\sum_{k=1}^n \\mathbb E[w_k^2] \\le (\\alpha / 1.96)^2$. If $w_k \\sim \\mathcal N(0, \\sigma_w^2)$, then $w_k^2 / \\sigma_w^2$ follows a $\\chi^2$-distribution with one degree of freedom, so $\\mathbb E[w_k^2] = \\sigma_w^2$. Hence $\\mathbb E[\\boldsymbol w^\\top \\boldsymbol w] = n \\sigma_w^2 \\le (\\alpha / 1.96)^2$, or $\\sigma_w \\le n^{-1/2} (\\alpha/1.96)$. We keep the biases at zero, since a nonzero bias would shift the pre-activations away from the linear region around the origin."
   ]
  },
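  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put rough numbers on the two claims above, here is a small illustrative sketch (`alpha = 0.25` mirrors the choice made here, and `n = 5` is just an example input dimension). It computes the relative gap between $\\tanh(\\alpha)$ and $\\alpha$ at the edge of the range, and the resulting bound $\\sigma_w \\le n^{-1/2}(\\alpha / 1.96)$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "alpha = 0.25\n",
    "# Relative deviation of tanh from the identity at the edge of [-alpha, alpha];\n",
    "# inside the range it is smaller, because tanh(t) / t decreases as |t| grows.\n",
    "rel_dev = 1 - math.tanh(alpha) / alpha\n",
    "print(f'relative deviation at alpha: {rel_dev:.4f}')  # about 0.0203\n",
    "\n",
    "# Bound on the weight standard deviation for an n-dimensional standard\n",
    "# Gaussian input: sigma_w <= n^(-1/2) * (alpha / 1.96).\n",
    "n = 5\n",
    "sigma_w_max = alpha / (1.96 * math.sqrt(n))\n",
    "print(f'sigma_w upper bound for n={n}: {sigma_w_max:.4f}')  # about 0.0570"
   ]
  },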
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multiplier on the standard deviation that gives a 95% confidence interval.\n",
    "CONF_SCALER = 1.96\n",
    "# Probabilistic absolute upper bound on the pre-activations of each layer.\n",
    "ALPHA = 0.25"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's run a small experiment to verify the derivation above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.9418)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Data dimension.\n",
    "dim = 5\n",
    "# Sample 10000 standard Gaussian data points.\n",
    "x = torch.randn(10000, dim)\n",
    "# Draw 200 independent weight vectors with the derived standard deviation.\n",
    "w = torch.randn(200, 1, dim) * (ALPHA / (math.sqrt(dim) * CONF_SCALER))\n",
    "# Compute the pre-activations, of shape (200, 10000).\n",
    "preact = (w * x).sum(2)\n",
    "# Plot the histogram of the pre-activations of the first trial.\n",
    "plt.hist(preact[0], bins=100)\n",
    "plt.axvline(-ALPHA, c='r', linewidth=0.8)\n",
    "plt.axvline(ALPHA, c='r', linewidth=0.8)\n",
    "# Fraction of pre-activations within [-ALPHA, ALPHA]; should be approximately 0.95:\n",
    "(preact.abs() <= ALPHA).float().mean()"
   ]
  },
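  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a fixed weight vector, $u = \\boldsymbol w^\\top \\boldsymbol x$ is exactly $\\mathcal N(0, \\|\\boldsymbol w\\|^2)$, so its coverage $P(|u| \\le \\alpha) = \\operatorname{erf}(\\alpha / (\\sqrt 2 \\|\\boldsymbol w\\|))$ can be computed without sampling $\\boldsymbol x$. The sketch below is illustrative only: it uses local copies `alpha`, `conf_scaler` and `n` of the constants so that it runs on its own, and averages this exact coverage over many random draws of $\\boldsymbol w$. Because the bound only constrains $\\mathbb E[\\boldsymbol w^\\top \\boldsymbol w]$ rather than every draw, the average should come out somewhat below 0.95, in line with the empirical value above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Local copies of ALPHA, CONF_SCALER and dim so that this cell runs on its own.\n",
    "alpha, conf_scaler, n = 0.25, 1.96, 5\n",
    "# Many independent draws of w, each entry N(0, sigma_w^2) at the derived bound.\n",
    "w_draws = torch.randn(100000, n) * (alpha / (math.sqrt(n) * conf_scaler))\n",
    "# For a fixed w, u = w^T x ~ N(0, ||w||^2), so P(|u| <= alpha) = erf(alpha / (sqrt(2) ||w||)).\n",
    "coverage = torch.erf(alpha / (math.sqrt(2) * w_draws.norm(dim=1)))\n",
    "# Averaged over w this lands below 0.95, since only E[w^T w] is constrained.\n",
    "print(coverage.mean().item())"
   ]
  },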
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "While this derivation lets us initialize the first layer, it does not directly apply to the second layer onwards. Denote the input to the second layer as $\\boldsymbol z \\in \\mathbb R^m$, the input to the first layer as $\\boldsymbol x \\in \\mathbb R^n$ as before, and the weights of the $i$-th unit in the first layer as $\\boldsymbol w_i$. It follows that $z_i = \\tanh(\\boldsymbol w_i^\\top \\boldsymbol x) \\approx \\boldsymbol w_i^\\top \\boldsymbol x$, so approximately $\\boldsymbol z \\sim \\mathcal N(\\mathbf 0, \\mathbf W_1 \\mathbf W_1^\\top)$, where $\\mathbf W_1$ stacks the $\\boldsymbol w_i^\\top$ as rows. In particular, $z_i \\sim \\mathcal N(0, \\lambda_i)$ with $\\lambda_i = \\boldsymbol w_i^\\top \\boldsymbol w_i$, and the components of $\\boldsymbol z$ are in general correlated. This means we cannot assume standard Gaussian inputs from the second layer onwards.\n",
    "\n",
    "Fortunately, the derivation is similar. Denote the weight vector of a second-layer unit as $\\boldsymbol v$, so that its pre-activation is $u = \\boldsymbol v^\\top \\boldsymbol z \\sim \\mathcal N(0, \\boldsymbol v^\\top \\mathbf W_1 \\mathbf W_1^\\top \\boldsymbol v)$. So instead of $\\mathbb E[\\boldsymbol w^\\top \\boldsymbol w] \\le (\\alpha / 1.96)^2$, we now require $\\mathbb E[\\boldsymbol v^\\top \\mathbf W_1 \\mathbf W_1^\\top \\boldsymbol v] \\le (\\alpha / 1.96)^2$. With the $v_j \\sim \\mathcal N(0, \\sigma_v^2)$ drawn independently, the cross terms vanish in expectation, so the left-hand side equals $\\sigma_v^2 \\sum_{j=1}^m \\lambda_j = \\sigma_v^2 \\operatorname{tr}(\\mathbf W_1 \\mathbf W_1^\\top)$. Solving for $\\sigma_v$, we get $\\sigma_v \\le (\\sum_{j=1}^m \\lambda_j)^{-1/2}(\\alpha / 1.96)$.\n",
    "\n",
    "With this in mind, let's try to initialize the first two layers of a three-layer MLP:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.9401)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# An MLP: input_dim -> hidden_dim -> hidden_dim -> output_dim. We focus on the\n",
    "# first two layers: phi(x) = tanh(W2 tanh(W1 x)), or approximately,\n",
    "# phi_l(x) = W2 W1 x.\n",
    "input_dim = 3\n",
    "hidden_dim = 5\n",
    "# Initialize 200 instances of W1.\n",
    "std1 = ALPHA / (CONF_SCALER * math.sqrt(input_dim))\n",
    "W1 = torch.randn(200, hidden_dim, input_dim) * std1\n",
    "# Initialize 200 instances of W2.\n",
    "var1 = torch.bmm(W1, W1.transpose(1, 2)).diagonal(dim1=1, dim2=2).sum(1)\n",
    "std2 = ALPHA / (CONF_SCALER * var1.sqrt())\n",
    "W2 = torch.randn(200, hidden_dim, hidden_dim) * std2.unsqueeze(1).unsqueeze(1)\n",
    "# 10000 input points.\n",
    "x = torch.randn(10000, input_dim)\n",
    "# Find approximate z_l = phi_l(x).\n",
    "z_l = torch.matmul(W1.unsqueeze(1), x.unsqueeze(2))  # of shape (200, 10000, hidden_dim, 1)\n",
    "z_l = torch.matmul(W2.unsqueeze(1), z_l).squeeze(3)  # of shape (200, 10000, hidden_dim)\n",
    "# Should be approximately 0.95:\n",
    "(z_l.abs() <= ALPHA).float().mean()"
   ]
  },
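  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A subtlety in the derivation above: the components of $\\boldsymbol z \\approx \\mathbf W_1 \\boldsymbol x$ are in general correlated, with covariance $\\mathbf C = \\mathbf W_1 \\mathbf W_1^\\top$ whose diagonal entries are the $\\lambda_j$. The bound still holds because, for a weight vector $\\boldsymbol v$ with independent zero-mean entries of variance $\\sigma_v^2$, the off-diagonal terms vanish in expectation: $\\mathbb E[\\boldsymbol v^\\top \\mathbf C \\boldsymbol v] = \\sigma_v^2 \\operatorname{tr}(\\mathbf C) = \\sigma_v^2 \\sum_j \\lambda_j$. The sketch below checks this by Monte Carlo; it is illustrative, and `W1_demo`, `C_z` and `sigma_v = 0.1` are arbitrary choices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "n_in, n_hidden = 3, 5\n",
    "# One draw of the first-layer weights, using the derived standard deviation.\n",
    "W1_demo = torch.randn(n_hidden, n_in) * (0.25 / (1.96 * math.sqrt(n_in)))\n",
    "# Approximate covariance of z = W1 x; the off-diagonal entries are generally nonzero.\n",
    "C_z = W1_demo @ W1_demo.T\n",
    "print(C_z)\n",
    "# Monte Carlo estimate of E[v^T C v] for v with i.i.d. N(0, sigma_v^2) entries.\n",
    "sigma_v = 0.1\n",
    "v = torch.randn(200000, n_hidden) * sigma_v\n",
    "mc = ((v @ C_z) * v).sum(dim=1).mean()\n",
    "# Theory: sigma_v^2 * tr(C), i.e. only the diagonal entries lambda_j contribute.\n",
    "print(mc.item(), (sigma_v ** 2 * C_z.trace()).item())"
   ]
  },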
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we construct the actual MLP:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLPGenerator(nn.Module):\n",
    "    \"\"\"A GAN generator parameterized as an MLP.\"\"\"\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim):\n",
    "        super().__init__()\n",
    "        self.input_dim = input_dim\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.output_dim = output_dim\n",
    "        self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
    "        self.fc2 = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.fc3 = nn.Linear(hidden_dim, output_dim)\n",
    "        self.init_parameters()\n",
    "    \n",
    "    def init_parameters(self):\n",
    "        std1 = ALPHA / (CONF_SCALER * math.sqrt(self.input_dim))\n",
    "        nn.init.normal_(self.fc1.weight, std=std1)\n",
    "        nn.init.zeros_(self.fc1.bias)\n",
    "        w1 = self.fc1.weight.data\n",
    "        var1 = torch.mm(w1, w1.t()).trace()\n",
    "        std2 = ALPHA / (CONF_SCALER * var1.sqrt())\n",
    "        nn.init.normal_(self.fc2.weight, std=std2)\n",
    "        nn.init.zeros_(self.fc2.bias)\n",
    "        # TODO: initialize fc3.\n",
    "    \n",
    "    def forward(self, x, output=True):\n",
    "        \"\"\"\n",
    "        If `output` is True, return the output; otherwise, return the last\n",
    "        hidden state.\n",
    "        \"\"\"\n",
    "        z = self.fc1(x).tanh()\n",
    "        z = self.fc2(z).tanh()\n",
    "        if not output:\n",
    "            return z\n",
    "        return self.fc3(z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have tried our best to make $\\phi$ linear by pushing the pre-activations into the linear region of $\\tanh$, but some nonlinearity inevitably remains. We therefore linearize $\\phi$ around the origin using its Jacobian $\\mathbf J$ at $\\boldsymbol x = \\mathbf 0$. Since all hidden biases are zero, $\\phi(\\mathbf 0) = \\mathbf 0$, and hence $\\phi(\\boldsymbol x) \\approx \\mathbf J \\boldsymbol x$. (Because $\\tanh'(0) = 1$, this Jacobian coincides with the product of the hidden weight matrices, $\\mathbf W_2 \\mathbf W_1$.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0091, -0.0107, -0.0203],\n",
       "        [-0.0327,  0.0492,  0.0109],\n",
       "        [-0.0524, -0.0188, -0.0861],\n",
       "        [ 0.0024,  0.0175,  0.0090],\n",
       "        [ 0.0171,  0.0054,  0.1103]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_dim = 3\n",
    "hidden_dim = 5\n",
    "output_dim = 2\n",
    "net = MLPGenerator(input_dim, hidden_dim, output_dim)\n",
    "J = autograd.functional.jacobian(lambda x: net(x, output=False), torch.zeros(input_dim))\n",
    "J"
   ]
  },
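  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To gauge how good the linearization $\\phi(\\boldsymbol x) \\approx \\mathbf J \\boldsymbol x$ is, one can compare both sides on standard Gaussian inputs. The sketch below is illustrative and self-contained: instead of reusing `net`, it rebuilds two hidden layers with the same initialization rule as `MLPGenerator.init_parameters`, and the names `fc1_lin`, `fc2_lin`, `phi_lin` and `J_lin` are hypothetical. The printed number is the mean relative error $\\|\\phi(\\boldsymbol x) - \\mathbf J \\boldsymbol x\\| / \\|\\phi(\\boldsymbol x)\\|$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "from torch import nn, autograd\n",
    "\n",
    "torch.manual_seed(0)\n",
    "alpha, conf_scaler = 0.25, 1.96\n",
    "n_in, n_hidden = 3, 5\n",
    "fc1_lin, fc2_lin = nn.Linear(n_in, n_hidden), nn.Linear(n_hidden, n_hidden)\n",
    "with torch.no_grad():\n",
    "    # Same rule as init_parameters: zero biases, small Gaussian weights.\n",
    "    nn.init.normal_(fc1_lin.weight, std=alpha / (conf_scaler * math.sqrt(n_in)))\n",
    "    nn.init.zeros_(fc1_lin.bias)\n",
    "    # tr(W1 W1^T) equals the sum of squared entries of W1.\n",
    "    std2 = alpha / (conf_scaler * fc1_lin.weight.square().sum().sqrt().item())\n",
    "    nn.init.normal_(fc2_lin.weight, std=std2)\n",
    "    nn.init.zeros_(fc2_lin.bias)\n",
    "\n",
    "def phi_lin(x):\n",
    "    # Last hidden state of the two tanh layers.\n",
    "    return fc2_lin(fc1_lin(x).tanh()).tanh()\n",
    "\n",
    "J_lin = autograd.functional.jacobian(phi_lin, torch.zeros(n_in))\n",
    "x_lin = torch.randn(10000, n_in)\n",
    "with torch.no_grad():\n",
    "    exact, linear = phi_lin(x_lin), x_lin @ J_lin.T\n",
    "# Mean relative error of the linear approximation phi(x) ~= J x.\n",
    "rel_err = ((exact - linear).norm(dim=1) / exact.norm(dim=1)).mean()\n",
    "print(rel_err.item())"
   ]
  },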
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The last hidden state is $\\boldsymbol z = \\phi(\\boldsymbol x) \\approx \\mathbf J \\boldsymbol x$ around $\\boldsymbol x = \\mathbf 0$. Denoting the weights and bias of the last layer as $(\\mathbf W, \\boldsymbol b)$, the network output is $\\boldsymbol y \\approx \\mathbf W \\mathbf J \\boldsymbol x + \\boldsymbol b$. With standard Gaussian noise as the network input $\\boldsymbol x$, $\\boldsymbol y$ approximately follows the Gaussian distribution $\\mathcal N(\\boldsymbol b, \\mathbf W \\mathbf J \\mathbf J^\\top \\mathbf W^\\top)$.\n",
    "\n",
    "Suppose we'd like to initialize the GAN generator so that $\\boldsymbol y \\sim \\mathcal N(\\boldsymbol\\mu, \\boldsymbol\\Sigma)$, where $\\boldsymbol\\Sigma$ is a diagonal covariance matrix. Clearly $\\boldsymbol b = \\boldsymbol\\mu$, so we only need to find $\\mathbf W$ such that $\\mathbf W \\mathbf J \\mathbf J^\\top \\mathbf W^\\top = \\boldsymbol\\Sigma$. Since $\\mathbf J \\mathbf J^\\top$ is symmetric positive semi-definite, it can be written as $\\mathbf J \\mathbf J^\\top = \\mathbf Q \\mathbf S \\mathbf Q^\\top$, where $\\mathbf Q$ has orthonormal columns ($\\mathbf Q^\\top \\mathbf Q = \\mathbf I$) and $\\mathbf S$ is diagonal with the positive eigenvalues; equivalently, $\\mathbf J = \\mathbf Q \\mathbf S^{\\frac{1}{2}} \\mathbf V^\\top$ is the thin singular value decomposition of $\\mathbf J$. (For $\\mathbf J \\in \\mathbb R^{m \\times n}$ with $m \\ge n$ and, generically, full column rank, $\\mathbf S$ is $n \\times n$ and invertible.) Following $\\mathbf S = \\mathbf S^{\\frac{1}{2}} \\mathbf S^{\\frac{1}{2}\\top}$ and $\\boldsymbol\\Sigma = \\boldsymbol\\Sigma^{\\frac{1}{2}} \\boldsymbol\\Sigma^{\\frac{1}{2}\\top}$, the condition holds whenever $\\mathbf W \\mathbf Q \\mathbf S^{\\frac{1}{2}} = \\boldsymbol\\Sigma^{\\frac{1}{2}} \\mathbf U$, where $\\mathbf U$ is any matrix with orthonormal rows, i.e. $\\mathbf U \\mathbf U^\\top = \\mathbf I$. Using $\\mathbf Q^\\top \\mathbf Q = \\mathbf I$, one solution is $\\mathbf W = \\boldsymbol\\Sigma^{\\frac{1}{2}} \\mathbf U \\mathbf S^{-\\frac{1}{2}} \\mathbf Q^\\top$.\n",
    "\n",
    "Notice that we have some freedom in choosing $\\mathbf U$. A simple choice is the one favored by weight decay: minimize $\\|\\mathbf W\\|_F^2 = \\operatorname{tr}(\\mathbf W^\\top \\mathbf W)$ subject to $\\mathbf U \\mathbf U^\\top = \\mathbf I$. By a Lagrange multiplier argument, the optimal $\\mathbf U$ consists of eigenvectors corresponding to the smallest eigenvalues of $\\mathbf S^{-\\frac{1}{2}}$, i.e. to the largest singular values of $\\mathbf J$. Since $\\mathbf S$ is diagonal, these eigenvectors are standard basis vectors; when the entries of $\\boldsymbol\\Sigma$ differ, the largest target variance should be paired with the largest singular value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Our desired output mean.\n",
    "MU = torch.tensor([2.0, 4.0])\n",
    "# Our desired output stdev.\n",
    "SIGMA = torch.tensor([1.5, 0.8])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.5458,  0.4690, -6.5890,  0.7764,  7.2268],\n",
       "        [-0.9419, 12.3732,  2.0683,  2.7170,  0.5894]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Q, S_half, _ = torch.linalg.svd(J, full_matrices=False)\n",
    "# torch.linalg.svd returns singular values in descending order, so the first\n",
    "# output_dim entries of S_half correspond to the smallest entries of 1/S_half.\n",
    "U = torch.eye(input_dim)[:output_dim]\n",
    "W = torch.diag(SIGMA).mm(U).mm(torch.diag(S_half.reciprocal())).mm(Q.t())\n",
    "W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check that W J J^T W^T equals Sigma = diag(SIGMA^2):\n",
    "torch.allclose(W.mm(J).mm(J.t()).mm(W.t()), torch.diag(SIGMA.square()), atol=1e-6)"
   ]
  },
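  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The weight-decay argument for choosing $\\mathbf U$ can be checked numerically: any $\\mathbf U$ with orthonormal rows reproduces the target covariance, but the choice that keeps the largest singular values of $\\mathbf J$ gives the smallest $\\|\\mathbf W\\|_F$. The sketch below is illustrative; it uses a random stand-in Jacobian `J_demo`, a hypothetical helper `last_layer`, and a random orthonormal `U_rand` obtained from a QR decomposition, rather than the notebook's `J`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "n_hidden, n_in, n_out = 5, 3, 2\n",
    "# Stand-in for the Jacobian of phi (hypothetical random values).\n",
    "J_demo = torch.randn(n_hidden, n_in) * 0.05\n",
    "# Target standard deviations, sorted in descending order.\n",
    "sigma = torch.tensor([1.5, 0.8])\n",
    "Q_d, S_half_d, _ = torch.linalg.svd(J_demo, full_matrices=False)\n",
    "\n",
    "def last_layer(U):\n",
    "    # W = Sigma^{1/2} U S^{-1/2} Q^T for a U with orthonormal rows.\n",
    "    return torch.diag(sigma) @ U @ torch.diag(S_half_d.reciprocal()) @ Q_d.T\n",
    "\n",
    "W_min = last_layer(torch.eye(n_in)[:n_out])\n",
    "# A random U with orthonormal rows: the transposed Q factor of a random n_in x n_out matrix.\n",
    "U_rand = torch.linalg.qr(torch.randn(n_in, n_out)).Q.T\n",
    "W_rand = last_layer(U_rand)\n",
    "target = torch.diag(sigma ** 2)\n",
    "for W_c in (W_min, W_rand):\n",
    "    # Both choices reproduce the target covariance W J J^T W^T = Sigma...\n",
    "    print(torch.allclose(W_c @ J_demo @ J_demo.T @ W_c.T, target, atol=1e-4))\n",
    "# ...but the weight-decay choice has the smaller Frobenius norm.\n",
    "print(W_min.norm().item(), W_rand.norm().item())"
   ]
  },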
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finishing the GAN generator above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLPGenerator(nn.Module):\n",
    "    \"\"\"A GAN generator parameterized as an MLP.\"\"\"\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim, enable_baseline=False):\n",
    "        super().__init__()\n",
    "        self.input_dim = input_dim\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.output_dim = output_dim\n",
    "        self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
    "        self.fc2 = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.fc3 = nn.Linear(hidden_dim, output_dim)\n",
    "        if enable_baseline:\n",
    "            self.init_parameters_baseline()\n",
    "        else:\n",
    "            self.init_parameters()\n",
    "    \n",
    "    def init_parameters_baseline(self):\n",
    "        tanh_gain = nn.init.calculate_gain('tanh')\n",
    "        nn.init.xavier_normal_(self.fc1.weight, gain=tanh_gain)\n",
    "        nn.init.zeros_(self.fc1.bias)\n",
    "        nn.init.xavier_normal_(self.fc2.weight, gain=tanh_gain)\n",
    "        nn.init.zeros_(self.fc2.bias)\n",
    "        nn.init.xavier_normal_(self.fc3.weight, gain=tanh_gain)\n",
    "        self.fc3.bias.data.copy_(MU)\n",
    "    \n",
    "    @torch.no_grad()\n",
    "    def init_parameters(self):\n",
    "        std1 = ALPHA / (CONF_SCALER * math.sqrt(self.input_dim))\n",
    "        nn.init.normal_(self.fc1.weight, std=std1)\n",
    "        nn.init.zeros_(self.fc1.bias)\n",
    "        w1 = self.fc1.weight.data\n",
    "        var1 = torch.mm(w1, w1.t()).trace()\n",
    "        std2 = ALPHA / (CONF_SCALER * var1.sqrt())\n",
    "        nn.init.normal_(self.fc2.weight, std=std2)\n",
    "        nn.init.zeros_(self.fc2.bias)\n",
    "\n",
    "        # Initialize self.fc3 from the Jacobian of the hidden layers at zero.\n",
    "        J = autograd.functional.jacobian(\n",
    "            lambda x: self(x, output=False), torch.zeros(self.input_dim))\n",
    "        Q, S_half, _ = torch.linalg.svd(J, full_matrices=False)\n",
    "        U = torch.eye(self.input_dim)[:self.output_dim]\n",
    "        W = torch.diag(SIGMA).mm(U).mm(torch.diag(S_half.reciprocal())).mm(Q.t())\n",
    "        self.fc3.weight.data.copy_(W)\n",
    "        self.fc3.bias.data.copy_(MU)\n",
    "    \n",
    "    def forward(self, x, output=True):\n",
    "        \"\"\"\n",
    "        If `output` is True, return the output; otherwise, return the last\n",
    "        hidden state.\n",
    "        \"\"\"\n",
    "        z = self.fc1(x).tanh()\n",
    "        z = self.fc2(z).tanh()\n",
    "        if not output:\n",
    "            return z\n",
    "        return self.fc3(z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the baseline (default) initialization, I set all hidden-layer biases to zero and use Xavier normal initialization with the recommended gain for the $\\tanh$ nonlinearity. In both the proposed and the default initialization, the output is aligned to the desired mean by setting the output layer's bias to that mean."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generate some samples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 1200x500 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "net = MLPGenerator(input_dim, hidden_dim, output_dim)\n",
    "x = torch.randn(1000, input_dim)\n",
    "with torch.no_grad():\n",
    "    y = net(x)\n",
    "plt.figure(figsize=(12, 5))\n",
    "plt.subplot(121)\n",
    "plt.title('Initialized GAN samples and the desired mean (red)')\n",
    "plt.scatter(y[:, 0], y[:, 1], marker='x', s=2, alpha=0.5)\n",
    "plt.scatter([MU[0]], [MU[1]], c='r')\n",
    "plt.xlim(MU[0] - 4, MU[0] + 4)\n",
    "plt.ylim(MU[1] - 4, MU[1] + 4)\n",
    "plt.subplot(122)\n",
    "plt.title('Desired Gaussian samples')\n",
    "y0 = D.Normal(MU, SIGMA).sample((1000,))\n",
    "plt.scatter(y0[:, 0], y0[:, 1], marker='x', s=2, alpha=0.5)\n",
    "plt.xlim(MU[0] - 4, MU[0] + 4)\n",
    "plt.ylim(MU[1] - 4, MU[1] + 4);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As can be seen, the GAN samples look almost the same as the desired Gaussian samples. The difference is that the GAN samples are slightly less spread out toward the boundary, due to the squashing effect of $\\tanh$.\n",
    "\n",
    "For comparison, we generate some baseline samples using the default initialization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 1200x500 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "net = MLPGenerator(input_dim, hidden_dim, output_dim, enable_baseline=True)\n",
    "x = torch.randn(1000, input_dim)\n",
    "with torch.no_grad():\n",
    "    y = net(x)\n",
    "plt.figure(figsize=(12, 5))\n",
    "plt.subplot(121)\n",
    "plt.title('Default GAN samples and the desired mean (red)')\n",
    "plt.scatter(y[:, 0], y[:, 1], marker='x', s=0.7, alpha=0.5)\n",
    "plt.scatter([MU[0]], [MU[1]], c='r')\n",
    "plt.xlim(MU[0] - 4, MU[0] + 4)\n",
    "plt.ylim(MU[1] - 4, MU[1] + 4)\n",
    "plt.subplot(122)\n",
    "plt.title('Desired Gaussian samples')\n",
    "y0 = D.Normal(MU, SIGMA).sample((1000,))\n",
    "plt.scatter(y0[:, 0], y0[:, 1], marker='x', s=2, alpha=0.5)\n",
    "plt.xlim(MU[0] - 4, MU[0] + 4)\n",
    "plt.ylim(MU[1] - 4, MU[1] + 4);"
   ]
  },
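  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The slightly smaller spread of the initialized samples can be traced to $\\tanh$ itself: for a zero-mean Gaussian pre-activation $s g$ with $g \\sim \\mathcal N(0, 1)$, $\\operatorname{Var}[\\tanh(s g)] < s^2$, and a third-order Taylor expansion $\\tanh(t) \\approx t - t^3/3$ suggests a ratio of roughly $1 - 2 s^2$. The Jacobian-based last layer, on the other hand, assumes the slope $\\tanh'(0) = 1$. The sketch below is illustrative; it sets `s` to the standard deviation $\\alpha / 1.96$ targeted throughout and estimates the ratio by sampling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Pre-activation standard deviation targeted by the initialization: alpha / 1.96.\n",
    "s = 0.25 / 1.96\n",
    "g = torch.randn(1_000_000, dtype=torch.float64)\n",
    "ratio = (s * g).tanh().var() / s ** 2\n",
    "# Sampled variance ratio versus the Taylor estimate 1 - 2 s^2 (both roughly 0.97).\n",
    "print(ratio.item(), 1 - 2 * s ** 2)"
   ]
  },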
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Remark: You may have noticed that the baseline distribution produced by the default GAN is always point-symmetric about the desired mean. There are three reasons:\n",
    "\n",
    "1. The input Gaussian noise is symmetric about zero.\n",
    "2. The activation function $\\tanh$ is odd, i.e. $\\tanh(-x) = -\\tanh(x)$.\n",
    "3. All biases except the last layer's are zero, so nothing breaks the symmetry."
   ]
  }
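  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The symmetry argument can be made concrete: with an odd activation and zero hidden biases, the generator satisfies $G(-\\boldsymbol x) - \\boldsymbol\\mu = -(G(\\boldsymbol x) - \\boldsymbol\\mu)$, so for symmetric input noise the output distribution is point-symmetric about $\\boldsymbol\\mu$. The sketch below checks this identity on a hypothetical `nn.Sequential` stand-in `net_sym` with Xavier-initialized weights, not on the notebook's `MLPGenerator`; the same identity holds for the proposed initialization, whose hidden biases are zero as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "mu = torch.tensor([2.0, 4.0])\n",
    "net_sym = nn.Sequential(\n",
    "    nn.Linear(3, 5), nn.Tanh(),\n",
    "    nn.Linear(5, 5), nn.Tanh(),\n",
    "    nn.Linear(5, 2),\n",
    ")\n",
    "with torch.no_grad():\n",
    "    for layer in net_sym:\n",
    "        if isinstance(layer, nn.Linear):\n",
    "            nn.init.xavier_normal_(layer.weight)\n",
    "            nn.init.zeros_(layer.bias)\n",
    "    # Only the output bias is nonzero, as in the baseline.\n",
    "    net_sym[-1].bias.copy_(mu)\n",
    "    x_sym = torch.randn(1000, 3)\n",
    "    # Odd symmetry about mu: G(-x) - mu == -(G(x) - mu).\n",
    "    print(torch.allclose(net_sym(-x_sym) - mu, -(net_sym(x_sym) - mu), atol=1e-6))"
   ]
  }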
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}