Created March 12, 2019 16:32.
Save system123/c4b8ef3824f2230f181f8cfba84f0cfd to your computer and use it in GitHub Desktop.
Fast Dense CNN Features
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models  # NOTE(review): imported but unused in this file — verify before removing
from math import ceil, floor  # NOTE(review): imported but never used below
from functools import reduce
# Paper this implements (pooling-rearrangement trick for dense CNN features):
# https://www.dfki.de/fileadmin/user_upload/import/9245_FastCNNFeature_BMVC.pdf
class LambdaBase(nn.Sequential):
    """Container that feeds the SAME input to every child module.

    Stores ``fn`` in ``lambda_func`` for subclasses (e.g. LambdaReduce)
    to combine the per-child outputs.
    """

    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn  # combining function used by subclasses

    def forward_prepare(self, input):
        # One output per registered child; if there are no children,
        # fall back to returning the raw input unchanged.
        results = [child(input) for child in self._modules.values()]
        return results if results else input
class LambdaReduce(LambdaBase):
    """Fold all child outputs into a single value with ``lambda_func``."""

    def forward(self, input):
        # forward_prepare yields one output per child module; combine them
        # pairwise, left to right, into a single result tensor.
        parts = self.forward_prepare(input)
        return reduce(self.lambda_func, parts)
class Unsqueeze(nn.Module):
    """nn.Module wrapper around ``torch.unsqueeze``.

    Inserts a size-1 axis at the configured dimension, so the op can be
    placed inside an ``nn.Sequential``.
    """

    def __init__(self, dim):
        super(Unsqueeze, self).__init__()
        self.dim = dim  # axis index at which the singleton axis is inserted

    def forward(self, x):
        return x.unsqueeze(self.dim)
class MultiPoolPrepare(nn.Module):
    """Pad an image so a (patchY x patchX) patch fits around every pixel.

    Zero-pads the input by (patch size - 1) in each spatial dimension,
    split between the two sides, then inserts a singleton axis at dim 3 —
    the layout the subsequent MultiMaxPooling / Unwarp* modules expect.
    """

    def __init__(self, patchX, patchY):
        super(MultiPoolPrepare, self).__init__()
        padx = patchX - 1
        pady = patchY - 1
        # BUG FIX: the original padded padx//2 on BOTH sides, which loses
        # one pixel of the required padding whenever the patch size is even
        # (padx odd).  Pad floor(padx/2) left/top and the remainder
        # right/bottom so left+right == padx and top+bottom == pady exactly.
        # (The file imports ceil/floor but never used them — this was the
        # evidently intended asymmetric split.)
        self.pad = nn.ZeroPad2d((padx // 2, padx - padx // 2,
                                 pady // 2, pady - pady // 2))

    def forward(self, x):
        # (B, C, H, W) -> (B, C, H + pady, 1, W + padx)
        return torch.unsqueeze(self.pad(x), 3)
class UnwarpPrepare(nn.Module):
    """Flatten all spatial axes and move them ahead of the channel axis.

    (B, C, *spatial) -> (B, prod(spatial), C)
    """

    def forward(self, x):
        batch, channels = x.shape[0], x.shape[1]
        flat = x.view(batch, channels, -1)  # collapse every spatial axis
        return flat.transpose(1, 2)
class UnwarpPool(nn.Module):
    """Re-interleave one level of MultiMaxPooling offsets.

    Reshapes the flattened tensor into
    (B, out_channels, img_h, img_w, dH, dW, rest) and swaps the img_w and
    dH axes, putting stride offsets back next to their spatial positions.
    """

    def __init__(self, out_channels, img_w, img_h, dW, dH):
        super(UnwarpPool, self).__init__()
        # Spatial size of the feature map at this pooling level and the
        # stride of the pooling layer being undone.
        self.out_ch = out_channels
        self.img_w = img_w
        self.img_h = img_h
        self.dW = dW
        self.dH = dH

    def forward(self, x):
        grouped = x.view(x.shape[0], self.out_ch, self.img_h, self.img_w,
                         self.dH, self.dW, -1)
        return grouped.transpose(3, 4)
def calculate_pad(hei, wid, hei_pad, wid_pad):
    """Return [left, right, top, bottom] zero-padding that grows a
    (hei, wid) map to (hei_pad, wid_pad), centered — any odd extra pixel
    goes to the right/bottom (matches F.pad's 4-tuple ordering)."""
    left = (wid_pad - wid) // 2
    right = wid_pad - wid - left
    top = (hei_pad - hei) // 2
    bottom = hei_pad - hei - top
    return [left, right, top, bottom]


# concat 3D/4D variable
# dim = 0: 3D; = 1: 4D
def concat_with_pad(seq, dim):
    """Concatenate tensors along ``dim``, zero-padding each to the largest
    spatial size in the sequence.

    Tensors are laid out (..., H, W): height at axis dim+1, width at
    axis dim+2.  Use dim=0 for 3-D inputs, dim=1 for 4-D inputs.
    """
    # Find the maximum height and width over the sequence.
    hei_pad = 0
    wid_pad = 0
    for t in seq:
        hei_pad = max(hei_pad, t.size(dim + 1))
        # BUG FIX: the original computed max(hei_pad, ...) here, coupling
        # the target width to the height and over-padding tall-and-narrow
        # maps to a square size.
        wid_pad = max(wid_pad, t.size(dim + 2))
    # Center-pad each tensor to (hei_pad, wid_pad), then concatenate.
    output = []
    for t in seq:
        pad = calculate_pad(t.size(dim + 1), t.size(dim + 2), hei_pad, wid_pad)
        output.append(F.pad(t, pad))
    return torch.cat(output, dim)
class MultiMaxPooling(nn.Module):
    """Max-pool the input once per stride offset and keep all results.

    For a pooling layer with stride (dH, dW), dense feature extraction
    needs dH*dW shifted poolings; their outputs are concatenated along
    the channel axis and later re-interleaved by UnwarpPool.
    """
    def __init__(self, kW, kH, dW, dH):
        # kW/kH: pooling kernel width/height; dW/dH: stride per axis.
        super(MultiMaxPooling, self).__init__()
        self.kW = kW
        self.kH = kH #Kernel size
        self.dW = dW #Step size (stride)
        self.dH = dH
        pools = []
        # One pooling op per offset (i, j) inside a stride cell.
        for i in range(0, self.dH):
            for j in range(0, self.dW):
                # NOTE(review): negative padding appears intended to shift
                # the pooling window by (j, i), but modern nn.MaxPool2d
                # rejects negative padding — confirm the target torch
                # version, or crop the input (x[:, :, i:, j:]) instead.
                # NOTE(review): stride is (self.dW, self.dH) while the
                # kernel is (kH, kW); PyTorch expects (height, width)
                # order, so this looks swapped when dW != dH. Harmless for
                # the square strides used in __main__, but verify.
                pools.append( nn.MaxPool2d( (self.kH, self.kW), stride=(self.dW, self.dH), padding=(-j, -i) ) )
        # Pairwise-concatenate all offset poolings on the channel axis,
        # padding mismatched spatial sizes via concat_with_pad.
        self.net = LambdaReduce(lambda x, y, dim=1: concat_with_pad((x, y), dim), *pools)
    def forward(self, x):
        return self.net(x)
if __name__ == "__main__":
    # Demo: wrap a VGG11 encoder so it emits a dense 256x256 feature map
    # for 128x128 patches centered at every pixel of a 256x256 input.
    from vgg11_ae import VGG11Encoder
    import numpy as np

    IMG_SIZE = 256   # input resolution fed to the network
    PATCH = 128      # patch size the dense features correspond to

    all_layers = []

    def remove_sequential(network):
        """Flatten `network` into `all_layers`, keeping only leaf modules."""
        for layer in network.children():
            if type(layer) == nn.Sequential:  # plain containers: recurse into them
                remove_sequential(layer)
            if list(layer.children()) == []:  # leaf module: record it
                all_layers.append(layer)

    batch = torch.ones((1, 1, IMG_SIZE, IMG_SIZE))
    vgg = VGG11Encoder(z_dim=256)
    remove_sequential(vgg)

    # Rebuild the network, replacing every MaxPool2d with its
    # multi-offset equivalent and remembering the strides.
    n = nn.ModuleList([MultiPoolPrepare(PATCH, PATCH)])
    strides = []
    for m in all_layers:
        if isinstance(m, nn.MaxPool2d):
            k = m.kernel_size[0] if isinstance(m.kernel_size, (list, tuple)) else m.kernel_size
            s = m.stride[0] if isinstance(m.stride, (list, tuple)) else m.stride
            strides.append(s)
            n.append(MultiMaxPooling(k, k, s, s))
        else:
            n.append(m)
    n.append(UnwarpPrepare())

    # Undo the pooling rearrangement, innermost level first.
    for i in range(len(strides), 0, -1):
        # BUG FIX: the original used true division (256/np.prod(...)),
        # which yields floats that Tensor.view rejects inside UnwarpPool;
        # spatial sizes must be integers.
        size = IMG_SIZE // int(np.prod(strides[:i]))
        n.append(UnwarpPool(256, size, size, strides[i - 1], strides[i - 1]))

    net = nn.Sequential(*n)
    print(net)
    img = net(batch)
    img = img.view(img.shape[0], IMG_SIZE, IMG_SIZE)
    print(img.shape)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.