Last active
May 11, 2021 02:54
-
-
Save amohant4/9e3b1d9fa1e5e083e8334300ba918b49 to your computer and use it in GitHub Desktop.
Implementation of octave convolution in pytorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class OctConv(nn.Module):
    """Octave convolution that carries both frequency bands in one tensor.

    Channels are split into a high-frequency (HF) group and a low-frequency
    (LF) group whose feature maps have half the height and width of the HF
    maps.  So that the layer can be chained with ordinary layers, both
    groups travel in a single 4-D tensor: each HF map of size (2h, 2w) is
    folded by a plain ``reshape`` into four channels of size (h, w) and
    concatenated in front of the LF channels.

    Expected input layout (N = batch size):
      * ``alpha_in == 0``: a regular ``(N, ch_in, H, W)`` tensor, treated
        entirely as HF.
      * otherwise: ``(N, 4 * ch_in_hf + ch_in_lf, h, w)`` in the packed
        layout described above.

    Output layout: ``(N, 4 * ch_out_hf + ch_out_lf, h, w)`` where (h, w) is
    the LF resolution.  If one of the output groups is empty only the other
    one is returned; note the HF group is returned still folded.

    Args:
        ch_in: Total number of logical input channels (HF + LF).
        ch_out: Total number of logical output channels (HF + LF).
        kernel_size: Integer kernel size shared by all convolution paths.
        stride: Only used to derive the padding; see the NOTE in
            ``__init__``.
        alphas: Pair ``(alpha_in, alpha_out)`` giving the fraction of
            input/output channels that are low-frequency.  Each must lie
            in [0, 1].

    Raises:
        AssertionError: If either alpha is outside [0, 1].
    """

    def __init__(self, ch_in, ch_out, kernel_size, stride=1, alphas=(0.5, 0.5)):
        super(OctConv, self).__init__()

        # Get layer parameters. Both ratios are validated; an out-of-range
        # alpha would otherwise produce a negative channel count below.
        self.alpha_in, self.alpha_out = alphas
        assert 0 <= self.alpha_in <= 1 and 0 <= self.alpha_out <= 1, \
            "Alphas must be in interval [0, 1]"
        self.kernel_size = kernel_size
        self.stride = stride
        # NOTE(review): `stride` only influences this padding value; it is
        # not forwarded to the nn.Conv2d modules below, so every path
        # convolves with stride 1. Kept as-is to preserve existing output
        # shapes -- confirm whether strided convolution was intended.
        self.padding = (kernel_size - stride) // 2

        # Calculate the exact number of high/low frequency channels
        self.ch_in_lf = int(self.alpha_in * ch_in)
        self.ch_in_hf = ch_in - self.ch_in_lf
        self.ch_out_lf = int(self.alpha_out * ch_out)
        self.ch_out_hf = ch_out - self.ch_out_lf

        # Create convolutional and other modules necessary. Not all paths
        # will be created in all cases, so we check the number of high/low
        # frequency channels in input/output to determine which paths are
        # present. Example: a first layer has alpha_in = 0, so hasLtoL and
        # hasLtoH (the two paths fed by low-frequency input) stay False.
        self.hasLtoL = self.hasLtoH = self.hasHtoL = self.hasHtoH = False
        if self.ch_in_lf and self.ch_out_lf:
            # Low -> low: same resolution in and out.
            self.hasLtoL = True
            self.conv_LtoL = nn.Conv2d(self.ch_in_lf, self.ch_out_lf,
                                       self.kernel_size, padding=self.padding)
        if self.ch_in_lf and self.ch_out_hf:
            # Low -> high: output is upsampled in forward().
            self.hasLtoH = True
            self.conv_LtoH = nn.Conv2d(self.ch_in_lf, self.ch_out_hf,
                                       self.kernel_size, padding=self.padding)
        if self.ch_in_hf and self.ch_out_lf:
            # High -> low: output is average-pooled in forward().
            self.hasHtoL = True
            self.conv_HtoL = nn.Conv2d(self.ch_in_hf, self.ch_out_lf,
                                       self.kernel_size, padding=self.padding)
        if self.ch_in_hf and self.ch_out_hf:
            # High -> high: same resolution in and out.
            self.hasHtoH = True
            self.conv_HtoH = nn.Conv2d(self.ch_in_hf, self.ch_out_hf,
                                       self.kernel_size, padding=self.padding)
        self.avg_pool = nn.AvgPool2d(2, 2)

    def _fold_hf(self, fmap):
        """Fold HF maps so their height/width match the LF maps.

        Height and width are halved, so the channel count grows by a
        factor of 4. This is a plain ``reshape`` and is the exact inverse
        of the unfolding done at the top of ``forward``. Assumes the
        incoming height and width are even.
        """
        op_h, op_w = fmap.shape[-2] // 2, fmap.shape[-1] // 2
        return fmap.reshape(-1, self.ch_out_hf * 4, op_h, op_w)

    def forward(self, input):
        """Apply the octave convolution to a (possibly packed) tensor.

        Args:
            input: 4-D tensor in the layout described in the class
                docstring.

        Returns:
            The packed output tensor, or only the HF (folded) / LF part
            when the other output group has zero channels.
        """
        fmap_h, fmap_w = input.shape[-2], input.shape[-1]

        # Split input into high frequency and low frequency components.
        # The HF channels were folded down to the LF resolution by the
        # producing layer, so they are unfolded back to (2h, 2w) here.
        input_hf = input_lf = None
        if self.ch_in_lf:
            n_packed_hf = self.ch_in_hf * 4
            if self.ch_in_hf:
                # Guarded: reshaping an empty slice with a -1 dimension is
                # ambiguous and raises, which used to break alpha_in == 1.
                input_hf = input[:, :n_packed_hf, :, :].reshape(
                    -1, self.ch_in_hf, fmap_h * 2, fmap_w * 2)
            input_lf = input[:, n_packed_hf:, :, :]
        else:
            # No LF input channels: the whole tensor is high frequency.
            input_hf = input

        # Absent branches contribute a neutral 0. to the sums below.
        LtoH = HtoH = LtoL = HtoL = 0.

        if self.hasLtoL:
            # No change in spatial dimensions: vanilla convolution.
            LtoL = self.conv_LtoL(input_lf)

        if self.hasHtoH:
            # No change in spatial dimensions: vanilla convolution, then
            # fold so it can be concatenated with the LF channels.
            HtoH = self._fold_hf(self.conv_HtoH(input_hf))

        if self.hasLtoH:
            # The spatial dimensions have to go up, so upsample bilinearly
            # by 2 and then fold like the other HF branch.
            # align_corners=False is the library default, stated explicitly.
            upsampled = F.interpolate(self.conv_LtoH(input_lf),
                                      scale_factor=2, mode='bilinear',
                                      align_corners=False)
            LtoH = self._fold_hf(upsampled)

        if self.hasHtoL:
            # The spatial dimensions have to go down, so average-pool the
            # result to halve the height and width.
            HtoL = self.avg_pool(self.conv_HtoL(input_hf))

        # Elementwise addition of the branches feeding each output group.
        out_hf = LtoH + HtoH
        out_lf = LtoL + HtoL

        # Not all paths are always present. A final layer has
        # alpha_out == 0, so only the high-frequency result is returned.
        # The all-low-frequency case is supported for generality. When both
        # groups exist they already share a spatial size and are
        # concatenated along the channel dimension.
        if self.ch_out_lf == 0:
            return out_hf
        if self.ch_out_hf == 0:
            return out_lf
        return torch.cat([out_hf, out_lf], dim=1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment