@ericjang
Last active February 24, 2016 16:19
vectorized matrix multiplication in TF
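import tensorflow as tf
# Applying the read filter. Inspired by https://github.com/ikostrikov/TensorFlow-VAE-GAN-DRAW/blob/master/main-draw.py
# Inputs, as implied by the reshapes below: Fx is batch x N x A, Fy is batch x N x B, x is batch x B x A.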
Fxt=tf.transpose(Fx, [0,2,1]) # batch x A x N
Fxt=tf.reshape(Fxt, [-1,1,A,N,1]) # batch x 1 x A x N x 1
Fxt=tf.tile(Fxt, [1,N,1,1,1]) # batch x N x A x N x 1 (repmat'ed along dim=1)
Fy=tf.reshape(Fy, [-1,N,B,1,1]) # batch x N x B x 1 x 1
x=tf.reshape(x,[-1,1,B,A,1]) # batch x 1 x B x A x 1
x=tf.tile(x,[1,N,1,1,1]) # batch x N x B x A x 1
Fydotx=tf.reduce_sum(Fy*x,2) # batch x N x A x 1
Fydotx=tf.reshape(Fydotx,[-1,N,A,1,1]) # batch x N x A x 1 x 1
FydotxdotFxt=tf.reduce_sum(Fydotx*Fxt,2) # batch x N x N x 1
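For checking the tile/reduce_sum code above, here is the same computation as explicit loops in plain Python: for each batch element, result = Fy · x · Fxᵀ. This is only a reference sketch and not part of the original gist. The function name `read_filter_reference` is made up for illustration, and the input shapes (Fy: batch x N x B, x: batch x B x A, Fx: batch x N x A, as nested lists) are assumptions inferred from the reshapes.

```python
def read_filter_reference(Fy, x, Fx):
    """Loop version of Fy . x . Fx^T, computed per batch element.

    Assumed shapes (nested lists): Fy is batch x N x B,
    x is batch x B x A, Fx is batch x N x A. Returns batch x N x N.
    """
    out = []
    for Fy_b, x_b, Fx_b in zip(Fy, x, Fx):
        N, B, A = len(Fy_b), len(x_b), len(x_b[0])
        # Step 1: Fydotx[i][a] = sum_k Fy[i][k] * x[k][a]   -> N x A
        Fydotx = [[sum(Fy_b[i][k] * x_b[k][a] for k in range(B))
                   for a in range(A)] for i in range(N)]
        # Step 2: result[i][j] = sum_a Fydotx[i][a] * Fx[j][a]   -> N x N
        out.append([[sum(Fydotx[i][a] * Fx_b[j][a] for a in range(A))
                     for j in range(N)] for i in range(N)])
    return out


# Tiny example: batch=1, N=2, B=2, A=3.
Fy = [[[1, 0], [0, 1]]]            # identity, so Fy . x == x
x = [[[1, 2, 3], [4, 5, 6]]]
Fx = [[[1, 0, 0], [0, 1, 1]]]
print(read_filter_reference(Fy, x, Fx))  # [[[1, 5], [4, 11]]]
```

Step 1 corresponds to the first reduce_sum over the B axis, and step 2 to the second reduce_sum over the A axis. Evaluating the TensorFlow tensors on small inputs and comparing them with this output is a cheap way to catch shape mistakes.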