@dvgodoy
Last active April 30, 2022 09:47
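import numpy as np
import torch.nn as nn

# set_seed, EncoderVar, and AutoEncoder are not defined in this gist;
# they are assumed to come from the code this snippet accompanies.

set_seed(13)  # seed the random number generators (helper defined elsewhere)
z_size = 1  # dimension of the latent space
input_shape = (1, 28, 28)  # (C, H, W): single-channel 28x28 images
n_features = int(np.prod(input_shape))  # C*H*W = 784

# Shared feature extractor wrapped by the variational encoder
base_model = nn.Sequential(
    # (C, H, W) -> C*H*W
    nn.Flatten(),
    # C*H*W -> 2048
    nn.Linear(n_features, 2048),
    nn.LeakyReLU(),
    # 2048 -> 2048
    nn.Linear(2048, 2048),
    nn.LeakyReLU(),
)
# EncoderVar (defined elsewhere) builds on base_model to produce a z_size latent code
encoder_var = EncoderVar(input_shape, z_size, base_model)

# Decoder mirrors the encoder: latent code -> image
decoder_var = nn.Sequential(
    # z_size -> 2048
    nn.Linear(z_size, 2048),
    nn.LeakyReLU(),
    # 2048 -> 2048
    nn.Linear(2048, 2048),
    nn.LeakyReLU(),
    # 2048 -> C*H*W
    nn.Linear(2048, n_features),
    # C*H*W -> (C, H, W)
    nn.Unflatten(1, input_shape),
)

# Variational autoencoder combining the encoder and decoder
model_vae = AutoEncoder(encoder_var, decoder_var)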
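EncoderVar's definition is not part of this gist. Assuming it follows the standard variational-autoencoder recipe, it would project the 2048 base_model features to a mean and a log-variance of size z_size. It would then draw z with the reparameterization trick, z = mu + sigma * eps with eps ~ N(0, 1). Training would typically add the closed-form KL divergence to a standard normal prior on top of the reconstruction loss. The framework-free sketch below illustrates just those two formulas with plain Python lists. The function names `reparameterize` and `kl_to_standard_normal` are hypothetical and do not appear in the original code.

```python
import math
import random


def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps, eps ~ N(0, 1), elementwise.

    Hypothetical helper: mu and log_var are lists of length z_size
    (z_size = 1 in the gist). Drawing eps separately keeps z a
    deterministic, differentiable function of mu and log_var.
    """
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv)
                      for m, lv in zip(mu, log_var))


rng = random.Random(13)
mu, log_var = [0.5], [-1.0]  # hypothetical encoder outputs for one image
z = reparameterize(mu, log_var, rng)
print(z, kl_to_standard_normal(mu, log_var))
```

In PyTorch the same two expressions would operate on tensors (for example, drawing eps with `torch.randn_like(mu)`), so gradients flow back through mu and log_var into the encoder.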