Last active
October 1, 2015 03:57
-
-
Save EderSantana/bd0d940c5c34386b76b6 to your computer and use it in GitHub Desktop.
DRAW config
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# We train this DRAW network using a model similar to the one
# described in the paper http://arxiv.org/pdf/1502.04623.pdf
# (Gregor et al., "DRAW: A Recurrent Neural Network for Image Generation").
#
# Dataset: Binary-MNIST from mila-udem/fuel
from keras.initializations import normal
from keras.optimizers import Adam
from seya.layers.draw import DRAW
def myinit(shape, scale=.01):
    """Return initial weights of ``shape`` drawn from a narrow normal.

    Thin wrapper around ``keras.initializations.normal`` that fixes a small
    default ``scale``. The layer configuration in this script uses it for
    both ``init`` and ``inner_init``.

    Args:
        shape: Shape of the weight tensor to create. It is passed straight
            through to ``normal``.
        scale: Spread of the normal distribution, forwarded to ``normal``.
            It defaults to 0.01, the value this config was written with.

    Returns:
        Whatever ``keras.initializations.normal`` returns for these
        arguments. Presumably this is a backend variable; confirm against
        the installed Keras version.
    """
    return normal(shape, scale=scale)
# h_dim: number of hidden states of the inner RNN (an LSTM here).
# z_dim: dimension of the Gaussian latent sample (see the VAE
#        reparametrization trick).
# N_enc / N_dec: presumably the read / write attention grid sizes from
#        the DRAW paper. TODO: confirm against seya.layers.draw.
# init / inner_init: both use the small-scale normal initializer above.
draw = DRAW(h_dim=256, z_dim=100, input_shape=(1, 28, 28), N_enc=2, N_dec=5,
            return_sequences=True, inner_rnn='lstm', init=myinit,
            inner_init=myinit)

# Adam with a small learning rate. clipnorm=10 rescales any gradient whose
# L2 norm exceeds 10, which guards against exploding gradients in the RNN.
# NOTE(review): this optimizer is not used anywhere in this snippet yet.
# It is presumably meant for model.compile(...) later on.
adam = Adam(lr=3e-4, clipnorm=10)
# Stay tuned for more details
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment