Skip to content

Instantly share code, notes, and snippets.

@ericjang
Created February 25, 2016 05:38
Show Gist options
  • Save ericjang/a19d3ec85f08663dabe8 to your computer and use it in GitHub Desktop.
Save ericjang/a19d3ec85f08663dabe8 to your computer and use it in GitHub Desktop.
def write_no_attn(h_dec):
    """Compute the write output without attention.

    Passes ``h_dec`` through ``linear`` with ``img_size`` as the size
    argument, inside the "write" variable scope. The scope's ``reuse``
    flag is taken from the module-level ``DO_SHARE``.

    Args:
        h_dec: decoder state handed to ``linear``.
            (presumably shaped batch x decoder_size -- TODO confirm)

    Returns:
        Whatever ``linear(h_dec, img_size)`` returns.
    """
    with tf.variable_scope("write", reuse=DO_SHARE):
        canvas_update = linear(h_dec, img_size)
    return canvas_update
def write_attn(h_dec):
    """Compute the write output through an attention window.

    Steps, as performed below:
      1. ``linear`` maps ``h_dec`` to a flat patch, which is reshaped to
         ``[batch_size, write_n, write_n]``.
      2. ``attn_window("write", h_dec, write_n)`` yields ``Fx``, ``Fy`` and
         ``gamma``.
      3. The patch is expanded as ``Fy^T * (patch * Fx)`` and flattened to
         ``[batch_size, B*A]``.
      4. The result is multiplied by ``1/gamma``.

    Args:
        h_dec: decoder state handed to ``linear`` and ``attn_window``.

    Returns:
        Tensor reshaped to ``[batch_size, B*A]``, scaled by ``1/gamma``.
    """
    # NOTE(review): the source's indentation was lost; it is assumed here that
    # only the projection lives inside the "writeW" scope -- confirm.
    with tf.variable_scope("writeW", reuse=DO_SHARE):
        patch_flat = linear(h_dec, write_size)  # batch x (write_n*write_n)

    patch = tf.reshape(patch_flat, [batch_size, write_n, write_n])

    Fx, Fy, gamma = attn_window("write", h_dec, write_n)
    Fy_t = tf.transpose(Fy, perm=[0, 2, 1])

    # Fy^T * (patch * Fx); inner product first, matching the nested call order.
    patch_Fx = tf.batch_matmul(patch, Fx)
    canvas = tf.batch_matmul(Fy_t, patch_Fx)
    canvas_flat = tf.reshape(canvas, [batch_size, B * A])

    # 1/gamma is reshaped to a column so it broadcasts over the B*A outputs
    # of each batch row.
    inv_gamma = tf.reshape(1.0 / gamma, [-1, 1])
    return canvas_flat * inv_gamma
# Bind `write` to the attention or non-attention implementation, chosen by
# the FLAGS.write_attn setting.
if FLAGS.write_attn:
    write = write_attn
else:
    write = write_no_attn
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment