Coarse to fine model
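This gist contains five Python files: dataloader.py (dataset preprocessing and loading), models.py (the coarse-to-fine model with encoder skip connections), models_ae.py (an autoencoder variant without skip connections), models_unet.py (a plain U-Net baseline), and a training script. The filenames are taken from the imports in the training script.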
dataloader.py
import torch
import torchvision
from torchvision import transforms
from PIL import Image, ImageFilter
import numpy as np
from tqdm import tqdm
import glob
import os
def preprocess(n_validation=2048):
    if os.path.exists("./waifus/checked"):
        return
    files = sorted(glob.glob("./waifus/images/*"))
    # Remove monochrome images
    for f in tqdm(files):
        try:
            with Image.open(f) as img:
                img = img.convert("RGB").resize((64, 64), Image.BILINEAR)
                array = np.asarray(img, np.float32).reshape(-1, 3) / 255.0
                corr = np.corrcoef(array, rowvar=False)
                if np.mean(corr) >= 0.995:
                    # Treat as monochrome if the inter-channel correlation is 0.995 or higher
                    os.remove(f)
        except Exception:
            # Unreadable/corrupt file: discard it as well
            os.remove(f)
    # Build the validation split
    if not os.path.exists("./waifus_val/images"):
        os.makedirs("./waifus_val/images")
    files = sorted(glob.glob("./waifus/images/*"))
    np.random.seed(123)
    np.random.shuffle(files)
    val_files = files[:n_validation]
    for f in tqdm(val_files):
        os.rename(f, f.replace("waifus", "waifus_val"))
    with open("./waifus/checked", "w") as fp:
        fp.write("")
class MultiInputWrapper(object):
    def __init__(self, base_func):
        """
        Wrapper class for applying transforms to multiple inputs.
        * base_func : transform function/class applied to each input.
          Either a single callable or a list of callables (one per input).
        """
        self.base_func = base_func

    def __call__(self, xs):
        # Tuples (e.g. the image size) are passed through untouched
        if isinstance(self.base_func, list):
            return [x if isinstance(x, tuple) else f(x)
                    for f, x in zip(self.base_func, xs)]
        else:
            return [x if isinstance(x, tuple) else self.base_func(x)
                    for x in xs]
class ColorAndGray(object):
    def __call__(self, img):
        # When called before ToTensor(), img is a PIL instance
        gray = img.convert("L").convert("RGB")  # convert back to 3 channels
        return img, gray, img.size

class RealAndMosaic(object):
    def __init__(self, mosaic_kernel, gaussian_blur_kernel):
        self.mosaic_kernel = mosaic_kernel
        self.gaussian_blur_kernel = gaussian_blur_kernel

    def __call__(self, img):
        # Mosaic: downscale by the kernel size, then upscale back
        mosaic = img.resize([x // self.mosaic_kernel for x in img.size]).resize(img.size)
        # Gaussian blur
        mosaic = mosaic.filter(ImageFilter.GaussianBlur(self.gaussian_blur_kernel))
        return img, mosaic, img.size
def load_dataset(transform_type, batch_size):
    """
    Loads the "My waifu" dataset.
    * transform_type = color  : [(real image, grayscale image, size), dummy label]
    * transform_type = mosaic : [(real image, mosaic image, size), dummy label]
    """
    assert transform_type in ["color", "mosaic"]
    # Preprocessing
    preprocess()
    if transform_type == "color":
        transform = transforms.Compose([
            transforms.Resize((256, 256), Image.LANCZOS),
            ColorAndGray(),
            MultiInputWrapper(transforms.ToTensor()),
            MultiInputWrapper([
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                # the gray image was converted back to 3 channels above
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                None
            ]),
        ])
    elif transform_type == "mosaic":
        transform = transforms.Compose([
            transforms.Resize((256, 256), Image.LANCZOS),
            RealAndMosaic(mosaic_kernel=16, gaussian_blur_kernel=4),
            MultiInputWrapper(transforms.ToTensor()),
            MultiInputWrapper(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))),
        ])
    trainset = torchvision.datasets.ImageFolder("./waifus", transform=transform)
    valset = torchvision.datasets.ImageFolder("./waifus_val", transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
    valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=4)
    return trainloader, valloader
def test():
    trainloader, valloader = load_dataset("mosaic", 16)
    for (img_real, img_input, size), _ in tqdm(valloader):
        torchvision.utils.save_image(img_real, "real.png", nrow=4, normalize=True, range=(-1.0, 1.0))
        torchvision.utils.save_image(img_input, "input.png", nrow=4, normalize=True, range=(-1.0, 1.0))
        break  # only the first batch is needed for a visual check

if __name__ == "__main__":
    test()
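For reference, a minimal sketch of what one batch from load_dataset above yields under the mosaic transform (assuming the ./waifus and ./waifus_val folders exist):

trainloader, valloader = load_dataset("mosaic", 4)
(img_real, img_input, size), _ = next(iter(trainloader))
print(img_real.shape)   # torch.Size([4, 3, 256, 256]), values in [-1, 1]
print(img_input.shape)  # torch.Size([4, 3, 256, 256]), the mosaic + blurred version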
models.py
import torch
from torch import nn
import torch.nn.functional as F
import torchvision

class ConvBnAct(nn.Module):
    # Conv2d -> BatchNorm -> activation
    def __init__(self, in_ch, out_ch, kernel=3, stride=1, act="relu"):
        super().__init__()
        layers = []
        layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=(kernel - 1) // 2))
        layers.append(nn.BatchNorm2d(out_ch))
        if act == "relu":
            layers.append(nn.ReLU(inplace=True))
        elif act == "tanh":
            layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)
class ResNetBlock(nn.Module):
    # Two 3x3 conv blocks with an identity shortcut
    def __init__(self, base_ch):
        super().__init__()
        self.conv1 = ConvBnAct(base_ch, base_ch)
        self.conv2 = ConvBnAct(base_ch, base_ch)

    def forward(self, inputs):
        x = self.conv2(self.conv1(inputs))
        return x + inputs

class Downsampling(nn.Module):
    # Downsampling via a strided 1x1 conv
    def __init__(self, in_ch, out_ch, pool_size):
        super().__init__()
        self.conv = ConvBnAct(in_ch, out_ch, kernel=1, stride=pool_size)

    def forward(self, inputs):
        return self.conv(inputs)

class Upsampling(nn.Module):
    # Upsample the base -> concat with the encoder side -> 1x1 conv
    def __init__(self, base_ch, enc_ch, out_ch, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor
        self.conv = ConvBnAct(base_ch + enc_ch, out_ch, kernel=1)

    def forward(self, base_inputs, enc_inputs):
        x = F.interpolate(base_inputs, scale_factor=self.scale_factor, mode="nearest")
        x = torch.cat([x, enc_inputs], dim=1)
        return self.conv(x)
def create_subnetwork(base_ch, downsampling=True, upsampling=True, decoder=True):
    # Input block
    conv_in = ConvBnAct(3, base_ch)
    # Encoder
    encoder_block = nn.Sequential(*[ResNetBlock(base_ch) for _ in range(3)])
    # Downsampling
    if downsampling:
        downsampling_block = Downsampling(base_ch, base_ch * 2, pool_size=2)
    else:
        downsampling_block = None
    # Upsampling
    if upsampling:
        upsampling_block = Upsampling(base_ch * 2, base_ch, base_ch, 2)
    else:
        upsampling_block = None
    # Decoder
    if decoder:
        decoder_block = ResNetBlock(base_ch)
    else:
        decoder_block = None
    # Output block
    conv_out = ConvBnAct(base_ch, 3, kernel=5, act="tanh")
    return conv_in, encoder_block, downsampling_block, upsampling_block, decoder_block, conv_out

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
class CoarseToFineModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.in1, self.enc1, self.down1, self.up1, self.dec1, self.out1 = create_subnetwork(16)   # 256x256
        self.in2, self.enc2, self.down2, self.up2, self.dec2, self.out2 = create_subnetwork(32)   # 128x128
        self.in3, self.enc3, self.down3, self.up3, self.dec3, self.out3 = create_subnetwork(64)   # 64x64
        self.in4, self.enc4, self.down4, self.up4, self.dec4, self.out4 = create_subnetwork(128)  # 32x32
        self.in5, self.enc5, self.down5, self.up5, self.dec5, self.out5 = create_subnetwork(256)  # 16x16
        self.in6, self.enc6, _, _, _, self.out6 = create_subnetwork(512)                          # 8x8
        self.learning_stage = 0

    def set_trainable(self, layer_index, trainable_flag):
        if layer_index == 0:
            layers = [self.in1, self.enc1, self.down1, self.up1, self.dec1, self.out1]
        elif layer_index == 1:
            layers = [self.in2, self.enc2, self.down2, self.up2, self.dec2, self.out2]
        elif layer_index == 2:
            layers = [self.in3, self.enc3, self.down3, self.up3, self.dec3, self.out3]
        elif layer_index == 3:
            layers = [self.in4, self.enc4, self.down4, self.up4, self.dec4, self.out4]
        elif layer_index == 4:
            layers = [self.in5, self.enc5, self.down5, self.up5, self.dec5, self.out5]
        elif layer_index == 5:
            layers = [self.in6, self.enc6, self.out6]
        for l in layers:
            for p in l.parameters():
                p.requires_grad = trainable_flag

    def switch_epoch(self, epoch, epoch_per_mode=50, burnin_epoch=10):
        # Once the burn-in epochs are over, unfreeze everything again
        if epoch % epoch_per_mode == burnin_epoch:
            print("---", epoch, ": unfreezing all stages")
            for i in range(6):
                self.set_trainable(i, True)
        # Burn-in setup + training stage switch
        elif epoch % epoch_per_mode == 0:
            new_training_stage = 5 - epoch // epoch_per_mode
            self.learning_stage = new_training_stage
            print("---", epoch, ": switching to stage", new_training_stage)
            if epoch != 0:
                self.set_trainable(new_training_stage + 1, False)
                print("---", epoch, ": freezing the weights of stage", new_training_stage + 1)

    def subblock_encoder_forward(self, main_inputs, encoder_layer, downsampling_layer, lower_input_layer, threshold):
        if self.learning_stage <= threshold:
            mid = encoder_layer(main_inputs)
            x = downsampling_layer(mid)
        else:
            x = F.avg_pool2d(main_inputs, kernel_size=2)
            if self.learning_stage == threshold + 1:
                x = lower_input_layer(x)
            mid = None
        return x, mid

    def subblock_decoder_forward(self, main_inputs, encoder_inputs, upsampling_layer, decoder_layer, output_layer, threshold):
        if self.learning_stage <= threshold:
            x = upsampling_layer(main_inputs, encoder_inputs)  # be sure to include the decoder
            x = decoder_layer(x)
            if self.learning_stage == threshold:
                x = output_layer(x)
        else:
            x = F.interpolate(main_inputs, scale_factor=2)
        return x

    def forward(self, inputs):
        # Encoder
        if self.learning_stage <= 0:
            x = self.in1(inputs)
        else:
            x = inputs
        x, mid1 = self.subblock_encoder_forward(x, self.enc1, self.down1, self.in2, 0)
        x, mid2 = self.subblock_encoder_forward(x, self.enc2, self.down2, self.in3, 1)
        x, mid3 = self.subblock_encoder_forward(x, self.enc3, self.down3, self.in4, 2)
        x, mid4 = self.subblock_encoder_forward(x, self.enc4, self.down4, self.in5, 3)
        x, mid5 = self.subblock_encoder_forward(x, self.enc5, self.down5, self.in6, 4)
        x = self.enc6(x)
        # Decoder
        if self.learning_stage >= 5:
            x = self.out6(x)
        x = self.subblock_decoder_forward(x, mid5, self.up5, self.dec5, self.out5, 4)
        x = self.subblock_decoder_forward(x, mid4, self.up4, self.dec4, self.out4, 3)
        x = self.subblock_decoder_forward(x, mid3, self.up3, self.dec3, self.out3, 2)
        x = self.subblock_decoder_forward(x, mid2, self.up2, self.dec2, self.out2, 1)
        x = self.subblock_decoder_forward(x, mid1, self.up1, self.dec1, self.out1, 0)
        return x
if __name__ == "__main__":
    x = torch.randn(16, 3, 256, 256)
    model = CoarseToFineModel()
    model.learning_stage = 0
    y = model(x)
    torchvision.utils.save_image(y, "hoge.png", nrow=4, padding=10, normalize=True, range=(-1.0, 1.0))
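For reference, a quick sketch of the progressive schedule that switch_epoch implements with its defaults (epoch_per_mode=50, burnin_epoch=10): training starts at stage 5 (8x8 output) and moves one stage finer every 50 epochs; for the first 10 burn-in epochs of each new stage the previous stage's weights are frozen, then everything is unfrozen again.

model = CoarseToFineModel()
for epoch in [0, 10, 50, 60, 100, 250]:
    model.switch_epoch(epoch)       # prints the stage switches / freezes
print(count_parameters(model))      # trainable parameters under the current freeze state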
models_ae.py — identical to models.py except that Upsampling takes no encoder skip input, making the decoder a plain autoencoder path.
import torch
from torch import nn
import torch.nn.functional as F
import torchvision

class ConvBnAct(nn.Module):
    # Conv2d -> BatchNorm -> activation
    def __init__(self, in_ch, out_ch, kernel=3, stride=1, act="relu"):
        super().__init__()
        layers = []
        layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=(kernel - 1) // 2))
        layers.append(nn.BatchNorm2d(out_ch))
        if act == "relu":
            layers.append(nn.ReLU(inplace=True))
        elif act == "tanh":
            layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)
class ResNetBlock(nn.Module):
    # Two 3x3 conv blocks with an identity shortcut
    def __init__(self, base_ch):
        super().__init__()
        self.conv1 = ConvBnAct(base_ch, base_ch)
        self.conv2 = ConvBnAct(base_ch, base_ch)

    def forward(self, inputs):
        x = self.conv2(self.conv1(inputs))
        return x + inputs

class Downsampling(nn.Module):
    # Downsampling via a strided 1x1 conv
    def __init__(self, in_ch, out_ch, pool_size):
        super().__init__()
        self.conv = ConvBnAct(in_ch, out_ch, kernel=1, stride=pool_size)

    def forward(self, inputs):
        return self.conv(inputs)

class Upsampling(nn.Module):
    # Upsample the base -> 1x1 conv (no encoder skip connection)
    def __init__(self, base_ch, out_ch, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor
        self.conv = ConvBnAct(base_ch, out_ch, kernel=1)

    def forward(self, base_inputs):
        x = F.interpolate(base_inputs, scale_factor=self.scale_factor, mode="nearest")
        return self.conv(x)
def create_subnetwork(base_ch, downsampling=True, upsampling=True, decoder=True):
    # Input block
    conv_in = ConvBnAct(3, base_ch)
    # Encoder
    encoder_block = nn.Sequential(*[ResNetBlock(base_ch) for _ in range(3)])
    # Downsampling
    if downsampling:
        downsampling_block = Downsampling(base_ch, base_ch * 2, pool_size=2)
    else:
        downsampling_block = None
    # Upsampling
    if upsampling:
        upsampling_block = Upsampling(base_ch * 2, base_ch, 2)
    else:
        upsampling_block = None
    # Decoder
    if decoder:
        decoder_block = ResNetBlock(base_ch)
    else:
        decoder_block = None
    # Output block
    conv_out = ConvBnAct(base_ch, 3, kernel=5, act="tanh")
    return conv_in, encoder_block, downsampling_block, upsampling_block, decoder_block, conv_out

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
class CoarseToFineAutoEncoderModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.in1, self.enc1, self.down1, self.up1, self.dec1, self.out1 = create_subnetwork(16)   # 256x256
        self.in2, self.enc2, self.down2, self.up2, self.dec2, self.out2 = create_subnetwork(32)   # 128x128
        self.in3, self.enc3, self.down3, self.up3, self.dec3, self.out3 = create_subnetwork(64)   # 64x64
        self.in4, self.enc4, self.down4, self.up4, self.dec4, self.out4 = create_subnetwork(128)  # 32x32
        self.in5, self.enc5, self.down5, self.up5, self.dec5, self.out5 = create_subnetwork(256)  # 16x16
        self.in6, self.enc6, _, _, _, self.out6 = create_subnetwork(512)                          # 8x8
        self.learning_stage = 0

    def set_trainable(self, layer_index, trainable_flag):
        if layer_index == 0:
            layers = [self.in1, self.enc1, self.down1, self.up1, self.dec1, self.out1]
        elif layer_index == 1:
            layers = [self.in2, self.enc2, self.down2, self.up2, self.dec2, self.out2]
        elif layer_index == 2:
            layers = [self.in3, self.enc3, self.down3, self.up3, self.dec3, self.out3]
        elif layer_index == 3:
            layers = [self.in4, self.enc4, self.down4, self.up4, self.dec4, self.out4]
        elif layer_index == 4:
            layers = [self.in5, self.enc5, self.down5, self.up5, self.dec5, self.out5]
        elif layer_index == 5:
            layers = [self.in6, self.enc6, self.out6]
        for l in layers:
            for p in l.parameters():
                p.requires_grad = trainable_flag

    def switch_epoch(self, epoch, epoch_per_mode=50, burnin_epoch=10):
        # Once the burn-in epochs are over, unfreeze everything again
        if epoch % epoch_per_mode == burnin_epoch:
            print("---", epoch, ": unfreezing all stages")
            for i in range(6):
                self.set_trainable(i, True)
        # Burn-in setup + training stage switch
        elif epoch % epoch_per_mode == 0:
            new_training_stage = 5 - epoch // epoch_per_mode
            self.learning_stage = new_training_stage
            print("---", epoch, ": switching to stage", new_training_stage)
            if epoch != 0:
                self.set_trainable(new_training_stage + 1, False)
                print("---", epoch, ": freezing the weights of stage", new_training_stage + 1)

    def subblock_encoder_forward(self, main_inputs, encoder_layer, downsampling_layer, lower_input_layer, threshold):
        if self.learning_stage <= threshold:
            mid = encoder_layer(main_inputs)
            x = downsampling_layer(mid)
        else:
            x = F.avg_pool2d(main_inputs, kernel_size=2)
            if self.learning_stage == threshold + 1:
                x = lower_input_layer(x)
            mid = None
        return x, mid

    def subblock_decoder_forward(self, main_inputs, encoder_inputs, upsampling_layer, decoder_layer, output_layer, threshold):
        # encoder_inputs is unused in this autoencoder variant; kept for interface parity with CoarseToFineModel
        if self.learning_stage <= threshold:
            x = upsampling_layer(main_inputs)
            x = decoder_layer(x)
            if self.learning_stage == threshold:
                x = output_layer(x)
        else:
            x = F.interpolate(main_inputs, scale_factor=2)
        return x

    def forward(self, inputs):
        # Encoder
        if self.learning_stage <= 0:
            x = self.in1(inputs)
        else:
            x = inputs
        x, mid1 = self.subblock_encoder_forward(x, self.enc1, self.down1, self.in2, 0)
        x, mid2 = self.subblock_encoder_forward(x, self.enc2, self.down2, self.in3, 1)
        x, mid3 = self.subblock_encoder_forward(x, self.enc3, self.down3, self.in4, 2)
        x, mid4 = self.subblock_encoder_forward(x, self.enc4, self.down4, self.in5, 3)
        x, mid5 = self.subblock_encoder_forward(x, self.enc5, self.down5, self.in6, 4)
        x = self.enc6(x)
        # Decoder
        if self.learning_stage >= 5:
            x = self.out6(x)
        x = self.subblock_decoder_forward(x, mid5, self.up5, self.dec5, self.out5, 4)
        x = self.subblock_decoder_forward(x, mid4, self.up4, self.dec4, self.out4, 3)
        x = self.subblock_decoder_forward(x, mid3, self.up3, self.dec3, self.out3, 2)
        x = self.subblock_decoder_forward(x, mid2, self.up2, self.dec2, self.out2, 1)
        x = self.subblock_decoder_forward(x, mid1, self.up1, self.dec1, self.out1, 0)
        return x
models_unet.py — plain U-Net baseline (no progressive training).
import torch
from torch import nn
import torch.nn.functional as F
import torchvision

class ConvBnAct(nn.Module):
    # Conv2d -> BatchNorm -> activation
    def __init__(self, in_ch, out_ch, kernel=3, stride=1, act="relu"):
        super().__init__()
        layers = []
        layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=(kernel - 1) // 2))
        layers.append(nn.BatchNorm2d(out_ch))
        if act == "relu":
            layers.append(nn.ReLU(inplace=True))
        elif act == "tanh":
            layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)
class ResNetBlock(nn.Module):
    # Two 3x3 conv blocks with an identity shortcut
    def __init__(self, base_ch):
        super().__init__()
        self.conv1 = ConvBnAct(base_ch, base_ch)
        self.conv2 = ConvBnAct(base_ch, base_ch)

    def forward(self, inputs):
        x = self.conv2(self.conv1(inputs))
        return x + inputs

class Downsampling(nn.Module):
    # Downsampling via a strided 1x1 conv
    def __init__(self, in_ch, out_ch, pool_size):
        super().__init__()
        self.conv = ConvBnAct(in_ch, out_ch, kernel=1, stride=pool_size)

    def forward(self, inputs):
        return self.conv(inputs)

class Upsampling(nn.Module):
    # Upsample the base -> concat with the encoder side -> 1x1 conv
    def __init__(self, base_ch, enc_ch, out_ch, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor
        self.conv = ConvBnAct(base_ch + enc_ch, out_ch, kernel=1)

    def forward(self, base_inputs, enc_inputs):
        x = F.interpolate(base_inputs, scale_factor=self.scale_factor, mode="nearest")
        x = torch.cat([x, enc_inputs], dim=1)
        return self.conv(x)
def create_subnetwork(base_ch, downsampling=True, upsampling=True, decoder=True):
    # Unused by Unet below; carried over from models.py
    # Input block
    conv_in = ConvBnAct(3, base_ch)
    # Encoder
    encoder_block = nn.Sequential(*[ResNetBlock(base_ch) for _ in range(3)])
    # Downsampling
    if downsampling:
        downsampling_block = Downsampling(base_ch, base_ch * 2, pool_size=2)
    else:
        downsampling_block = None
    # Upsampling
    if upsampling:
        upsampling_block = Upsampling(base_ch * 2, base_ch, base_ch, 2)
    else:
        upsampling_block = None
    # Decoder
    if decoder:
        decoder_block = ResNetBlock(base_ch)
    else:
        decoder_block = None
    # Output block
    conv_out = ConvBnAct(base_ch, 3, kernel=5, act="tanh")
    return conv_in, encoder_block, downsampling_block, upsampling_block, decoder_block, conv_out

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
class Unet(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc1 = self.encoder_block(3, 16, downsampling=False)  # 256x256
        self.enc2 = self.encoder_block(16, 32)    # 128x128
        self.enc3 = self.encoder_block(32, 64)    # 64x64
        self.enc4 = self.encoder_block(64, 128)   # 32x32
        self.enc5 = self.encoder_block(128, 256)  # 16x16
        self.enc6 = self.encoder_block(256, 512)  # 8x8
        self.up5 = Upsampling(512, 256, 256, 2)
        self.up4 = Upsampling(256, 128, 128, 2)
        self.up3 = Upsampling(128, 64, 64, 2)
        self.up2 = Upsampling(64, 32, 32, 2)
        self.up1 = Upsampling(32, 16, 16, 2)
        self.dec5 = ResNetBlock(256)
        self.dec4 = ResNetBlock(128)
        self.dec3 = ResNetBlock(64)
        self.dec2 = ResNetBlock(32)
        self.dec1 = ResNetBlock(16)
        self.out = ConvBnAct(16, 3, act="tanh")

    def encoder_block(self, old_ch, base_ch, downsampling=True):
        return nn.Sequential(
            ConvBnAct(old_ch, base_ch, stride=2 if downsampling else 1, kernel=1),
            ResNetBlock(base_ch),
            ResNetBlock(base_ch),
            ResNetBlock(base_ch)
        )

    def decoder_block(self, base_ch, enc_ch, out_ch):
        # Unused helper: forward wires the decoder path explicitly
        # (Upsampling.forward takes two inputs, so it cannot be chained inside nn.Sequential)
        return nn.Sequential(
            Upsampling(base_ch, enc_ch, out_ch, 2),
            ResNetBlock(out_ch),
        )

    def forward(self, inputs):
        mid1 = self.enc1(inputs)
        mid2 = self.enc2(mid1)
        mid3 = self.enc3(mid2)
        mid4 = self.enc4(mid3)
        mid5 = self.enc5(mid4)
        x = self.enc6(mid5)
        x = self.dec5(self.up5(x, mid5))
        x = self.dec4(self.up4(x, mid4))
        x = self.dec3(self.up3(x, mid3))
        x = self.dec2(self.up2(x, mid2))
        x = self.dec1(self.up1(x, mid1))
        return self.out(x)
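models_unet.py ends without a self-test; a minimal sketch of one, mirroring the shape check at the bottom of models.py:

if __name__ == "__main__":
    x = torch.randn(2, 3, 256, 256)
    model = Unet()
    y = model(x)
    print(y.shape)  # expected: torch.Size([2, 3, 256, 256])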
Training script — switch the model class inside train() to compare the three models.
import torch
import torch.nn.functional as F
import torchvision
from dataloader import load_dataset
from models import CoarseToFineModel
from models_ae import CoarseToFineAutoEncoderModel
from models_unet import Unet
import os
import pickle
import statistics
from tqdm import tqdm
import time
def train(transform_type):
    trainloader, testloader = load_dataset(transform_type, 64)
    model = CoarseToFineModel().to("cuda")  # switch the model here
    # model = CoarseToFineAutoEncoderModel().to("cuda")
    # model = Unet().to("cuda")
    print(model)
    model = torch.nn.DataParallel(model)
    loss_l1 = torch.nn.L1Loss()
    loss_mse = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    result = {"loss": [], "psnr": [], "val_psnr": [], "time": []}
    for i in range(300):
        # Update the training stage (switch here too; Unet has no stages)
        model.module.switch_epoch(i)
        log_loss = []
        log_psnr = []
        log_val_psnr = []
        start_time = time.time()
        with tqdm(trainloader) as pbar:
            for (img_real, img_input, size), _ in pbar:
                # forward pass
                img_real, img_input = img_real.to("cuda"), img_input.to("cuda")
                optimizer.zero_grad()
                y_pred = model(img_input)
                # Match the resolution of the target to the current coarse-to-fine stage
                # coarse_ratio = 2 ** model.module.learning_stage
                # if coarse_ratio > 1:
                #     y_true = F.interpolate(F.avg_pool2d(img_real, kernel_size=coarse_ratio), scale_factor=coarse_ratio)
                # else:
                #     y_true = img_real
                y_true = img_real
                # Optimize with the L1 loss
                loss = loss_l1(y_pred, y_true)
                # PSNR (data range is [-1, 1], so the peak value is 2)
                psnr = 10.0 * torch.log10(2.0 ** 2 / (loss_mse(y_pred, img_real) + 1e-10))
                # backprop
                loss.backward()
                optimizer.step()
                log_loss.append(loss.item())
                log_psnr.append(psnr.item())
                pbar.set_postfix(epoch=i, loss=statistics.mean(log_loss), psnr=statistics.mean(log_psnr))
        if not os.path.exists(transform_type):
            os.mkdir(transform_type)
        batch_len = min(len(img_input), len(y_pred), len(y_true), 16)
        x = torch.cat([img_input[:batch_len], y_pred[:batch_len], y_true[:batch_len]], dim=0)
        torchvision.utils.save_image(x, transform_type + f"/epoch_{i:03}.png", nrow=8, padding=10,
                                     normalize=True, range=(-1.0, 1.0))
        with tqdm(testloader) as pbar:
            for (img_real, img_input, size), _ in pbar:
                img_real, img_input = img_real.to("cuda"), img_input.to("cuda")
                with torch.no_grad():
                    y_pred = model(img_input)
                    psnr = 10.0 * torch.log10(2.0 ** 2 / (loss_mse(y_pred, img_real) + 1e-10))
                log_val_psnr.append(psnr.item())
                pbar.set_postfix(epoch=i, val_psnr=statistics.mean(log_val_psnr))
        result["loss"].append(statistics.mean(log_loss))
        result["psnr"].append(statistics.mean(log_psnr))
        result["val_psnr"].append(statistics.mean(log_val_psnr))
        result["time"].append(time.time() - start_time)
        with open(transform_type + "/logs.pkl", "wb") as fp:
            pickle.dump(result, fp)
if __name__ == "__main__":
    train("mosaic")