-
-
Save piscisaureus/72d2a9536aeb55a3f7631b28c8b44fcd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// tslint:disable:typedef | |
// tslint:disable:comment-format | |
import * as tf from "@tensorflow/tfjs"; | |
// ---- Model shape --------------------------------------------------------
/** Base channel width; the conv stacks use DIM, 2 * DIM and 4 * DIM filters. */
const DIM = 64;
/** Pixels in one flattened MNIST image (28 rows x 28 columns). */
const OUTPUT_DIM = 28 * 28;

// ---- Training schedule --------------------------------------------------
/** Samples per training batch. */
const BATCH_SIZE = 50;
/** Critic (discriminator) updates performed per generator update. */
const CRITIC_ITERS = 5;
/** Weight (lambda) applied to the gradient-penalty term. */
const LAMBDA = 10;
/** Total number of generator iterations to train for. */
const ITERS = 200000;
/**
 * WGAN-GP MNIST generator, ported layer-by-layer from the PyTorch reference.
 *
 * Input:  a batch of 128-dim noise vectors (shape [batch, 128]).
 * Output: a batch of flattened images (shape [batch, OUTPUT_DIM]) squashed
 *         into (0, 1) by a sigmoid.
 *
 * tf.js conv layers default to `dataFormat: "channelsLast"` (NHWC), whereas
 * the PyTorch original is NCHW, so all intermediate shapes here are written
 * as [height, width, channels].
 *
 * Spatial size per stage: 4x4 -> 8x8 -> (crop) 7x7 -> 11x11 -> 28x28.
 */
const generator = tf.sequential({ layers: [
  // Preprocess: nn.Linear(128, 4*4*4*DIM) + nn.ReLU(True)
  tf.layers.dense({ inputShape: [128], units: 4 * 4 * 4 * DIM }),
  tf.layers.activation({ activation: "relu" }),
  // output.view(-1, 4*DIM, 4, 4), expressed channels-last.
  tf.layers.reshape({ targetShape: [4, 4, 4 * DIM] }),
  // Block 1: nn.ConvTranspose2d(4*DIM, 2*DIM, 5) + ReLU  -> 8x8
  tf.layers.conv2dTranspose({ filters: 2 * DIM, kernelSize: 5 }),
  tf.layers.activation({ activation: "relu" }),
  // output[:, :, :7, :7]: drop the last row and column -> 7x7.
  // Without this crop the final map is 30x30 and the reshape to OUTPUT_DIM fails.
  tf.layers.cropping2D({ cropping: [[0, 1], [0, 1]] }),
  // Block 2: nn.ConvTranspose2d(2*DIM, DIM, 5) + ReLU  -> 11x11
  tf.layers.conv2dTranspose({ filters: DIM, kernelSize: 5 }),
  tf.layers.activation({ activation: "relu" }),
  // Deconv output: nn.ConvTranspose2d(DIM, 1, 8, stride=2)  -> 28x28x1
  tf.layers.conv2dTranspose({ filters: 1, kernelSize: 8, strides: 2 }),
  // Finalize: sigmoid, then output.view(-1, OUTPUT_DIM)
  tf.layers.activation({ activation: "sigmoid" }),
  tf.layers.reshape({ targetShape: [OUTPUT_DIM] }),
] });
/**
 * WGAN-GP MNIST critic (discriminator), ported from the PyTorch reference.
 *
 * Input:  a batch of flattened images (shape [batch, OUTPUT_DIM]).
 * Output: one unbounded score per image (shape [batch, 1]). There is no
 *         sigmoid, matching the Wasserstein critic in the original.
 *
 * tf.js conv layers default to channels-last, so the image is reshaped to
 * [28, 28, 1] rather than PyTorch's [1, 28, 28].
 *
 * PyTorch's `padding=2` is reproduced with an explicit zeroPadding2d before
 * each "valid" conv. Spatial size per stage: 28 -> 14 -> 7 -> 4.
 */
const discriminator = tf.sequential({ layers: [
  // input.view(-1, 1, 28, 28), channels-last. As the first layer of a
  // Sequential model, it must declare inputShape.
  tf.layers.reshape({ inputShape: [OUTPUT_DIM], targetShape: [28, 28, 1] }),
  // nn.Conv2d(1, DIM, 5, stride=2, padding=2) + ReLU  -> 14x14
  tf.layers.zeroPadding2d({ padding: 2 }),
  tf.layers.conv2d({ filters: DIM, kernelSize: 5, strides: 2 }),
  tf.layers.activation({ activation: "relu" }),
  // nn.Conv2d(DIM, 2*DIM, 5, stride=2, padding=2) + ReLU  -> 7x7
  tf.layers.zeroPadding2d({ padding: 2 }),
  tf.layers.conv2d({ filters: 2 * DIM, kernelSize: 5, strides: 2 }),
  tf.layers.activation({ activation: "relu" }),
  // nn.Conv2d(2*DIM, 4*DIM, 5, stride=2, padding=2) + ReLU  -> 4x4
  tf.layers.zeroPadding2d({ padding: 2 }),
  tf.layers.conv2d({ filters: 4 * DIM, kernelSize: 5, strides: 2 }),
  tf.layers.activation({ activation: "relu" }),
  // output.view(-1, 4*4*4*DIM). targetShape excludes the batch axis, so a
  // leading -1 here would add a spurious extra dimension.
  tf.layers.reshape({ targetShape: [4 * 4 * 4 * DIM] }),
  // nn.Linear(4*4*4*DIM, 1)
  tf.layers.dense({ units: 1 }),
] });
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment