Skip to content

Instantly share code, notes, and snippets.

@enochkan
Created June 12, 2020 05:41
Show Gist options
  • Save enochkan/b6ac9c9bfc00400afe6d49ac04596cca to your computer and use it in GitHub Desktop.
Save enochkan/b6ac9c9bfc00400afe6d49ac04596cca to your computer and use it in GitHub Desktop.
vox2vox generator
class GeneratorUNet(nn.Module):
    """3D U-Net generator (vox2vox).

    Layout: an encoder of four ``UNetDown`` stages, a bottleneck of four
    ``UNetMid`` stages, a decoder of three ``UNetUp`` stages fed with the
    encoder outputs as skip connections, and a final ``ConvTranspose3d``
    followed by ``tanh``.

    Args:
        in_channels: Number of channels of the input volume.
        out_channels: Number of channels of the generated volume.
        dropout: Dropout probability passed to every ``UNetMid`` stage.
            Defaults to 0.2, the value that was previously hard-coded.

    NOTE(review): the input is presumably a 5-D tensor
    ``(N, in_channels, D, H, W)``. Which spatial sizes are valid depends on
    ``UNetDown``/``UNetMid``/``UNetUp``, defined elsewhere; confirm there
    before changing the layer layout.
    """

    def __init__(self, in_channels=1, out_channels=1, *, dropout=0.2):
        super().__init__()

        # Encoder: channel width doubles per stage, 64 -> 128 -> 256 -> 512.
        # The first stage is built with normalize=False.
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512)

        # Bottleneck: forward() passes each stage's input twice, and the
        # in-channels are 1024 = 2 x 512. Presumably UNetMid concatenates its
        # two arguments along the channel axis -- verify in UNetMid.
        self.mid1 = UNetMid(1024, 512, dropout=dropout)
        self.mid2 = UNetMid(1024, 512, dropout=dropout)
        self.mid3 = UNetMid(1024, 512, dropout=dropout)
        self.mid4 = UNetMid(1024, 256, dropout=dropout)

        # Decoder: each stage's in-channels equal the previous stage's
        # out-channels plus the matching encoder skip's channels
        # (up2: 256 + 256 = 512, up3: 128 + 128 = 256). This suggests UNetUp
        # concatenates its output with the skip tensor -- confirm in UNetUp.
        self.up1 = UNetUp(256, 256)
        self.up2 = UNetUp(512, 128)
        self.up3 = UNetUp(256, 64)

        # 128 = up3 output (64) + d1 skip (64). Kernel 4 / stride 2 /
        # padding 1 doubles every spatial dimension; tanh bounds the output
        # to [-1, 1].
        self.final = nn.Sequential(
            nn.ConvTranspose3d(128, out_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the U-Net on ``x``.

        Args:
            x: Input volume with ``in_channels`` channels.

        Returns:
            Tensor with ``out_channels`` channels and values in [-1, 1].
        """
        # Encoder; d1..d3 are kept as skip connections for the decoder.
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)

        # Bottleneck: each stage takes the previous output as both arguments.
        m1 = self.mid1(d4, d4)
        m2 = self.mid2(m1, m1)
        m3 = self.mid3(m2, m2)
        m4 = self.mid4(m3, m3)

        # Decoder with skips from the encoder, deepest first.
        u1 = self.up1(m4, d3)
        u2 = self.up2(u1, d2)
        u3 = self.up3(u2, d1)
        return self.final(u3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment