Created
May 21, 2020 19:13
-
-
Save bearpelican/d2833b3e65134e66b59904f0aef11666 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from fastai.vision import *" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Data" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"One-time download, uncomment the next cells to get the data." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#path = Config().data_path()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#! wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip -P {path}\n", | |
"#! unzip -q -n {path}/horse2zebra.zip -d {path}\n", | |
"#! rm {path}/horse2zebra.zip" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Root of the datasets directory. The path is relative, so it depends on where the notebook is run from.\n", | |
"data_path = Path('../../../../mnt/wamri/WAMRI-LevensonLab/datasets/')\n", | |
"# Folder of the dataset used in this notebook.\n", | |
"muse2he_path = data_path/'muse2he_urothelial_carcinoma'\n", | |
"muse2he_path.ls()\n", | |
"\n", | |
"# The two image domains live in 'trainA' and 'trainB' (presumably MUSE and H&E images respectively -- confirm).\n", | |
"muse_path = muse2he_path/'trainA'\n", | |
"he_path = muse2he_path/'trainB'\n", | |
"\n", | |
"torch.cuda.set_device(0) #set GPU id" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# path = Config().data_path()/'horse2zebra'\n", | |
"# path.ls()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"See [this tutorial](https://docs.fast.ai/tutorial.itemlist.html) for a detailed walkthrough of how/why this custom `ItemList` was created." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class ImageTuple(ItemBase):\n", | |
"    \"Two images handled as a single item; `data` holds both tensors rescaled as `-1 + 2*x`.\"\n", | |
"    def __init__(self, img1, img2):\n", | |
"        self.img1 = img1\n", | |
"        self.img2 = img2\n", | |
"        pair = (img1, img2)\n", | |
"        rescaled = [-1+2*img1.data, -1+2*img2.data]\n", | |
"        self.obj, self.data = pair, rescaled\n", | |
"\n", | |
"    def apply_tfms(self, tfms, **kwargs):\n", | |
"        \"Apply `tfms` to each image in turn and wrap the two results in a new `ImageTuple`.\"\n", | |
"        first = self.img1.apply_tfms(tfms, **kwargs)\n", | |
"        second = self.img2.apply_tfms(tfms, **kwargs)\n", | |
"        return ImageTuple(first, second)\n", | |
"\n", | |
"    def to_one(self):\n", | |
"        \"Concatenate both tensors along dim 2 and map values back with `0.5 + x/2`.\"\n", | |
"        joined = torch.cat(self.data, 2)\n", | |
"        return Image(0.5+joined/2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class TargetTupleList(ItemList):\n", | |
"    \"Label list whose `reconstruct` rebuilds an `ImageTuple` from a pair of tensors.\"\n", | |
"    def reconstruct(self, t:Tensor):\n", | |
"        # A tensor whose size is empty is handed back untouched.\n", | |
"        if len(t.size()) == 0:\n", | |
"            return t\n", | |
"        first = Image(t[0]/2+0.5)\n", | |
"        second = Image(t[1]/2+0.5)\n", | |
"        return ImageTuple(first, second)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class ImageTupleList(ImageList):\n", | |
"    \"`ImageList` whose items are `ImageTuple`s: item `i` of this list plus a randomly chosen file from `itemsB`.\"\n", | |
"    _label_cls=TargetTupleList\n", | |
"    def __init__(self, items, itemsB=None, **kwargs):\n", | |
"        self.itemsB = itemsB\n", | |
"        super().__init__(items, **kwargs)\n", | |
"\n", | |
"    def new(self, items, **kwargs):\n", | |
"        \"Create a new list of the same kind, carrying `itemsB` over.\"\n", | |
"        return super().new(items, itemsB=self.itemsB, **kwargs)\n", | |
"\n", | |
"    def get(self, i):\n", | |
"        \"Return item `i` paired with an image opened from a randomly chosen entry of `itemsB`.\"\n", | |
"        img1 = super().get(i)\n", | |
"        fn = self.itemsB[random.randint(0, len(self.itemsB)-1)]\n", | |
"        return ImageTuple(img1, open_image(fn))\n", | |
"\n", | |
"    def reconstruct(self, t:Tensor):\n", | |
"        \"Rebuild an `ImageTuple` from `t[0]` and `t[1]`, mapping values with `x/2 + 0.5`.\"\n", | |
"        return ImageTuple(Image(t[0]/2+0.5),Image(t[1]/2+0.5))\n", | |
"\n", | |
"    @classmethod\n", | |
"    def from_folders(cls, path, folderA, folderB, **kwargs):\n", | |
"        \"Build the list from `path/folderA`, taking the partner files from `path/folderB`.\"\n", | |
"        itemsB = ImageList.from_folder(path/folderB).items\n", | |
"        res = super().from_folder(path/folderA, itemsB=itemsB, **kwargs)\n", | |
"        res.path = path\n", | |
"        return res\n", | |
"\n", | |
"    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(12,6), **kwargs):\n", | |
"        \"Show the `xs` and `ys` on a figure of `figsize`. `kwargs` are passed to the show method.\"\n", | |
"        rows = int(math.sqrt(len(xs)))\n", | |
"        # squeeze=False always returns a 2-D array of axes, so a 1x1 grid needs no special case.\n", | |
"        fig, axs = plt.subplots(rows,rows,figsize=figsize,squeeze=False)\n", | |
"        for i, ax in enumerate(axs.flatten()):\n", | |
"            xs[i].to_one().show(ax=ax, **kwargs)\n", | |
"        plt.tight_layout()\n", | |
"\n", | |
"    def show_xyzs(self, xs, ys, zs, figsize:Tuple[int,int]=None, **kwargs):\n", | |
"        \"\"\"Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`.\n", | |
"        `kwargs` are passed to the show method.\"\"\"\n", | |
"        figsize = ifnone(figsize, (12,3*len(xs)))\n", | |
"        # squeeze=False keeps `axs` 2-D even for a single row; without it `axs[i,0]`\n", | |
"        # raises when only one item is shown, because subplots(1, 2) returns a 1-D array.\n", | |
"        fig,axs = plt.subplots(len(xs), 2, figsize=figsize, squeeze=False)\n", | |
"        fig.suptitle('Ground truth / Predictions', weight='bold', size=14)\n", | |
"        for i,(x,z) in enumerate(zip(xs,zs)):\n", | |
"            x.to_one().show(ax=axs[i,0], **kwargs)\n", | |
"            z.to_one().show(ax=axs[i,1], **kwargs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# src = ImageTupleList.from_folders(muse2he_path, 'trainA', 'trainB').split_none().label_empty()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Build the training DataBunch from the 'trainA'/'trainB' folders: no train/valid split, empty labels,\n", | |
"# the same [crop, flip_lr] transform list passed twice, batch size 4 with 2 workers.\n", | |
"# NOTE(review): batches come out as 256x256 (see the `one_batch` cell below), so the 256 crop\n", | |
"# presumably runs after the 512 squish resize -- confirm against the fastai transform order.\n", | |
"data = (ImageTupleList.from_folders(muse2he_path, 'trainA', 'trainB')\n", | |
" .split_none()\n", | |
" .label_empty()\n", | |
" .transform(2*[[crop(size=256,row_pct=0,col_pct=0),flip_lr(p=0.5)]],size=512,resize_method=ResizeMethod.SQUISH)\n", | |
" .databunch(bs=4,num_workers=2))\n", | |
"data.valid_dl = data.train_dl # a hack for proper evaluation of loss and metrics at end of training" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Visual sanity check of the data pipeline.\n", | |
"data.show_batch(rows=2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(torch.Size([4, 3, 256, 256]), torch.Size([4, 3, 256, 256]))" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Grab one batch; `xb` is indexed as a pair here, and the shape of each of its two tensors is displayed.\n", | |
"xb, yb = data.one_batch()\n", | |
"xb[0].shape, xb[1].shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# First image of the first input tensor in the first training batch.\n", | |
"train_image = next(iter(data.train_dl))[0][0][0]\n", | |
"# Move channels last, then apply (x+1)/2*255 to undo the `-1 + 2*x` scaling used by `ImageTuple` before display.\n", | |
"plt.imshow(((train_image.permute(1,2,0)+1)/2*255).cpu().to(torch.int))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Models" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We use the models that were introduced in the [cycleGAN paper](https://arxiv.org/abs/1703.10593)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def convT_norm_relu(ch_in:int, ch_out:int, norm_layer:nn.Module, ks:int=3, stride:int=2, bias:bool=True):\n", | |
"    \"Return the layers of a transposed-conv unit: `ConvTranspose2d`, then `norm_layer`, then an in-place ReLU.\"\n", | |
"    upconv = nn.ConvTranspose2d(ch_in, ch_out, kernel_size=ks, stride=stride, padding=1, output_padding=1, bias=bias)\n", | |
"    norm = norm_layer(ch_out)\n", | |
"    act = nn.ReLU(True)\n", | |
"    return [upconv, norm, act]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def pad_conv_norm_relu(ch_in:int, ch_out:int, pad_mode:str, norm_layer:nn.Module, ks:int=3, bias:bool=True,\n", | |
"                       pad=1, stride:int=1, activ:bool=True, init:Callable=nn.init.kaiming_normal_, init_gain:float=0.02)->List[nn.Module]:\n", | |
"    \"\"\"Return the layers of a padded conv unit: optional padding layer, `Conv2d`, `norm_layer`, optional ReLU.\n", | |
"\n", | |
"    `pad_mode` must be 'zeros' (padding done by the conv itself), 'reflection' (`ReflectionPad2d`)\n", | |
"    or 'border' (`ReplicationPad2d`); any other value raises `ValueError` instead of silently\n", | |
"    building an unpadded conv. When `init` is truthy it is applied to the conv weight (with mean 0.0\n", | |
"    and std `init_gain` if it is `nn.init.normal_`) and an existing bias is zeroed.\n", | |
"    Set `activ=False` to leave out the trailing in-place ReLU.\"\"\"\n", | |
"    if pad_mode not in ('zeros', 'reflection', 'border'):\n", | |
"        raise ValueError(f'padding {pad_mode} not implemented.')\n", | |
"    layers = []\n", | |
"    if pad_mode == 'reflection': layers.append(nn.ReflectionPad2d(pad))\n", | |
"    elif pad_mode == 'border':   layers.append(nn.ReplicationPad2d(pad))\n", | |
"    # Explicit padding layers already pad the input, so the conv itself pads only in 'zeros' mode.\n", | |
"    p = pad if pad_mode == 'zeros' else 0\n", | |
"    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=p, stride=stride, bias=bias)\n", | |
"    if init:\n", | |
"        if init == nn.init.normal_:\n", | |
"            init(conv.weight, 0.0, init_gain)\n", | |
"        else:\n", | |
"            init(conv.weight)\n", | |
"        # `conv.bias` is None when the layer was built with bias=False.\n", | |
"        if conv.bias is not None: conv.bias.data.fill_(0.)\n", | |
"    layers += [conv, norm_layer(ch_out)]\n", | |
"    if activ: layers.append(nn.ReLU(inplace=True))\n", | |
"    return layers" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class ResnetBlock(nn.Module):\n", | |
"    \"Residual block: two padded conv+norm units (ReLU after the first only) whose output is added to the input.\"\n", | |
"    def __init__(self, dim:int, pad_mode:str='reflection', norm_layer:nn.Module=None, dropout:float=0., bias:bool=True):\n", | |
"        super().__init__()\n", | |
"        assert pad_mode in ['zeros', 'reflection', 'border'], f'padding {pad_mode} not implemented.'\n", | |
"        norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)\n", | |
"        first_unit = pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias)\n", | |
"        # Dropout, when requested, sits between the two conv units.\n", | |
"        between = [] if dropout == 0 else [nn.Dropout(dropout)]\n", | |
"        second_unit = pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias, activ=False)\n", | |
"        self.conv_block = nn.Sequential(*first_unit, *between, *second_unit)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        residual = self.conv_block(x)\n", | |
"        return x + residual" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def resnet_generator(ch_in:int, ch_out:int, n_ftrs:int=64, norm_layer:nn.Module=None,\n", | |
"                     dropout:float=0., n_blocks:int=9, pad_mode:str='reflection')->nn.Module:\n", | |
"    \"Generator: 7x7 stem, two stride-2 convs, `n_blocks` `ResnetBlock`s, two transposed convs, 7x7 conv and `Tanh`.\"\n", | |
"    norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)\n", | |
"    bias = (norm_layer == nn.InstanceNorm2d)\n", | |
"    layers = pad_conv_norm_relu(ch_in, n_ftrs, 'reflection', norm_layer, pad=3, ks=7, bias=bias)\n", | |
"    # Two stride-2 convs, each doubling the number of features.\n", | |
"    for _ in range(2):\n", | |
"        layers.extend(pad_conv_norm_relu(n_ftrs, n_ftrs *2, 'zeros', norm_layer, stride=2, bias=bias))\n", | |
"        n_ftrs = n_ftrs * 2\n", | |
"    for _ in range(n_blocks):\n", | |
"        layers.append(ResnetBlock(n_ftrs, pad_mode, norm_layer, dropout, bias))\n", | |
"    # Two transposed convs, each halving the number of features.\n", | |
"    for _ in range(2):\n", | |
"        layers.extend(convT_norm_relu(n_ftrs, n_ftrs//2, norm_layer, bias=bias))\n", | |
"        n_ftrs = n_ftrs // 2\n", | |
"    head = [nn.ReflectionPad2d(3), nn.Conv2d(n_ftrs, ch_out, kernel_size=7, padding=0), nn.Tanh()]\n", | |
"    return nn.Sequential(*layers, *head)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Instantiate a generator with 3 input and 3 output channels to inspect its layers.\n", | |
"resnet_generator(3, 3)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def conv_norm_lr(ch_in:int, ch_out:int, norm_layer:nn.Module=None, ks:int=3, bias:bool=True, pad:int=1, stride:int=1,\n", | |
"                 activ:bool=True, slope:float=0.2, init:Callable=nn.init.normal_, init_gain:float=0.02)->List[nn.Module]:\n", | |
"    \"\"\"Return the layers of a conv unit: `Conv2d`, optional `norm_layer`, optional `LeakyReLU(slope)`.\n", | |
"\n", | |
"    When `init` is truthy it is applied to the conv weight (with mean 0.0 and std `init_gain`\n", | |
"    if it is `nn.init.normal_`) and an existing bias is zeroed. `norm_layer=None` skips the\n", | |
"    normalization layer; `activ=False` skips the in-place LeakyReLU.\"\"\"\n", | |
"    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=pad, stride=stride, bias=bias)\n", | |
"    if init:\n", | |
"        if init == nn.init.normal_:\n", | |
"            init(conv.weight, 0.0, init_gain)\n", | |
"        else:\n", | |
"            init(conv.weight)\n", | |
"        # `conv.bias` is None when the layer was built with bias=False.\n", | |
"        if conv.bias is not None: conv.bias.data.fill_(0.)\n", | |
"    layers = [conv]\n", | |
"    if norm_layer is not None: layers.append(norm_layer(ch_out))\n", | |
"    if activ: layers.append(nn.LeakyReLU(slope, inplace=True))\n", | |
"    return layers" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def discriminator(ch_in:int, n_ftrs:int=64, n_layers:int=3, norm_layer:nn.Module=None, sigmoid:bool=False)->nn.Module:\n", | |
"    \"Discriminator made of 4x4 conv units that ends in a 1-channel conv, followed by `Sigmoid` if `sigmoid`.\"\n", | |
"    norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)\n", | |
"    bias = (norm_layer == nn.InstanceNorm2d)\n", | |
"    # The first unit is built without a normalization layer.\n", | |
"    layers = conv_norm_lr(ch_in, n_ftrs, ks=4, stride=2, pad=1)\n", | |
"    for i in range(n_layers-1):\n", | |
"        # Features double for the stride-2 units while i <= 3, then stay constant.\n", | |
"        new_ftrs = n_ftrs if i > 3 else 2*n_ftrs\n", | |
"        layers.extend(conv_norm_lr(n_ftrs, new_ftrs, norm_layer, ks=4, stride=2, pad=1, bias=bias))\n", | |
"        n_ftrs = new_ftrs\n", | |
"    new_ftrs = n_ftrs if n_layers > 3 else 2*n_ftrs\n", | |
"    layers.extend(conv_norm_lr(n_ftrs, new_ftrs, norm_layer, ks=4, stride=1, pad=1, bias=bias))\n", | |
"    layers.append(nn.Conv2d(new_ftrs, 1, kernel_size=4, stride=1, padding=1))\n", | |
"    if sigmoid:\n", | |
"        layers.append(nn.Sigmoid())\n", | |
"    return nn.Sequential(*layers)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Sequential(\n", | |
" (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", | |
" (1): LeakyReLU(negative_slope=0.2, inplace=True)\n", | |
" (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", | |
" (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n", | |
" (4): LeakyReLU(negative_slope=0.2, inplace=True)\n", | |
" (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", | |
" (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n", | |
" (7): LeakyReLU(negative_slope=0.2, inplace=True)\n", | |
" (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))\n", | |
" (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n", | |
" (10): LeakyReLU(negative_slope=0.2, inplace=True)\n", | |
" (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))\n", | |
")" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Instantiate a discriminator for 3-channel input to inspect its layers.\n", | |
"discriminator(3)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We group two discriminators and two generators in a single model, then a `Callback` will take care of training them properly." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class CycleGAN(nn.Module):\n", | |
"    \"Holds the two generators (`G_A`, `G_B`) and the two discriminators (`D_A`, `D_B`) in one module.\"\n", | |
"\n", | |
"    def __init__(self, ch_in:int, ch_out:int, n_features:int=64, disc_layers:int=3, gen_blocks:int=6, lsgan:bool=True,\n", | |
"                 drop:float=0., norm_layer:nn.Module=None):\n", | |
"        super().__init__()\n", | |
"        # Without LSGAN the discriminators end in a sigmoid.\n", | |
"        use_sigmoid = not lsgan\n", | |
"        self.D_A = discriminator(ch_in, n_features, disc_layers, norm_layer, sigmoid=use_sigmoid)\n", | |
"        self.D_B = discriminator(ch_in, n_features, disc_layers, norm_layer, sigmoid=use_sigmoid)\n", | |
"        self.G_A = resnet_generator(ch_in, ch_out, n_features, norm_layer, drop, gen_blocks)\n", | |
"        self.G_B = resnet_generator(ch_in, ch_out, n_features, norm_layer, drop, gen_blocks)\n", | |
"        #G_A: takes real input B and generates fake input A\n", | |
"        #G_B: takes real input A and generates fake input B\n", | |
"        #D_A: trained to tell real input A from fake input A\n", | |
"        #D_B: trained to tell real input B from fake input B\n", | |
"\n", | |
"    def forward(self, real_A, real_B):\n", | |
"        fake_A = self.G_A(real_B)\n", | |
"        fake_B = self.G_B(real_A)\n", | |
"        # Identity passes, needed for the identity loss during training.\n", | |
"        idt_A = self.G_A(real_A)\n", | |
"        idt_B = self.G_B(real_B)\n", | |
"        outputs = [fake_A, fake_B, idt_A, idt_B]\n", | |
"        if self.training:\n", | |
"            return outputs\n", | |
"        # Outside training the four outputs are stacked on a new dim 1 and returned as one tensor.\n", | |
"        return torch.cat([out[:,None] for out in outputs],1)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"`AdaptiveLoss` is a wrapper around a PyTorch loss function to compare an output of any size with a single number (0. or 1.). It will generate a target with the same shape as the output. A discriminator returns a feature map, and we want it to predict zeros (or ones) for each feature." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class AdaptiveLoss(nn.Module):\n", | |
"    \"Wrap `crit` so an output of any shape is scored against an all-ones or all-zeros target of that shape.\"\n", | |
"    def __init__(self, crit):\n", | |
"        super().__init__()\n", | |
"        self.crit = crit\n", | |
"\n", | |
"    def forward(self, output, target:bool, **kwargs):\n", | |
"        # Ones when `target` is truthy, zeros otherwise; created from `output` so dtype and device match.\n", | |
"        make_target = output.new_ones if target else output.new_zeros\n", | |
"        targ = make_target(*output.size())\n", | |
"        return self.crit(output, targ, **kwargs)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The main loss used to train the generators. It has three parts:\n", | |
"- the classic GAN loss: they must make the critics believe their images are real\n", | |
"- identity loss: if they are given an image from the set they are trying to imitate, they should return the same thing\n", | |
"- cycle loss: if an image from A goes through the generator that imitates B then through the generator that imitates A, it should be the same as the initial image. Same for B and switching the generators" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class CycleGanLoss(nn.Module):\n", | |
"    \"Generator loss: the sum of an identity term, an adversarial term and a cycle term.\"\n", | |
"\n", | |
"    def __init__(self, cgan:nn.Module, lambda_A:float=10., lambda_B:float=10, lambda_idt:float=0.5, lsgan:bool=True):\n", | |
"        super().__init__()\n", | |
"        self.cgan = cgan\n", | |
"        self.l_A, self.l_B, self.l_idt = lambda_A, lambda_B, lambda_idt\n", | |
"        self.crit = AdaptiveLoss(F.mse_loss if lsgan else F.binary_cross_entropy)\n", | |
"\n", | |
"    def set_training(self,training):\n", | |
"        \"Record whether `forward` will receive the list (training) or the stacked tensor (otherwise).\"\n", | |
"        self.training = training\n", | |
"\n", | |
"    def set_input(self, input):\n", | |
"        \"Store the real images of the current batch; `forward` compares the outputs against them.\"\n", | |
"        self.real_A,self.real_B = input\n", | |
"\n", | |
"    def forward(self, output, target):\n", | |
"        if self.training:\n", | |
"            fake_A, fake_B, idt_A, idt_B = output\n", | |
"        else:\n", | |
"            # Outside training `output` is one tensor with the four results stacked on dim 1.\n", | |
"            fake_A, fake_B, idt_A, idt_B = (output[:,k,:,:,:] for k in range(4))\n", | |
"\n", | |
"        #Generators should return identity on the datasets they try to convert to\n", | |
"        idt_term_A = self.l_A * F.l1_loss(idt_A, self.real_A)\n", | |
"        idt_term_B = self.l_B * F.l1_loss(idt_B, self.real_B)\n", | |
"        self.id_loss = self.l_idt * (idt_term_A + idt_term_B)\n", | |
"        #Generators are trained to trick the discriminators so the following should be ones\n", | |
"        adv_A = self.crit(self.cgan.D_A(fake_A), True)\n", | |
"        adv_B = self.crit(self.cgan.D_B(fake_B), True)\n", | |
"        self.gen_loss = adv_A + adv_B\n", | |
"        #Cycle loss\n", | |
"        cyc_A = self.l_A * F.l1_loss(self.cgan.G_A(fake_B), self.real_A)\n", | |
"        cyc_B = self.l_B * F.l1_loss(self.cgan.G_B(fake_A), self.real_B)\n", | |
"        self.cyc_loss = cyc_A + cyc_B\n", | |
"        return self.id_loss+self.gen_loss+self.cyc_loss" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The main callback to train a cycle GAN. The training loop will train the generators (so `learn.opt` is given those parameters) while the critics are trained by the callback during `on_batch_end`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class CycleGANTrainer(LearnerCallback):\n", | |
"    \"\"\"Callback that trains a `CycleGAN`: the training loop updates the generators while the two\n", | |
"    discriminators are updated here in `on_batch_end`. Also records the individual loss terms.\"\"\"\n", | |
"    _order = -20 #Need to run before the Recorder\n", | |
"\n", | |
"    def _set_trainable(self, D_A=False, D_B=False):\n", | |
"        \"Enable gradients for the generators (default) or for the requested discriminator(s) only.\"\n", | |
"        gen = (not D_A) and (not D_B)\n", | |
"        requires_grad(self.learn.model.G_A, gen)\n", | |
"        requires_grad(self.learn.model.G_B, gen)\n", | |
"        requires_grad(self.learn.model.D_A, D_A)\n", | |
"        requires_grad(self.learn.model.D_B, D_B)\n", | |
"        if not gen:\n", | |
"            # Keep the discriminator optimizers on the same hyper-parameters as the main optimizer.\n", | |
"            self.opt_D_A.lr, self.opt_D_A.mom = self.learn.opt.lr, self.learn.opt.mom\n", | |
"            self.opt_D_A.wd, self.opt_D_A.beta = self.learn.opt.wd, self.learn.opt.beta\n", | |
"            self.opt_D_B.lr, self.opt_D_B.mom = self.learn.opt.lr, self.learn.opt.mom\n", | |
"            self.opt_D_B.wd, self.opt_D_B.beta = self.learn.opt.wd, self.learn.opt.beta\n", | |
"\n", | |
"    def on_train_begin(self, metrics_names, **kwargs):\n", | |
"        \"Create (or re-sync) the generator and discriminator optimizers and register the extra metric names.\"\n", | |
"        self.G_A,self.G_B = self.learn.model.G_A,self.learn.model.G_B\n", | |
"        self.D_A,self.D_B = self.learn.model.D_A,self.learn.model.D_B\n", | |
"        self.crit = self.learn.loss_func.crit\n", | |
"        if not getattr(self,'opt_G',None):\n", | |
"            self.opt_G = self.learn.opt.new([nn.Sequential(*flatten_model(self.G_A), *flatten_model(self.G_B))])\n", | |
"        else:\n", | |
"            # Read the optimizer through `self.learn`, like the rest of this class.\n", | |
"            self.opt_G.lr,self.opt_G.wd = self.learn.opt.lr,self.learn.opt.wd\n", | |
"            self.opt_G.mom,self.opt_G.beta = self.learn.opt.mom,self.learn.opt.beta\n", | |
"        if not getattr(self,'opt_D_A',None):\n", | |
"            self.opt_D_A = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_A))])\n", | |
"        if not getattr(self,'opt_D_B',None):\n", | |
"            self.opt_D_B = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_B))])\n", | |
"        # The training loop steps `learn.opt`, so point it at the generator optimizer.\n", | |
"        self.learn.opt.opt = self.opt_G.opt\n", | |
"        self._set_trainable()\n", | |
"        self.id_smter,self.gen_smter,self.cyc_smter = SmoothenValue(0.98),SmoothenValue(0.98),SmoothenValue(0.98)\n", | |
"        self.da_smter,self.db_smter = SmoothenValue(0.98),SmoothenValue(0.98)\n", | |
"        self.recorder.add_metric_names(['id_loss', 'gen_loss', 'cyc_loss', 'D_A_loss', 'D_B_loss'])\n", | |
"\n", | |
"    def on_epoch_begin(self, **kwargs):\n", | |
"        torch.cuda.empty_cache()\n", | |
"\n", | |
"    def on_batch_begin(self, last_input, **kwargs):\n", | |
"        \"Tell the loss function the current mode and hand it the real images of this batch.\"\n", | |
"        self.training = self.learn.model.training\n", | |
"        self.learn.loss_func.set_training(self.training)\n", | |
"        self.learn.loss_func.set_input(last_input)\n", | |
"\n", | |
"    def on_backward_begin(self, **kwargs):\n", | |
"        \"Record the three generator loss terms computed by the loss function for this batch.\"\n", | |
"        self.id_smter.add_value(self.loss_func.id_loss.detach().cpu())\n", | |
"        self.gen_smter.add_value(self.loss_func.gen_loss.detach().cpu())\n", | |
"        self.cyc_smter.add_value(self.loss_func.cyc_loss.detach().cpu())\n", | |
"\n", | |
"    def on_batch_end(self, last_input, last_output, **kwargs):\n", | |
"        \"Compute both discriminator losses on detached fakes; backward and step only while training.\"\n", | |
"        self.G_A.zero_grad(); self.G_B.zero_grad()\n", | |
"        if isinstance(last_output, (list, tuple)):\n", | |
"            # Training mode: the model returns [fake_A, fake_B, idt_A, idt_B].\n", | |
"            fake_A, fake_B = last_output[0].detach(), last_output[1].detach()\n", | |
"        else:\n", | |
"            # Eval mode: the model returns one tensor with the four outputs stacked on dim 1, so\n", | |
"            # the fakes are taken along dim 1. Indexing dim 0 would pick batch items instead.\n", | |
"            fake_A, fake_B = last_output[:,0].detach(), last_output[:,1].detach()\n", | |
"        real_A, real_B = last_input\n", | |
"        self._set_trainable(D_A=True)\n", | |
"        self.D_A.zero_grad()\n", | |
"        loss_D_A = 0.5 * (self.crit(self.D_A(real_A), True) + self.crit(self.D_A(fake_A), False))\n", | |
"        self.da_smter.add_value(loss_D_A.detach().cpu())\n", | |
"        if self.training:\n", | |
"            loss_D_A.backward()\n", | |
"            self.opt_D_A.step()\n", | |
"        self._set_trainable(D_B=True)\n", | |
"        self.D_B.zero_grad()\n", | |
"        loss_D_B = 0.5 * (self.crit(self.D_B(real_B), True) + self.crit(self.D_B(fake_B), False))\n", | |
"        self.db_smter.add_value(loss_D_B.detach().cpu())\n", | |
"\n", | |
"        if self.training:\n", | |
"            loss_D_B.backward()\n", | |
"            self.opt_D_B.step()\n", | |
"        # Hand gradient tracking back to the generators for the next batch.\n", | |
"        self._set_trainable()\n", | |
"\n", | |
"    def on_epoch_end(self, last_metrics, **kwargs):\n", | |
"        \"Append the smoothed loss terms to `last_metrics`, in the order registered in `on_train_begin`.\"\n", | |
"        return add_metrics(last_metrics, [s.smooth for s in [self.id_smter,self.gen_smter,self.cyc_smter,\n", | |
"                                                             self.da_smter,self.db_smter]])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Training" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Taken from https://github.com/fastai/fastai/blob/master/fastai/callbacks/flat_cos_anneal.py\n", | |
"# Work of Zach Mueller and Mikhail Grankin\n", | |
"from fastai.callback import *\n", | |
"from fastai.callbacks import *\n", | |
"def FlatAnnealScheduler(learn, lr:float=2e-4, n_epochs:int=100, n_epochs_decay:int=100, curve:str='linear'):\n", | |
"    \"\"\"Return a `GeneralScheduler` with two phases: `lr` held for `n_epochs` epochs, then annealed\n", | |
"    along `curve` for `n_epochs_decay` epochs.\n", | |
"\n", | |
"    `curve` is one of 'cosine', 'linear' or 'exponential'; any other value raises `ValueError`.\"\"\"\n", | |
"    n = len(learn.data.train_dl)\n", | |
"    # Exact integer batch counts. The earlier float form int(n*tot*(n_epochs/tot)) could round\n", | |
"    # down by one batch (e.g. n_epochs=57, n_epochs_decay=43), shifting the phase boundary.\n", | |
"    anneal_start = n * n_epochs\n", | |
"    batch_finish = n * n_epochs_decay\n", | |
"    if curve==\"cosine\": curve_type=annealing_cos\n", | |
"    elif curve==\"linear\": curve_type=annealing_linear\n", | |
"    elif curve==\"exponential\": curve_type=annealing_exp\n", | |
"    else: raise ValueError(f\"annealing type not supported {curve}\")\n", | |
"    phase0 = TrainingPhase(anneal_start).schedule_hp('lr', lr)\n", | |
"    phase1 = TrainingPhase(batch_finish).schedule_hp('lr', lr, anneal=curve_type)\n", | |
"    phases = [phase0, phase1]\n", | |
"    return GeneralScheduler(learn, phases)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def fit_fa(learn:Learner, n_epochs:int=100, n_epochs_decay:int=100, lr:float=2e-4, curve:str='linear',\n", | |
"           wd:float=None, callbacks:Optional[CallbackList]=None)->None:\n", | |
"    \"Fit `learn` for `n_epochs+n_epochs_decay` epochs with a `FlatAnnealScheduler` (flat `lr`, then annealed along `curve`).\"\n", | |
"    max_lr = learn.lr_range(lr)\n", | |
"    cbs = listify(callbacks)\n", | |
"    scheduler = FlatAnnealScheduler(learn, lr, n_epochs, n_epochs_decay, curve)\n", | |
"    cbs.append(scheduler)\n", | |
"    total_epochs = n_epochs+n_epochs_decay\n", | |
"    learn.fit(total_epochs, max_lr, wd=wd, callbacks=cbs)\n", | |
"\n", | |
"# Expose the function as a `Learner` method.\n", | |
"Learner.fit_fa = fit_fa" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class AverageMetric(LearnerCallback):\n", | |
"    \"Average `func(last_input, last_output, *last_target)` over an epoch, weighting each batch by its size.\"\n", | |
"    def __init__(self, func):\n", | |
"        # NOTE(review): the base-class __init__ is not called here -- confirm nothing relies on `self.learn`.\n", | |
"        # A plain function carries `__name__`; anything else is expected to be a partial wrapping one.\n", | |
"        if hasattr(func, '__name__'): name = func.__name__\n", | |
"        else: name = func.func.__name__\n", | |
"        self.func = func\n", | |
"        self.name = name\n", | |
"        self.world = num_distrib()\n", | |
"\n", | |
"    def on_epoch_begin(self, **kwargs):\n", | |
"        \"Reset the running sum and the sample count.\"\n", | |
"        self.val = 0.\n", | |
"        self.count = 0\n", | |
"\n", | |
"    def on_batch_begin(self, last_input, **kwargs):\n", | |
"        \"Keep the batch input so `on_batch_end` can pass it to `func`.\"\n", | |
"        self.last_input = last_input\n", | |
"\n", | |
"    def on_batch_end(self, last_output, last_target, **kwargs):\n", | |
"        \"Add this batch's metric value, weighted by batch size, to the running sum.\"\n", | |
"        targets = last_target if is_listy(last_target) else [last_target]\n", | |
"        self.count += first_el(targets).size(0)\n", | |
"        val = self.func(self.last_input, last_output, *targets)\n", | |
"        if self.world:\n", | |
"            # Sum the value across processes, then divide by `self.world`.\n", | |
"            val = val.clone()\n", | |
"            dist.all_reduce(val, op=dist.ReduceOp.SUM)\n", | |
"            val /= self.world\n", | |
"        self.val += first_el(targets).size(0) * val.detach().cpu()\n", | |
"\n", | |
"    def on_epoch_end(self, last_metrics, **kwargs):\n", | |
"        \"Append the epoch average to `last_metrics`.\"\n", | |
"        epoch_average = self.val/self.count\n", | |
"        return add_metrics(last_metrics, epoch_average)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from metrics import *" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def ssim_fastai(xb,yb,_):\n", | |
"    \"Evaluate `ssim` between the real A inputs and the generated B images, both rescaled with `(x/2 + 0.5)*255`.\"\n", | |
"    real_A, _real_B = xb\n", | |
"    # Index 1 on dim 1 of the stacked model output is the generated B image (index 0 is the generated A).\n", | |
"    fake_B = yb[:,1,:,:,:]\n", | |
"    real_scaled = (real_A/2 + 0.5)*255\n", | |
"    fake_scaled = (fake_B/2 + 0.5)*255\n", | |
"    return ssim(real_scaled,fake_scaled)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def psnr_fastai(xb, yb, _):\n", | |
"    \"Evaluate `psnr` between the real A inputs and the generated B images, both rescaled with `(x/2 + 0.5)*255`.\"\n", | |
"    real_A, _real_B = xb\n", | |
"    # Index 1 on dim 1 of the stacked model output is the generated B image (index 0 is the generated A).\n", | |
"    fake_B = yb[:,1,:,:,:]\n", | |
"    real_scaled = (real_A/2 + 0.5)*255\n", | |
"    fake_scaled = (fake_B/2 + 0.5)*255\n", | |
"    return psnr(real_scaled,fake_scaled)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# 3-channel images in both directions; 9 residual blocks per generator (the class default is 6).\n", | |
"cycle_gan = CycleGAN(3,3, gen_blocks=9)\n", | |
"# Adam with betas (0.5, 0.999). `CycleGANTrainer` is attached as a callback, and the two image\n", | |
"# metrics are wrapped in the `AverageMetric` defined above.\n", | |
"learn = Learner(data, cycle_gan, loss_func=CycleGanLoss(cycle_gan), opt_func=partial(optim.Adam, betas=(0.5,0.999)),\n", | |
" callback_fns=[CycleGANTrainer],metrics=[AverageMetric(ssim_fastai),AverageMetric(psnr_fastai)])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#learn.lr_find()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#learn.recorder.plot()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>ssim_fastai</th>\n", | |
" <th>psnr_fastai</th>\n", | |
" <th>id_loss</th>\n", | |
" <th>gen_loss</th>\n", | |
" <th>cyc_loss</th>\n", | |
" <th>D_A_loss</th>\n", | |
" <th>D_B_loss</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>6.561883</td>\n", | |
" <td>5.135588</td>\n", | |
" <td>0.473679</td>\n", | |
" <td>9.387200</td>\n", | |
" <td>1.884625</td>\n", | |
" <td>0.796862</td>\n", | |
" <td>3.880396</td>\n", | |
" <td>0.324226</td>\n", | |
" <td>0.321400</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>5.647283</td>\n", | |
" <td>4.936441</td>\n", | |
" <td>0.465910</td>\n", | |
" <td>10.266727</td>\n", | |
" <td>1.607140</td>\n", | |
" <td>0.716910</td>\n", | |
" <td>3.323231</td>\n", | |
" <td>0.280169</td>\n", | |
" <td>0.277287</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>2</td>\n", | |
" <td>5.211876</td>\n", | |
" <td>4.717415</td>\n", | |
" <td>0.565417</td>\n", | |
" <td>9.830503</td>\n", | |
" <td>1.462108</td>\n", | |
" <td>0.719012</td>\n", | |
" <td>3.030756</td>\n", | |
" <td>0.264041</td>\n", | |
" <td>0.256112</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>3</td>\n", | |
" <td>4.929071</td>\n", | |
" <td>7.111959</td>\n", | |
" <td>0.502945</td>\n", | |
" <td>8.873150</td>\n", | |
" <td>1.369080</td>\n", | |
" <td>0.718374</td>\n", | |
" <td>2.841614</td>\n", | |
" <td>0.286366</td>\n", | |
" <td>0.260715</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>4</td>\n", | |
" <td>4.971952</td>\n", | |
" <td>4.743098</td>\n", | |
" <td>0.542775</td>\n", | |
" <td>9.094391</td>\n", | |
" <td>1.351116</td>\n", | |
" <td>0.769224</td>\n", | |
" <td>2.851610</td>\n", | |
" <td>0.250340</td>\n", | |
" <td>0.247486</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>5</td>\n", | |
" <td>4.926872</td>\n", | |
" <td>5.141806</td>\n", | |
" <td>0.515856</td>\n", | |
" <td>9.646178</td>\n", | |
" <td>1.288682</td>\n", | |
" <td>0.873950</td>\n", | |
" <td>2.764238</td>\n", | |
" <td>0.273015</td>\n", | |
" <td>0.241806</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>6</td>\n", | |
" <td>4.840660</td>\n", | |
" <td>4.774198</td>\n", | |
" <td>0.530930</td>\n", | |
" <td>8.995791</td>\n", | |
" <td>1.263278</td>\n", | |
" <td>0.897851</td>\n", | |
" <td>2.679529</td>\n", | |
" <td>0.246634</td>\n", | |
" <td>0.240409</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>7</td>\n", | |
" <td>4.715111</td>\n", | |
" <td>5.307607</td>\n", | |
" <td>0.485308</td>\n", | |
" <td>8.678684</td>\n", | |
" <td>1.200815</td>\n", | |
" <td>0.959975</td>\n", | |
" <td>2.554320</td>\n", | |
" <td>0.208429</td>\n", | |
" <td>0.203911</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>8</td>\n", | |
" <td>4.717269</td>\n", | |
" <td>5.407540</td>\n", | |
" <td>0.478808</td>\n", | |
" <td>8.845898</td>\n", | |
" <td>1.175990</td>\n", | |
" <td>1.029106</td>\n", | |
" <td>2.512172</td>\n", | |
" <td>0.190207</td>\n", | |
" <td>0.200359</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>9</td>\n", | |
" <td>4.771541</td>\n", | |
" <td>4.586983</td>\n", | |
" <td>0.521542</td>\n", | |
" <td>9.332252</td>\n", | |
" <td>1.168054</td>\n", | |
" <td>1.083328</td>\n", | |
" <td>2.520158</td>\n", | |
" <td>0.220892</td>\n", | |
" <td>0.198271</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>10</td>\n", | |
" <td>4.653002</td>\n", | |
" <td>5.640585</td>\n", | |
" <td>0.528441</td>\n", | |
" <td>9.376346</td>\n", | |
" <td>1.133995</td>\n", | |
" <td>1.054805</td>\n", | |
" <td>2.464202</td>\n", | |
" <td>0.203103</td>\n", | |
" <td>0.278971</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>11</td>\n", | |
" <td>4.518406</td>\n", | |
" <td>4.460074</td>\n", | |
" <td>0.426281</td>\n", | |
" <td>8.360112</td>\n", | |
" <td>1.089192</td>\n", | |
" <td>1.063930</td>\n", | |
" <td>2.365285</td>\n", | |
" <td>0.186630</td>\n", | |
" <td>0.224311</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>12</td>\n", | |
" <td>4.445597</td>\n", | |
" <td>4.410108</td>\n", | |
" <td>0.473550</td>\n", | |
" <td>8.817196</td>\n", | |
" <td>1.074073</td>\n", | |
" <td>1.043954</td>\n", | |
" <td>2.327569</td>\n", | |
" <td>0.207671</td>\n", | |
" <td>0.253871</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>13</td>\n", | |
" <td>4.358193</td>\n", | |
" <td>4.004363</td>\n", | |
" <td>0.533533</td>\n", | |
" <td>9.450353</td>\n", | |
" <td>1.067293</td>\n", | |
" <td>0.979354</td>\n", | |
" <td>2.311545</td>\n", | |
" <td>0.214134</td>\n", | |
" <td>0.234769</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>14</td>\n", | |
" <td>4.156375</td>\n", | |
" <td>3.787990</td>\n", | |
" <td>0.524728</td>\n", | |
" <td>9.441799</td>\n", | |
" <td>1.031719</td>\n", | |
" <td>0.918310</td>\n", | |
" <td>2.206345</td>\n", | |
" <td>0.236098</td>\n", | |
" <td>0.212539</td>\n", | |
" <td>00:29</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>15</td>\n", | |
" <td>4.053391</td>\n", | |
" <td>3.995172</td>\n", | |
" <td>0.527359</td>\n", | |
" <td>10.342698</td>\n", | |
" <td>1.005355</td>\n", | |
" <td>0.885043</td>\n", | |
" <td>2.162991</td>\n", | |
" <td>0.240513</td>\n", | |
" <td>0.207699</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>16</td>\n", | |
" <td>4.067924</td>\n", | |
" <td>3.625066</td>\n", | |
" <td>0.502665</td>\n", | |
" <td>9.238791</td>\n", | |
" <td>1.003777</td>\n", | |
" <td>0.911058</td>\n", | |
" <td>2.153088</td>\n", | |
" <td>0.246938</td>\n", | |
" <td>0.191794</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>17</td>\n", | |
" <td>4.037714</td>\n", | |
" <td>3.711069</td>\n", | |
" <td>0.464412</td>\n", | |
" <td>7.298216</td>\n", | |
" <td>1.000791</td>\n", | |
" <td>0.896303</td>\n", | |
" <td>2.140619</td>\n", | |
" <td>0.249679</td>\n", | |
" <td>0.187362</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>18</td>\n", | |
" <td>3.857249</td>\n", | |
" <td>4.345922</td>\n", | |
" <td>0.544652</td>\n", | |
" <td>9.388030</td>\n", | |
" <td>0.948343</td>\n", | |
" <td>0.891913</td>\n", | |
" <td>2.016993</td>\n", | |
" <td>0.257431</td>\n", | |
" <td>0.198414</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>19</td>\n", | |
" <td>3.762412</td>\n", | |
" <td>4.109817</td>\n", | |
" <td>0.490418</td>\n", | |
" <td>9.542933</td>\n", | |
" <td>0.926382</td>\n", | |
" <td>0.875323</td>\n", | |
" <td>1.960707</td>\n", | |
" <td>0.261851</td>\n", | |
" <td>0.210432</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>20</td>\n", | |
" <td>3.702829</td>\n", | |
" <td>3.526021</td>\n", | |
" <td>0.464647</td>\n", | |
" <td>8.307538</td>\n", | |
" <td>0.898579</td>\n", | |
" <td>0.908769</td>\n", | |
" <td>1.895481</td>\n", | |
" <td>0.261681</td>\n", | |
" <td>0.217534</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>21</td>\n", | |
" <td>3.622834</td>\n", | |
" <td>3.706043</td>\n", | |
" <td>0.502107</td>\n", | |
" <td>10.158941</td>\n", | |
" <td>0.875865</td>\n", | |
" <td>0.904829</td>\n", | |
" <td>1.842140</td>\n", | |
" <td>0.240720</td>\n", | |
" <td>0.208291</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>22</td>\n", | |
" <td>3.538983</td>\n", | |
" <td>3.435079</td>\n", | |
" <td>0.546980</td>\n", | |
" <td>9.465912</td>\n", | |
" <td>0.851039</td>\n", | |
" <td>0.919320</td>\n", | |
" <td>1.768624</td>\n", | |
" <td>0.258890</td>\n", | |
" <td>0.202507</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>23</td>\n", | |
" <td>3.513905</td>\n", | |
" <td>3.234569</td>\n", | |
" <td>0.508788</td>\n", | |
" <td>8.967988</td>\n", | |
" <td>0.841775</td>\n", | |
" <td>0.934261</td>\n", | |
" <td>1.737869</td>\n", | |
" <td>0.246404</td>\n", | |
" <td>0.231526</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>24</td>\n", | |
" <td>3.515435</td>\n", | |
" <td>3.423039</td>\n", | |
" <td>0.495962</td>\n", | |
" <td>8.955706</td>\n", | |
" <td>0.829034</td>\n", | |
" <td>0.963347</td>\n", | |
" <td>1.723054</td>\n", | |
" <td>0.217231</td>\n", | |
" <td>0.240800</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>25</td>\n", | |
" <td>3.495776</td>\n", | |
" <td>3.450833</td>\n", | |
" <td>0.489034</td>\n", | |
" <td>9.838599</td>\n", | |
" <td>0.812738</td>\n", | |
" <td>1.014734</td>\n", | |
" <td>1.668304</td>\n", | |
" <td>0.194369</td>\n", | |
" <td>0.248123</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>26</td>\n", | |
" <td>3.449779</td>\n", | |
" <td>3.365834</td>\n", | |
" <td>0.487408</td>\n", | |
" <td>9.713398</td>\n", | |
" <td>0.795532</td>\n", | |
" <td>1.049908</td>\n", | |
" <td>1.604338</td>\n", | |
" <td>0.181287</td>\n", | |
" <td>0.248181</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>27</td>\n", | |
" <td>3.393342</td>\n", | |
" <td>3.238819</td>\n", | |
" <td>0.534515</td>\n", | |
" <td>9.728602</td>\n", | |
" <td>0.766979</td>\n", | |
" <td>1.090326</td>\n", | |
" <td>1.536036</td>\n", | |
" <td>0.172006</td>\n", | |
" <td>0.235865</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>28</td>\n", | |
" <td>3.330395</td>\n", | |
" <td>3.262840</td>\n", | |
" <td>0.526032</td>\n", | |
" <td>9.696510</td>\n", | |
" <td>0.745259</td>\n", | |
" <td>1.104066</td>\n", | |
" <td>1.481069</td>\n", | |
" <td>0.159359</td>\n", | |
" <td>0.235580</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>29</td>\n", | |
" <td>3.290842</td>\n", | |
" <td>3.225084</td>\n", | |
" <td>0.528845</td>\n", | |
" <td>9.593530</td>\n", | |
" <td>0.723691</td>\n", | |
" <td>1.135712</td>\n", | |
" <td>1.431439</td>\n", | |
" <td>0.151975</td>\n", | |
" <td>0.231465</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"learn.fit_fa(lr=2e-4,n_epochs=15,n_epochs_decay=15)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learn.save('30fit')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's look at some results using `Learner.show_results`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learn.show_results(ds_type=DatasetType.Train, rows=2)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Now let's go through all the images of the training set and find the ones that are the best converted (according to our critics) or the worst converted." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(100, 344)" | |
] | |
}, | |
"execution_count": 38, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"len(learn.data.train_ds.items),len(learn.data.train_ds.itemsB)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_batch(filenames, tfms, **kwargs):\n", | |
" samples = [open_image(fn) for fn in filenames]\n", | |
" for s in samples: s = s.apply_tfms(tfms, **kwargs)\n", | |
" batch = torch.stack([s.data for s in samples], 0).cuda()\n", | |
" return 2. * (batch - 0.5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 40, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"fnames = learn.data.train_ds.items[:8]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x = get_batch(fnames, get_transforms()[1], size=128)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 42, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learn.model.eval()\n", | |
"tfms = get_transforms()[1]\n", | |
"bs = 16" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 43, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_losses(fnames, gen, crit, bs=16):\n", | |
" losses_in,losses_out = [],[]\n", | |
" with torch.no_grad():\n", | |
" for i in progress_bar(range(0, len(fnames), bs)):\n", | |
" xb = get_batch(fnames[i:i+bs], tfms, size=128)\n", | |
" fakes = gen(xb)\n", | |
" preds_in,preds_out = crit(xb),crit(fakes)\n", | |
" loss_in = learn.loss_func.crit(preds_in, True,reduction='none')\n", | |
" loss_out = learn.loss_func.crit(preds_out,True,reduction='none')\n", | |
" losses_in.append(loss_in.view(loss_in.size(0),-1).mean(1))\n", | |
" losses_out.append(loss_out.view(loss_out.size(0),-1).mean(1))\n", | |
" return torch.cat(losses_in),torch.cat(losses_out)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 44, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"\n", | |
" <div>\n", | |
" <style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
" </style>\n", | |
" <progress value='7' class='' max='7', style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | |
" 100.00% [7/7 00:03<00:00]\n", | |
" </div>\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"losses_A = get_losses(data.train_ds.x.items, learn.model.G_B, learn.model.D_B)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 45, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"\n", | |
" <div>\n", | |
" <style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
" </style>\n", | |
" <progress value='22' class='' max='22', style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | |
" 100.00% [22/22 00:13<00:00]\n", | |
" </div>\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"losses_B = get_losses(data.train_ds.x.itemsB, learn.model.G_A, learn.model.D_A)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 46, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def show_best(fnames, losses, gen, n=8):\n", | |
" sort_idx = losses.argsort()\n", | |
" _,axs = plt.subplots(n//2, 4, figsize=(12,2*n))\n", | |
" xb = get_batch(fnames[sort_idx][:n], tfms, size=128)\n", | |
" with torch.no_grad():\n", | |
" fakes = gen(xb)\n", | |
" xb,fakes = (1+xb)/2,(1+fakes)/2\n", | |
" for i in range(n):\n", | |
" axs.flatten()[2*i].imshow(xb[i].permute(1,2,0).cpu())\n", | |
" axs.flatten()[2*i].axis('off')\n", | |
" axs.flatten()[2*i+1].imshow(fakes[i].permute(1,2,0).cpu())\n", | |
" axs.flatten()[2*i+1].set_title(losses[sort_idx][i].item())\n", | |
" axs.flatten()[2*i+1].axis('off')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"show_best(data.train_ds.x.items, losses_A[1].cpu(), learn.model.G_B)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"show_best(data.train_ds.x.itemsB, losses_B[1].cpu(), learn.model.G_A)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Save Image Predictions (Fake images)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 49, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torch.utils.data import Dataset, DataLoader\n", | |
"import torchvision\n", | |
"import glob\n", | |
"\n", | |
"class FolderDataset(Dataset):\n", | |
" def __init__(self, path,transforms=None):\n", | |
" self.files = glob.glob(path+'/*')\n", | |
" self.totensor = torchvision.transforms.ToTensor()\n", | |
" if transforms:\n", | |
" self.transform = torchvision.transforms.Compose(transforms)\n", | |
" else:\n", | |
" self.transform = lambda x: x\n", | |
" \n", | |
" def __len__(self):\n", | |
" return len(self.files)\n", | |
"\n", | |
" def __getitem__(self, idx):\n", | |
" image = PIL.Image.open(self.files[idx % len(self.files)])\n", | |
" image = self.totensor(image)\n", | |
" image = self.transform(image)\n", | |
" return self.files[idx], image\n", | |
"\n", | |
"def load_dataset(test_path):\n", | |
" dataset = FolderDataset(\n", | |
" path=test_path,\n", | |
" #transforms=[torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n", | |
" ) \n", | |
" loader = torch.utils.data.DataLoader(\n", | |
" dataset,\n", | |
" batch_size=2,\n", | |
" num_workers=4,\n", | |
" shuffle=True\n", | |
" )\n", | |
" return loader" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 50, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import tqdm\n", | |
"def get_preds_cyclegan(learn,test_path,pred_path,suffix='png'):\n", | |
" \n", | |
" assert os.path.exists(test_path)\n", | |
" \n", | |
" if not os.path.exists(pred_path):\n", | |
" os.mkdir(pred_path)\n", | |
" \n", | |
" model = learn.model.G_A\n", | |
" \n", | |
" test_dl = load_dataset(test_path)\n", | |
" \n", | |
" for i, xb in tqdm.tqdm(enumerate(test_dl),total=len(test_dl)):\n", | |
" fn, im = xb\n", | |
" preds = (learn.model.G_B(im.cuda())/2 + 0.5)\n", | |
" for i in range(len(fn)):\n", | |
" new_fn = os.path.join(pred_path,'.'.join([os.path.basename(fn[i]).split('.')[0]+'_fakeB',suffix])) \n", | |
" torchvision.utils.save_image(preds[i],new_fn)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 51, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 50/50 [00:12<00:00, 3.86it/s]\n" | |
] | |
} | |
], | |
"source": [ | |
"get_preds_cyclegan(learn,str(muse2he_path/'testA'),'./preds')" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment