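// DCGAN on MNIST using the PyTorch C++ frontend (libtorch): a convolutional
// generator and discriminator are trained adversarially, and samples generated
// from a fixed noise batch are periodically written to out.csv.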
#include <torch/torch.h>

#include "mnist_reader.h"

#include <fstream>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>

using namespace torch;
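// DCGAN-style weight initialization (Radford et al., 2015): convolution weights
// are drawn from N(0, 0.02), batch-norm scales from N(1, 0.02) with zero bias.
// Applied to every submodule below via `modules().apply(...)`.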
void initialize_weights(nn::Module& module) {
  if (module.name().find("Conv2d") != std::string::npos) {
    module.parameters()["weight"].data().normal_(0.0, 0.02);
  } else if (module.name().find("BatchNorm") != std::string::npos) {
    auto parameters = module.parameters();
    parameters["weight"].data().normal_(1.0, 0.02);
    parameters["bias"].data().fill_(0);
  }
}
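// Writes a tensor to out.csv as a flat, comma-separated list of floats so a
// generated sample can be inspected (e.g. reshaped to 28x28 and plotted)
// outside of this program.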
void store_csv_tensor(Tensor tensor) {
  std::ofstream out("out.csv");
  auto flat = tensor.flatten();
  for (int64_t i = 0; i < flat.numel(); ++i) {
    out << flat[i].toCFloat() << ",";
  }
}
auto main() -> int {
  const int64_t kNoiseSize = 100;
  const int64_t kNumberOfEpochs = 30;
  const int64_t kBatchSize = 60;
  const int64_t kSampleEvery = 100;

  const at::Device device(at::kCUDA, 0);
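  // Generator: maps a kNoiseSize x 1 x 1 noise vector to a 1 x 28 x 28 image
  // through a stack of transposed convolutions. The spatial size grows
  // 1 -> 4 -> 7 -> 14 -> 28, and the final tanh puts pixels in [-1, 1].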
  nn::Sequential generator(
      // Layer 1: kNoiseSize x 1 x 1 -> 256 x 4 x 4
      nn::Conv2d(nn::Conv2dOptions(kNoiseSize, 256, 4).with_bias(false).transposed(true)),
      nn::BatchNorm(256),
      nn::Functional(at::relu),
      // Layer 2: 256 x 4 x 4 -> 128 x 7 x 7
      nn::Conv2d(nn::Conv2dOptions(256, 128, 3).stride(2).padding(1).with_bias(false).transposed(true)),
      nn::BatchNorm(128),
      nn::Functional(at::relu),
      // Layer 3: 128 x 7 x 7 -> 64 x 14 x 14
      nn::Conv2d(nn::Conv2dOptions(128, 64, 4).stride(2).padding(1).with_bias(false).transposed(true)),
      nn::BatchNorm(64),
      nn::Functional(at::relu),
      // Layer 4: 64 x 14 x 14 -> 1 x 28 x 28
      nn::Conv2d(nn::Conv2dOptions(64, 1, 4).stride(2).padding(1).with_bias(false).transposed(true)),
      nn::Functional(at::tanh));
  generator.to(device);
  generator.modules().apply(initialize_weights);
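  // Discriminator: maps a 1 x 28 x 28 image to a single probability that the
  // image is real. The spatial size shrinks 28 -> 14 -> 7 -> 3 -> 1, and the
  // final sigmoid produces the probability fed to binary cross-entropy below.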
  nn::Sequential discriminator(
      // Layer 1: 1 x 28 x 28 -> 64 x 14 x 14
      nn::Conv2d(nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).with_bias(false)),
      nn::Functional(at::leaky_relu, 0.2),
      // Layer 2: 64 x 14 x 14 -> 128 x 7 x 7
      nn::Conv2d(nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).with_bias(false)),
      nn::BatchNorm(128),
      nn::Functional(at::leaky_relu, 0.2),
      // Layer 3: 128 x 7 x 7 -> 256 x 3 x 3
      nn::Conv2d(nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).with_bias(false)),
      nn::BatchNorm(256),
      nn::Functional(at::leaky_relu, 0.2),
      // Layer 4: 256 x 3 x 3 -> 1 x 1 x 1
      nn::Conv2d(nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).with_bias(false)),
      nn::Functional(at::sigmoid));
  discriminator.to(device);
  discriminator.modules().apply(initialize_weights);
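  // Adam with beta1 = 0.5, following the DCGAN paper; the discriminator uses a
  // larger learning rate (5e-4) than the generator (2e-4).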
  optim::Adam generator_optimizer(generator.parameters(), optim::AdamOptions(2e-4).beta1(0.5));
  optim::Adam discriminator_optimizer(discriminator.parameters(), optim::AdamOptions(5e-4).beta1(0.5));

  // Load the MNIST training images (expected as floats in [0, 1]), rescale them
  // to [-1, 1] to match the generator's tanh output, and split them into
  // batches of kBatchSize images of shape 1 x 28 x 28.
  auto examples = read_mnist_examples("test/cpp/api/mnist/train-images-idx3-ubyte");
  examples = (examples * 2) - 1;
  examples = examples.reshape({examples.size(0) / kBatchSize, kBatchSize, 1, 28, 28});
  examples = examples.to(device);

  // A fixed noise batch, used to sample images from the generator at regular
  // intervals so progress is comparable across iterations.
  const auto fixed_noise = torch::randn({kBatchSize, kNoiseSize, 1, 1}, device);

  std::cout << std::setprecision(4) << "\n";
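  // Standard GAN training: each iteration first updates the discriminator on a
  // real batch (soft labels in [0.8, 1.0]) plus a fake batch (labels 0), then
  // updates the generator to make the discriminator label its fakes as real.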
  for (int64_t epoch = 0; epoch < kNumberOfEpochs; ++epoch) {
    for (int64_t i = 0; i < examples.size(0); ++i) {
      // Train discriminator with real images.
      discriminator.zero_grad();
      torch::Tensor real_images = examples[i].to(device);
      // Soft labels in [0.8, 1.0] for real images (one-sided label smoothing).
      auto real_labels = torch::empty(kBatchSize, device).uniform_(0.8, 1.0);
      auto real_output = discriminator.forward(real_images);
      auto d_loss_real = at::binary_cross_entropy(real_output, real_labels);
      d_loss_real.backward();

      // Train discriminator with fake images. The fakes are detached so this
      // backward pass does not touch the generator's parameters.
      auto noise = torch::randn({kBatchSize, kNoiseSize, 1, 1}, device);
      torch::Tensor fake_images = generator.forward(noise);
      auto fake_labels = torch::zeros(kBatchSize, device);
      auto fake_output = discriminator.forward(torch::Tensor(fake_images.detach()));
      auto d_loss_fake = at::binary_cross_entropy(fake_output, fake_labels);
      d_loss_fake.backward();
      auto d_loss = d_loss_real + d_loss_fake;
      discriminator_optimizer.step();

      // Train generator: label the fakes as real and backpropagate the
      // discriminator's error through the (non-detached) fake images.
      generator.zero_grad();
      fake_labels.fill_(1);
      fake_output = discriminator.forward(fake_images);
      auto g_loss = at::binary_cross_entropy(fake_output, fake_labels);
      g_loss.backward();
      generator_optimizer.step();

      std::cout << "\r[" << epoch << "/" << kNumberOfEpochs << "][" << i << "/" << examples.size(0)
                << "] D_loss: " << d_loss.toCFloat() << " | G_loss: " << g_loss.toCFloat() << std::flush;

      // Periodically write a sample generated from the fixed noise to out.csv.
      if (i % kSampleEvery == 0) {
        auto fake_images = generator.forward(fixed_noise);
        // Map pixels from [-1, 1] back to [0, 1] before saving.
        auto image = (fake_images[0] + 1) / 2;
        store_csv_tensor(image);
        // std::cout << "\nWrote tensor to out.csv" << std::endl;
      }
    }
  }
  std::cout << std::endl;
}