Skip to content

Instantly share code, notes, and snippets.

@koshian2
Created June 27, 2019 09:11
Show Gist options
  • Save koshian2/4b00bd6d2453450c1d7ea703f2218ca3 to your computer and use it in GitHub Desktop.
Pix2pix STL Colorize
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
import os
import pickle
import statistics
class ColorAndGray(object):
    """Transform that pairs an image with its grayscale version.

    Must run before ToTensor(), i.e. while the input is still a PIL image
    instance (it relies on PIL's ``convert``).
    """

    def __call__(self, img):
        # "L" is PIL's single-channel luminance mode.
        return (img, img.convert("L"))
# Wrapper that fans a transform (or a list of transforms) out over several inputs.
class MultiInputWrapper(object):
    """Apply transform(s) to every element of a sequence of inputs.

    If ``base_func`` is a list, each input is processed by the transform at
    the matching position; otherwise the single transform is applied to
    every input.
    """

    def __init__(self, base_func):
        self.base_func = base_func

    def __call__(self, xs):
        funcs = self.base_func
        if isinstance(funcs, list):
            return [func(item) for func, item in zip(funcs, xs)]
        return [funcs(item) for item in xs]
def load_datasets():
    """Build the STL-10 unlabeled-split DataLoader yielding (color, gray) pairs.

    Both images are normalized to [-1, 1] per channel (3 channels for the
    color image, 1 for the grayscale one). Downloads the dataset on first use.
    """
    pipeline = transforms.Compose([
        ColorAndGray(),
        MultiInputWrapper(transforms.ToTensor()),
        MultiInputWrapper([
            transforms.Normalize(mean=(0.5,0.5,0.5,), std=(0.5,0.5,0.5,)),
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
        ]),
    ])
    dataset = torchvision.datasets.STL10(
        root="./data", split="unlabeled", download=True, transform=pipeline)
    return torch.utils.data.DataLoader(
        dataset, batch_size=512, shuffle=True, num_workers=4, pin_memory=True)
class Generator(nn.Module):
    """U-Net style generator: 1-channel grayscale in, 3-channel color out.

    The encoder downsamples 96 -> 24 -> 12 -> 6; the decoder mirrors it with
    nearest-neighbor upsampling and skip connections. The final tanh bounds
    the output to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Encoder (comments give output shape as CxHxW)
        self.enc1 = self.conv_bn_relu(1, 32, kernel_size=5)                   # 32x96x96
        self.enc2 = self.conv_bn_relu(32, 64, kernel_size=3, pool_kernel=4)   # 64x24x24
        self.enc3 = self.conv_bn_relu(64, 128, kernel_size=3, pool_kernel=2)  # 128x12x12
        self.enc4 = self.conv_bn_relu(128, 256, kernel_size=3, pool_kernel=2) # 256x6x6
        # Decoder (negative pool_kernel = nearest upsample by that factor)
        self.dec1 = self.conv_bn_relu(256, 128, kernel_size=3, pool_kernel=-2)       # 128x12x12
        self.dec2 = self.conv_bn_relu(128 + 128, 64, kernel_size=3, pool_kernel=-2)  # 64x24x24
        self.dec3 = self.conv_bn_relu(64 + 64, 32, kernel_size=3, pool_kernel=-4)    # 32x96x96
        self.dec4 = nn.Sequential(
            nn.Conv2d(32 + 32, 3, kernel_size=5, padding=2),
            nn.Tanh(),
        )

    def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None):
        """Conv-BN-ReLU block, optionally preceded by avg-pool (pool_kernel > 0)
        or nearest upsample by |pool_kernel| (pool_kernel < 0)."""
        modules = []
        if pool_kernel is not None:
            if pool_kernel > 0:
                modules.append(nn.AvgPool2d(pool_kernel))
            elif pool_kernel < 0:
                modules.append(nn.UpsamplingNearest2d(scale_factor=-pool_kernel))
        modules += [
            nn.Conv2d(in_ch, out_ch, kernel_size, padding=(kernel_size - 1) // 2),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*modules)

    def forward(self, x):
        skip1 = self.enc1(x)
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        bottleneck = self.enc4(skip3)
        up = self.dec1(bottleneck)
        up = self.dec2(torch.cat([up, skip3], dim=1))
        up = self.dec3(torch.cat([up, skip2], dim=1))
        return self.dec4(torch.cat([up, skip1], dim=1))
class Discriminator(nn.Module):
    """PatchGAN discriminator over a 4-channel (color + gray) input.

    For 96x96 inputs it emits a 1x3x3 map of real/fake logits (one per
    image patch) rather than a single scalar.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self.conv_bn_relu(4, 16, kernel_size=5, reps=1)  # fake/true color + gray
        self.conv2 = self.conv_bn_relu(16, 32, pool_kernel=4)
        self.conv3 = self.conv_bn_relu(32, 64, pool_kernel=2)
        self.conv4 = self.conv_bn_relu(64, 128, pool_kernel=2)
        self.conv5 = self.conv_bn_relu(128, 256, pool_kernel=2)
        self.out_patch = nn.Conv2d(256, 1, kernel_size=1)  # 1x3x3 logit map

    def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None, reps=2):
        """``reps`` Conv-BN-LeakyReLU blocks; optional avg-pool before the first."""
        modules = []
        for rep in range(reps):
            if rep == 0 and pool_kernel is not None:
                modules.append(nn.AvgPool2d(pool_kernel))
            modules.append(nn.Conv2d(in_ch if rep == 0 else out_ch, out_ch,
                                     kernel_size, padding=(kernel_size - 1) // 2))
            modules.append(nn.BatchNorm2d(out_ch))
            modules.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*modules)

    def forward(self, x):
        features = self.conv1(x)
        for stage in (self.conv2, self.conv3, self.conv4, self.conv5):
            features = stage(features)
        return self.out_patch(features)
def train():
    """Train the pix2pix STL-10 colorization GAN (grayscale -> RGB).

    Runs 200 epochs; per-epoch mean losses are collected in ``result`` and
    pickled to stl_color/logs.pkl. Sample image grids are written every
    epoch and model checkpoints every 10 epochs. Requires CUDA (device is
    hard-coded) and wraps both nets in DataParallel.
    """
    # Models (DataParallel over all visible GPUs)
    device = "cuda"
    torch.backends.cudnn.benchmark = True
    model_G, model_D = Generator(), Discriminator()
    model_G, model_D = nn.DataParallel(model_G), nn.DataParallel(model_D)
    model_G, model_D = model_G.to(device), model_D.to(device)
    params_G = torch.optim.Adam(model_G.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    params_D = torch.optim.Adam(model_D.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    # Label tensors for the PatchGAN loss, sliced to the actual batch size
    # (assumes batches never exceed 512, the DataLoader batch size)
    ones = torch.ones(512, 1, 3, 3).to(device)
    zeros = torch.zeros(512, 1, 3, 3).to(device)
    # Loss functions
    bce_loss = nn.BCEWithLogitsLoss()
    mae_loss = nn.L1Loss()
    # Loss history (per-epoch means)
    result = {}
    result["log_loss_G_sum"] = []
    result["log_loss_G_bce"] = []
    result["log_loss_G_mae"] = []
    result["log_loss_D"] = []
    # Training loop
    dataset = load_datasets()
    for i in range(200):
        log_loss_G_sum, log_loss_G_bce, log_loss_G_mae, log_loss_D = [], [], [], []
        for (real_color, input_gray), _ in tqdm(dataset):
            batch_len = len(real_color)
            real_color, input_gray = real_color.to(device), input_gray.to(device)
            # --- Generator step ---
            # Produce fake color images from the grayscale input
            fake_color = model_G(input_gray)
            # Detached copy for the discriminator step (no grad back to G)
            fake_color_tensor = fake_color.detach()
            # Adversarial + L1 loss: fool D while staying close to the target
            LAMBD = 100.0  # weight of the L1 (MAE) term relative to the BCE term
            out = model_D(torch.cat([fake_color, input_gray], dim=1))
            loss_G_bce = bce_loss(out, ones[:batch_len])
            loss_G_mae = LAMBD * mae_loss(fake_color, real_color)
            loss_G_sum = loss_G_bce + loss_G_mae
            log_loss_G_bce.append(loss_G_bce.item())
            log_loss_G_mae.append(loss_G_mae.item())
            log_loss_G_sum.append(loss_G_sum.item())
            # Backprop and update G only (D grads are cleared, not stepped)
            params_D.zero_grad()
            params_G.zero_grad()
            loss_G_sum.backward()
            params_G.step()
            # --- Discriminator step ---
            # Real (color, gray) pairs should be classified as real
            real_out = model_D(torch.cat([real_color, input_gray], dim=1))
            loss_D_real = bce_loss(real_out, ones[:batch_len])
            # Fake pairs (detached) should be classified as fake
            fake_out = model_D(torch.cat([fake_color_tensor, input_gray], dim=1))
            loss_D_fake = bce_loss(fake_out, zeros[:batch_len])
            # Total discriminator loss
            loss_D = loss_D_real + loss_D_fake
            log_loss_D.append(loss_D.item())
            # Backprop and update D only
            params_D.zero_grad()
            params_G.zero_grad()
            loss_D.backward()
            params_D.step()
        result["log_loss_G_sum"].append(statistics.mean(log_loss_G_sum))
        result["log_loss_G_bce"].append(statistics.mean(log_loss_G_bce))
        result["log_loss_G_mae"].append(statistics.mean(log_loss_G_mae))
        result["log_loss_D"].append(statistics.mean(log_loss_D))
        print(f"log_loss_G_sum = {result['log_loss_G_sum'][-1]} " +
              f"({result['log_loss_G_bce'][-1]}, {result['log_loss_G_mae'][-1]}) " +
              f"log_loss_D = {result['log_loss_D'][-1]}")
        # Save sample image grids (uses the last batch of the epoch)
        if not os.path.exists("stl_color"):
            os.mkdir("stl_color")
        # NOTE(review): the `range=` kwarg of save_image was renamed to
        # `value_range` in newer torchvision releases — confirm the installed
        # version still accepts `range=`.
        torchvision.utils.save_image(fake_color_tensor[:min(batch_len, 100)],
                                     f"stl_color/fake_epoch_{i:03}.png",
                                     range=(-1.0,1.0), normalize=True)
        torchvision.utils.save_image(real_color[:min(batch_len, 100)],
                                     f"stl_color/real_epoch_{i:03}.png",
                                     range=(-1.0, 1.0), normalize=True)
        # Save model checkpoints every 10 epochs and at the final epoch
        if not os.path.exists("stl_color/models"):
            os.mkdir("stl_color/models")
        if i % 10 == 0 or i == 199:
            torch.save(model_G.state_dict(), f"stl_color/models/gen_{i:03}.pytorch")
            torch.save(model_D.state_dict(), f"stl_color/models/dis_{i:03}.pytorch")
        # Save the loss log (overwritten every epoch)
        with open("stl_color/logs.pkl", "wb") as fp:
            pickle.dump(result, fp)
# Run training when executed as a script.
if __name__ == "__main__":
    train()
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
import os
import pickle
import statistics
def load_datasets():
    """Return a DataLoader over the STL-10 unlabeled split (RGB tensors in [0, 1]).

    Downloads the dataset on first use; no normalization beyond ToTensor().
    """
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
    ])
    dataset = torchvision.datasets.STL10(
        root="./data", split="unlabeled", download=True, transform=to_tensor)
    return torch.utils.data.DataLoader(
        dataset, batch_size=512, shuffle=True, num_workers=4, pin_memory=True)
# Color-space conversion constants (RGB <-> YCrCb), laid out as 1x1 conv
# weights so a whole batch converts with a single conv2d over channels.
RGB2YCrCb = np.array([[0.299, 0.587, 0.114],
                      [0.5, -0.418688, -0.081312],
                      [-0.168736, -0.331264, 0.5]], np.float32)
YCrCb2RGB = np.array([[1, 1.402, 0],
                      [1, -0.714136, -0.344136],
                      [1, 0, 1.772]], np.float32)
# Fall back to CPU so importing this module does not crash on machines
# without a GPU (the original unconditionally moved these to "cuda").
_CONVERSION_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RGB2YCrCb = torch.as_tensor(RGB2YCrCb.reshape(3, 3, 1, 1)).to(_CONVERSION_DEVICE)
YCrCb2RGB = torch.as_tensor(YCrCb2RGB.reshape(3, 3, 1, 1)).to(_CONVERSION_DEVICE)
def preprocess_generator(rgb_tensor):
    """Convert an RGB batch in [0, 1] to YCrCb rescaled to [-1, 1] per channel.

    After the 1x1 conv, Y lies in [0, 1] and Cr/Cb in [-0.5, 0.5]; doubling
    puts Cr/Cb in [-1, 1], and shifting Y down by 1 puts it in [-1, 1] too.
    """
    ycrcb = 2.0 * nn.functional.conv2d(rgb_tensor, RGB2YCrCb)
    ycrcb[:, 0, :, :] -= 1.0
    return ycrcb
def deprocess_generator(ycrcb_tensor):
    """Convert a [-1, 1]-scaled YCrCb batch back to RGB in [0, 1].

    Halving maps all channels to [-0.5, 0.5] (already correct for Cr/Cb);
    adding 0.5 to Y restores its native [0, 1] range before the 1x1 conv
    back to RGB.
    """
    scaled = ycrcb_tensor * 0.5
    scaled[:, 0, :, :] += 0.5
    return nn.functional.conv2d(scaled, YCrCb2RGB)
class Generator(nn.Module):
    """U-Net generator predicting chroma from luminance.

    Input : 1x96x96 Y channel rescaled to [-1, 1] (natively [0, 1]).
    Output: 2x96x96 CrCb rescaled to [-1, 1] via tanh (natively [-0.5, 0.5]).
    """

    def __init__(self):
        super().__init__()
        # Encoder (comments give output shape as CxHxW)
        self.enc1 = self.conv_bn_relu(1, 32, kernel_size=5)                   # 32x96x96
        self.enc2 = self.conv_bn_relu(32, 64, kernel_size=3, pool_kernel=4)   # 64x24x24
        self.enc3 = self.conv_bn_relu(64, 128, kernel_size=3, pool_kernel=2)  # 128x12x12
        self.enc4 = self.conv_bn_relu(128, 256, kernel_size=3, pool_kernel=2) # 256x6x6
        # Decoder (negative pool_kernel = nearest upsample by that factor)
        self.dec1 = self.conv_bn_relu(256, 128, kernel_size=3, pool_kernel=-2)       # 128x12x12
        self.dec2 = self.conv_bn_relu(128 + 128, 64, kernel_size=3, pool_kernel=-2)  # 64x24x24
        self.dec3 = self.conv_bn_relu(64 + 64, 32, kernel_size=3, pool_kernel=-4)    # 32x96x96
        self.dec4 = nn.Sequential(
            nn.Conv2d(32 + 32, 2, kernel_size=5, padding=2),
            nn.Tanh(),
        )

    def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None):
        """Conv-BN-ReLU block, optionally preceded by avg-pool (pool_kernel > 0)
        or nearest upsample by |pool_kernel| (pool_kernel < 0)."""
        modules = []
        if pool_kernel is not None:
            if pool_kernel > 0:
                modules.append(nn.AvgPool2d(pool_kernel))
            elif pool_kernel < 0:
                modules.append(nn.UpsamplingNearest2d(scale_factor=-pool_kernel))
        modules += [
            nn.Conv2d(in_ch, out_ch, kernel_size, padding=(kernel_size - 1) // 2),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*modules)

    def forward(self, x):
        skip1 = self.enc1(x)
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        bottleneck = self.enc4(skip3)
        up = self.dec1(bottleneck)
        up = self.dec2(torch.cat([up, skip3], dim=1))
        up = self.dec3(torch.cat([up, skip2], dim=1))
        return self.dec4(torch.cat([up, skip1], dim=1))
class Discriminator(nn.Module):
    """PatchGAN discriminator over 3-channel input (YCrCb converted back to
    RGB before being fed in). Emits a 1x3x3 logit patch map for 96x96 inputs."""

    def __init__(self):
        super().__init__()
        self.conv1 = self.conv_bn_relu(3, 16, kernel_size=5, reps=1)  # RGB
        self.conv2 = self.conv_bn_relu(16, 32, pool_kernel=4)
        self.conv3 = self.conv_bn_relu(32, 64, pool_kernel=2)
        self.conv4 = self.conv_bn_relu(64, 128, pool_kernel=2)
        self.conv5 = self.conv_bn_relu(128, 256, pool_kernel=2)
        self.out_patch = nn.Conv2d(256, 1, kernel_size=1)  # 1x3x3 logit map

    def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None, reps=2):
        """``reps`` Conv-BN-LeakyReLU blocks; optional avg-pool before the first."""
        modules = []
        for rep in range(reps):
            if rep == 0 and pool_kernel is not None:
                modules.append(nn.AvgPool2d(pool_kernel))
            modules.append(nn.Conv2d(in_ch if rep == 0 else out_ch, out_ch,
                                     kernel_size, padding=(kernel_size - 1) // 2))
            modules.append(nn.BatchNorm2d(out_ch))
            modules.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*modules)

    def forward(self, x):
        features = self.conv1(x)
        for stage in (self.conv2, self.conv3, self.conv4, self.conv5):
            features = stage(features)
        return self.out_patch(features)
def train():
    """Train the YCrCb-based colorization GAN on STL-10.

    The generator predicts CrCb from Y; the discriminator judges the
    reconstructed RGB image. 200 epochs; per-epoch mean losses are pickled,
    sample grids saved every epoch, checkpoints every 10 epochs. Requires
    CUDA (device is hard-coded) and wraps both nets in DataParallel.
    """
    # Models (DataParallel over all visible GPUs)
    device = "cuda"
    torch.backends.cudnn.benchmark = True
    model_G, model_D = Generator(), Discriminator()
    model_G, model_D = nn.DataParallel(model_G), nn.DataParallel(model_D)
    model_G, model_D = model_G.to(device), model_D.to(device)
    params_G = torch.optim.Adam(model_G.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    params_D = torch.optim.Adam(model_D.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    # Label tensors for the PatchGAN loss, sliced to the actual batch size
    # (assumes batches never exceed 512, the DataLoader batch size)
    ones = torch.ones(512, 1, 3, 3).to(device)
    zeros = torch.zeros(512, 1, 3, 3).to(device)
    # Loss functions
    bce_loss = nn.BCEWithLogitsLoss()
    mae_loss = nn.L1Loss()
    # Loss history (per-epoch means)
    result = {}
    result["log_loss_G_sum"] = []
    result["log_loss_G_bce"] = []
    result["log_loss_G_mae"] = []
    result["log_loss_D"] = []
    # Training loop
    dataset = load_datasets()
    for i in range(200):
        log_loss_G_sum, log_loss_G_bce, log_loss_G_mae, log_loss_D = [], [], [], []
        for real_rgb, _ in tqdm(dataset):
            batch_len = len(real_rgb)
            real_rgb = real_rgb.to(device)
            real_ycrcb = preprocess_generator(real_rgb)
            # --- Generator step ---
            # Predict CrCb from the Y channel, rebuild the full YCrCb image,
            # then convert back to RGB for the discriminator and the RGB loss
            fake_crcb = model_G(real_ycrcb[:,:1,:,:])
            fake_ycrcb = torch.cat([real_ycrcb[:,:1,:,:], fake_crcb], dim=1)
            fake_rgb = deprocess_generator(fake_ycrcb)
            # Detached copy for the discriminator step (no grad back to G)
            fake_rgb_tensor = fake_rgb.detach()
            # Adversarial loss: make D call the fake RGB image real
            out = model_D(fake_rgb)
            loss_G_bce = bce_loss(out, ones[:batch_len])
            # L1 terms: weight 75 on the CrCb prediction, 25 on reconstructed RGB
            loss_G_mae = 75 * mae_loss(fake_crcb, real_ycrcb[:, 1:,:,:]) + 25 * mae_loss(fake_rgb, real_rgb)
            loss_G_sum = loss_G_bce + loss_G_mae
            log_loss_G_bce.append(loss_G_bce.item())
            log_loss_G_mae.append(loss_G_mae.item())
            log_loss_G_sum.append(loss_G_sum.item())
            # Backprop and update G only (D grads are cleared, not stepped)
            params_D.zero_grad()
            params_G.zero_grad()
            loss_G_sum.backward()
            params_G.step()
            # --- Discriminator step ---
            # Real RGB images should be classified as real
            real_out = model_D(real_rgb)
            loss_D_real = bce_loss(real_out, ones[:batch_len])
            # Fake (detached) RGB images should be classified as fake
            fake_out = model_D(fake_rgb_tensor)
            loss_D_fake = bce_loss(fake_out, zeros[:batch_len])
            # Total discriminator loss
            loss_D = loss_D_real + loss_D_fake
            log_loss_D.append(loss_D.item())
            # Backprop and update D only
            params_D.zero_grad()
            params_G.zero_grad()
            loss_D.backward()
            params_D.step()
        result["log_loss_G_sum"].append(statistics.mean(log_loss_G_sum))
        result["log_loss_G_bce"].append(statistics.mean(log_loss_G_bce))
        result["log_loss_G_mae"].append(statistics.mean(log_loss_G_mae))
        result["log_loss_D"].append(statistics.mean(log_loss_D))
        print(f"log_loss_G_sum = {result['log_loss_G_sum'][-1]} " +
              f"({result['log_loss_G_bce'][-1]}, {result['log_loss_G_mae'][-1]}) " +
              f"log_loss_D = {result['log_loss_D'][-1]}")
        # Save sample image grids (uses the last batch of the epoch)
        if not os.path.exists("stl_color"):
            os.mkdir("stl_color")
        torchvision.utils.save_image(fake_rgb_tensor[:min(batch_len, 100)],
                                     f"stl_color/fake_epoch_{i:03}.png")
        torchvision.utils.save_image(real_rgb[:min(batch_len, 100)],
                                     f"stl_color/real_epoch_{i:03}.png")
        # Save model checkpoints every 10 epochs and at the final epoch
        if not os.path.exists("stl_color/models"):
            os.mkdir("stl_color/models")
        if i % 10 == 0 or i == 199:
            torch.save(model_G.state_dict(), f"stl_color/models/gen_{i:03}.pytorch")
            torch.save(model_D.state_dict(), f"stl_color/models/dis_{i:03}.pytorch")
        # Save the loss log (overwritten every epoch)
        with open("stl_color/logs.pkl", "wb") as fp:
            pickle.dump(result, fp)
# Run training when executed as a script.
if __name__ == "__main__":
    train()
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
import os
import pickle
import statistics
def load_datasets():
    """Return a DataLoader over the STL-10 unlabeled split (RGB tensors in [0, 1]).

    Downloads the dataset on first use; no normalization beyond ToTensor().
    """
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
    ])
    dataset = torchvision.datasets.STL10(
        root="./data", split="unlabeled", download=True, transform=to_tensor)
    return torch.utils.data.DataLoader(
        dataset, batch_size=512, shuffle=True, num_workers=4, pin_memory=True)
# Color-space conversion constants (RGB <-> YCrCb), laid out as 1x1 conv
# weights so a whole batch converts with a single conv2d over channels.
RGB2YCrCb = np.array([[0.299, 0.587, 0.114],
                      [0.5, -0.418688, -0.081312],
                      [-0.168736, -0.331264, 0.5]], np.float32)
YCrCb2RGB = np.array([[1, 1.402, 0],
                      [1, -0.714136, -0.344136],
                      [1, 0, 1.772]], np.float32)
# Fall back to CPU so importing this module does not crash on machines
# without a GPU (the original unconditionally moved these to "cuda").
_CONVERSION_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RGB2YCrCb = torch.as_tensor(RGB2YCrCb.reshape(3, 3, 1, 1)).to(_CONVERSION_DEVICE)
YCrCb2RGB = torch.as_tensor(YCrCb2RGB.reshape(3, 3, 1, 1)).to(_CONVERSION_DEVICE)
def preprocess_generator(rgb_tensor):
    """Convert an RGB batch in [0, 1] to YCrCb rescaled to [-1, 1] per channel.

    After the 1x1 conv, Y lies in [0, 1] and Cr/Cb in [-0.5, 0.5]; doubling
    puts Cr/Cb in [-1, 1], and shifting Y down by 1 puts it in [-1, 1] too.
    """
    ycrcb = 2.0 * nn.functional.conv2d(rgb_tensor, RGB2YCrCb)
    ycrcb[:, 0, :, :] -= 1.0
    return ycrcb
def deprocess_generator(ycrcb_tensor):
    """Convert a [-1, 1]-scaled YCrCb batch back to RGB in [0, 1].

    Halving maps all channels to [-0.5, 0.5] (already correct for Cr/Cb);
    adding 0.5 to Y restores its native [0, 1] range before the 1x1 conv
    back to RGB.
    """
    scaled = ycrcb_tensor * 0.5
    scaled[:, 0, :, :] += 0.5
    return nn.functional.conv2d(scaled, YCrCb2RGB)
class Generator(nn.Module):
    """U-Net generator predicting chroma from luminance.

    Input : 1x96x96 Y channel rescaled to [-1, 1] (natively [0, 1]).
    Output: 2x96x96 CrCb rescaled to [-1, 1] via tanh (natively [-0.5, 0.5]).
    """

    def __init__(self):
        super().__init__()
        # Encoder (comments give output shape as CxHxW)
        self.enc1 = self.conv_bn_relu(1, 32, kernel_size=5)                   # 32x96x96
        self.enc2 = self.conv_bn_relu(32, 64, kernel_size=3, pool_kernel=4)   # 64x24x24
        self.enc3 = self.conv_bn_relu(64, 128, kernel_size=3, pool_kernel=2)  # 128x12x12
        self.enc4 = self.conv_bn_relu(128, 256, kernel_size=3, pool_kernel=2) # 256x6x6
        # Decoder (negative pool_kernel = nearest upsample by that factor)
        self.dec1 = self.conv_bn_relu(256, 128, kernel_size=3, pool_kernel=-2)       # 128x12x12
        self.dec2 = self.conv_bn_relu(128 + 128, 64, kernel_size=3, pool_kernel=-2)  # 64x24x24
        self.dec3 = self.conv_bn_relu(64 + 64, 32, kernel_size=3, pool_kernel=-4)    # 32x96x96
        self.dec4 = nn.Sequential(
            nn.Conv2d(32 + 32, 2, kernel_size=5, padding=2),
            nn.Tanh(),
        )

    def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None):
        """Conv-BN-ReLU block, optionally preceded by avg-pool (pool_kernel > 0)
        or nearest upsample by |pool_kernel| (pool_kernel < 0)."""
        modules = []
        if pool_kernel is not None:
            if pool_kernel > 0:
                modules.append(nn.AvgPool2d(pool_kernel))
            elif pool_kernel < 0:
                modules.append(nn.UpsamplingNearest2d(scale_factor=-pool_kernel))
        modules += [
            nn.Conv2d(in_ch, out_ch, kernel_size, padding=(kernel_size - 1) // 2),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*modules)

    def forward(self, x):
        skip1 = self.enc1(x)
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        bottleneck = self.enc4(skip3)
        up = self.dec1(bottleneck)
        up = self.dec2(torch.cat([up, skip3], dim=1))
        up = self.dec3(torch.cat([up, skip2], dim=1))
        return self.dec4(torch.cat([up, skip1], dim=1))
class Discriminator(nn.Module):
    """PatchGAN discriminator over a 3-channel YCrCb input.

    Emits a 1x3x3 logit patch map for 96x96 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self.conv_bn_relu(3, 16, kernel_size=5, reps=1)  # YCrCb
        self.conv2 = self.conv_bn_relu(16, 32, pool_kernel=4)
        self.conv3 = self.conv_bn_relu(32, 64, pool_kernel=2)
        self.conv4 = self.conv_bn_relu(64, 128, pool_kernel=2)
        self.conv5 = self.conv_bn_relu(128, 256, pool_kernel=2)
        self.out_patch = nn.Conv2d(256, 1, kernel_size=1)  # 1x3x3 logit map

    def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None, reps=2):
        """``reps`` Conv-BN-LeakyReLU blocks; optional avg-pool before the first."""
        modules = []
        for rep in range(reps):
            if rep == 0 and pool_kernel is not None:
                modules.append(nn.AvgPool2d(pool_kernel))
            modules.append(nn.Conv2d(in_ch if rep == 0 else out_ch, out_ch,
                                     kernel_size, padding=(kernel_size - 1) // 2))
            modules.append(nn.BatchNorm2d(out_ch))
            modules.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*modules)

    def forward(self, x):
        features = self.conv1(x)
        for stage in (self.conv2, self.conv3, self.conv4, self.conv5):
            features = stage(features)
        return self.out_patch(features)
def train():
    """Train the YCrCb colorization GAN on STL-10 with D judging in YCrCb.

    The generator predicts CrCb from Y; the discriminator sees the YCrCb
    image directly (unlike the RGB-discriminator variant). 200 epochs;
    losses pickled, sample grids saved per epoch (converted to RGB at save
    time), checkpoints every 10 epochs. Requires CUDA (hard-coded device).
    """
    # Models (DataParallel over all visible GPUs)
    device = "cuda"
    torch.backends.cudnn.benchmark = True
    model_G, model_D = Generator(), Discriminator()
    model_G, model_D = nn.DataParallel(model_G), nn.DataParallel(model_D)
    model_G, model_D = model_G.to(device), model_D.to(device)
    params_G = torch.optim.Adam(model_G.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    params_D = torch.optim.Adam(model_D.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    # Label tensors for the PatchGAN loss, sliced to the actual batch size
    # (assumes batches never exceed 512, the DataLoader batch size)
    ones = torch.ones(512, 1, 3, 3).to(device)
    zeros = torch.zeros(512, 1, 3, 3).to(device)
    # Loss functions
    bce_loss = nn.BCEWithLogitsLoss()
    mae_loss = nn.L1Loss()
    # Loss history (per-epoch means)
    result = {}
    result["log_loss_G_sum"] = []
    result["log_loss_G_bce"] = []
    result["log_loss_G_mae"] = []
    result["log_loss_D"] = []
    # Training loop
    dataset = load_datasets()
    for i in range(200):
        log_loss_G_sum, log_loss_G_bce, log_loss_G_mae, log_loss_D = [], [], [], []
        for real_rgb, _ in tqdm(dataset):
            batch_len = len(real_rgb)
            real_rgb = real_rgb.to(device)
            real_ycrcb = preprocess_generator(real_rgb)
            # --- Generator step ---
            # Predict CrCb from the Y channel and rebuild the YCrCb image;
            # RGB is also reconstructed for the RGB part of the L1 loss
            fake_crcb = model_G(real_ycrcb[:,:1,:,:])
            fake_ycrcb = torch.cat([real_ycrcb[:,:1,:,:], fake_crcb], dim=1)
            fake_rgb = deprocess_generator(fake_ycrcb)
            # Detached copy for the discriminator step (no grad back to G)
            fake_ycrcb_tensor = fake_ycrcb.detach()
            # Adversarial loss: make D call the fake YCrCb image real
            out = model_D(fake_ycrcb)
            loss_G_bce = bce_loss(out, ones[:batch_len])
            # L1 terms: weight 75 on the CrCb prediction, 25 on reconstructed RGB
            loss_G_mae = 75 * mae_loss(fake_crcb, real_ycrcb[:, 1:,:,:]) + 25 * mae_loss(fake_rgb, real_rgb)
            loss_G_sum = loss_G_bce + loss_G_mae
            log_loss_G_bce.append(loss_G_bce.item())
            log_loss_G_mae.append(loss_G_mae.item())
            log_loss_G_sum.append(loss_G_sum.item())
            # Backprop and update G only (D grads are cleared, not stepped)
            params_D.zero_grad()
            params_G.zero_grad()
            loss_G_sum.backward()
            params_G.step()
            # --- Discriminator step ---
            # Real YCrCb images should be classified as real
            real_out = model_D(real_ycrcb)
            loss_D_real = bce_loss(real_out, ones[:batch_len])
            # Fake (detached) YCrCb images should be classified as fake
            fake_out = model_D(fake_ycrcb_tensor)
            loss_D_fake = bce_loss(fake_out, zeros[:batch_len])
            # Total discriminator loss
            loss_D = loss_D_real + loss_D_fake
            log_loss_D.append(loss_D.item())
            # Backprop and update D only
            params_D.zero_grad()
            params_G.zero_grad()
            loss_D.backward()
            params_D.step()
        result["log_loss_G_sum"].append(statistics.mean(log_loss_G_sum))
        result["log_loss_G_bce"].append(statistics.mean(log_loss_G_bce))
        result["log_loss_G_mae"].append(statistics.mean(log_loss_G_mae))
        result["log_loss_D"].append(statistics.mean(log_loss_D))
        print(f"log_loss_G_sum = {result['log_loss_G_sum'][-1]} " +
              f"({result['log_loss_G_bce'][-1]}, {result['log_loss_G_mae'][-1]}) " +
              f"log_loss_D = {result['log_loss_D'][-1]}")
        # Save sample image grids (uses the last batch of the epoch)
        if not os.path.exists("stl_color"):
            os.mkdir("stl_color")
        # Convert the saved fakes from YCrCb back to RGB for viewing
        fake_rgb_tensor = deprocess_generator(fake_ycrcb_tensor)
        torchvision.utils.save_image(fake_rgb_tensor[:min(batch_len, 100)],
                                     f"stl_color/fake_epoch_{i:03}.png")
        torchvision.utils.save_image(real_rgb[:min(batch_len, 100)],
                                     f"stl_color/real_epoch_{i:03}.png")
        # Save model checkpoints every 10 epochs and at the final epoch
        if not os.path.exists("stl_color/models"):
            os.mkdir("stl_color/models")
        if i % 10 == 0 or i == 199:
            torch.save(model_G.state_dict(), f"stl_color/models/gen_{i:03}.pytorch")
            torch.save(model_D.state_dict(), f"stl_color/models/dis_{i:03}.pytorch")
        # Save the loss log (overwritten every epoch)
        with open("stl_color/logs.pkl", "wb") as fp:
            pickle.dump(result, fp)
# Run training when executed as a script.
if __name__ == "__main__":
    train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment