Skip to content

Instantly share code, notes, and snippets.

@enochkan
Created June 12, 2020 05:11
Show Gist options
  • Save enochkan/2f7ce86dd977fb6917e1e95d77fbc814 to your computer and use it in GitHub Desktop.
Save enochkan/2f7ce86dd977fb6917e1e95d77fbc814 to your computer and use it in GitHub Desktop.
Vox2Vox Encoder, Bottleneck and Decoder blocks
#***********************#
#***Code by:************#
#***Chi Nok Enoch Kan***#
#***********************#
#*******<(^.^)>*********#
#***********************#
#*****Encoder Block*****#
#***********************#
#***********************#
#***********************#
class UNetDown(nn.Module):
    """Encoder (down-sampling) block of the 3D U-Net generator.

    Applies a strided 3D convolution, optionally followed by instance
    normalisation, then a LeakyReLU and optionally dropout.

    Args:
        in_size: Number of input channels.
        out_size: Number of output channels.
        normalize: If true, add ``InstanceNorm3d`` after the convolution.
        dropout: Dropout probability; any falsy value (e.g. ``0.0``)
            disables the dropout layer entirely.
    """

    def __init__(self, in_size: int, out_size: int,
                 normalize: bool = True, dropout: float = 0.0):
        super().__init__()
        # kernel 4 / stride 2 / padding 1 maps a spatial size D to D // 2.
        layers = [
            nn.Conv3d(in_size, out_size, kernel_size=4, stride=2, padding=1,
                      bias=False),
        ]
        if normalize:
            layers.append(nn.InstanceNorm3d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        # Layer order is kept stable so existing checkpoints still load.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the block on ``x`` and return the down-sampled feature map.

        ``x`` is expected to carry ``in_size`` channels; each spatial
        dimension of the result is half that of the input (floored).
        """
        return self.model(x)
#*****Bottleneck Block*****#
class UNetMid(nn.Module):
    """Bottleneck block of the 3D U-Net generator.

    Concatenates the incoming features with a skip tensor, applies a
    stride-1 3D convolution + instance norm + LeakyReLU (and optional
    dropout), then zero-pads the result back to the input spatial size.

    Args:
        in_size: Number of channels *after* concatenating ``x`` and
            ``skip_input`` (i.e. the convolution's input channels).
        out_size: Number of output channels.
        dropout: Dropout probability; any falsy value (e.g. ``0.0``)
            disables the dropout layer entirely.
    """

    def __init__(self, in_size: int, out_size: int, dropout: float = 0.0):
        super().__init__()
        layers = [
            nn.Conv3d(in_size, out_size, kernel_size=4, stride=1, padding=1,
                      bias=False),
            nn.InstanceNorm3d(out_size),
            nn.LeakyReLU(0.2),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        # Layer order is kept stable so existing checkpoints still load.
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        """Fuse ``x`` with ``skip_input`` and return same-size features.

        Both tensors are joined along dim 1 (the channel axis for batched
        ``(N, C, D, H, W)`` input), so they must agree on every other
        dimension.
        """
        x = torch.cat((x, skip_input), 1)
        x = self.model(x)
        # A kernel-4 / stride-1 / padding-1 convolution shrinks every
        # spatial axis by one; pad one zero voxel at the front of the last
        # three dimensions to restore the original size.
        x = nn.functional.pad(x, (1, 0, 1, 0, 1, 0))
        return x
#*****Decoder Block*****#
class UNetUp(nn.Module):
    """Decoder (up-sampling) block of the 3D U-Net generator.

    Applies a strided transposed 3D convolution + instance norm + ReLU
    (and optional dropout), then concatenates the matching encoder
    features onto the result.

    Args:
        in_size: Number of input channels.
        out_size: Number of channels produced by the transposed
            convolution (before the skip tensor is concatenated).
        dropout: Dropout probability; any falsy value (e.g. ``0.0``)
            disables the dropout layer entirely.
    """

    def __init__(self, in_size: int, out_size: int, dropout: float = 0.0):
        super().__init__()
        # kernel 4 / stride 2 / padding 1 maps a spatial size D to 2 * D.
        layers = [
            nn.ConvTranspose3d(in_size, out_size, kernel_size=4, stride=2,
                               padding=1, bias=False),
            nn.InstanceNorm3d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        # Layer order is kept stable so existing checkpoints still load.
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        """Up-sample ``x`` and append ``skip_input`` along dim 1.

        ``skip_input`` must match the up-sampled tensor on every dimension
        except dim 1; the result has ``out_size`` plus the skip tensor's
        dim-1 size along that axis.
        """
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment