@ericjang
Last active February 24, 2016 05:19
DRAW high-level implementation
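The `# eq N` comments in the code refer to equation numbers in the DRAW paper (Gregor et al., 2015, "DRAW: A Recurrent Neural Network For Image Generation"). For reference, one time step t of that recurrence reads as follows, where σ is the logistic sigmoid and [·,·] is concatenation. In the code, `enc_state`/`dec_state` carry the LSTM states, `cs[t]` is the canvas c_t and `x_hat` is the error image.

$$
\begin{aligned}
\hat{x}_t &= x - \sigma(c_{t-1}) && (3)\\
r_t &= \mathrm{read}(x, \hat{x}_t, h^{dec}_{t-1}) && (4)\\
h^{enc}_t &= \mathrm{RNN}^{enc}(h^{enc}_{t-1}, [r_t, h^{dec}_{t-1}]) && (5)\\
z_t &\sim Q(Z_t \mid h^{enc}_t) && (6)\\
h^{dec}_t &= \mathrm{RNN}^{dec}(h^{dec}_{t-1}, z_t) && (7)\\
c_t &= c_{t-1} + \mathrm{write}(h^{dec}_t) && (8)
\end{aligned}
$$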
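import tensorflow as tf # written against the TensorFlow 0.x API (early 2016)

# x, T, batch_size, img_size, dec_size, lstm_enc, lstm_dec and the helpers
# read, write, encode, decode, sampleQ are assumed to be defined elsewhere (not shown in this snippet).
cs,mus,logsigmas,sigmas=[0]*T,[0]*T,[0]*T,[0]*T # per-step canvases and Q(z) parameters, kept for computing the losses later
# initial states
DO_SHARE=False # the helpers create their variables on the first step
h_dec_prev=tf.zeros((batch_size,dec_size))
enc_state=lstm_enc.zero_state(batch_size, tf.float32)
dec_state=lstm_dec.zero_state(batch_size, tf.float32)
# build the graph, unrolled over T time steps
for t in range(T):
    c_prev = tf.zeros((batch_size,img_size)) if t==0 else cs[t-1]
    x_hat=x-tf.sigmoid(c_prev) # error image, eq 3
    r=read(x,x_hat,h_dec_prev) # eq 4
    # TF 0.x signature is tf.concat(dim, values); in TF >= 1.0 write tf.concat([r,h_dec_prev], 1)
    h_enc,enc_state=encode(enc_state,tf.concat(1,[r,h_dec_prev])) # eq 5
    z,mus[t],logsigmas[t],sigmas[t]=sampleQ(h_enc) # eq 6
    h_dec,dec_state=decode(dec_state,z) # eq 7
    cs[t]=c_prev+write(h_dec) # accumulate onto the canvas and store, eq 8
    h_dec_prev=h_dec
    DO_SHARE=True # from now on, share variables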
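To trace the data flow of the loop above without TensorFlow or the attention helpers, here is a minimal NumPy sketch. It is an illustration under stated assumptions, not the author's code: `read` is replaced by the paper's no-attention variant `[x, x_hat]`, `write` by a single linear map, each LSTM by one tanh layer, and Q(z) is a diagonal Gaussian sampled with the reparameterization trick. All sizes, the weights (`W_enc`, `W_mu`, `W_logsigma`, `W_dec`, `W_write`) and the helper names `draw_forward` and `latent_loss` are hypothetical. The sketch also shows one use for keeping `mus`, `logsigmas` and `sigmas` per step: they feed the closed-form KL (latent) loss.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes and random weights standing in for learned parameters.
batch_size, img_size, z_size, hid_size, T = 4, 16, 3, 8, 5
W_enc = rng.normal(scale=0.1, size=(2 * img_size + 2 * hid_size, hid_size))
W_mu = rng.normal(scale=0.1, size=(hid_size, z_size))
W_logsigma = rng.normal(scale=0.1, size=(hid_size, z_size))
W_dec = rng.normal(scale=0.1, size=(z_size + hid_size, hid_size))
W_write = rng.normal(scale=0.1, size=(hid_size, img_size))


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def draw_forward(x):
    """Simplified DRAW loop: no attention, one tanh layer per RNN instead of an LSTM."""
    n = x.shape[0]
    h_enc = np.zeros((n, hid_size))
    h_dec_prev = np.zeros((n, hid_size))
    c = np.zeros_like(x)  # canvas starts empty
    cs, mus, logsigmas, sigmas = [], [], [], []
    for t in range(T):
        x_hat = x - sigmoid(c)                  # error image, eq 3
        r = np.concatenate([x, x_hat], axis=1)  # read without attention, eq 4
        h_enc = np.tanh(np.concatenate([r, h_dec_prev, h_enc], axis=1) @ W_enc)  # eq 5
        mu = h_enc @ W_mu
        logsigma = h_enc @ W_logsigma
        sigma = np.exp(logsigma)
        z = mu + sigma * rng.standard_normal(mu.shape)  # reparameterized sample, eq 6
        h_dec = np.tanh(np.concatenate([z, h_dec_prev], axis=1) @ W_dec)  # eq 7
        c = c + h_dec @ W_write                 # additive canvas update, eq 8
        h_dec_prev = h_dec
        cs.append(c); mus.append(mu); logsigmas.append(logsigma); sigmas.append(sigma)
    return cs, mus, logsigmas, sigmas


def latent_loss(mus, logsigmas, sigmas):
    """KL(N(mu, sigma^2) || N(0, 1)) summed over steps and latent dims, averaged over the batch."""
    kl = sum(0.5 * np.sum(m**2 + s**2 - 2 * ls - 1, axis=1)
             for m, ls, s in zip(mus, logsigmas, sigmas))
    return float(np.mean(kl))


x = (rng.random((batch_size, img_size)) > 0.5).astype(float)  # hypothetical binary images
cs, mus, logsigmas, sigmas = draw_forward(x)
print(len(cs), cs[-1].shape, latent_loss(mus, logsigmas, sigmas) >= 0)  # 5 (4, 16) True
```

Note that the `- 1` sits inside the per-dimension sum, so the constant contributed by the KL is −T·z_size/2 rather than −T/2 when z has more than one dimension.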