Created
July 31, 2019 16:54
-
-
Save koshian2/efec585c5041e2b1dbb64e311436ca52 to your computer and use it in GitHub Desktop.
max(log D) DCGAN, CIFAR-10
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| import torchvision | |
| from torchvision import transforms | |
| from tqdm import tqdm | |
| import statistics | |
| import os | |
| import pickle | |
| import glob | |
| from inception_score import inception_score | |
def weight_init(layer: nn.Module) -> None:
    """Initialise convolution weights from N(0, 0.02) in place.

    Intended to be passed to ``nn.Module.apply``. Only ``nn.Conv2d`` and
    ``nn.ConvTranspose2d`` layers (including subclasses) are touched; every
    other module, and the bias of the matched layers, is left as-is.

    Args:
        layer: The module visited by ``apply``.
    """
    # isinstance (rather than an exact type comparison) so that subclasses of
    # the convolution layers are initialised as well.
    if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(layer.weight, 0.0, 0.02)
# 8,286,339 trainable parameters
class Generator(nn.Module):
    """Upsample-and-convolve generator producing 3-channel images in [-1, 1].

    A 100-channel input passes through four stages, each made of an optional
    nearest-neighbour upsampling followed by repeated Conv3x3 -> BatchNorm ->
    ReLU units, and a final Conv3x3 + Tanh that maps to 3 channels. For a 1x1
    spatial input the stages yield 4x4, 8x8, 16x16 and 32x32 feature maps.
    """

    def __init__(self):
        super().__init__()
        # Spatial sizes in the trailing comments assume a 1x1 latent input.
        self.conv1 = self.conv_bn_act(in_ch=100, out_ch=512, upsampling_scale=4)  # 4x4
        self.conv2 = self.conv_bn_act(in_ch=512, out_ch=256, upsampling_scale=2)  # 8x8
        self.conv3 = self.conv_bn_act(in_ch=256, out_ch=128, upsampling_scale=2)  # 16x16
        self.conv4 = self.conv_bn_act(in_ch=128, out_ch=64, upsampling_scale=2)  # 32x32
        to_rgb = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.out = nn.Sequential(to_rgb, nn.Tanh())

    def conv_bn_act(self, in_ch, out_ch, upsampling_scale, reps=3):
        """Build one stage: optional upsampling, then ``reps`` conv units.

        Args:
            in_ch: Input channels of the first convolution.
            out_ch: Output channels of every convolution in the stage.
            upsampling_scale: Nearest-neighbour scale factor; no upsampling
                layer is inserted unless it is greater than 1.
            reps: Number of Conv3x3 -> BatchNorm -> ReLU units.

        Returns:
            An ``nn.Sequential`` holding the stage's layers in order.
        """
        stage = []
        if upsampling_scale > 1:
            stage.append(nn.UpsamplingNearest2d(scale_factor=upsampling_scale))
        channels_in = in_ch
        for _ in range(reps):
            stage += [
                nn.Conv2d(channels_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
            # Only the first convolution changes the channel count.
            channels_in = out_ch
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run ``x`` through the four stages and the Tanh output head."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.out):
            x = stage(x)
        return x
# 1,553,409 trainable parameters
class Discriminator(nn.Module):
    """Convolutional discriminator returning one sigmoid score per image.

    Four Conv3x3 -> BatchNorm -> ReLU stages widen the channels from 3 to
    512; stages 2-4 first halve the resolution with average pooling. The
    head average-pools by 4, projects to a single channel with a 1x1
    convolution and applies a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self.conv_bn_act(in_ch=3, out_ch=64, downsampling=1)
        self.conv2 = self.conv_bn_act(in_ch=64, out_ch=128, downsampling=2)
        self.conv3 = self.conv_bn_act(in_ch=128, out_ch=256, downsampling=2)
        self.conv4 = self.conv_bn_act(in_ch=256, out_ch=512, downsampling=2)
        head = [nn.AvgPool2d(4), nn.Conv2d(512, 1, kernel_size=1), nn.Sigmoid()]
        self.out = nn.Sequential(*head)

    def conv_bn_act(self, in_ch, out_ch, downsampling):
        """Build one stage: optional average pooling, then a single conv unit.

        Args:
            in_ch: Input channels of the convolution.
            out_ch: Output channels of the convolution.
            downsampling: Average-pooling kernel size; no pooling layer is
                inserted unless it is greater than 1.

        Returns:
            An ``nn.Sequential`` holding the stage's layers in order.
        """
        pooling = [nn.AvgPool2d(downsampling)] if downsampling > 1 else []
        unit = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*pooling, *unit)

    def forward(self, inputs):
        """Score ``inputs`` and flatten everything after the batch dimension."""
        scores = inputs
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.out):
            scores = stage(scores)
        return scores.view(scores.size(0), -1)
def load_data(batch_size):
    """Return a shuffled DataLoader over the CIFAR-10 training split.

    Images are converted to tensors and normalised per channel with mean 0.5
    and std 0.5. The dataset is stored under ``./data`` and downloaded if
    it is not already there.

    Args:
        batch_size: Number of images per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` using 4 worker processes.
    """
    half = (0.5, 0.5, 0.5)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(half, half),
    ])
    cifar_train = torchvision.datasets.CIFAR10(
        root="./data", train=True, transform=preprocess, download=True
    )
    return torch.utils.data.DataLoader(
        cifar_train, batch_size=batch_size, shuffle=True, num_workers=4
    )
def train(g_loss):
    """Train the generator/discriminator pair on CIFAR-10 for 300 epochs.

    Each iteration first updates the generator on a fresh batch of fake
    images, then updates the discriminator on the real batch and on the
    (detached) fakes produced in the generator step.

    Args:
        g_loss: Generator objective, also used as the output directory name.
            ``"min"`` minimises ``BCE(D(G(z)), 1)``; ``"max"`` minimises
            ``-BCE(D(G(z)), 0)``.

    Raises:
        ValueError: If ``g_loss`` is neither ``"min"`` nor ``"max"``.

    Side effects:
        Writes ``{g_loss}/epoch_XXX.png`` (sample grid) and
        ``{g_loss}/logs.pkl`` (per-epoch mean losses) every epoch, and
        ``{g_loss}/models/gen_epoch_XXX.pytorch`` /
        ``dis_epoch_XXX.pytorch`` every 10th epoch.
    """
    # Fail fast: an unknown mode would otherwise leave `loss` unbound inside
    # the training loop.
    if g_loss not in ("min", "max"):
        raise ValueError(f"g_loss must be 'min' or 'max', got {g_loss!r}")

    device = "cuda"
    batch_size = 256
    trainloader = load_data(batch_size)

    model_G = Generator()
    model_D = Discriminator()
    model_G.apply(weight_init)
    model_D.apply(weight_init)
    if device == "cuda":
        model_D = torch.nn.DataParallel(model_D.to(device))
        model_G = torch.nn.DataParallel(model_G.to(device))

    param_G = torch.optim.Adam(model_G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    param_D = torch.optim.Adam(model_D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    bce_loss = torch.nn.BCELoss()

    # Full-size label tensors; sliced per batch because the last batch of an
    # epoch can be smaller than batch_size.
    ones = torch.ones(batch_size, 1, device=device)
    zeros = torch.zeros(batch_size, 1, device=device)

    # Create the output tree once instead of re-checking it every epoch.
    os.makedirs(f"{g_loss}/models", exist_ok=True)

    result = {"d_loss": [], "g_loss": []}
    for epoch in range(300):
        log_loss_D, log_loss_G = [], []
        for real_img, _ in tqdm(trainloader):
            batch_len = len(real_img)
            real_img = real_img.to(device)

            # ---- train G ----
            # Sample the latent noise directly on the training device.
            rand = torch.randn(batch_len, 100, 1, 1, device=device)
            fake_img = model_G(rand)
            # Detached copy reused for the D step so that no gradient flows
            # back into G from the discriminator update.
            fake_img_tensor = fake_img.detach()
            g_out = model_D(fake_img)
            if g_loss == "min":
                loss = bce_loss(g_out, ones[:batch_len])
            else:  # "max"
                loss = -bce_loss(g_out, zeros[:batch_len])
            log_loss_G.append(loss.item())
            # backprop (only G is stepped here)
            param_D.zero_grad()
            param_G.zero_grad()
            loss.backward()
            param_G.step()

            # ---- train D ----
            # real images labelled as ones
            d_out = model_D(real_img)
            loss_real = bce_loss(d_out, ones[:batch_len])
            # fake images labelled as zeros
            d_out = model_D(fake_img_tensor)
            loss_fake = bce_loss(d_out, zeros[:batch_len])
            loss = (loss_real + loss_fake) / 2.0
            log_loss_D.append(loss.item())
            # backprop; zero_grad also discards the D gradients accumulated
            # during the G step above
            param_D.zero_grad()
            param_G.zero_grad()
            loss.backward()
            param_D.step()

        # epoch-level loss log
        result["d_loss"].append(statistics.mean(log_loss_D))
        result["g_loss"].append(statistics.mean(log_loss_G))
        print(f"epoch = {epoch}, g_loss = {result['g_loss'][-1]}, d_loss = {result['d_loss'][-1]}")

        # Save a sample grid from the last batch of the epoch. The images are
        # mapped from [-1, 1] to [0, 1] by hand: this matches
        # normalize=True with a (-1, 1) value range, but does not depend on
        # the keyword that torchvision renamed from `range` to `value_range`.
        samples = fake_img_tensor[:256].clamp(-1.0, 1.0).add(1.0).div(2.0)
        torchvision.utils.save_image(samples, f"{g_loss}/epoch_{epoch:03}.png", nrow=16, padding=3)

        # save model weights every 10th epoch
        if epoch % 10 == 0:
            torch.save(model_G.state_dict(), f"{g_loss}/models/gen_epoch_{epoch:03}.pytorch")
            torch.save(model_D.state_dict(), f"{g_loss}/models/dis_epoch_{epoch:03}.pytorch")

        # persist the loss history (overwritten each epoch)
        with open(g_loss + "/logs.pkl", "wb") as fp:
            pickle.dump(result, fp)
def calc_inception(directory):
    """Compute an inception score for every generator checkpoint in a directory.

    For each file matching ``{directory}/gen*`` (in sorted order) the weights
    are loaded into a fresh ``Generator``, 500 batches of 100 images are
    sampled, and the images are passed to ``inception_score``. The
    accumulated results are printed and re-written to
    ``is_{directory without '/models'}.pkl`` after every checkpoint.

    Args:
        directory: Directory containing the saved generator state dicts,
            e.g. ``"max/models"``.
    """
    device = "cuda"
    files = sorted(glob.glob(f"{directory}/gen*"))
    result = {}
    for f in tqdm(files):
        model = Generator().to(device)
        if device == "cuda":
            # Wrapped the same way as in training so the state-dict keys match.
            model = torch.nn.DataParallel(model)
        model.load_state_dict(torch.load(f))
        model.eval()

        # Pure inference: disable autograd so no graph is built or retained
        # for the 500 forward passes.
        with torch.no_grad():
            images = [
                model(torch.randn(100, 100, 1, 1, device=device))
                for _ in range(500)
            ]
        output = torch.cat(images, dim=0)

        key = os.path.basename(f).replace(".pytorch", "")
        result[key] = inception_score(output, cuda=True, batch_size=32, resize=True)
        print(result)
        # Rewritten after every checkpoint so partial results survive an
        # interrupted run.
        with open(f"is_{directory.replace('/models', '')}.pkl", "wb") as fp:
            pickle.dump(result, fp)
if __name__ == "__main__":
    # Script entry point: train with the "max" generator objective.
    # To score previously saved checkpoints instead, replace the call below
    # with calc_inception("max/models").
    train("max")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment