@dvgodoy
Created April 30, 2022 12:34
import torch
import torch.nn as nn

# set_seed and EncoderVar are not defined in this gist; they must be imported from elsewhere
set_seed(13)
z_size = 1          # size of the latent vector z
n_filters = 32      # number of filters in the first conv layer
in_channels = 1     # single-channel images
img_size = 28       # 28x28 images
input_shape = (in_channels, img_size, img_size)

# Convolutional feature extractor: in_channels@28x28 -> flat vector of (n_filters*2)*7*7 features
base_model = nn.Sequential(
    # in_channels@28x28 -> n_filters@28x28
    nn.Conv2d(in_channels, n_filters, kernel_size=3, stride=1, padding=1),
    nn.LeakyReLU(),
    # n_filters@28x28 -> (n_filters*2)@14x14
    nn.Conv2d(n_filters, n_filters*2, kernel_size=3, stride=2, padding=1),
    nn.LeakyReLU(),
    # (n_filters*2)@14x14 -> (n_filters*2)@7x7
    nn.Conv2d(n_filters*2, n_filters*2, kernel_size=3, stride=2, padding=1),
    nn.LeakyReLU(),
    # (n_filters*2)@7x7 -> (n_filters*2)@7x7
    nn.Conv2d(n_filters*2, n_filters*2, kernel_size=3, stride=1, padding=1),
    nn.LeakyReLU(),
    # (n_filters*2)@7x7 -> (n_filters*2)*7*7
    nn.Flatten(),
)

encoder_var_cnn = EncoderVar(input_shape, z_size, base_model)
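The shape comments in `base_model` follow from the Conv2d output-size formula given in the PyTorch documentation, `floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1`. The pure-Python check below reproduces them without needing PyTorch. `conv2d_out` is a hypothetical helper written for this illustration, and the layer settings are copied from the snippet with `n_filters = 32`.

```python
def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Height/width after a Conv2d layer (formula from the PyTorch Conv2d docs)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# (out_channels, kernel_size, stride, padding) of the four Conv2d layers, with n_filters = 32
convs = [(32, 3, 1, 1), (64, 3, 2, 1), (64, 3, 2, 1), (64, 3, 1, 1)]

size = 28  # img_size
shapes = []
for channels, k, s, p in convs:
    size = conv2d_out(size, k, s, p)
    shapes.append((channels, size))
    print(f"{channels}@{size}x{size}")

# nn.Flatten() turns channels@HxW into a vector of channels*H*W features
flat_features = shapes[-1][0] * size * size
print(flat_features)  # 3136
```

So the flattened output that `base_model` hands to the encoder has 64 * 7 * 7 = 3136 features per image.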
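`EncoderVar` itself is not part of this gist. Judging only by its name and arguments (an input shape, a latent size `z_size`, and a feature-extracting `base_model`), it is presumably the encoder of a variational autoencoder. The following is a minimal sketch under that assumption. `EncoderVarSketch` is a hypothetical stand-in and not the author's class. It infers the flattened feature size with a dummy forward pass, maps the features to a mean and a log-variance, and samples `z` with the reparameterization trick. `torch.manual_seed` replaces the gist's `set_seed`, whose definition is not shown.

```python
import torch
import torch.nn as nn

class EncoderVarSketch(nn.Module):
    """Hypothetical variational encoder head; NOT the gist author's EncoderVar."""
    def __init__(self, input_shape, z_size, base_model):
        super().__init__()
        self.input_shape = input_shape
        self.z_size = z_size
        self.base_model = base_model
        # Infer the flattened feature size with a dummy forward pass
        with torch.no_grad():
            n_features = base_model(torch.zeros(1, *input_shape)).size(1)
        # Two linear heads: mean and log-variance of q(z|x)
        self.lin_mu = nn.Linear(n_features, z_size)
        self.lin_log_var = nn.Linear(n_features, z_size)

    def forward(self, x):
        features = self.base_model(x)
        mu = self.lin_mu(features)
        log_var = self.lin_log_var(features)
        # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + eps * std
        # KL(q(z|x) || N(0, I)) per sample, summed over the latent dimensions
        kl = -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1)
        return z, kl

torch.manual_seed(13)  # stands in for the gist's set_seed(13)
n_filters, in_channels, img_size, z_size = 32, 1, 28, 1
input_shape = (in_channels, img_size, img_size)
base_model = nn.Sequential(
    nn.Conv2d(in_channels, n_filters, 3, 1, 1), nn.LeakyReLU(),
    nn.Conv2d(n_filters, n_filters*2, 3, 2, 1), nn.LeakyReLU(),
    nn.Conv2d(n_filters*2, n_filters*2, 3, 2, 1), nn.LeakyReLU(),
    nn.Conv2d(n_filters*2, n_filters*2, 3, 1, 1), nn.LeakyReLU(),
    nn.Flatten(),
)

encoder = EncoderVarSketch(input_shape, z_size, base_model)
x = torch.randn(16, *input_shape)  # a fake batch of 16 images
z, kl = encoder(x)
print(z.shape, kl.shape)  # torch.Size([16, 1]) torch.Size([16])
```

Returning the KL term together with `z` is a choice made for this sketch. An implementation could just as well keep `mu` and `log_var` as attributes and compute the KL loss in a separate method.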