Created
June 27, 2019 09:11
-
-
Save koshian2/4b00bd6d2453450c1d7ea703f2218ca3 to your computer and use it in GitHub Desktop.
Pix2pix STL Colorize
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch import nn | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| import numpy as np | |
| from tqdm import tqdm | |
| import os | |
| import pickle | |
| import statistics | |
class ColorAndGray(object):
    """Transform that pairs an image with its grayscale rendition.

    When placed before ToTensor() in a Compose pipeline, ``img`` is a PIL
    image instance, so ``convert("L")`` yields the grayscale copy.
    """

    def __call__(self, img):
        grayscale = img.convert("L")
        return img, grayscale
# Wrapper that fans a transform (or list of transforms) out over multiple inputs
class MultiInputWrapper(object):
    """Apply transform(s) element-wise to a sequence of inputs.

    If ``base_func`` is a list, each function is paired with the input at the
    same position; otherwise the single callable is applied to every input.
    """

    def __init__(self, base_func):
        self.base_func = base_func

    def __call__(self, xs):
        funcs = self.base_func
        if not isinstance(funcs, list):
            return [funcs(x) for x in xs]
        return [func(x) for func, x in zip(funcs, xs)]
def load_datasets():
    """Build a DataLoader over unlabeled STL-10 yielding (color, gray) pairs.

    Color images are normalized to [-1, 1] across 3 channels and the
    grayscale copy across its single channel.
    """
    transform = transforms.Compose([
        ColorAndGray(),
        MultiInputWrapper(transforms.ToTensor()),
        MultiInputWrapper([
            transforms.Normalize(mean=(0.5, 0.5, 0.5,), std=(0.5, 0.5, 0.5,)),
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
        ]),
    ])
    trainset = torchvision.datasets.STL10(
        root="./data", split="unlabeled", download=True, transform=transform)
    return torch.utils.data.DataLoader(
        trainset, batch_size=512, shuffle=True, num_workers=4, pin_memory=True)
class Generator(nn.Module):
    """U-Net-style colorizer: 1-channel grayscale in, 3-channel tanh RGB out.

    Spatial-size comments assume a 96x96 input; any size divisible by 16
    flows through the down/up-sampling path consistently.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.enc1 = self.conv_bn_relu(1, 32, kernel_size=5)                    # 32x96x96
        self.enc2 = self.conv_bn_relu(32, 64, kernel_size=3, pool_kernel=4)    # 64x24x24
        self.enc3 = self.conv_bn_relu(64, 128, kernel_size=3, pool_kernel=2)   # 128x12x12
        self.enc4 = self.conv_bn_relu(128, 256, kernel_size=3, pool_kernel=2)  # 256x6x6
        # Decoder; input channel counts include the concatenated skip features
        self.dec1 = self.conv_bn_relu(256, 128, kernel_size=3, pool_kernel=-2)       # 128x12x12
        self.dec2 = self.conv_bn_relu(128 + 128, 64, kernel_size=3, pool_kernel=-2)  # 64x24x24
        self.dec3 = self.conv_bn_relu(64 + 64, 32, kernel_size=3, pool_kernel=-4)    # 32x96x96
        self.dec4 = nn.Sequential(
            nn.Conv2d(32 + 32, 3, kernel_size=5, padding=2),
            nn.Tanh(),
        )

    def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None):
        """Conv+BN+ReLU unit; positive pool_kernel avg-pools, negative upsamples."""
        layers = []
        if pool_kernel is not None and pool_kernel > 0:
            layers.append(nn.AvgPool2d(pool_kernel))
        elif pool_kernel is not None and pool_kernel < 0:
            layers.append(nn.UpsamplingNearest2d(scale_factor=-pool_kernel))
        layers += [
            nn.Conv2d(in_ch, out_ch, kernel_size, padding=(kernel_size - 1) // 2),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        skip1 = self.enc1(x)
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        bottleneck = self.enc4(skip3)
        up = self.dec1(bottleneck)
        up = self.dec2(torch.cat([up, skip3], dim=1))
        up = self.dec3(torch.cat([up, skip2], dim=1))
        return self.dec4(torch.cat([up, skip1], dim=1))
class Discriminator(nn.Module):
    """PatchGAN critic over (color, gray) pairs; 3x3 logit patches for 96x96 input."""

    def __init__(self):
        super().__init__()
        # 4 input channels: fake/true color (3) concatenated with gray (1)
        self.conv1 = self.conv_bn_relu(4, 16, kernel_size=5, reps=1)
        self.conv2 = self.conv_bn_relu(16, 32, pool_kernel=4)
        self.conv3 = self.conv_bn_relu(32, 64, pool_kernel=2)
        self.conv4 = self.conv_bn_relu(64, 128, pool_kernel=2)
        self.conv5 = self.conv_bn_relu(128, 256, pool_kernel=2)
        self.out_patch = nn.Conv2d(256, 1, kernel_size=1)  # 1x3x3 patch logits

    def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None, reps=2):
        """`reps` Conv+BN+LeakyReLU units, optionally preceded by avg-pooling."""
        layers = []
        for rep in range(reps):
            if rep == 0 and pool_kernel is not None:
                layers.append(nn.AvgPool2d(pool_kernel))
            layers.extend([
                nn.Conv2d(in_ch if rep == 0 else out_ch, out_ch,
                          kernel_size, padding=(kernel_size - 1) // 2),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        return nn.Sequential(*layers)

    def forward(self, x):
        features = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            features = stage(features)
        return self.out_patch(features)
def train():
    """Train the pix2pix gray->color GAN on STL-10 (requires CUDA).

    Each epoch: logs per-epoch mean generator/discriminator losses, saves
    sample fake/real image grids, checkpoints weights every 10 epochs (and
    on the last epoch), and pickles the loss history under ./stl_color.
    """
    # Models (wrapped in DataParallel for multi-GPU training)
    device = "cuda"
    torch.backends.cudnn.benchmark = True
    model_G, model_D = Generator(), Discriminator()
    model_G, model_D = nn.DataParallel(model_G), nn.DataParallel(model_D)
    model_G, model_D = model_G.to(device), model_D.to(device)
    params_G = torch.optim.Adam(model_G.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    params_D = torch.optim.Adam(model_D.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    # Target label tensors for the PatchGAN loss, pre-allocated at the
    # maximum batch size (512) and sliced to the actual batch length below
    ones = torch.ones(512, 1, 3, 3).to(device)
    zeros = torch.zeros(512, 1, 3, 3).to(device)
    # Loss functions
    bce_loss = nn.BCEWithLogitsLoss()
    mae_loss = nn.L1Loss()
    # Loss history (per-epoch means)
    result = {}
    result["log_loss_G_sum"] = []
    result["log_loss_G_bce"] = []
    result["log_loss_G_mae"] = []
    result["log_loss_D"] = []
    # Training loop
    dataset = load_datasets()
    for i in range(200):
        log_loss_G_sum, log_loss_G_bce, log_loss_G_mae, log_loss_D = [], [], [], []
        for (real_color, input_gray), _ in tqdm(dataset):
            batch_len = len(real_color)
            real_color, input_gray = real_color.to(device), input_gray.to(device)
            # --- Generator update ---
            # Produce fake color images from the grayscale input
            fake_color = model_G(input_gray)
            # Detach so D's update below sees the fakes without G gradients
            fake_color_tensor = fake_color.detach()
            # Adversarial loss: make D label the fakes as real
            LAMBD = 100.0  # weight of the L1 (MAE) term relative to the BCE term
            out = model_D(torch.cat([fake_color, input_gray], dim=1))
            loss_G_bce = bce_loss(out, ones[:batch_len])
            loss_G_mae = LAMBD * mae_loss(fake_color, real_color)
            loss_G_sum = loss_G_bce + loss_G_mae
            log_loss_G_bce.append(loss_G_bce.item())
            log_loss_G_mae.append(loss_G_mae.item())
            log_loss_G_sum.append(loss_G_sum.item())
            # Backprop and step G only (D grads are zeroed, not stepped)
            params_D.zero_grad()
            params_G.zero_grad()
            loss_G_sum.backward()
            params_G.step()
            # --- Discriminator update ---
            # Real color images should be classified as real
            real_out = model_D(torch.cat([real_color, input_gray], dim=1))
            loss_D_real = bce_loss(real_out, ones[:batch_len])
            # Detached fakes should be classified as fake
            fake_out = model_D(torch.cat([fake_color_tensor, input_gray], dim=1))
            loss_D_fake = bce_loss(fake_out, zeros[:batch_len])
            # Total discriminator loss
            loss_D = loss_D_real + loss_D_fake
            log_loss_D.append(loss_D.item())
            # Backprop and step D only
            params_D.zero_grad()
            params_G.zero_grad()
            loss_D.backward()
            params_D.step()
        result["log_loss_G_sum"].append(statistics.mean(log_loss_G_sum))
        result["log_loss_G_bce"].append(statistics.mean(log_loss_G_bce))
        result["log_loss_G_mae"].append(statistics.mean(log_loss_G_mae))
        result["log_loss_D"].append(statistics.mean(log_loss_D))
        print(f"log_loss_G_sum = {result['log_loss_G_sum'][-1]} " +
              f"({result['log_loss_G_bce'][-1]}, {result['log_loss_G_mae'][-1]}) " +
              f"log_loss_D = {result['log_loss_D'][-1]}")
        # Output directory
        if not os.path.exists("stl_color"):
            os.mkdir("stl_color")
        # Save sample grids from the last batch of the epoch.
        # NOTE(review): `range=` was renamed `value_range=` in newer
        # torchvision releases — confirm against the installed version.
        torchvision.utils.save_image(fake_color_tensor[:min(batch_len, 100)],
                                     f"stl_color/fake_epoch_{i:03}.png",
                                     range=(-1.0, 1.0), normalize=True)
        torchvision.utils.save_image(real_color[:min(batch_len, 100)],
                                     f"stl_color/real_epoch_{i:03}.png",
                                     range=(-1.0, 1.0), normalize=True)
        # Save model checkpoints
        if not os.path.exists("stl_color/models"):
            os.mkdir("stl_color/models")
        if i % 10 == 0 or i == 199:
            torch.save(model_G.state_dict(), f"stl_color/models/gen_{i:03}.pytorch")
            torch.save(model_D.state_dict(), f"stl_color/models/dis_{i:03}.pytorch")
        # Persist the loss log
        with open("stl_color/logs.pkl", "wb") as fp:
            pickle.dump(result, fp)
# Script entry point: kick off training when executed directly.
if __name__ == "__main__":
    train()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch import nn | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| import numpy as np | |
| from tqdm import tqdm | |
| import os | |
| import pickle | |
| import statistics | |
def load_datasets():
    """Return a DataLoader over unlabeled STL-10 RGB images scaled to [0, 1]."""
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
    ])
    dataset = torchvision.datasets.STL10(root="./data", split="unlabeled",
                                         download=True, transform=to_tensor)
    loader = torch.utils.data.DataLoader(dataset, batch_size=512, shuffle=True,
                                         num_workers=4, pin_memory=True)
    return loader
# Color-space conversion constants, applied as 1x1 convolution kernels.
# Rows map RGB -> (Y, Cr, Cb) and (Y, Cr, Cb) -> RGB respectively
# (BT.601-style coefficients; the matrices are mutual inverses).
RGB2YCrCb = np.array([[0.299, 0.587, 0.114],
                      [0.5, -0.418688, -0.081312],
                      [-0.168736, -0.331264, 0.5]], np.float32)
YCrCb2RGB = np.array([[1, 1.402, 0],
                      [1, -0.714136, -0.344136],
                      [1, 0, 1.772]], np.float32)
# Reshape to (out_ch, in_ch, 1, 1) conv weights. Fall back to the CPU when
# CUDA is unavailable so importing this module does not crash on CPU-only
# machines (the original unconditionally called .to("cuda")).
_CONST_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RGB2YCrCb = torch.as_tensor(RGB2YCrCb.reshape(3, 3, 1, 1)).to(_CONST_DEVICE)
YCrCb2RGB = torch.as_tensor(YCrCb2RGB.reshape(3, 3, 1, 1)).to(_CONST_DEVICE)
def preprocess_generator(rgb_tensor):
    """Convert an RGB tensor in [0, 1] to YCrCb rescaled to [-1, 1] per channel."""
    ycrcb = nn.functional.conv2d(rgb_tensor, RGB2YCrCb)  # Y in [0,1], Cr/Cb in [-0.5,0.5]
    ycrcb = ycrcb * 2.0           # Cr/Cb now span [-1, 1]; Y spans [0, 2]
    ycrcb[:, 0, :, :] -= 1.0      # shift Y into [-1, 1]
    return ycrcb
def deprocess_generator(ycrcb_tensor):
    """Convert a [-1, 1]-scaled YCrCb tensor back to RGB in [0, 1]."""
    scaled = ycrcb_tensor / 2.0   # all channels back to [-0.5, 0.5]; Cr/Cb done
    scaled[:, 0, :, :] += 0.5     # restore Y to [0, 1]
    # Apply the inverse color matrix to get RGB in [0, 1]
    return nn.functional.conv2d(scaled, YCrCb2RGB)
class Generator(nn.Module):
    """U-Net-style chroma predictor.

    Input:  1x96x96 Y channel scaled to [-1, 1] (natively [0, 1]).
    Output: 2x96x96 Cr/Cb channels in [-1, 1] (natively [-0.5, 0.5]).
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.enc1 = self.conv_bn_relu(1, 32, kernel_size=5)                    # 32x96x96
        self.enc2 = self.conv_bn_relu(32, 64, kernel_size=3, pool_kernel=4)    # 64x24x24
        self.enc3 = self.conv_bn_relu(64, 128, kernel_size=3, pool_kernel=2)   # 128x12x12
        self.enc4 = self.conv_bn_relu(128, 256, kernel_size=3, pool_kernel=2)  # 256x6x6
        # Decoder; input channels include the concatenated skip features
        self.dec1 = self.conv_bn_relu(256, 128, kernel_size=3, pool_kernel=-2)       # 128x12x12
        self.dec2 = self.conv_bn_relu(128 + 128, 64, kernel_size=3, pool_kernel=-2)  # 64x24x24
        self.dec3 = self.conv_bn_relu(64 + 64, 32, kernel_size=3, pool_kernel=-4)    # 32x96x96
        self.dec4 = nn.Sequential(
            nn.Conv2d(32 + 32, 2, kernel_size=5, padding=2),
            nn.Tanh(),
        )

    def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None):
        """Conv+BN+ReLU unit; positive pool_kernel avg-pools, negative upsamples."""
        stages = []
        if pool_kernel is not None:
            if pool_kernel > 0:
                stages.append(nn.AvgPool2d(pool_kernel))
            elif pool_kernel < 0:
                stages.append(nn.UpsamplingNearest2d(scale_factor=-pool_kernel))
        stages.append(nn.Conv2d(in_ch, out_ch, kernel_size,
                                padding=(kernel_size - 1) // 2))
        stages.append(nn.BatchNorm2d(out_ch))
        stages.append(nn.ReLU(inplace=True))
        return nn.Sequential(*stages)

    def forward(self, x):
        feat1 = self.enc1(x)
        feat2 = self.enc2(feat1)
        feat3 = self.enc3(feat2)
        feat4 = self.enc4(feat3)
        h = self.dec1(feat4)
        h = self.dec2(torch.cat([h, feat3], dim=1))
        h = self.dec3(torch.cat([h, feat2], dim=1))
        h = self.dec4(torch.cat([h, feat1], dim=1))
        return h
class Discriminator(nn.Module):
    """PatchGAN critic over RGB images (YCrCb is converted to RGB upstream).

    Produces 1x3x3 logit patches for a 96x96 input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self.conv_bn_relu(3, 16, kernel_size=5, reps=1)  # RGB input
        self.conv2 = self.conv_bn_relu(16, 32, pool_kernel=4)
        self.conv3 = self.conv_bn_relu(32, 64, pool_kernel=2)
        self.conv4 = self.conv_bn_relu(64, 128, pool_kernel=2)
        self.conv5 = self.conv_bn_relu(128, 256, pool_kernel=2)
        self.out_patch = nn.Conv2d(256, 1, kernel_size=1)  # 1x3x3 patch logits

    def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None, reps=2):
        """`reps` Conv+BN+LeakyReLU units, optionally preceded by avg-pooling."""
        stages = []
        for rep in range(reps):
            if rep == 0 and pool_kernel is not None:
                stages.append(nn.AvgPool2d(pool_kernel))
            stages.append(nn.Conv2d(in_ch if rep == 0 else out_ch, out_ch,
                                    kernel_size, padding=(kernel_size - 1) // 2))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*stages)

    def forward(self, x):
        h = self.conv1(x)
        h = self.conv2(h)
        h = self.conv3(h)
        h = self.conv4(h)
        h = self.conv5(h)
        return self.out_patch(h)
def train():
    """Train the YCrCb pix2pix colorizer on STL-10 (requires CUDA).

    G predicts Cr/Cb from the Y channel; D judges reconstructed RGB images.
    Each epoch logs mean losses, saves sample grids, checkpoints weights
    every 10 epochs (and on the last), and pickles losses under ./stl_color.
    """
    # Models (wrapped in DataParallel for multi-GPU training)
    device = "cuda"
    torch.backends.cudnn.benchmark = True
    model_G, model_D = Generator(), Discriminator()
    model_G, model_D = nn.DataParallel(model_G), nn.DataParallel(model_D)
    model_G, model_D = model_G.to(device), model_D.to(device)
    params_G = torch.optim.Adam(model_G.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    params_D = torch.optim.Adam(model_D.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    # Target labels for the PatchGAN loss, pre-allocated at the maximum
    # batch size (512) and sliced to the actual batch length below
    ones = torch.ones(512, 1, 3, 3).to(device)
    zeros = torch.zeros(512, 1, 3, 3).to(device)
    # Loss functions
    bce_loss = nn.BCEWithLogitsLoss()
    mae_loss = nn.L1Loss()
    # Loss history (per-epoch means)
    result = {}
    result["log_loss_G_sum"] = []
    result["log_loss_G_bce"] = []
    result["log_loss_G_mae"] = []
    result["log_loss_D"] = []
    # Training loop
    dataset = load_datasets()
    for i in range(200):
        log_loss_G_sum, log_loss_G_bce, log_loss_G_mae, log_loss_D = [], [], [], []
        for real_rgb, _ in tqdm(dataset):
            batch_len = len(real_rgb)
            real_rgb = real_rgb.to(device)
            real_ycrcb = preprocess_generator(real_rgb)
            # --- Generator update ---
            # Predict Cr/Cb from Y, reassemble YCrCb, convert back to RGB
            fake_crcb = model_G(real_ycrcb[:, :1, :, :])
            fake_ycrcb = torch.cat([real_ycrcb[:, :1, :, :], fake_crcb], dim=1)
            fake_rgb = deprocess_generator(fake_ycrcb)
            # Detached copy for D's update and for saving sample images
            fake_rgb_tensor = fake_rgb.detach()
            # Adversarial loss: make D label the fakes as real
            out = model_D(fake_rgb)
            loss_G_bce = bce_loss(out, ones[:batch_len])
            # L1 terms: weight 75 on Cr/Cb and 25 on the reconstructed RGB
            loss_G_mae = 75 * mae_loss(fake_crcb, real_ycrcb[:, 1:, :, :]) + 25 * mae_loss(fake_rgb, real_rgb)
            loss_G_sum = loss_G_bce + loss_G_mae
            log_loss_G_bce.append(loss_G_bce.item())
            log_loss_G_mae.append(loss_G_mae.item())
            log_loss_G_sum.append(loss_G_sum.item())
            # Backprop and step G only (D grads are zeroed, not stepped)
            params_D.zero_grad()
            params_G.zero_grad()
            loss_G_sum.backward()
            params_G.step()
            # --- Discriminator update ---
            # Real color images should be classified as real
            real_out = model_D(real_rgb)
            loss_D_real = bce_loss(real_out, ones[:batch_len])
            # Detached fakes should be classified as fake
            fake_out = model_D(fake_rgb_tensor)
            loss_D_fake = bce_loss(fake_out, zeros[:batch_len])
            # Total discriminator loss
            loss_D = loss_D_real + loss_D_fake
            log_loss_D.append(loss_D.item())
            # Backprop and step D only
            params_D.zero_grad()
            params_G.zero_grad()
            loss_D.backward()
            params_D.step()
        result["log_loss_G_sum"].append(statistics.mean(log_loss_G_sum))
        result["log_loss_G_bce"].append(statistics.mean(log_loss_G_bce))
        result["log_loss_G_mae"].append(statistics.mean(log_loss_G_mae))
        result["log_loss_D"].append(statistics.mean(log_loss_D))
        print(f"log_loss_G_sum = {result['log_loss_G_sum'][-1]} " +
              f"({result['log_loss_G_bce'][-1]}, {result['log_loss_G_mae'][-1]}) " +
              f"log_loss_D = {result['log_loss_D'][-1]}")
        # Output directory
        if not os.path.exists("stl_color"):
            os.mkdir("stl_color")
        # Save sample grids from the last batch of the epoch
        torchvision.utils.save_image(fake_rgb_tensor[:min(batch_len, 100)],
                                     f"stl_color/fake_epoch_{i:03}.png")
        torchvision.utils.save_image(real_rgb[:min(batch_len, 100)],
                                     f"stl_color/real_epoch_{i:03}.png")
        # Save model checkpoints
        if not os.path.exists("stl_color/models"):
            os.mkdir("stl_color/models")
        if i % 10 == 0 or i == 199:
            torch.save(model_G.state_dict(), f"stl_color/models/gen_{i:03}.pytorch")
            torch.save(model_D.state_dict(), f"stl_color/models/dis_{i:03}.pytorch")
        # Persist the loss log
        with open("stl_color/logs.pkl", "wb") as fp:
            pickle.dump(result, fp)
# Script entry point: kick off training when executed directly.
if __name__ == "__main__":
    train()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch import nn | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| import numpy as np | |
| from tqdm import tqdm | |
| import os | |
| import pickle | |
| import statistics | |
def load_datasets():
    """Create the unlabeled STL-10 DataLoader; images are RGB in [0, 1]."""
    pipeline = transforms.Compose([
        transforms.ToTensor(),
    ])
    stl10 = torchvision.datasets.STL10(root="./data",
                                       split="unlabeled",
                                       download=True,
                                       transform=pipeline)
    return torch.utils.data.DataLoader(stl10, batch_size=512, shuffle=True,
                                       num_workers=4, pin_memory=True)
# Color-space conversion constants, applied as 1x1 convolution kernels.
# Rows map RGB -> (Y, Cr, Cb) and (Y, Cr, Cb) -> RGB respectively
# (BT.601-style coefficients; the matrices are mutual inverses).
RGB2YCrCb = np.array([[0.299, 0.587, 0.114],
                      [0.5, -0.418688, -0.081312],
                      [-0.168736, -0.331264, 0.5]], np.float32)
YCrCb2RGB = np.array([[1, 1.402, 0],
                      [1, -0.714136, -0.344136],
                      [1, 0, 1.772]], np.float32)
# Reshape to (out_ch, in_ch, 1, 1) conv weights. Fall back to the CPU when
# CUDA is unavailable so importing this module does not crash on CPU-only
# machines (the original unconditionally called .to("cuda")).
_CONST_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RGB2YCrCb = torch.as_tensor(RGB2YCrCb.reshape(3, 3, 1, 1)).to(_CONST_DEVICE)
YCrCb2RGB = torch.as_tensor(YCrCb2RGB.reshape(3, 3, 1, 1)).to(_CONST_DEVICE)
def preprocess_generator(rgb_tensor):
    """Map an RGB tensor in [0, 1] to YCrCb with every channel in [-1, 1]."""
    out = nn.functional.conv2d(rgb_tensor, RGB2YCrCb)  # Y in [0,1], Cr/Cb in [-0.5,0.5]
    out = out * 2.0           # Cr/Cb now span [-1, 1]; Y spans [0, 2]
    out[:, 0, :, :] -= 1.0    # shift Y into [-1, 1]
    return out
def deprocess_generator(ycrcb_tensor):
    """Map a [-1, 1]-scaled YCrCb tensor back to RGB in [0, 1]."""
    halved = ycrcb_tensor / 2.0   # every channel to [-0.5, 0.5]; Cr/Cb done
    halved[:, 0, :, :] += 0.5     # restore Y to [0, 1]
    # Apply the inverse color matrix to obtain RGB in [0, 1]
    return nn.functional.conv2d(halved, YCrCb2RGB)
class Generator(nn.Module):
    """U-Net-style chroma predictor.

    Input:  1x96x96 Y channel scaled to [-1, 1] (natively [0, 1]).
    Output: 2x96x96 Cr/Cb channels in [-1, 1] (natively [-0.5, 0.5]).
    """

    def __init__(self):
        super().__init__()
        # Encoder path
        self.enc1 = self.conv_bn_relu(1, 32, kernel_size=5)                    # 32x96x96
        self.enc2 = self.conv_bn_relu(32, 64, kernel_size=3, pool_kernel=4)    # 64x24x24
        self.enc3 = self.conv_bn_relu(64, 128, kernel_size=3, pool_kernel=2)   # 128x12x12
        self.enc4 = self.conv_bn_relu(128, 256, kernel_size=3, pool_kernel=2)  # 256x6x6
        # Decoder path; channel counts include the concatenated skips
        self.dec1 = self.conv_bn_relu(256, 128, kernel_size=3, pool_kernel=-2)       # 128x12x12
        self.dec2 = self.conv_bn_relu(128 + 128, 64, kernel_size=3, pool_kernel=-2)  # 64x24x24
        self.dec3 = self.conv_bn_relu(64 + 64, 32, kernel_size=3, pool_kernel=-4)    # 32x96x96
        self.dec4 = nn.Sequential(
            nn.Conv2d(32 + 32, 2, kernel_size=5, padding=2),
            nn.Tanh(),
        )

    def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None):
        """Conv+BN+ReLU unit; positive pool_kernel avg-pools, negative upsamples."""
        modules = []
        if pool_kernel is not None and pool_kernel > 0:
            modules.append(nn.AvgPool2d(pool_kernel))
        elif pool_kernel is not None and pool_kernel < 0:
            modules.append(nn.UpsamplingNearest2d(scale_factor=-pool_kernel))
        modules += [
            nn.Conv2d(in_ch, out_ch, kernel_size, padding=(kernel_size - 1) // 2),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*modules)

    def forward(self, x):
        s1 = self.enc1(x)
        s2 = self.enc2(s1)
        s3 = self.enc3(s2)
        bottom = self.enc4(s3)
        y = self.dec1(bottom)
        y = self.dec2(torch.cat([y, s3], dim=1))
        y = self.dec3(torch.cat([y, s2], dim=1))
        return self.dec4(torch.cat([y, s1], dim=1))
class Discriminator(nn.Module):
    """PatchGAN critic operating directly on YCrCb images.

    Produces 1x3x3 logit patches for a 96x96 input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self.conv_bn_relu(3, 16, kernel_size=5, reps=1)  # YCrCb input
        self.conv2 = self.conv_bn_relu(16, 32, pool_kernel=4)
        self.conv3 = self.conv_bn_relu(32, 64, pool_kernel=2)
        self.conv4 = self.conv_bn_relu(64, 128, pool_kernel=2)
        self.conv5 = self.conv_bn_relu(128, 256, pool_kernel=2)
        self.out_patch = nn.Conv2d(256, 1, kernel_size=1)  # 1x3x3 patch logits

    def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None, reps=2):
        """`reps` Conv+BN+LeakyReLU units, optionally preceded by avg-pooling."""
        modules = []
        for rep in range(reps):
            if rep == 0 and pool_kernel is not None:
                modules.append(nn.AvgPool2d(pool_kernel))
            modules.extend([
                nn.Conv2d(in_ch if rep == 0 else out_ch, out_ch,
                          kernel_size, padding=(kernel_size - 1) // 2),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        return nn.Sequential(*modules)

    def forward(self, x):
        h = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            h = stage(h)
        return self.out_patch(h)
def train():
    """Train the YCrCb pix2pix colorizer on STL-10 (requires CUDA).

    Variant where the discriminator judges YCrCb tensors directly (fakes are
    only converted to RGB for saving sample images). Each epoch logs mean
    losses, saves sample grids, checkpoints weights every 10 epochs (and on
    the last), and pickles the loss history under ./stl_color.
    """
    # Models (wrapped in DataParallel for multi-GPU training)
    device = "cuda"
    torch.backends.cudnn.benchmark = True
    model_G, model_D = Generator(), Discriminator()
    model_G, model_D = nn.DataParallel(model_G), nn.DataParallel(model_D)
    model_G, model_D = model_G.to(device), model_D.to(device)
    params_G = torch.optim.Adam(model_G.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    params_D = torch.optim.Adam(model_D.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    # Target labels for the PatchGAN loss, pre-allocated at the maximum
    # batch size (512) and sliced to the actual batch length below
    ones = torch.ones(512, 1, 3, 3).to(device)
    zeros = torch.zeros(512, 1, 3, 3).to(device)
    # Loss functions
    bce_loss = nn.BCEWithLogitsLoss()
    mae_loss = nn.L1Loss()
    # Loss history (per-epoch means)
    result = {}
    result["log_loss_G_sum"] = []
    result["log_loss_G_bce"] = []
    result["log_loss_G_mae"] = []
    result["log_loss_D"] = []
    # Training loop
    dataset = load_datasets()
    for i in range(200):
        log_loss_G_sum, log_loss_G_bce, log_loss_G_mae, log_loss_D = [], [], [], []
        for real_rgb, _ in tqdm(dataset):
            batch_len = len(real_rgb)
            real_rgb = real_rgb.to(device)
            real_ycrcb = preprocess_generator(real_rgb)
            # --- Generator update ---
            # Predict Cr/Cb from Y, reassemble YCrCb, convert to RGB for the L1 term
            fake_crcb = model_G(real_ycrcb[:, :1, :, :])
            fake_ycrcb = torch.cat([real_ycrcb[:, :1, :, :], fake_crcb], dim=1)
            fake_rgb = deprocess_generator(fake_ycrcb)
            # Detached copy for D's update and for saving sample images
            fake_ycrcb_tensor = fake_ycrcb.detach()
            # Adversarial loss: make D label the fake YCrCb as real
            out = model_D(fake_ycrcb)
            loss_G_bce = bce_loss(out, ones[:batch_len])
            # L1 terms: weight 75 on Cr/Cb and 25 on the reconstructed RGB
            loss_G_mae = 75 * mae_loss(fake_crcb, real_ycrcb[:, 1:, :, :]) + 25 * mae_loss(fake_rgb, real_rgb)
            loss_G_sum = loss_G_bce + loss_G_mae
            log_loss_G_bce.append(loss_G_bce.item())
            log_loss_G_mae.append(loss_G_mae.item())
            log_loss_G_sum.append(loss_G_sum.item())
            # Backprop and step G only (D grads are zeroed, not stepped)
            params_D.zero_grad()
            params_G.zero_grad()
            loss_G_sum.backward()
            params_G.step()
            # --- Discriminator update ---
            # Real YCrCb images should be classified as real
            real_out = model_D(real_ycrcb)
            loss_D_real = bce_loss(real_out, ones[:batch_len])
            # Detached fakes should be classified as fake
            fake_out = model_D(fake_ycrcb_tensor)
            loss_D_fake = bce_loss(fake_out, zeros[:batch_len])
            # Total discriminator loss
            loss_D = loss_D_real + loss_D_fake
            log_loss_D.append(loss_D.item())
            # Backprop and step D only
            params_D.zero_grad()
            params_G.zero_grad()
            loss_D.backward()
            params_D.step()
        result["log_loss_G_sum"].append(statistics.mean(log_loss_G_sum))
        result["log_loss_G_bce"].append(statistics.mean(log_loss_G_bce))
        result["log_loss_G_mae"].append(statistics.mean(log_loss_G_mae))
        result["log_loss_D"].append(statistics.mean(log_loss_D))
        print(f"log_loss_G_sum = {result['log_loss_G_sum'][-1]} " +
              f"({result['log_loss_G_bce'][-1]}, {result['log_loss_G_mae'][-1]}) " +
              f"log_loss_D = {result['log_loss_D'][-1]}")
        # Output directory
        if not os.path.exists("stl_color"):
            os.mkdir("stl_color")
        # Save sample grids from the last batch; fakes are converted to RGB first
        fake_rgb_tensor = deprocess_generator(fake_ycrcb_tensor)
        torchvision.utils.save_image(fake_rgb_tensor[:min(batch_len, 100)],
                                     f"stl_color/fake_epoch_{i:03}.png")
        torchvision.utils.save_image(real_rgb[:min(batch_len, 100)],
                                     f"stl_color/real_epoch_{i:03}.png")
        # Save model checkpoints
        if not os.path.exists("stl_color/models"):
            os.mkdir("stl_color/models")
        if i % 10 == 0 or i == 199:
            torch.save(model_G.state_dict(), f"stl_color/models/gen_{i:03}.pytorch")
            torch.save(model_D.state_dict(), f"stl_color/models/dis_{i:03}.pytorch")
        # Persist the loss log
        with open("stl_color/logs.pkl", "wb") as fp:
            pickle.dump(result, fp)
# Script entry point: kick off training when executed directly.
if __name__ == "__main__":
    train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment