ACGAN(1) CIFAR-10
An Auxiliary Classifier GAN (ACGAN) trained on CIFAR-10, implemented in PyTorch as two files. The first, models.py (imported by the training script below), defines the Generator and Discriminator:
import torch
from torch import nn

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Project the 110-dim input (100-dim noise + 10-dim one-hot label) to 384 dims
        self.linear = nn.Sequential(
            nn.Linear(110, 384),
            nn.ReLU(True)
        )
        # Upsample a 384x1x1 feature map: 1x1 -> 5x5 -> 14x14 -> 32x32
        self.conv1 = self.transposeconv_bn_relu(384, 192, 5)
        self.conv2 = self.transposeconv_bn_relu(192, 96, 6)
        self.conv3 = self.transposeconv_bn_relu(96, 3, 6, use_bn=False, act="tanh")

    def transposeconv_bn_relu(self, in_ch, out_ch, kernel_size, use_bn=True, act="relu"):
        layers = []
        layers.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride=2))
        if use_bn:
            layers.append(nn.BatchNorm2d(out_ch))
        if act == "relu":
            layers.append(nn.ReLU(True))
        elif act == "tanh":
            layers.append(nn.Tanh())
        return nn.Sequential(*layers)

    def forward(self, inputs):
        # (N, 110) -> (N, 384) -> (N, 384, 1, 1), then upsample to (N, 3, 32, 32)
        x = self.linear(inputs).view(inputs.size(0), -1, 1, 1)
        return self.conv3(self.conv2(self.conv1(x)))

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Alternating strides 2,1,2,1,2,1 take 32x32 down to 4x4
        self.conv1 = self.conv_bn_lkrelu(3, 16, 2, use_bn=False)
        self.conv2 = self.conv_bn_lkrelu(16, 32, 1)
        self.conv3 = self.conv_bn_lkrelu(32, 64, 2)
        self.conv4 = self.conv_bn_lkrelu(64, 128, 1)
        self.conv5 = self.conv_bn_lkrelu(128, 256, 2)
        self.conv6 = self.conv_bn_lkrelu(256, 512, 1)
        # Two heads on the flattened 512x4x4 = 8192 features:
        # a real/fake logit and 10-way class logits
        self.prob = nn.Linear(8192, 1)
        self.classes = nn.Linear(8192, 10)

    def conv_bn_lkrelu(self, in_ch, out_ch, stride, use_bn=True):
        layers = []
        layers.append(nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1))
        if use_bn:
            layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.LeakyReLU(0.2, True))
        layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    def forward(self, inputs):
        x = self.conv6(self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(inputs))))))
        x = x.view(x.size(0), -1)
        return self.prob(x), self.classes(x)
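
As a quick shape sanity check (a minimal sketch, not part of the original gist): the generator consumes a 100-dim noise vector concatenated with a 10-dim one-hot label and upsamples 1x1 -> 5x5 -> 14x14 -> 32x32, while the discriminator takes the 32x32 image back down to 4x4; the 8192 in its linear heads is exactly the flattened 512x4x4 feature map.

# shape check for the models above (sketch; not part of the gist)
import torch
from models import Generator, Discriminator

G, D = Generator(), Discriminator()
z = torch.randn(4, 100)
onehot = torch.eye(10)[torch.tensor([0, 1, 2, 3])]  # four class labels as one-hot rows
img = G(torch.cat([z, onehot], dim=1))
assert img.shape == (4, 3, 32, 32)                  # 1x1 -> 5x5 -> 14x14 -> 32x32
prob, classes = D(img)
assert prob.shape == (4, 1) and classes.shape == (4, 10)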
The second file is the training script:
import torch
from torch import nn
import torchvision
from torchvision import transforms
from tqdm import tqdm
from models import Generator, Discriminator
import os
import pickle
import statistics

def load_dataset(batch_size):
    # Scale CIFAR-10 to [-1, 1] to match the generator's tanh output
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = torchvision.datasets.CIFAR10(root="./data", train=True, transform=trans, download=True)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

def weight_init(layer):
    # DCGAN-style init for convolutions and transposed convolutions
    if type(layer) in [nn.Conv2d, nn.ConvTranspose2d]:
        nn.init.normal_(layer.weight, 0.0, 0.02)
        nn.init.zeros_(layer.bias)

class ACGAN_loss():
    def __init__(self, batch_size, device):
        self.ones = torch.ones(batch_size, 1).to(device)
        self.zeros = torch.zeros(batch_size, 1).to(device)
        self.source_loss = torch.nn.BCEWithLogitsLoss()
        self.classes_loss = torch.nn.CrossEntropyLoss()

    def __call__(self, real_outs, fake_outs, real_label, network_type):
        assert network_type in ["D", "G"]
        batch_len = len(real_outs[0])
        # source loss L_S: real images toward 1, fakes toward 0
        loss_s = self.source_loss(real_outs[0], self.ones[:batch_len])
        loss_s += self.source_loss(fake_outs[0], self.zeros[:batch_len])
        # class loss L_C: both real and fake images should match the
        # label the fake was conditioned on
        loss_c = self.classes_loss(real_outs[1], real_label)
        loss_c += self.classes_loss(fake_outs[1], real_label)
        if network_type == "D":
            # D minimizes both terms (maximizes L_S + L_C in the ACGAN paper)
            return loss_s + loss_c
        else:
            # G maximizes L_C - L_S, so it minimizes loss_c - loss_s
            return loss_c - loss_s

def train():
    output_dir = "cifar_acgan"
    device = "cuda"
    batch_size = 100
    dataloader = load_dataset(batch_size)

    model_G = Generator()
    model_D = Discriminator()
    model_G.apply(weight_init)
    model_D.apply(weight_init)
    model_G, model_D = model_G.to(device), model_D.to(device)
    if device == "cuda":
        model_G, model_D = torch.nn.DataParallel(model_G), torch.nn.DataParallel(model_D)

    param_G = torch.optim.Adam(model_G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    param_D = torch.optim.Adam(model_D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    loss_func = ACGAN_loss(batch_size, device)
    result = {"d_loss": [], "g_loss": []}

    for epoch in range(101):
        log_loss_D, log_loss_G = [], []
        for real_img, real_label in tqdm(dataloader):
            batch_len = len(real_img)
            real_img, real_label = real_img.to(device), real_label.to(device)

            # train G: noise concatenated with the one-hot label of the real batch,
            # built directly on the device so the indexing and cat stay on one device
            rand_X = torch.randn(batch_len, 100, device=device)
            label_onehot = torch.eye(10, device=device)[real_label]
            rand_X = torch.cat([rand_X, label_onehot], dim=1)
            fake_img = model_G(rand_X)
            fake_img_tensor = fake_img.detach()  # reused for the D step and for saving samples
            fake_out = model_D(fake_img)
            real_out = model_D(real_img)
            loss = loss_func(real_out, fake_out, real_label, "G")
            log_loss_G.append(loss.item())
            # backprop; only G steps, D's gradients are zeroed again below
            param_D.zero_grad()
            param_G.zero_grad()
            loss.backward()
            param_G.step()

            # train D on the real batch and on the detached fakes
            d_out_real = model_D(real_img)
            d_out_fake = model_D(fake_img_tensor)
            loss = loss_func(d_out_real, d_out_fake, real_label, "D")
            log_loss_D.append(loss.item())
            # backprop
            param_D.zero_grad()
            param_G.zero_grad()
            loss.backward()
            param_D.step()

        # log
        result["d_loss"].append(statistics.mean(log_loss_D))
        result["g_loss"].append(statistics.mean(log_loss_G))
        print(f"epoch = {epoch}, g_loss = {result['g_loss'][-1]}, d_loss = {result['d_loss'][-1]}")

        # save a 10x10 grid of the last fake batch
        # (newer torchvision renames the `range` keyword to `value_range`)
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)
        torchvision.utils.save_image(fake_img_tensor[:100], f"{output_dir}/epoch_{epoch:03}.png",
                                     nrow=10, padding=3, normalize=True, range=(-1.0, 1.0))

        # save the weights every 10 epochs
        if not os.path.exists(output_dir + "/models"):
            os.mkdir(output_dir + "/models")
        if epoch % 10 == 0:
            torch.save(model_G.state_dict(), f"{output_dir}/models/gen_epoch_{epoch:03}.pytorch")
            torch.save(model_D.state_dict(), f"{output_dir}/models/dis_epoch_{epoch:03}.pytorch")

    # dump the loss log
    with open(output_dir + "/logs.pkl", "wb") as fp:
        pickle.dump(result, fp)

if __name__ == "__main__":
    train()
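
After training, class-conditional samples can be drawn from a saved generator checkpoint along these lines (a sketch: the checkpoint path and class index are illustrative, and the `range` keyword matches the 2019-era torchvision used above). Because the models were wrapped in torch.nn.DataParallel before saving, the state_dict keys carry a "module." prefix, so the fresh Generator is wrapped the same way here:

# sampling sketch (assumes train() completed and wrote gen_epoch_100.pytorch)
import torch
import torchvision
from models import Generator

device = "cuda"
model_G = torch.nn.DataParallel(Generator().to(device))  # match the saved "module." key prefix
model_G.load_state_dict(torch.load("cifar_acgan/models/gen_epoch_100.pytorch"))
model_G.eval()

label = 3  # CIFAR-10 class index to condition on (3 = cat)
with torch.no_grad():
    z = torch.randn(64, 100, device=device)
    labels = torch.full((64,), label, dtype=torch.long, device=device)
    fake = model_G(torch.cat([z, torch.eye(10, device=device)[labels]], dim=1))
torchvision.utils.save_image(fake, "samples.png", nrow=8, normalize=True, range=(-1.0, 1.0))

The loss history written to cifar_acgan/logs.pkl is a plain dict of per-epoch means and can be read back with pickle.load:

import pickle
with open("cifar_acgan/logs.pkl", "rb") as fp:
    result = pickle.load(fp)
print(len(result["g_loss"]), result["g_loss"][-1])  # 101 epochs, last mean G loss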