@system123
Created March 12, 2019 16:32
Fast Dense CNN Features

import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models
from math import ceil, floor
from functools import reduce
# https://www.dfki.de/fileadmin/user_upload/import/9245_FastCNNFeature_BMVC.pdf
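#
# Idea from the paper: instead of running the CNN once per image patch, run it
# once on the whole (padded) image. A pooling layer with stride s keeps only one
# of the s*s possible output grids, so each one is replaced by a bank of s*s
# shifted pooling layers (MultiMaxPooling) whose outputs are carried along side
# by side and re-interleaved into image layout at the end (UnwarpPrepare/UnwarpPool).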
class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        # apply every child module to the same input and collect the results
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input
class LambdaReduce(LambdaBase):
    def forward(self, input):
        # fold the per-branch outputs together with the binary lambda_func
        return reduce(self.lambda_func, self.forward_prepare(input))
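# e.g. LambdaReduce(torch.add, nn.Identity(), nn.Identity()) runs both branches
# on the same input and sums the results (here: doubling the tensor).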
class Unsqueeze(nn.Module):
    """Module wrapper around torch.unsqueeze."""
    def __init__(self, dim):
        super(Unsqueeze, self).__init__()
        self.dim = dim

    def forward(self, x):
        return torch.unsqueeze(x, self.dim)
class MultiPoolPrepare(nn.Module):
    """Zero-pads the input so that every pixel gets a full patchX x patchY neighbourhood."""
    def __init__(self, patchX, patchY):
        super(MultiPoolPrepare, self).__init__()
        padx = patchX - 1
        pady = patchY - 1
        # split the (patch size - 1) padding as evenly as possible between the two sides
        self.net = nn.ZeroPad2d((floor(padx / 2), ceil(padx / 2), floor(pady / 2), ceil(pady / 2)))

    def forward(self, x):
        return self.net(x)
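# e.g. with the demo below: MultiPoolPrepare(128, 128) pads a (1, 1, 256, 256)
# batch to (1, 1, 383, 383), i.e. 256 + 128 - 1 on each spatial dimension.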
class UnwarpPrepare(nn.Module):
    """Flattens the spatial dims and swaps them with the channels: (B, C, H, W) -> (B, H*W, C)."""
    def forward(self, x):
        y = x.view(x.shape[0], x.shape[1], -1)
        y = torch.transpose(y, 1, 2)
        return y
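# e.g. UnwarpPrepare()(torch.randn(2, 3, 4, 5)).shape -> torch.Size([2, 20, 3])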
class UnwarpPool(nn.Module):
    """Regroups the dW*dH interleaved pooling outputs of one stage back into spatial order."""
    def __init__(self, out_channels, img_w, img_h, dW, dH):
        super(UnwarpPool, self).__init__()
        self.img_w = img_w
        self.img_h = img_h
        self.dW = dW
        self.dH = dH
        self.out_ch = out_channels

    def forward(self, x):
        # reshape (not view): upstream transposes leave the tensor non-contiguous
        y = x.reshape(x.shape[0], self.out_ch, self.img_h, self.img_w, self.dH, self.dW, -1)
        y = torch.transpose(y, 3, 4)
        return y
def calculate_pad(hei, wid, hei_pad, wid_pad):
    left = (wid_pad - wid) // 2
    right = wid_pad - wid - left
    top = (hei_pad - hei) // 2
    bottom = hei_pad - hei - top
    return [left, right, top, bottom]
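# e.g. calculate_pad(3, 3, 4, 4) -> [0, 1, 0, 1]: a 3x3 map is padded to 4x4,
# with the odd remainder going to the right/bottom.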
# concatenate 3-D/4-D tensors whose spatial sizes may differ slightly
# dim = 0 for 3-D inputs; dim = 1 for 4-D inputs
def concat_with_pad(seq, dim):
    # get the maximum spatial size in the sequence
    hei_pad = 0
    wid_pad = 0
    for input in seq:
        hei_pad = max(hei_pad, input.size(dim + 1))
        wid_pad = max(wid_pad, input.size(dim + 2))
    # pad each input up to the maximum size
    output = []
    for input in seq:
        pad = calculate_pad(input.size(dim + 1), input.size(dim + 2), hei_pad, wid_pad)
        input_pad = F.pad(input, pad)
        output.append(input_pad)
    return torch.cat(output, dim)
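# e.g. concat_with_pad((torch.zeros(1, 1, 4, 4), torch.zeros(1, 1, 3, 3)), 1).shape
#   -> torch.Size([1, 2, 4, 4])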
class MultiMaxPooling(nn.Module):
    """One max-pool per (row, col) offset of the stride grid, concatenated with padding."""
    def __init__(self, kW, kH, dW, dH):
        super(MultiMaxPooling, self).__init__()
        self.kW = kW
        self.kH = kH  # kernel size
        self.dW = dW  # step size (stride)
        self.dH = dH
        pools = []
        for i in range(0, self.dH):
            for j in range(0, self.dW):
                # shift the pooling grid by (i, j): a negative ZeroPad2d crops i rows
                # and j columns off the top-left (MaxPool2d rejects negative padding)
                pools.append(nn.Sequential(
                    nn.ZeroPad2d((-j, 0, -i, 0)),
                    nn.MaxPool2d((self.kH, self.kW), stride=(self.dH, self.dW)),
                ))
        self.net = LambdaReduce(lambda x, y, dim=1: concat_with_pad((x, y), dim), *pools)

    def forward(self, x):
        return self.net(x)
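# e.g. MultiMaxPooling(2, 2, 2, 2)(torch.randn(1, 1, 8, 8)).shape
#   -> torch.Size([1, 4, 4, 4]): four offset grids, padded to a common size
#   by concat_with_pad and concatenated along the channel dimension.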
if __name__ == "__main__":
    from vgg11_ae import VGG11Encoder
    import numpy as np

    all_layers = []

    def remove_sequential(network):
        for layer in network.children():
            if isinstance(layer, nn.Sequential):  # if sequential layer, apply recursively to layers in sequential layer
                remove_sequential(layer)
            if list(layer.children()) == []:  # if leaf node, add it to list
                all_layers.append(layer)

    batch = torch.ones((1, 1, 256, 256))
    vgg = VGG11Encoder(z_dim=256)
    remove_sequential(vgg)

    # rebuild the encoder, swapping every strided max-pool for a MultiMaxPooling
    n = nn.ModuleList([MultiPoolPrepare(128, 128)])
    strides = []
    for m in all_layers:
        if isinstance(m, nn.MaxPool2d):
            k = m.kernel_size[0] if isinstance(m.kernel_size, (list, tuple)) else m.kernel_size
            s = m.stride[0] if isinstance(m.stride, (list, tuple)) else m.stride
            strides.append(s)
            n.append(MultiMaxPooling(k, k, s, s))
        else:
            n.append(m)

    # undo the interleaving, one pooling stage at a time
    n.append(UnwarpPrepare())
    for i in range(len(strides), 0, -1):
        sz = 256 // int(np.prod(strides[:i]))  # spatial size after the first i pooling stages
        n.append(UnwarpPool(256, sz, sz, strides[i - 1], strides[i - 1]))

    net = nn.Sequential(*n)
    print(net)

    img = net(batch)
    img = img.reshape(img.shape[0], 256, 256)
    print(img.shape)