Skip to content

Instantly share code, notes, and snippets.

@nyk510
Created November 10, 2016 09:16
Show Gist options
  • Save nyk510/84b38176f701243918f4dc498f7775ba to your computer and use it in GitHub Desktop.
Save nyk510/84b38176f701243918f4dc498f7775ba to your computer and use it in GitHub Desktop.
# coding: utf-8
# In[3]:
from chainer import Chain
from chainer import Variable,optimizers
import chainer.functions as F
import chainer.links as L
# In[2]:
import matplotlib.pyplot as plt
import numpy as np
# Notebook-only: enable inline matplotlib rendering.  `get_ipython` is not
# imported anywhere in this file, so it is presumably injected by the
# IPython/Jupyter runtime; this line fails under a plain `python` run.
get_ipython().magic('matplotlib inline')
# In[194]:
class Generator(Chain):
    """Image generator: turns a random latent vector into an image.

    A linear layer projects the latent vector onto a 512-channel 3x3 feature
    map, which four transposed convolutions upsample to a single-channel
    output.  With the kernel/stride/pad values below the spatial size should
    grow 3 -> 4 -> 6 -> 10 -> 28 (per Chainer's deconvolution output-size
    formula; the smoke-test cell prints the resulting shape).
    """

    def __init__(self, z_dim):
        """Create the layers.

        Args:
            z_dim: Length of the latent input vector.
        """
        super(Generator, self).__init__(
            l1=L.Linear(z_dim, 3 * 3 * 512),
            dc1=L.Deconvolution2D(512, 256, 2, stride=2, pad=1),
            dc2=L.Deconvolution2D(256, 128, 2, stride=2, pad=1),
            dc3=L.Deconvolution2D(128, 64, 2, stride=2, pad=1),
            dc4=L.Deconvolution2D(64, 1, 3, stride=3, pad=1),
            bn1=L.BatchNormalization(512),
            bn2=L.BatchNormalization(256),
            bn3=L.BatchNormalization(128),
            bn4=L.BatchNormalization(64),
        )

    def __call__(self, z, test=False):
        """Generate images from latent vectors.

        Args:
            z: Variable of latent vectors; its first axis is the batch.
            test: Forwarded unchanged to every batch-normalization layer.

        Returns:
            The raw output of the last deconvolution (no activation applied).
        """
        n_samples = z.data.shape[0]
        # Reshape the linear projection into a 512-channel 3x3 feature map.
        seed = F.reshape(self.l1(z), (n_samples, 512, 3, 3))
        h = F.relu(self.bn1(seed, test=test))
        # Each stage: deconvolution -> batch norm -> ReLU.
        stages = (
            (self.dc1, self.bn2),
            (self.dc2, self.bn3),
            (self.dc3, self.bn4),
        )
        for deconv, norm in stages:
            h = F.relu(norm(deconv(h), test=test))
        return self.dc4(h)
# In[195]:
class Descriminator(Chain):
    """Convolutional discriminator that emits two logits per input image.

    Four strided convolutions feed a linear layer sized for a 512-channel
    3x3 feature map.  For 28x28 inputs the spatial size should shrink
    28 -> 10 -> 6 -> 4 -> 3 (per Chainer's convolution output-size formula).
    The training code in this file treats class 0 as "real" and class 1 as
    "generated".
    """

    def __init__(self,):
        """Create the convolution, batch-normalization and linear layers."""
        super(Descriminator, self).__init__(
            c1=L.Convolution2D(1, 64, 3, stride=3, pad=1),
            c2=L.Convolution2D(64, 128, 2, stride=2, pad=1),
            c3=L.Convolution2D(128, 256, 2, stride=2, pad=1),
            c4=L.Convolution2D(256, 512, 2, stride=2, pad=1),
            l1=L.Linear(3 * 3 * 512, 2),
            bn1=L.BatchNormalization(128),
            bn2=L.BatchNormalization(256),
            bn3=L.BatchNormalization(512),
        )

    def __call__(self, x, test=False):
        """Score a batch of images.

        Args:
            x: Variable holding single-channel images.
            test: Forwarded unchanged to every batch-normalization layer.

        Returns:
            Variable with two unnormalized class scores per image.
        """
        # The first convolution is followed by ReLU only (no batch norm).
        h = F.relu(self.c1(x))
        # Remaining stages: convolution -> batch norm -> ReLU.
        stages = (
            (self.c2, self.bn1),
            (self.c3, self.bn2),
            (self.c4, self.bn3),
        )
        for conv, norm in stages:
            h = F.relu(norm(conv(h), test=test))
        return self.l1(h)
# In[196]:
# Smoke test: push a small random batch through both untrained networks to
# check that the layer shapes line up.
z_dim = 100
gn = Generator(z_dim=z_dim)
dc = Descriminator()
# 1000 normal samples (mean 10) arranged as ten latent vectors of length 100.
z = np.random.normal(loc=10, size=1000)
z = z.reshape(10, -1).astype(np.float32)
z = Variable(z)
x = gn(z)
print(x.shape)
y = dc(x)
# In[187]:
x.shape
# In[188]:
y.data
# In[114]:
from sklearn.datasets import fetch_mldata
# In[115]:
# NOTE(review): fetch_mldata was deprecated and later removed from
# scikit-learn; on a current install this needs fetch_openml or another
# MNIST source.  Left unchanged here so that the data layout stays the same.
data = fetch_mldata('MNIST original')
# In[118]:
# Convert to float32 and scale by 1/256 (into [0, 1) assuming 8-bit pixels).
X = data['data']
X = np.array(X, dtype=np.float32)
X /= 256.
# In[120]:
X.shape
# In[123]:
# Side length of a square image with 784 pixels.
784 ** .5
# In[266]:
# Training hyper-parameters.
n_train = X.shape[0]
epochs = 1
batchsize = 1000
# In[267]:
# Lay the flat rows out as (sample, channel, height, width) for the conv nets.
X = X.reshape(n_train, 1, 28, 28)
# In[268]:
import pandas as pd
df_log = pd.DataFrame()
# In[269]:
# This cell originally read `x_data.shape`, but `x_data` is first bound inside
# the training loop further down, so running the file top to bottom raised a
# NameError.  Inspect the full data array instead.
X.shape
# In[307]:
# Fresh networks for the real training run (replaces the smoke-test ones).
z_dim = 100
gn = Generator(z_dim=z_dim)
dc = Descriminator()
# One Adam optimizer per network, each with beta1 = 0.5.
o_gen = optimizers.Adam(beta1=0.5)
o_gen.setup(gn)
o_dis = optimizers.Adam(beta1=0.5)
o_dis.setup(dc)
# In[ ]:
# Adversarial training loop.  Label convention: 0 = real image, 1 = generated.
for epoch in range(epochs):
    # Fresh random ordering of the training samples for this epoch.
    perm = np.random.permutation(n_train)
    sum_loss_of_dis = np.float32(0)
    sum_loss_of_gen = np.float32(0)
    for i in range(n_train // batchsize):
        print('iter {i}'.format(i=i))
        # Load real data through the permutation.  The original sliced X
        # sequentially, which ignored the shuffle entirely; if the dataset is
        # stored grouped by label, that makes each minibatch nearly
        # single-class.
        batch_idx = perm[i * batchsize:(i + 1) * batchsize]
        x_data = Variable(X[batch_idx])

        # Sample latent vectors uniformly from [-1, 1) and generate images.
        z = np.random.uniform(-1, 1, (batchsize, z_dim))
        z = z.astype(dtype=np.float32)
        z = Variable(z)
        x = gn(z)
        y1 = dc(x)

        real_label = Variable(np.zeros(batchsize, dtype=np.int32))
        fake_label = Variable(np.ones(batchsize, dtype=np.int32))

        # The generator wants its samples classified as 0 (real): it is
        # trying to fool the discriminator.
        loss_gen = F.softmax_cross_entropy(y1, real_label)
        # The discriminator wants generated samples classified as 1 (fake)...
        loss_dis = F.softmax_cross_entropy(y1, fake_label)
        # ...and real images classified as 0 (real).
        y2 = dc(x_data)
        loss_dis += F.softmax_cross_entropy(y2, real_label)

        # Update the generator first, then the discriminator.  Gradients are
        # zeroed before each backward pass, so whatever one loss accumulated
        # in the other network is cleared before that network is updated.
        o_gen.zero_grads()
        loss_gen.backward()
        o_gen.update()

        o_dis.zero_grads()
        loss_dis.backward()
        o_dis.update()

        sum_loss_of_dis += loss_dis.data
        sum_loss_of_gen += loss_gen.data
    # Report the losses summed over all iterations of this epoch.
    print('loss\tdis-{sum_loss_of_dis}_gen-{sum_loss_of_gen}'.format(
        sum_loss_of_dis=sum_loss_of_dis, sum_loss_of_gen=sum_loss_of_gen))
# In[274]:
# Show the first image of the last real minibatch used in training.
plt.imshow(x_data[0].data.reshape(28, 28))
# In[298]:
# Generate ten images from uniform latent vectors and score them.
z = np.random.uniform(-1, 1, 1000)
z = z.reshape(-1, 100).astype(np.float32)
z = Variable(z)
x = gn(z)
y = dc(x)
x = x.data
# NOTE(review): the next assignment replaces the generated images with ten
# real training images, so the plots and the final discriminator call below
# operate on real data, while `y` still holds the scores of the generated
# batch.  Confirm which of the two was intended.
x = X[perm[:10]]
# In[299]:
x = x.reshape(-1, 28, 28)
# In[300]:
# Draw each image in its own cell of a 4x4 grid.
for i, xx in enumerate(x):
    plt.subplot(4, 4, i + 1)
    plt.imshow(xx)
# In[301]:
y.data
# In[305]:
x.shape
# In[306]:
# Discriminator scores for the images currently held in `x`.
dc(Variable(x.reshape(-1, 1, 28, 28))).data
# In[ ]:
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment