Last active
January 3, 2019 07:15
-
-
Save rezoo/4e005611aaa4dad26697 to your computer and use it in GitHub Desktop.
Simple implementation of Generative Adversarial Nets using chainer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gzip | |
import os | |
import numpy as np | |
import six | |
from six.moves.urllib import request | |
# Base URL and archive names of the four MNIST files (IDX format, gzipped).
# NOTE(review): yann.lecun.com serves over plain http and has been
# unreliable in recent years — confirm the host still responds.
parent = 'http://yann.lecun.com/exdb/mnist'
train_images = 'train-images-idx3-ubyte.gz'
train_labels = 'train-labels-idx1-ubyte.gz'
test_images = 't10k-images-idx3-ubyte.gz'
test_labels = 't10k-labels-idx1-ubyte.gz'
num_train = 60000  # examples in the training split
num_test = 10000   # examples in the test split
dim = 784          # pixels per image (28 * 28)
def load_mnist(images, labels, num, dim=784):
    """Parse a gzipped IDX image/label file pair into numpy arrays.

    Args:
        images: path to a gzipped IDX3 image file (16-byte header,
            then one uint8 per pixel).
        labels: path to a gzipped IDX1 label file (8-byte header,
            then one uint8 per example).
        num: number of examples to read.
        dim: pixels per image; defaults to 784 (28 * 28), matching the
            previous module-level constant.

    Returns:
        (data, target): writable uint8 arrays of shape (num, dim)
        and (num,).
    """
    with gzip.open(images, 'rb') as f_images,\
            gzip.open(labels, 'rb') as f_labels:
        # Skip the IDX magic-number/shape headers.
        f_images.read(16)
        f_labels.read(8)
        # Bulk-read the payloads and decode with np.frombuffer instead of
        # the original per-byte f.read(1)/ord() loop (num * dim python
        # calls); this is orders of magnitude faster for the real dataset.
        data = np.frombuffer(
            f_images.read(num * dim), dtype=np.uint8).reshape((num, dim))
        target = np.frombuffer(
            f_labels.read(num), dtype=np.uint8).reshape((num,))
    # frombuffer yields read-only views; copy to keep the original
    # contract of returning writable arrays.
    return data.copy(), target.copy()
def download_mnist_data():
    """Download the four MNIST archives, convert them, and pickle the result.

    Side effects: writes the four .gz archives and 'mnist.pkl' into the
    current working directory. The pickle holds a dict with 'data'
    (70000 x 784 uint8) and 'target' (70000,) built by stacking the
    train and test splits.
    """
    # One loop replaces the four copy-pasted download stanzas; the
    # printed messages and download order are unchanged.
    for name in (train_images, train_labels, test_images, test_labels):
        print('Downloading {:s}...'.format(name))
        request.urlretrieve('{:s}/{:s}'.format(parent, name), name)
        print('Done')
    print('Converting training data...')
    data_train, target_train = load_mnist(train_images, train_labels,
                                          num_train)
    print('Done')
    print('Converting test data...')
    data_test, target_test = load_mnist(test_images, test_labels, num_test)
    mnist = {}
    mnist['data'] = np.append(data_train, data_test, axis=0)
    mnist['target'] = np.append(target_train, target_test, axis=0)
    print('Done')
    print('Save output...')
    with open('mnist.pkl', 'wb') as output:
        # -1 selects the highest available pickle protocol.
        six.moves.cPickle.dump(mnist, output, -1)
    print('Done')
    print('Convert completed')
def load_mnist_data():
    """Return the MNIST dict, downloading and converting it on first use."""
    pickle_path = 'mnist.pkl'
    # Build the pickle lazily: only hit the network when it is absent.
    if not os.path.exists(pickle_path):
        download_mnist_data()
    with open(pickle_path, 'rb') as mnist_pickle:
        return six.moves.cPickle.load(mnist_pickle)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import os | |
import logging | |
import argparse | |
import pickle | |
import numpy as np | |
from PIL import Image | |
import chainer | |
from chainer import cuda | |
import chainer.functions as F | |
import chainer.optimizers | |
import data | |
class GANModel(chainer.FunctionSet):
    """MLP generator/discriminator pair for a simple GAN over 784-dim MNIST."""

    # Dimensionality of the latent noise vector z.
    n_hidden = 100

    def __init__(self):
        layers = dict(
            # Generator: z (n_hidden) -> 500 -> 500 -> 784 pixels.
            g_fc0=F.Linear(self.n_hidden, 500),
            g_fc1=F.Linear(500, 500),
            g_fc2=F.Linear(500, 784),
            # Discriminator: 784 pixels -> 240 -> 240 -> 1 logit.
            d_fc0=F.Linear(784, 240),
            d_fc1=F.Linear(240, 240),
            d_fc2=F.Linear(240, 1),
        )
        super(GANModel, self).__init__(**layers)

    @property
    def generators(self):
        """Layers belonging to the generator network."""
        return [self.g_fc0, self.g_fc1, self.g_fc2]

    @property
    def discriminators(self):
        """Layers belonging to the discriminator network."""
        return [self.d_fc0, self.d_fc1, self.d_fc2]

    def make_z(self, n):
        """Sample n latent vectors: 0.2 * standard normal, float32."""
        noise = np.asarray(np.random.randn(n, self.n_hidden),
                           dtype=np.float32)
        return 0.2 * noise

    def make_generator(self, z, train=True):
        """Map latent z to a 784-dim output (raw, no squashing applied)."""
        hidden = F.dropout(F.relu(self.g_fc0(z)), train=train)
        hidden = F.dropout(F.relu(self.g_fc1(hidden)), train=train)
        return self.g_fc2(hidden)

    def make_discriminator(self, x, t, train=True):
        """Score x and return sigmoid cross-entropy against labels t.

        NOTE: `train` is accepted for interface symmetry but unused here.
        """
        hidden = F.relu(self.d_fc0(x))
        hidden = F.relu(self.d_fc1(hidden))
        return F.sigmoid_cross_entropy(self.d_fc2(hidden), t)

    def collect_generator_parameters(self):
        """Return (parameters, gradients) concatenated over generator layers."""
        params, grads = (), ()
        for layer in self.generators:
            params += layer.parameters
            grads += layer.gradients
        return (params, grads)

    def collect_discriminator_parameters(self):
        """Return (parameters, gradients) concatenated over discriminator layers."""
        params, grads = (), ()
        for layer in self.discriminators:
            params += layer.parameters
            grads += layer.gradients
        return (params, grads)

    def generate(self, z_data, train=True):
        """Run the generator on a raw z array and return the raw output array."""
        out = self.make_generator(chainer.Variable(z_data), train=train)
        return out.data

    def forward_xy(self, x_data, t_data):
        """Discriminator loss on given samples x_data with labels t_data."""
        return self.make_discriminator(
            chainer.Variable(x_data), chainer.Variable(t_data))

    def forward_zy(self, z_data, t_data):
        """Discriminator loss on samples generated from z_data, labels t_data."""
        fake = self.make_generator(chainer.Variable(z_data))
        return self.make_discriminator(fake, chainer.Variable(t_data))
def main():
    """Train a simple GAN on MNIST, saving a sample grid per epoch and the model.

    NOTE(review): this targets Python 2 (xrange) and a pre-1.0 chainer API
    (cuda.init, FunctionSet, tuple-based optimizer setup).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=50)      # minibatch size
    parser.add_argument('--epoch', type=int, default=30)      # training epochs
    parser.add_argument('-g', '--gpu', type=int, default=-1)  # GPU id, -1 = CPU
    parser.add_argument('--display', type=int, default=100)   # log every N iterations
    parser.add_argument('--image', default='./')              # output dir for sample PNGs
    parser.add_argument('--dst', default='model.pkl')         # pickled model path
    args = parser.parse_args()
    logging.basicConfig(
        format="%(asctime)s [%(levelname)s] %(message)s",
        level=logging.DEBUG)
    logger = logging.getLogger(__name__)
    if 0 <= args.gpu:
        cuda.init(args.gpu)
    logger.info('loading MNIST dataset...')
    mnist = data.load_mnist_data()
    # Rescale pixel values from [0, 255] to [0, 1].
    mnist['data'] = mnist['data'].astype(np.float32)
    mnist['data'] /= 255
    x_train = mnist['data']
    logger.info('constructing GAN model...')
    model = GANModel()
    if 0 <= args.gpu:
        model = model.to_gpu(args.gpu)
    logger.info('initializing two optimizers...')
    # The generator optimizer uses a NEGATIVE alpha: combined with the
    # fake-labeled (t=0) batches below, the generator performs gradient
    # *ascent* on the discriminator's loss, i.e. it learns to make the
    # discriminator misclassify generated samples (author-confirmed trick).
    g_optimizer = chainer.optimizers.Adam(alpha=-1e-5)
    g_optimizer.setup(model.collect_generator_parameters())
    d_optimizer = chainer.optimizers.Adam(alpha=1e-5)
    d_optimizer.setup(model.collect_discriminator_parameters())
    # Fixed batch of 100 latent vectors, rendered as a 10x10 sample grid
    # at the start of every epoch to visualize training progress.
    example_z = model.make_z(100)
    if 0 <= args.gpu:
        example_z = cuda.to_gpu(example_z, args.gpu)
    iteration = 0
    for epoch in xrange(1, args.epoch + 1):
        logger.info('epoch %i', epoch)
        perm = np.random.permutation(x_train.shape[0])
        sum_dloss, sum_gloss = 0.0, 0.0
        # Render the fixed latent batch into a 280x280 (10x10 tiles of
        # 28x28) grayscale image and save it as <epoch>.png.
        example_x = cuda.to_cpu(model.generate(example_z, train=False))
        example_x = example_x.reshape(10, 10, 28, 28).transpose([0, 2, 1, 3])
        example_x = np.clip(
            255 * example_x.reshape(280, 280), 0.0, 255.0).astype(np.uint8)
        img = Image.fromarray(example_x)
        img.save(os.path.join(args.image, "{:03}.png".format(epoch)))
        for i in xrange(0, x_train.shape[0], args.batch):
            iteration += 1
            # Last batch of the epoch may be smaller than args.batch.
            batchsize = min(i + args.batch, x_train.shape[0]) - i
            # ---- update discriminator ----
            # First half of the batch: generated samples (label 0);
            # second half: real training images (label 1).
            x_batch = np.empty(
                (2 * batchsize, x_train.shape[1]), dtype=np.float32)
            t_batch = np.empty((2 * batchsize, 1), dtype=np.int32)
            z_batch = model.make_z(batchsize)
            if 0 <= args.gpu:
                z_batch = cuda.to_gpu(z_batch, args.gpu)
            x_batch[:batchsize] = cuda.to_cpu(model.generate(z_batch))
            x_batch[batchsize:] = x_train[perm[i:i + batchsize]]
            t_batch[:batchsize] = 0
            t_batch[batchsize:] = 1
            if 0 <= args.gpu:
                x_batch = cuda.to_gpu(x_batch, args.gpu)
                t_batch = cuda.to_gpu(t_batch, args.gpu)
            # Only the discriminator's gradients are cleared and applied here.
            d_optimizer.zero_grads()
            #g_optimizer.zero_grads()
            dloss = model.forward_xy(x_batch, t_batch)
            dloss.backward()
            d_optimizer.update()
            dloss_data = float(cuda.to_cpu(dloss.data))
            # Weight by the 2*batchsize samples that contributed to this loss.
            sum_dloss += dloss_data * batchsize * 2
            # ---- update generator ----
            # Fakes are labeled 0 here; with the negative alpha above this
            # pushes the generator toward fooling the discriminator.
            z_batch = model.make_z(batchsize)
            t_batch = np.zeros((batchsize, 1), dtype=np.int32)
            if 0 <= args.gpu:
                z_batch = cuda.to_gpu(z_batch, args.gpu)
                t_batch = cuda.to_gpu(t_batch, args.gpu)
            #d_optimizer.zero_grads()
            g_optimizer.zero_grads()
            gloss = model.forward_zy(z_batch, t_batch)
            gloss.backward()
            g_optimizer.update()
            gloss_data = float(cuda.to_cpu(gloss.data))
            sum_gloss += gloss_data * batchsize
            if iteration % args.display == 0:
                logger.info(
                    'loss D:%.3e G:%.3e iter:%i',
                    dloss_data, gloss_data, iteration)
        # Per-epoch mean losses over all contributing samples.
        ave_dloss = sum_dloss / (2 * x_train.shape[0])
        ave_gloss = sum_gloss / x_train.shape[0]
        logger.info(
            'train mean loss D:%.3e G:%.3e epoch:%i',
            ave_dloss, ave_gloss, epoch)
    logger.info('done. now pickling model...')
    with open(args.dst, 'wb') as fp:
        pickle.dump(model, fp)


if __name__ == "__main__":
    sys.exit(main())
@vebmaylrie Sorry for replying late.
Exactly — in order to make the discriminator misclassify the sampled data,
I set the alpha of the generator's optimizer to a negative value.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Should line 169 use np.ones(*)? I ask because the generator should be trained to make the discriminator misclassify the sampled data.