Skip to content

Instantly share code, notes, and snippets.

@ericjang
Created February 24, 2016 14:36
Show Gist options
  • Save ericjang/5a49da4efdab031c9d8b to your computer and use it in GitHub Desktop.
Save ericjang/5a49da4efdab031c9d8b to your computer and use it in GitHub Desktop.
def read_attn(x, x_hat, h_dec_prev):
    """Read attention-filtered glimpses of ``x`` and ``x_hat`` and concatenate them.

    Filterbank parameters come from ``attn_window("read", h_dec_prev, read_n)``
    (defined elsewhere). The same ``Fx``, ``Fy`` and ``gamma`` are applied to
    both inputs.

    Args:
        x: batch of flattened images. It is reshaped below to ``[-1, B, A]``,
            so each row presumably holds ``B * A`` values — confirm against
            the caller.
        x_hat: tensor with the same layout as ``x``.
        h_dec_prev: previous decoder state, forwarded to ``attn_window``.

    Returns:
        A tensor with one row per batch element. Each row is the
        ``read_n * read_n`` glimpse of ``x`` followed by the
        ``read_n * read_n`` glimpse of ``x_hat``, concatenated along axis 1.
    """
    Fx, Fy, gamma = attn_window("read", h_dec_prev, read_n)

    def filter_img(img, Fx, Fy, gamma, N):
        """Return ``gamma * (Fy @ img @ Fx^T)`` per batch element, flattened to ``[-1, N * N]``.

        For the products and the final reshape to be valid, ``Fy`` is
        presumably ``[batch, N, B]`` and ``Fx`` ``[batch, N, A]``. This is
        inferred from the matmul shapes below — confirm against attn_window.
        """
        Fxt = tf.transpose(Fx, perm=[0, 2, 1])
        img = tf.reshape(img, [-1, B, A])
        # Since TF 1.0, tf.matmul multiplies rank-3 operands batch-wise over
        # the leading axis. It replaces tf.batch_matmul, which was removed.
        glimpse = tf.matmul(Fy, tf.matmul(img, Fxt))
        glimpse = tf.reshape(glimpse, [-1, N * N])
        # gamma is reshaped to a column so it broadcasts across the N*N
        # features. It presumably holds one scale per batch element.
        return glimpse * tf.reshape(gamma, [-1, 1])

    x_glimpse = filter_img(x, Fx, Fy, gamma, read_n)  # batch x (read_n*read_n)
    x_hat_glimpse = filter_img(x_hat, Fx, Fy, gamma, read_n)
    # Since TF 1.0 the signature is tf.concat(values, axis). The old
    # tf.concat(axis, values) argument order no longer works.
    return tf.concat([x_glimpse, x_hat_glimpse], 1)  # concat along feature axis
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment