This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def forward(self, input, target):
    """Combined loss: pixel-space L1 plus weighted perceptual L1 terms.

    `input` is the predicted image, `target` the ground truth. The pixel
    term is divided by 100 so the perceptual terms dominate.
    """
    losses = [F.l1_loss(input, target) / 100]  # pixel loss, scaled down
    # Clone the prediction's activations first — the second forward pass
    # below overwrites each hook's stored features with the target's.
    pred_acts = self.make_fts(input, True)
    true_acts = self.make_fts(target)
    # Weighted perceptual L1 between matching feature maps.
    for pred, true, weight in zip(pred_acts, true_acts, self.wgts):
        losses.append(F.l1_loss(pred, true.features) * weight)
    return sum(losses)
def close(self):
    """Detach every registered forward hook so the stored activations
    can be garbage-collected."""
    for hook in self.sfs:
        hook.remove()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SavedFeatures():
    """Forward-hook wrapper that captures a module's most recent output.

    The captured tensor is exposed as ``self.features``; call
    ``remove()`` to detach the hook and release the reference.
    """

    def __init__(self, m):
        # Register ourselves so every forward pass through `m` records
        # its output on this object.
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, model, input, output):
        # Invoked by PyTorch after each forward pass; stash the activation.
        self.features = output

    def remove(self):
        # Unregister the hook so the stored tensor can be freed.
        self.hook.remove()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class FeatureLoss(nn.Module):
    """Perceptual feature loss.

    Hooks the activations of model `m` at the layers indexed by
    `layer_ids`, weighting each layer's contribution by `layer_wgts`.
    """

    def __init__(self, m, layer_ids, layer_wgts):
        super().__init__()
        self.m, self.wgts = m, layer_wgts
        # BUG FIX: the original iterated over an undefined name `blocks`;
        # the hooked layers are selected by the `layer_ids` parameter.
        self.sfs = [SavedFeatures(m[i]) for i in layer_ids]

    def make_fts(self, x, clone=False):
        """Run `x` through the model and return the hooked activations.

        With clone=True, returns detached tensor copies (safe to keep
        across a later forward pass); otherwise returns the hook objects.
        """
        self.m(x)  # forward pass populates each SavedFeatures.features
        return [(o.features.data.clone() if clone else o) for o in self.sfs]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the super-resolution model and train it with pixel MSE (stage 1).
m = SrResnet(scale)
# NOTE(review): wd=0.9 is an unusually large weight decay — confirm intended.
learner = Learner(data, m, opt_func= optim.Adam, wd=0.9)
learner.loss_func = F.mse_loss  # plain pixel-space MSE for this stage
#-----------------------
lr = 1e-4 # learning rate
cycles = 8 # cycle_len following the 1cycle policy
#-----------------------
# NOTE(review): pct_start=0.9 spends 90% of the schedule warming up — atypical
# for 1cycle (default 0.3); confirm this is deliberate.
learner.fit_one_cycle(cycles, lr, pct_start=0.9)
learner.show_results(rows=1, imgsize = 5) # shows the model input/prediction/ground-truth
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SrResnet(nn.Module):
    """Super-resolution ResNet: a conv stem, 8 residual blocks, then a
    pixel-shuffle upsampler to `scale` and a linear 3-channel output conv.

    NOTE(review): no forward() is visible in this chunk — presumably
    `return self.features(x)` is defined elsewhere; confirm.
    """

    def __init__(self, scale, nf = 64):
        super().__init__()
        layers = [conv(3, nf)]                 # stem: conv -> ReLU
        layers.extend(res_block(nf) for _ in range(8))
        layers += [
            conv(nf, nf),
            upsample(nf, nf, scale),           # pixel-shuffle up to `scale`
            nn.BatchNorm2d(nf),
            conv(nf, 3, actn=False),           # linear output, 3 channels
        ]
        self.features = nn.Sequential(*layers)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def res_block(nf):
    """Residual unit: conv -> ReLU -> conv, with the branch output scaled
    by 0.1 before rejoining the identity (a factor < 1 stabilizes training)."""
    branch = [conv(nf, nf), conv(nf, nf, actn=False)]
    return ResSequential(branch, 0.1)
| #---------------------- | |
def upsample(ni, nf, scale):
    """Upsample by `scale` (a power of two) with repeated conv + PixelShuffle(2).

    Each stage expands channels 4x then pixel-shuffles, doubling the
    spatial resolution.

    NOTE(review): every stage convolves from `ni` channels — this assumes
    ni == nf after the first shuffle; confirm callers always pass ni == nf.
    """
    layers = []
    # BUG FIX: int(math.log(scale, 2)) can truncate to n-1 when the float
    # quotient lands just under the integer (e.g. log(2**29, 2)); round()
    # of log2 is exact for power-of-two scales.
    for _ in range(round(math.log2(scale))):
        layers += [conv(ni, nf * 4), nn.PixelShuffle(2)]
    return nn.Sequential(*layers)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def conv(ni, nf, kernel_size = 3, actn=True):
    """Same-padding convolution (3x3 by default) from `ni` to `nf`
    channels, optionally followed by an in-place ReLU, packaged as an
    nn.Sequential."""
    modules = [nn.Conv2d(ni, nf, kernel_size, padding=kernel_size // 2)]
    if actn:
        modules.append(nn.ReLU(True))
    return nn.Sequential(*modules)
| #---------------------- | |
class ResSequential(nn.Module):
    """Holds a stack of layers intended as a residual branch whose output
    is scaled by `res_scale` (a factor below 1.0 helps stabilize training).

    NOTE(review): forward() is not visible in this chunk — presumably
    `x + self.layers(x) * self.res_scale` elsewhere; confirm.
    """

    def __init__(self, layers, res_scale=1.0):
        super().__init__()
        self.res_scale = res_scale
        self.layers = nn.Sequential(*layers)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Assemble the (low-res input, high-res target) DataBunch for this scale.
data = get_data(bs, (sz_lr, sz_lr*scale))
# Bare expression: in a notebook cell this displays the DataBunch summary.
data
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Super-resolution factor and batch size (larger scale -> smaller batch).
scale, bs = 2, 8
# scale, bs = 4, 4
sz_lr = 256  # low-resolution input size in pixels
# 90/10 train/validation split with a fixed seed for reproducibility.
# NOTE(review): `path` and `ImageImageList` (fastai) are defined elsewhere.
src = ImageImageList.from_folder(path).split_by_rand_pct(0.1, seed = 42)
def get_data(bs, size):
    """Build a fastai DataBunch of (low-res, high-res) image pairs.

    bs: batch size. size: (lr_size, hr_size) tuple — inputs are resized
    to size[0] and targets to size[1]; both are normalized with
    ImageNet statistics. Relies on module-level `src` and `path`.
    """
    data = (src.label_from_func(lambda x: path/x.name)
            .transform(get_transforms(do_flip=True, max_rotate=20),
                       size=size[0], tfm_y=True)
            .transform_y(size=size[1])
            .databunch(bs=bs, no_check=True).normalize(imagenet_stats, do_y=True))
    # BUG FIX: the original function built `data` but fell off the end,
    # returning None to every caller.
    return data
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## Numpy implementation of GRU
# Params
# NOTE(review): Wz/Uz/bz, Wr/Ur/br, Wh, sigmoid and train_data must be
# defined elsewhere; none are visible in this chunk.
ht = []  # hidden-state history, one entry per timestep
for t, xt in enumerate(train_data):
    # NOTE(review): `h` is undefined — this line presumably meant to seed
    # `ht` with an initial hidden state; confirm against the original.
    if t == 0: h.append(t)
    if t >= 1:
        ht_1 = ht[t-1]  # previous hidden state
        zt = sigmoid(np.dot(xt, Wz) + np.dot(ht_1, Uz) + bz) # Update Gate Vector
        rt = sigmoid(np.dot(xt, Wr)+ np.dot(ht_1, Ur) + br) # Forget/Reset Gate Vector
        # NOTE(review): uses undefined `z` (likely `zt`), never uses `rt`,
        # and the expression is truncated at the chunk boundary — the GRU
        # candidate state should be tanh(xt·Wh + (rt*ht_1)·Uh + bh).
        ht.append((1 - z) * ht_1 + z * np.tanh(np.dot(xt, Wh)