@koshian2
Created July 31, 2019 16:54
max(log D) DCGAN, CIFAR-10
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from tqdm import tqdm
import statistics
import os
import pickle
import glob
from inception_score import inception_score
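# NOTE: inception_score is an external helper module, not included in this
# gist; it is assumed to provide inception_score(imgs, cuda=..., batch_size=...,
# resize=...) as called in calc_inception() below.
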
def weight_init(layer):
    # DCGAN-style init: conv / transposed-conv weights ~ N(0, 0.02)
    if type(layer) in [nn.Conv2d, nn.ConvTranspose2d]:
        nn.init.normal_(layer.weight, 0.0, 0.02)

# Generator: 8,286,339 parameters
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = self.conv_bn_act(100, 512, 4)  # 4x4
        self.conv2 = self.conv_bn_act(512, 256, 2)  # 8x8
        self.conv3 = self.conv_bn_act(256, 128, 2)  # 16x16
        self.conv4 = self.conv_bn_act(128, 64, 2)   # 32x32
        self.out = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def conv_bn_act(self, in_ch, out_ch, upsampling_scale, reps=3):
        # nearest-neighbor upsampling followed by `reps` Conv-BN-ReLU blocks
        layers = []
        if upsampling_scale > 1:
            layers.append(nn.UpsamplingNearest2d(scale_factor=upsampling_scale))
        for i in range(reps):
            layers.append(nn.Conv2d(in_ch if i == 0 else out_ch, out_ch, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.out(self.conv4(self.conv3(self.conv2(self.conv1(x)))))

# Discriminator: 1,553,409 parameters
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = self.conv_bn_act(3, 64, 1)     # 32x32
        self.conv2 = self.conv_bn_act(64, 128, 2)   # 16x16
        self.conv3 = self.conv_bn_act(128, 256, 2)  # 8x8
        self.conv4 = self.conv_bn_act(256, 512, 2)  # 4x4
        self.out = nn.Sequential(
            nn.AvgPool2d(4),
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def conv_bn_act(self, in_ch, out_ch, downsampling):
        # average-pool downsampling followed by a single Conv-BN-ReLU block
        layers = []
        if downsampling > 1:
            layers.append(nn.AvgPool2d(downsampling))
        layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1))
        layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.ReLU(True))
        return nn.Sequential(*layers)

    def forward(self, inputs):
        x = self.out(self.conv4(self.conv3(self.conv2(self.conv1(inputs)))))
        return x.view(x.size(0), -1)

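# A minimal sanity-check helper (added for illustration, not part of the original
# training flow): it verifies that a (n, 100, 1, 1) latent maps to (n, 3, 32, 32)
# images and that the discriminator returns one probability per sample.
def shape_check(n=2):
    z = torch.randn(n, 100, 1, 1)
    fake = Generator()(z)
    score = Discriminator()(fake)
    assert fake.shape == (n, 3, 32, 32)
    assert score.shape == (n, 1)
    return fake.shape, score.shape
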
def load_data(batch_size):
    # CIFAR-10, scaled to [-1, 1] to match the generator's Tanh output
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = torchvision.datasets.CIFAR10(root="./data", train=True, transform=trans, download=True)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    return dataloader

def train(g_loss):
    device = "cuda"
    batch_size = 256
    trainloader = load_data(batch_size)

    model_G = Generator()
    model_D = Discriminator()
    model_G.apply(weight_init)
    model_D.apply(weight_init)
    if device == "cuda":
        model_D = torch.nn.DataParallel(model_D.to(device))
        model_G = torch.nn.DataParallel(model_G.to(device))

    param_G = torch.optim.Adam(model_G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    param_D = torch.optim.Adam(model_D.parameters(), lr=0.0002, betas=(0.5, 0.999))

    bce_loss = torch.nn.BCELoss()
    ones = torch.ones(batch_size, 1).to(device)
    zeros = torch.zeros(batch_size, 1).to(device)

    result = {"d_loss": [], "g_loss": []}

    for epoch in range(300):
        log_loss_D, log_loss_G = [], []
        for real_img, _ in tqdm(trainloader):
            batch_len = len(real_img)
            real_img = real_img.to(device)

            # train G
            rand = torch.randn(batch_len, 100, 1, 1)
            fake_img = model_G(rand)
            fake_img_tensor = fake_img.detach()
            g_out = model_D(fake_img)
            if g_loss == "min":
                # G loss: -log D(G(z))  (BCE against ones)
                loss = bce_loss(g_out, ones[:batch_len])
            elif g_loss == "max":
                # G loss: log(1 - D(G(z)))  (negated BCE against zeros)
                loss = -bce_loss(g_out, zeros[:batch_len])
            log_loss_G.append(loss.item())
            # backprop
            param_D.zero_grad()
            param_G.zero_grad()
            loss.backward()
            param_G.step()

            # train D
            # -- real as ones
            d_out = model_D(real_img)
            loss_real = bce_loss(d_out, ones[:batch_len])
            # -- fake as zeros
            d_out = model_D(fake_img_tensor)
            loss_fake = bce_loss(d_out, zeros[:batch_len])
            loss = (loss_real + loss_fake) / 2.0
            log_loss_D.append(loss.item())
            # backprop
            param_D.zero_grad()
            param_G.zero_grad()
            loss.backward()
            param_D.step()

        # log epoch means
        result["d_loss"].append(statistics.mean(log_loss_D))
        result["g_loss"].append(statistics.mean(log_loss_G))
        print(f"epoch = {epoch}, g_loss = {result['g_loss'][-1]}, d_loss = {result['d_loss'][-1]}")

        # save a sample grid ("range" was renamed "value_range" in newer torchvision)
        if not os.path.exists(g_loss):
            os.mkdir(g_loss)
        torchvision.utils.save_image(fake_img_tensor[:256], f"{g_loss}/epoch_{epoch:03}.png",
                                     nrow=16, padding=3, normalize=True, range=(-1.0, 1.0))

        # save model weights every 10 epochs
        if not os.path.exists(g_loss + "/models"):
            os.mkdir(g_loss + "/models")
        if epoch % 10 == 0:
            torch.save(model_G.state_dict(), f"{g_loss}/models/gen_epoch_{epoch:03}.pytorch")
            torch.save(model_D.state_dict(), f"{g_loss}/models/dis_epoch_{epoch:03}.pytorch")

        # dump loss history
        with open(g_loss + "/logs.pkl", "wb") as fp:
            pickle.dump(result, fp)

def calc_inception(directory):
    device = "cuda"
    files = sorted(glob.glob(f"{directory}/gen*"))
    result = {}

    for f in tqdm(files):
        # checkpoints were saved from a DataParallel-wrapped model, so wrap again before loading
        model = Generator().to(device)
        if device == "cuda":
            model = torch.nn.DataParallel(model)
        model.load_state_dict(torch.load(f))
        model.eval()

        images = []
        for i in range(500):  # 500 batches x 100 samples = 50,000 images
            x = torch.randn(100, 100, 1, 1).to(device)
            images.append(model(x).detach())
        output = torch.cat(images, dim=0)

        key = os.path.basename(f).replace(".pytorch", "")
        result[key] = inception_score(output, cuda=True, batch_size=32, resize=True)
        print(result)

    with open(f"is_{directory.replace('/models', '')}.pkl", "wb") as fp:
        pickle.dump(result, fp)

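# A usage sketch (not part of the original gist): load one of the generator
# checkpoints produced by train() and write a sample grid to disk. The file
# name in the example call is hypothetical. Because the checkpoints were saved
# from a DataParallel-wrapped model, the state dict keys carry a "module."
# prefix, so the model is wrapped the same way before loading.
def sample_from_checkpoint(ckpt_path, out_png, n=64):
    model = torch.nn.DataParallel(Generator().to("cuda"))
    model.load_state_dict(torch.load(ckpt_path))
    model.eval()
    with torch.no_grad():
        imgs = model(torch.randn(n, 100, 1, 1).to("cuda"))
    torchvision.utils.save_image(imgs, out_png, nrow=8, normalize=True, range=(-1.0, 1.0))
# e.g. sample_from_checkpoint("max/models/gen_epoch_290.pytorch", "sample_grid.png")
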
if __name__ == "__main__":
    # calc_inception("max/models")
    # exit()
    train("max")