Skip to content

Instantly share code, notes, and snippets.

@Flock1
Created September 21, 2020 09:39
Show Gist options
  • Save Flock1/eaf1740fe4b0bd63f96dfdb93770ba71 to your computer and use it in GitHub Desktop.
Save Flock1/eaf1740fe4b0bd63f96dfdb93770ba71 to your computer and use it in GitHub Desktop.
class UpSample(nn.Module):
    """Bilinear resize followed by a 3x3 convolution.

    The input is resized with ``nn.functional.interpolate``
    (``mode='bilinear'``, ``align_corners=True``) and then passed through a
    stride-1, padding-1 ``nn.Conv2d`` that maps ``feat_in`` channels to
    ``feat_out`` channels.

    Args:
        feat_in: Number of input channels of the convolution.
        feat_out: Number of output channels of the convolution.
        out_shape: Target spatial size handed to ``interpolate`` as ``size``.
            When given it takes precedence over ``scale``.
        scale: Scale factor handed to ``interpolate`` as ``scale_factor``;
            only used when ``out_shape`` is ``None``.
    """

    def __init__(self, feat_in, feat_out, out_shape=None, scale=2):
        super().__init__()
        self.conv = nn.Conv2d(
            feat_in, feat_out, kernel_size=(3, 3), stride=1, padding=1
        )
        self.out_shape, self.scale = out_shape, scale

    def forward(self, x):
        """Resize ``x`` and apply the convolution to the resized tensor."""
        # interpolate() rejects calls that specify both ``size`` and
        # ``scale_factor``; forward exactly one of them so that passing
        # ``out_shape`` while leaving ``scale`` at its default still works.
        if self.out_shape is not None:
            resize_kwargs = {'size': self.out_shape}
        else:
            resize_kwargs = {'scale_factor': self.scale}
        resized = nn.functional.interpolate(
            x, mode='bilinear', align_corners=True, **resize_kwargs
        )
        return self.conv(resized)
def get_upSamp(feat_in, feat_out, out_shape=None, scale=2, act='relu'):
    """Build one upsampling stage as an ``nn.Sequential``.

    The stage always starts with an ``UpSample`` module (registered under the
    default key ``'0'``). Depending on ``act`` it is followed by:

    * ``'relu'`` -- an in-place ``nn.ReLU`` named ``'ReLU'`` and then an
      ``nn.BatchNorm2d(feat_out)`` named ``'BN'``;
    * ``'sig'`` -- an ``nn.Sigmoid`` named ``'Sigmoid'``;
    * any other value -- nothing is appended.

    Args:
        feat_in: Input channel count forwarded to ``UpSample``.
        feat_out: Output channel count forwarded to ``UpSample`` (and used as
            the ``BatchNorm2d`` feature count).
        out_shape: Forwarded to ``UpSample``.
        scale: Forwarded to ``UpSample``.
        act: Activation selector, see above.

    Returns:
        The assembled ``nn.Sequential``, placed on the CUDA device when one is
        available and on the CPU otherwise.
    """
    # Local import: only needed for the CUDA-availability check below.
    import torch

    layer = nn.Sequential(
        UpSample(feat_in, feat_out, out_shape=out_shape, scale=scale)
    )
    if act == 'relu':
        layer.add_module('ReLU', nn.ReLU(inplace=True))
        layer.add_module('BN', nn.BatchNorm2d(feat_out))
    elif act == 'sig':
        layer.add_module('Sigmoid', nn.Sigmoid())

    # Move the whole stage at once instead of calling .cuda() on individual
    # modules, and fall back to the CPU so the factory does not crash on
    # machines without a usable GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return layer.to(device)
def add_layer(m, feat_in, feat_out, name, out_shape=None, scale=2, act='relu'):
    """Create an upsampling stage and register it on ``m`` under ``name``.

    The stage is built by ``get_upSamp`` from ``feat_in``, ``feat_out``,
    ``out_shape``, ``scale`` and ``act``, then attached with
    ``m.add_module``. ``m`` is modified in place; nothing is returned.
    """
    m.add_module(
        name,
        get_upSamp(feat_in, feat_out, out_shape=out_shape, scale=scale, act=act),
    )
# Rebuild ``m`` from all but its last three child modules.
# NOTE(review): ``m`` is defined before this point (not visible in this
# excerpt) -- presumably a convolutional backbone; confirm against the caller.
m = nn.Sequential(*list(m.children())[:-3])

# Bottleneck: a 2x2 convolution reducing 256 channels to ``code_sz`` channels.
code_sz = 32
conv = nn.Conv2d(256, code_sz, kernel_size=(2, 2)).cuda()
m.add_module('CodeIn', conv)

# Swap the module stored under key '0' for a 7x7, stride-2 convolution that
# takes a single input channel.
m._modules['0'] = nn.Conv2d(
    1, 64, kernel_size=(7, 7), stride=2, padding=1
).cuda()

# Decoder: expand the code to 256 channels at a fixed 64x64 size ...
add_layer(m, code_sz, 256, 'CodeOut', out_shape=(64, 64), scale=None)

# ... then three stages using add_layer's default scale and activation,
# each halving the channel count (256 -> 128 -> 64 -> 32) ...
decoder_widths = (256, 128, 64, 32)
for stage, (ch_in, ch_out) in enumerate(zip(decoder_widths, decoder_widths[1:])):
    add_layer(m, ch_in, ch_out, 'Upsample{}'.format(stage))

# ... and a final same-size (scale=1) stage down to one channel with a sigmoid.
add_layer(m, 32, 1, 'Upsample3', act='sig', scale=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment