Last active
May 10, 2020 14:00
-
-
Save guillefix/23bff068bdc457649b81027942873ce5 to your computer and use it in GitHub Desktop.
Temporary workaround to get Conv2dLocal to work in PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
# In[1]: | |
import math | |
import torch | |
from torch.nn.parameter import Parameter | |
import torch.nn.functional as F | |
import torch.nn as nn | |
Module = nn.Module | |
import collections | |
from itertools import repeat | |
# In[2]: | |
def _ntuple(n):
    """Return a parser that normalises an argument to an ``n``-tuple.

    Iterables (tuples, lists, ...) are returned unchanged; any other value
    is repeated ``n`` times, e.g. ``_pair(3) == (3, 3)``.
    """
    # ``collections.Iterable`` was only a deprecated alias and was removed in
    # Python 3.10; the ABC lives in ``collections.abc``. Imported once per
    # parser factory rather than on every call.
    from collections.abc import Iterable

    def parse(x):
        if isinstance(x, Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)
# In[3]: | |
class _ConvNd(Module):
    """Common state for N-d convolution modules: hyper-parameters, weight, bias.

    ``kernel_size``, ``stride``, ``padding``, ``dilation`` and
    ``output_padding`` are stored exactly as given. ``__repr__`` calls
    ``len()`` on padding, dilation and output_padding, so pass them as tuples.

    Raises:
        ValueError: if ``in_channels`` or ``out_channels`` is not divisible
            by ``groups``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding, groups, bias):
        super().__init__()
        # Validate in_channels first, then out_channels.
        for label, channels in (('in_channels', in_channels),
                                ('out_channels', out_channels)):
            if channels % groups != 0:
                raise ValueError('{} must be divisible by groups'.format(label))

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups

        # Transposed layout: (in, out // groups, *k); regular: (out, in // groups, *k).
        if transposed:
            leading_dims = (in_channels, out_channels // groups)
        else:
            leading_dims = (out_channels, in_channels // groups)
        self.weight = Parameter(torch.Tensor(*leading_dims, *kernel_size))
        self.register_parameter(
            'bias', Parameter(torch.Tensor(out_channels)) if bias else None)
        self.reset_parameters()

    def reset_parameters(self):
        """Redraw weight (then bias, if any) from U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
        # fan_in = in_channels * prod(kernel_size).
        # NOTE(review): groups / transposed layout are not factored in here.
        fan_in = self.in_channels
        for extent in self.kernel_size:
            fan_in *= extent
        bound = 1. / math.sqrt(fan_in)
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def __repr__(self):
        # Only non-default optional settings are shown.
        parts = ['{}({}, {}, kernel_size={}, stride={}'.format(
            type(self).__name__, self.in_channels, self.out_channels,
            self.kernel_size, self.stride)]
        if self.padding != (0,) * len(self.padding):
            parts.append(', padding={}'.format(self.padding))
        if self.dilation != (1,) * len(self.dilation):
            parts.append(', dilation={}'.format(self.dilation))
        if self.output_padding != (0,) * len(self.output_padding):
            parts.append(', output_padding={}'.format(self.output_padding))
        if self.groups != 1:
            parts.append(', groups={}'.format(self.groups))
        if self.bias is None:
            parts.append(', bias=False')
        parts.append(')')
        return ''.join(parts)
# In[4]: | |
class Conv2dLocal(Module):
    """Locally connected 2-D layer: a convolution whose filters are not shared.

    Each of the ``out_height x out_width`` output positions owns its own
    ``out_channels x in_channels x kH x kW`` filter bank, so the input spatial
    size (``in_height``, ``in_width``) is fixed when the layer is built.

    Args:
        in_height, in_width: spatial size of the expected input.
        in_channels, out_channels: channel counts.
        kernel_size, stride, padding, dilation: int or pair (normalised by ``_pair``).
        bias: if true, learn a bias of shape (out_channels, out_height, out_width),
            i.e. one value per output channel *and* position.
    """

    def __init__(self, in_height, in_width, in_channels, out_channels,
                 kernel_size, stride=1, padding=0, bias=True, dilation=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.in_height = in_height
        self.in_width = in_width

        def output_extent(axis, extent):
            # Standard convolution output-size formula along one spatial axis.
            numerator = (extent + 2 * self.padding[axis]
                         - self.dilation[axis] * (self.kernel_size[axis] - 1) - 1)
            return int(math.floor(numerator / self.stride[axis] + 1))

        self.out_height = output_extent(0, in_height)
        self.out_width = output_extent(1, in_width)

        # Layout: outH x outW x outC x inC x kH x kW (one filter bank per position).
        weight_shape = (self.out_height, self.out_width,
                        out_channels, in_channels, *self.kernel_size)
        self.weight = Parameter(torch.Tensor(*weight_shape))
        self.register_parameter(
            'bias',
            Parameter(torch.Tensor(out_channels, self.out_height, self.out_width))
            if bias else None)
        self.reset_parameters()

    def reset_parameters(self):
        """Redraw weight (then bias, if any) from U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
        # fan_in of a single filter = in_channels * kH * kW.
        fan_in = self.in_channels
        for extent in self.kernel_size:
            fan_in *= extent
        bound = 1. / math.sqrt(fan_in)
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def __repr__(self):
        # Only non-default optional settings are shown.
        parts = ['{}({}, {}, kernel_size={}, stride={}'.format(
            type(self).__name__, self.in_channels, self.out_channels,
            self.kernel_size, self.stride)]
        if self.padding != (0,) * len(self.padding):
            parts.append(', padding={}'.format(self.padding))
        if self.dilation != (1,) * len(self.dilation):
            parts.append(', dilation={}'.format(self.dilation))
        if self.bias is None:
            parts.append(', bias=False')
        parts.append(')')
        return ''.join(parts)

    def forward(self, input):
        """Apply the locally connected convolution via ``conv2d_local``."""
        return conv2d_local(input, self.weight, self.bias,
                            padding=self.padding, stride=self.stride,
                            dilation=self.dilation)
# In[5]: | |
unfold = F.unfold
# In[6]:


def conv2d_local(input, weight, bias=None, padding=0, stride=1, dilation=1):
    """Apply a locally connected (unshared-weight) 2-D convolution.

    Unlike ``F.conv2d``, each output location ``(i, j)`` uses its own filter
    bank ``weight[i, j]``.

    Args:
        input: 4-D tensor ``(N, inC, H, W)``.
        weight: 6-D tensor ``(outH, outW, outC, inC, kH, kW)``; may be
            non-contiguous.
        bias: optional tensor expandable to ``(N, outC, outH, outW)``,
            e.g. ``(outC, outH, outW)`` as created by ``Conv2dLocal``.
        padding, stride, dilation: int or pair, forwarded to ``F.unfold``.

    Returns:
        Tensor of shape ``(N, outC, outH, outW)``.

    Raises:
        NotImplementedError: if ``input`` is not 4-D or ``weight`` is not 6-D.
        RuntimeError: if the patches extracted from ``input`` do not match
            ``weight`` (wrong channel count or wrong number of locations).
    """
    if input.dim() != 4:
        raise NotImplementedError("Input Error: Only 4D input Tensors supported (got {}D)".format(input.dim()))
    if weight.dim() != 6:
        # outH x outW x outC x inC x kH x kW
        raise NotImplementedError("Input Error: Only 6D weight Tensors supported (got {}D)".format(weight.dim()))
    outH, outW, outC, inC, kH, kW = weight.size()
    batch = input.size(0)
    patch_len = inC * kH * kW
    locations = outH * outW

    # N x [inC * kH * kW] x [outH * outW]
    cols = unfold(input, (kH, kW), dilation=dilation, padding=padding, stride=stride)
    # Without this check, a single-location input (L == 1) would be silently
    # broadcast against every filter bank by matmul, producing wrong output.
    if cols.size(1) != patch_len or cols.size(2) != locations:
        raise RuntimeError(
            "Input Error: unfolded input gives {} values per patch at {} "
            "locations, but weight expects {} values per patch at {} "
            "locations (outH={}, outW={})".format(
                cols.size(1), cols.size(2), patch_len, locations, outH, outW))

    # N x [outH * outW] x 1 x [inC * kH * kW]
    cols = cols.transpose(1, 2).unsqueeze(2)
    # [outH * outW] x [inC * kH * kW] x outC; reshape (not view) so a
    # non-contiguous weight also works.
    filters = weight.reshape(locations, outC, patch_len).transpose(1, 2)
    # matmul multiplies the trailing matrices (1 x K) @ (K x outC) and
    # broadcasts the leading batch dims (N, L) against (L,) -> N x L x 1 x outC.
    out = torch.matmul(cols, filters)
    out = out.reshape(batch, outH, outW, outC).permute(0, 3, 1, 2)
    if bias is not None:
        out = out + bias.expand_as(out)
    return out
# In[8]: | |
# lc = Conv2dLocal(3, 3, 64, 2,3) | |
# In[9]: | |
# lc(torch.autograd.Variable(torch.randn((1,64,3,3)))) | |
# In[43]: | |
# x=torch.autograd.Variable(torch.randn((64,6,6))) | |
# In[47]: | |
# lc._backend.SpatialConvolutionLocal?? | |
# In[58]: | |
# from torch.nn import Conv2dLocal | |
# In[59]: | |
# Conv2dLocal?? |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, why don't I have the `F.unfold` function? And why does this line (the `torch.matmul(...)` call in `conv2d_local`)
work? I guess `cols` has 4 dimensions after the permutation, with shape (N x [outH * outW] x 1 x [inC * kH * kW]), while the weight has 3 dimensions after the permutation, with shape ([outH x outW] x [inC x kH x kW] x outC). Unless I'm misunderstanding, shouldn't the two arguments to `torch.matmul()` have the same number of dimensions?