@koshian2 · Created July 28, 2019
Coarse to fine model

# ===== dataloader.py =====
import torch
import torchvision
from torchvision import transforms
from PIL import Image, ImageFilter
import numpy as np
from tqdm import tqdm
import glob
import os
def preprocess(n_validation=2048):
    if os.path.exists("./waifus/checked"):
        return
    files = sorted(glob.glob("./waifus/images/*"))
    # Remove monochrome images
    for f in tqdm(files):
        try:
            with Image.open(f) as img:
                img = img.convert("RGB").resize((64, 64), Image.BILINEAR)
                array = np.asarray(img, np.float32).reshape(-1, 3) / 255.0
                corr = np.corrcoef(array, rowvar=False)
                if np.mean(corr) >= 0.995:
                    # Treat the image as monochrome if the mean inter-channel
                    # correlation is 0.995 or higher
                    os.remove(f)
        except Exception:
            # Unreadable/corrupt files are removed as well
            os.remove(f)
    # Build the validation split
    if not os.path.exists("./waifus_val/images"):
        os.makedirs("./waifus_val/images")
    files = sorted(glob.glob("./waifus/images/*"))
    np.random.seed(123)
    np.random.shuffle(files)
    val_files = files[:n_validation]
    for f in tqdm(val_files):
        os.rename(f, f.replace("waifus", "waifus_val"))
    with open("./waifus/checked", "w") as fp:
        fp.write("")
class MultiInputWrapper(object):
    def __init__(self, base_func):
        """
        Wrapper class for applying transforms to multiple inputs.
        * base_func : transform function/class applied to each input.
                      Either a single callable or a list of callables
                      (use None in the list for elements to pass through).
        """
        self.base_func = base_func

    def __call__(self, xs):
        # Elements that are tuples (e.g. the image size) are passed through
        if isinstance(self.base_func, list):
            return [x if isinstance(x, tuple) else f(x)
                    for f, x in zip(self.base_func, xs)]
        else:
            return [x if isinstance(x, tuple) else self.base_func(x)
                    for x in xs]
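# Illustrative sketch (not part of the original gist): applying one transform
# to every element, then per-element transforms, on a (color, gray, size)
# triple as produced by ColorAndGray below.
def _demo_multi_input_wrapper():
    color = Image.new("RGB", (8, 8))
    triple = ColorAndGray()(color)                       # (color, gray, size)
    tensors = MultiInputWrapper(transforms.ToTensor())(triple)
    normed = MultiInputWrapper([
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        None,                                            # size tuple passes through
    ])(tensors)
    print([type(t) for t in normed])  # [Tensor, Tensor, tuple]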
class ColorAndGray(object):
    def __call__(self, img):
        # When called before ToTensor(), img is a PIL instance
        gray = img.convert("L").convert("RGB")  # back to 3 channels
        return img, gray, img.size
class RealAndMosaic(object):
    def __init__(self, mosaic_kernel, gaussian_blur_kernel):
        self.mosaic_kernel = mosaic_kernel
        self.gaussian_blur_kernel = gaussian_blur_kernel

    def __call__(self, img):
        # Mosaic: shrink by the kernel size, then scale back up
        mosaic = img.resize([x // self.mosaic_kernel for x in img.size]).resize(img.size)
        # Gaussian blur
        mosaic = mosaic.filter(ImageFilter.GaussianBlur(self.gaussian_blur_kernel))
        return img, mosaic, img.size
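# Illustrative sketch (not part of the original gist): the mosaic degradation
# is just downscale -> upscale -> blur, so image size is preserved.
def _demo_real_and_mosaic():
    img = Image.new("RGB", (256, 256))
    real, mosaic, size = RealAndMosaic(mosaic_kernel=16, gaussian_blur_kernel=4)(img)
    print(real.size, mosaic.size, size)  # (256, 256) for all three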
def load_dataset(transform_type, batch_size):
    """
    Loads the "My waifu" dataset.
    * transform_type = color  : [(real image, grayscale image, size), dummy label]
    * transform_type = mosaic : [(real image, mosaic image, size), dummy label]
    """
    assert transform_type in ["color", "mosaic"]
    # Preprocessing
    preprocess()
    if transform_type == "color":
        transform = transforms.Compose([
            transforms.Resize((256, 256), Image.LANCZOS),
            ColorAndGray(),
            MultiInputWrapper(transforms.ToTensor()),
            MultiInputWrapper([
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                transforms.Normalize(mean=(0.5,), std=(0.5,)),
                None
            ]),
        ])
    elif transform_type == "mosaic":
        transform = transforms.Compose([
            transforms.Resize((256, 256), Image.LANCZOS),
            RealAndMosaic(mosaic_kernel=16, gaussian_blur_kernel=4),
            MultiInputWrapper(transforms.ToTensor()),
            MultiInputWrapper(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))),
        ])
    trainset = torchvision.datasets.ImageFolder("./waifus", transform=transform)
    valset = torchvision.datasets.ImageFolder("./waifus_val", transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
    valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=4)
    return trainloader, valloader
def test():
    trainloader, valloader = load_dataset("mosaic", 16)
    for (img_real, img_input, size), _ in tqdm(valloader):
        torchvision.utils.save_image(img_real, "real.png", nrow=4, normalize=True, range=(-1.0, 1.0))
        torchvision.utils.save_image(img_input, "input.png", nrow=4, normalize=True, range=(-1.0, 1.0))
        exit()  # dump only the first batch

if __name__ == "__main__":
    test()
# ===== models.py =====
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
class ConvBnAct(nn.Module):
def __init__(self, in_ch, out_ch, kernel=3, stride=1, act="relu"):
super().__init__()
layers = []
layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=(kernel - 1) // 2))
layers.append(nn.BatchNorm2d(out_ch))
if act == "relu":
layers.append(nn.ReLU(inplace=True))
elif act == "tanh":
layers.append(nn.Tanh())
self.main = nn.Sequential(*layers)
def forward(self, x):
return self.main(x)
class ResNetBlock(nn.Module):
def __init__(self, base_ch):
super().__init__()
self.conv1 = ConvBnAct(base_ch, base_ch)
self.conv2 = ConvBnAct(base_ch, base_ch)
def forward(self, inputs):
x = self.conv2(self.conv1(inputs))
return x + inputs
class Downsampling(nn.Module):
    # Downsampling via a strided 1x1 conv
def __init__(self, in_ch, out_ch, pool_size):
super().__init__()
self.conv = ConvBnAct(in_ch, out_ch, kernel=1, stride=pool_size)
def forward(self, inputs):
return self.conv(inputs)
class Upsampling(nn.Module):
    # Upsample the base features -> concat with the encoder features -> 1x1 conv
    def __init__(self, base_ch, enc_ch, out_ch, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor
        self.conv = ConvBnAct(base_ch + enc_ch, out_ch, kernel=1)

    def forward(self, base_inputs, enc_inputs):
        # F.upsample_nearest is deprecated; F.interpolate is the equivalent
        x = F.interpolate(base_inputs, scale_factor=self.scale_factor, mode="nearest")
        x = torch.cat([x, enc_inputs], dim=1)
        return self.conv(x)
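# Illustrative sketch (not part of the original gist): Upsampling doubles the
# spatial size of the coarse features and fuses them with the skip connection.
def _demo_upsampling():
    up = Upsampling(base_ch=32, enc_ch=16, out_ch=16, scale_factor=2)
    coarse = torch.randn(1, 32, 8, 8)
    skip = torch.randn(1, 16, 16, 16)
    print(up(coarse, skip).shape)  # torch.Size([1, 16, 16, 16])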
def create_subnetwork(base_ch, downsampling=True, upsampling=True, decoder=True):
    # Input block
    conv_in = ConvBnAct(3, base_ch)
    # Encoder
    encoder_block = nn.Sequential(*[ResNetBlock(base_ch) for i in range(3)])
    # Downsampling
    if downsampling:
        downsampling_block = Downsampling(base_ch, base_ch * 2, pool_size=2)
    else:
        downsampling_block = None
    # Upsampling
    if upsampling:
        upsampling_block = Upsampling(base_ch * 2, base_ch, base_ch, 2)
    else:
        upsampling_block = None
    # Decoder
    if decoder:
        decoder_block = ResNetBlock(base_ch)
    else:
        decoder_block = None
    # Output block
    conv_out = ConvBnAct(base_ch, 3, kernel=5, act="tanh")
    return conv_in, encoder_block, downsampling_block, upsampling_block, decoder_block, conv_out
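# Illustrative sketch (not part of the original gist): channel flow through
# one subnetwork at base_ch=16.
def _demo_subnetwork():
    conv_in, enc, down, up, dec, conv_out = create_subnetwork(16)
    x = torch.randn(1, 3, 32, 32)
    h = enc(conv_in(x))          # (1, 16, 32, 32)
    d = down(h)                  # (1, 32, 16, 16), channels doubled
    u = dec(up(d, h))            # (1, 16, 32, 32), skip fused back in
    print(conv_out(u).shape)     # torch.Size([1, 3, 32, 32])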
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class CoarseToFineModel(nn.Module):
def __init__(self):
super().__init__()
self.in1, self.enc1, self.down1, self.up1, self.dec1, self.out1 = create_subnetwork(16) # 256x256
self.in2, self.enc2, self.down2, self.up2, self.dec2, self.out2 = create_subnetwork(32) # 128x128
self.in3, self.enc3, self.down3, self.up3, self.dec3, self.out3 = create_subnetwork(64) # 64x64
self.in4, self.enc4, self.down4, self.up4, self.dec4, self.out4 = create_subnetwork(128) # 32x32
self.in5, self.enc5, self.down5, self.up5, self.dec5, self.out5 = create_subnetwork(256) # 16x16
self.in6, self.enc6, _, _, _, self.out6 = create_subnetwork(512) # 8x8
self.learning_stage = 0
def set_trainable(self, layer_index, trainable_flag):
if layer_index == 0:
layers = [self.in1, self.enc1, self.down1, self.up1, self.dec1, self.out1]
elif layer_index == 1:
layers = [self.in2, self.enc2, self.down2, self.up2, self.dec2, self.out2]
elif layer_index == 2:
layers = [self.in3, self.enc3, self.down3, self.up3, self.dec3, self.out3]
elif layer_index == 3:
layers = [self.in4, self.enc4, self.down4, self.up4, self.dec4, self.out4]
elif layer_index == 4:
layers = [self.in5, self.enc5, self.down5, self.up5, self.dec5, self.out5]
elif layer_index == 5:
layers = [self.in6, self.enc6, self.out6]
for l in layers:
for p in l.parameters():
p.requires_grad = trainable_flag
    def switch_epoch(self, epoch, epoch_per_mode=50, burnin_epoch=10):
        # Once the burn-in epochs are over, unfreeze everything again
        if epoch % epoch_per_mode == burnin_epoch:
            print("---", epoch, ": unfreezing all stages")
            for i in range(6):
                self.set_trainable(i, True)
        # Burn-in setup + stage switch
        elif epoch % epoch_per_mode == 0:
            new_training_stage = 5 - epoch // epoch_per_mode
            self.learning_stage = new_training_stage
            print("---", epoch, ": switching to stage", new_training_stage)
            if epoch != 0:
                self.set_trainable(new_training_stage + 1, False)
                print("---", epoch, ": freezing the weights of stage", new_training_stage + 1)
def subblock_encoder_forward(self, main_inputs, encoder_layer, downsampling_layer, lower_input_layer, threshold):
if self.learning_stage <= threshold:
mid = encoder_layer(main_inputs)
x = downsampling_layer(mid)
else:
x = F.avg_pool2d(main_inputs, kernel_size=2)
if self.learning_stage == threshold + 1:
x = lower_input_layer(x)
mid = None
return x, mid
    def subblock_decoder_forward(self, main_inputs, encoder_inputs, upsampling_layer, decoder_layer, output_layer, threshold):
        if self.learning_stage <= threshold:
            x = upsampling_layer(main_inputs, encoder_inputs)  # apply the decoder block next
            x = decoder_layer(x)
            if self.learning_stage == threshold:
                x = output_layer(x)
        else:
            x = F.interpolate(main_inputs, scale_factor=2)
        return x
def forward(self, inputs):
# Encoder
if self.learning_stage <= 0:
x = self.in1(inputs)
else:
x = inputs
x, mid1 = self.subblock_encoder_forward(x, self.enc1, self.down1, self.in2, 0)
x, mid2 = self.subblock_encoder_forward(x, self.enc2, self.down2, self.in3, 1)
x, mid3 = self.subblock_encoder_forward(x, self.enc3, self.down3, self.in4, 2)
x, mid4 = self.subblock_encoder_forward(x, self.enc4, self.down4, self.in5, 3)
x, mid5 = self.subblock_encoder_forward(x, self.enc5, self.down5, self.in6, 4)
x = self.enc6(x)
# Decoder
if self.learning_stage >= 5:
x = self.out6(x)
        x = self.subblock_decoder_forward(x, mid5, self.up5, self.dec5, self.out5, 4)
        x = self.subblock_decoder_forward(x, mid4, self.up4, self.dec4, self.out4, 3)
        x = self.subblock_decoder_forward(x, mid3, self.up3, self.dec3, self.out3, 2)
        x = self.subblock_decoder_forward(x, mid2, self.up2, self.dec2, self.out2, 1)
        x = self.subblock_decoder_forward(x, mid1, self.up1, self.dec1, self.out1, 0)
return x
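# Illustrative sketch (not part of the original gist): dry-run of the coarse-
# to-fine schedule. Every 50 epochs the model moves one stage finer
# (5 -> 4 -> ... -> 0); on each switch the previous (coarser) stage is frozen
# for 10 burn-in epochs, then everything is unfrozen again.
def _demo_schedule():
    model = CoarseToFineModel()
    for epoch in [0, 10, 50, 60, 100, 250, 260]:
        model.switch_epoch(epoch)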
if __name__ == "__main__":
x = torch.randn(16, 3, 256, 256)
model = CoarseToFineModel()
model.learning_stage = 0
y = model(x)
torchvision.utils.save_image(y, "hoge.png", nrow=4, padding=10, normalize=True, range=(-1.0,1.0))
# ===== models_ae.py =====
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
class ConvBnAct(nn.Module):
def __init__(self, in_ch, out_ch, kernel=3, stride=1, act="relu"):
super().__init__()
layers = []
layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=(kernel - 1) // 2))
layers.append(nn.BatchNorm2d(out_ch))
if act == "relu":
layers.append(nn.ReLU(inplace=True))
elif act == "tanh":
layers.append(nn.Tanh())
self.main = nn.Sequential(*layers)
def forward(self, x):
return self.main(x)
class ResNetBlock(nn.Module):
def __init__(self, base_ch):
super().__init__()
self.conv1 = ConvBnAct(base_ch, base_ch)
self.conv2 = ConvBnAct(base_ch, base_ch)
def forward(self, inputs):
x = self.conv2(self.conv1(inputs))
return x + inputs
class Downsampling(nn.Module):
    # Downsampling via a strided 1x1 conv
def __init__(self, in_ch, out_ch, pool_size):
super().__init__()
self.conv = ConvBnAct(in_ch, out_ch, kernel=1, stride=pool_size)
def forward(self, inputs):
return self.conv(inputs)
class Upsampling(nn.Module):
    # Upsample the base features -> 1x1 conv (no skip connection in the AE variant)
    def __init__(self, base_ch, out_ch, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor
        self.conv = ConvBnAct(base_ch, out_ch, kernel=1)

    def forward(self, base_inputs):
        # F.upsample_nearest is deprecated; F.interpolate is the equivalent
        x = F.interpolate(base_inputs, scale_factor=self.scale_factor, mode="nearest")
        return self.conv(x)
def create_subnetwork(base_ch, downsampling=True, upsampling=True, decoder=True):
    # Input block
    conv_in = ConvBnAct(3, base_ch)
    # Encoder
    encoder_block = nn.Sequential(*[ResNetBlock(base_ch) for i in range(3)])
    # Downsampling
    if downsampling:
        downsampling_block = Downsampling(base_ch, base_ch * 2, pool_size=2)
    else:
        downsampling_block = None
    # Upsampling
    if upsampling:
        upsampling_block = Upsampling(base_ch * 2, base_ch, 2)
    else:
        upsampling_block = None
    # Decoder
    if decoder:
        decoder_block = ResNetBlock(base_ch)
    else:
        decoder_block = None
    # Output block
    conv_out = ConvBnAct(base_ch, 3, kernel=5, act="tanh")
    return conv_in, encoder_block, downsampling_block, upsampling_block, decoder_block, conv_out
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class CoarseToFineAutoEncoderModel(nn.Module):
def __init__(self):
super().__init__()
self.in1, self.enc1, self.down1, self.up1, self.dec1, self.out1 = create_subnetwork(16) # 256x256
self.in2, self.enc2, self.down2, self.up2, self.dec2, self.out2 = create_subnetwork(32) # 128x128
self.in3, self.enc3, self.down3, self.up3, self.dec3, self.out3 = create_subnetwork(64) # 64x64
self.in4, self.enc4, self.down4, self.up4, self.dec4, self.out4 = create_subnetwork(128) # 32x32
self.in5, self.enc5, self.down5, self.up5, self.dec5, self.out5 = create_subnetwork(256) # 16x16
self.in6, self.enc6, _, _, _, self.out6 = create_subnetwork(512) # 8x8
self.learning_stage = 0
def set_trainable(self, layer_index, trainable_flag):
if layer_index == 0:
layers = [self.in1, self.enc1, self.down1, self.up1, self.dec1, self.out1]
elif layer_index == 1:
layers = [self.in2, self.enc2, self.down2, self.up2, self.dec2, self.out2]
elif layer_index == 2:
layers = [self.in3, self.enc3, self.down3, self.up3, self.dec3, self.out3]
elif layer_index == 3:
layers = [self.in4, self.enc4, self.down4, self.up4, self.dec4, self.out4]
elif layer_index == 4:
layers = [self.in5, self.enc5, self.down5, self.up5, self.dec5, self.out5]
elif layer_index == 5:
layers = [self.in6, self.enc6, self.out6]
for l in layers:
for p in l.parameters():
p.requires_grad = trainable_flag
    def switch_epoch(self, epoch, epoch_per_mode=50, burnin_epoch=10):
        # Once the burn-in epochs are over, unfreeze everything again
        if epoch % epoch_per_mode == burnin_epoch:
            print("---", epoch, ": unfreezing all stages")
            for i in range(6):
                self.set_trainable(i, True)
        # Burn-in setup + stage switch
        elif epoch % epoch_per_mode == 0:
            new_training_stage = 5 - epoch // epoch_per_mode
            self.learning_stage = new_training_stage
            print("---", epoch, ": switching to stage", new_training_stage)
            if epoch != 0:
                self.set_trainable(new_training_stage + 1, False)
                print("---", epoch, ": freezing the weights of stage", new_training_stage + 1)
def subblock_encoder_forward(self, main_inputs, encoder_layer, downsampling_layer, lower_input_layer, threshold):
if self.learning_stage <= threshold:
mid = encoder_layer(main_inputs)
x = downsampling_layer(mid)
else:
x = F.avg_pool2d(main_inputs, kernel_size=2)
if self.learning_stage == threshold + 1:
x = lower_input_layer(x)
mid = None
return x, mid
    def subblock_decoder_forward(self, main_inputs, encoder_inputs, upsampling_layer, decoder_layer, output_layer, threshold):
        # encoder_inputs is kept for interface parity but unused (no skip connection)
        if self.learning_stage <= threshold:
            x = upsampling_layer(main_inputs)
            x = decoder_layer(x)
            if self.learning_stage == threshold:
                x = output_layer(x)
        else:
            x = F.interpolate(main_inputs, scale_factor=2)
        return x
def forward(self, inputs):
# Encoder
if self.learning_stage <= 0:
x = self.in1(inputs)
else:
x = inputs
x, mid1 = self.subblock_encoder_forward(x, self.enc1, self.down1, self.in2, 0)
x, mid2 = self.subblock_encoder_forward(x, self.enc2, self.down2, self.in3, 1)
x, mid3 = self.subblock_encoder_forward(x, self.enc3, self.down3, self.in4, 2)
x, mid4 = self.subblock_encoder_forward(x, self.enc4, self.down4, self.in5, 3)
x, mid5 = self.subblock_encoder_forward(x, self.enc5, self.down5, self.in6, 4)
x = self.enc6(x)
# Decoder
if self.learning_stage >= 5:
x = self.out6(x)
        x = self.subblock_decoder_forward(x, mid5, self.up5, self.dec5, self.out5, 4)
        x = self.subblock_decoder_forward(x, mid4, self.up4, self.dec4, self.out4, 3)
        x = self.subblock_decoder_forward(x, mid3, self.up3, self.dec3, self.out3, 2)
        x = self.subblock_decoder_forward(x, mid2, self.up2, self.dec2, self.out2, 1)
        x = self.subblock_decoder_forward(x, mid1, self.up1, self.dec1, self.out1, 0)
return x
# ===== models_unet.py =====
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
class ConvBnAct(nn.Module):
def __init__(self, in_ch, out_ch, kernel=3, stride=1, act="relu"):
super().__init__()
layers = []
layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=(kernel - 1) // 2))
layers.append(nn.BatchNorm2d(out_ch))
if act == "relu":
layers.append(nn.ReLU(inplace=True))
elif act == "tanh":
layers.append(nn.Tanh())
self.main = nn.Sequential(*layers)
def forward(self, x):
return self.main(x)
class ResNetBlock(nn.Module):
def __init__(self, base_ch):
super().__init__()
self.conv1 = ConvBnAct(base_ch, base_ch)
self.conv2 = ConvBnAct(base_ch, base_ch)
def forward(self, inputs):
x = self.conv2(self.conv1(inputs))
return x + inputs
class Downsampling(nn.Module):
    # Downsampling via a strided 1x1 conv
def __init__(self, in_ch, out_ch, pool_size):
super().__init__()
self.conv = ConvBnAct(in_ch, out_ch, kernel=1, stride=pool_size)
def forward(self, inputs):
return self.conv(inputs)
class Upsampling(nn.Module):
    # Upsample the base features -> concat with the encoder features -> 1x1 conv
    def __init__(self, base_ch, enc_ch, out_ch, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor
        self.conv = ConvBnAct(base_ch + enc_ch, out_ch, kernel=1)

    def forward(self, base_inputs, enc_inputs):
        # F.upsample_nearest is deprecated; F.interpolate is the equivalent
        x = F.interpolate(base_inputs, scale_factor=self.scale_factor, mode="nearest")
        x = torch.cat([x, enc_inputs], dim=1)
        return self.conv(x)
def create_subnetwork(base_ch, downsampling=True, upsampling=True, decoder=True):
    # (kept from models.py; not used by the Unet class below)
    # Input block
    conv_in = ConvBnAct(3, base_ch)
    # Encoder
    encoder_block = nn.Sequential(*[ResNetBlock(base_ch) for i in range(3)])
    # Downsampling
    if downsampling:
        downsampling_block = Downsampling(base_ch, base_ch * 2, pool_size=2)
    else:
        downsampling_block = None
    # Upsampling
    if upsampling:
        upsampling_block = Upsampling(base_ch * 2, base_ch, base_ch, 2)
    else:
        upsampling_block = None
    # Decoder
    if decoder:
        decoder_block = ResNetBlock(base_ch)
    else:
        decoder_block = None
    # Output block
    conv_out = ConvBnAct(base_ch, 3, kernel=5, act="tanh")
    return conv_in, encoder_block, downsampling_block, upsampling_block, decoder_block, conv_out
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class Unet(nn.Module):
def __init__(self):
super().__init__()
self.enc1 = self.encoder_block(3, 16, downsampling=False) # 256x256
self.enc2 = self.encoder_block(16, 32) # 128x128
self.enc3 = self.encoder_block(32, 64) # 64x64
self.enc4 = self.encoder_block(64, 128) # 32x32
self.enc5 = self.encoder_block(128, 256) # 16x16
self.enc6 = self.encoder_block(256, 512) # 8x8
self.up5 = Upsampling(512, 256, 256, 2)
self.up4 = Upsampling(256, 128, 128, 2)
self.up3 = Upsampling(128, 64, 64, 2)
self.up2 = Upsampling(64, 32, 32, 2)
self.up1 = Upsampling(32, 16, 16, 2)
self.dec5 = ResNetBlock(256)
self.dec4 = ResNetBlock(128)
self.dec3 = ResNetBlock(64)
self.dec2 = ResNetBlock(32)
self.dec1 = ResNetBlock(16)
self.out = ConvBnAct(16, 3, act="tanh")
def encoder_block(self, old_ch, base_ch, downsampling=True):
return nn.Sequential(
ConvBnAct(old_ch, base_ch, stride=2 if downsampling else 1, kernel=1),
ResNetBlock(base_ch),
ResNetBlock(base_ch),
ResNetBlock(base_ch)
)
    def decoder_block(self, base_ch, enc_ch, out_ch):
        # Unused helper: Upsampling.forward takes two inputs (base + skip), so
        # it cannot be chained in nn.Sequential; forward() wires the up/dec
        # blocks manually instead. The missing scale_factor argument is filled in.
        return nn.Sequential(
            Upsampling(base_ch, enc_ch, out_ch, 2),
            ResNetBlock(out_ch),
        )
def forward(self, inputs):
mid1 = self.enc1(inputs)
mid2 = self.enc2(mid1)
mid3 = self.enc3(mid2)
mid4 = self.enc4(mid3)
mid5 = self.enc5(mid4)
x = self.enc6(mid5)
x = self.dec5(self.up5(x, mid5))
x = self.dec4(self.up4(x, mid4))
x = self.dec3(self.up3(x, mid3))
x = self.dec2(self.up2(x, mid2))
x = self.dec1(self.up1(x, mid1))
return self.out(x)
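# Illustrative sketch (not part of the original gist): sanity check of the
# U-Net baseline's output shape and trainable parameter count.
if __name__ == "__main__":
    model = Unet()
    x = torch.randn(2, 3, 256, 256)
    print(model(x).shape)           # torch.Size([2, 3, 256, 256])
    print(count_parameters(model))  # trainable parameter count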
# ===== training script =====
import torch
import torch.nn.functional as F
import torchvision
from dataloader import load_dataset
from models import CoarseToFineModel
from models_ae import CoarseToFineAutoEncoderModel
from models_unet import Unet
import os
import pickle
import statistics
from tqdm import tqdm
import time
def train(transform_type):
    trainloader, testloader = load_dataset(transform_type, 64)
    model = CoarseToFineModel().to("cuda")  # switch the model here
    # model = CoarseToFineAutoEncoderModel().to("cuda")
    # model = Unet().to("cuda")
    print(model)
    model = torch.nn.DataParallel(model)
    loss_l1 = torch.nn.L1Loss()
    loss_mse = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    result = {"loss": [], "psnr": [], "val_psnr": [], "time": []}
    for i in range(300):
        # Update the training stage
        model.module.switch_epoch(i)  # switch this too (Unet has no switch_epoch)
        log_loss = []
        log_psnr = []
        log_val_psnr = []
        start_time = time.time()
        with tqdm(trainloader) as pbar:
            for (img_real, img_input, size), _ in pbar:
                # forward prop
                img_real, img_input = img_real.to("cuda"), img_input.to("cuda")
                optimizer.zero_grad()
                y_pred = model(img_input)
                # Optionally match the target resolution to the coarse-to-fine stage
                #coarse_ratio = 2 ** model.module.learning_stage
                #if coarse_ratio > 1:
                #    y_true = F.interpolate(F.avg_pool2d(img_real, kernel_size=coarse_ratio), scale_factor=coarse_ratio)
                #else:
                #    y_true = img_real
                y_true = img_real
                # Optimize with L1 loss
                loss = loss_l1(y_pred, y_true)
                # PSNR (peak value is 2.0 because images are scaled to [-1, 1])
                psnr = 10.0 * torch.log10(2.0 ** 2 / (loss_mse(y_pred, img_real) + 1e-10))
                # back prop
                loss.backward()
                optimizer.step()
                log_loss.append(loss.item())
                log_psnr.append(psnr.item())
                pbar.set_postfix(epoch=i, loss=statistics.mean(log_loss), psnr=statistics.mean(log_psnr))
        if not os.path.exists(transform_type):
            os.mkdir(transform_type)
        # Save a snapshot grid of input / prediction / ground truth per epoch
        batch_len = min(len(img_input), len(y_pred), len(y_true), 16)
        x = torch.cat([img_input[:batch_len], y_pred[:batch_len], y_true[:batch_len]], dim=0)
        torchvision.utils.save_image(x, transform_type + f"/epoch_{i:03}.png", nrow=8, padding=10,
                                     normalize=True, range=(-1.0, 1.0))
        with tqdm(testloader) as pbar:
            for (img_real, img_input, size), _ in pbar:
                img_real, img_input = img_real.to("cuda"), img_input.to("cuda")
                with torch.no_grad():
                    y_pred = model(img_input)
                psnr = 10.0 * torch.log10(2.0 ** 2 / (loss_mse(y_pred, img_real) + 1e-10))
                log_val_psnr.append(psnr.item())
                pbar.set_postfix(epoch=i, val_psnr=statistics.mean(log_val_psnr))
        result["loss"].append(statistics.mean(log_loss))
        result["psnr"].append(statistics.mean(log_psnr))
        result["val_psnr"].append(statistics.mean(log_val_psnr))
        result["time"].append(time.time() - start_time)
        with open(transform_type + "/logs.pkl", "wb") as fp:
            pickle.dump(result, fp)
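# Illustrative sketch (not part of the original gist): the PSNR used above in
# standalone form, assuming predictions/targets scaled to [-1, 1] (peak = 2.0).
def _psnr_from_mse(mse, peak=2.0, eps=1e-10):
    return 10.0 * torch.log10(peak ** 2 / (mse + eps))
# e.g. _psnr_from_mse(torch.tensor(0.01)) -> ~26.02 dB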
if __name__ == "__main__":
train("mosaic")