Gist MrYakobo/a77be1f0db4d7e00ed12ea95af8ccc74, created February 10, 2021 09:37
Maximum mean discrepancy for TensorFlow
# based on https://github.com/wzell/mann/blob/master/models/maximum_mean_discrepancy.py
# (the original didn't work for me on TensorFlow)
import tensorflow as tf

def gaussian_kernel(x1, x2, beta=1.0):
    """
    Gaussian kernel matrix K[i, j] = exp(-beta * ||x1[i] - x2[j]||^2)
    for x1 of shape (n, d) and x2 of shape (m, d); returns shape (n, m).
    """
    r = tf.expand_dims(x1, 1)                            # (n, 1, d)
    sq_dist = tf.reduce_sum(tf.square(r - x2), axis=-1)  # (n, m) via broadcasting
    return tf.exp(-beta * sq_dist)

def MMD(x1, x2, beta):
    """
    Maximum mean discrepancy (MMD) based on a Gaussian kernel, for
    TensorFlow (tf.keras) models. Returns the biased estimate of the
    squared MMD between the samples x1 and x2.
    - Gretton, Arthur, et al. "A kernel method for the two-sample-problem."
      Advances in Neural Information Processing Systems. 2007.
    """
    x1x1 = gaussian_kernel(x1, x1, beta)
    x1x2 = gaussian_kernel(x1, x2, beta)
    x2x2 = gaussian_kernel(x2, x2, beta)
    # mean of each kernel block: E[k(x1, x1')] - 2 E[k(x1, x2)] + E[k(x2, x2')]
    diff = tf.reduce_mean(x1x1) - 2 * tf.reduce_mean(x1x2) + tf.reduce_mean(x2x2)
    return diff
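To sanity-check the estimator without building a model, here is a minimal NumPy sketch. It computes the same combination of kernel-block means as `MMD` above: mean(k(x1, x1)) - 2 mean(k(x1, x2)) + mean(k(x2, x2)). It assumes the Gaussian kernel k(x, y) = exp(-beta * ||x - y||^2) and 2-D float arrays of shape (n, d) and (m, d). The helpers `gaussian_kernel_np` and `mmd_np` are hypothetical names used only for illustration. They are not part of the gist or of any library.

```python
import numpy as np


def gaussian_kernel_np(x1, x2, beta=1.0):
    # Hypothetical helper: pairwise squared Euclidean distances by broadcasting,
    # (n, 1, d) - (1, m, d) -> (n, m, d), summed over the feature axis.
    diff = x1[:, None, :] - x2[None, :, :]
    sq_dist = np.sum(diff ** 2, axis=-1)  # shape (n, m)
    return np.exp(-beta * sq_dist)


def mmd_np(x1, x2, beta=1.0):
    # Hypothetical helper: biased estimate of the squared MMD,
    # i.e. the mean of each Gram block, diagonals included.
    k11 = gaussian_kernel_np(x1, x1, beta)
    k12 = gaussian_kernel_np(x1, x2, beta)
    k22 = gaussian_kernel_np(x2, x2, beta)
    return k11.mean() - 2.0 * k12.mean() + k22.mean()


rng = np.random.default_rng(0)
a = rng.normal(size=(200, 3))
b = rng.normal(size=(200, 3))           # same distribution as a
c = rng.normal(loc=1.0, size=(200, 3))  # mean shifted by 1 in every dimension

print(mmd_np(a, a))  # 0.0: all three Gram blocks are identical
print(mmd_np(a, b))  # close to 0 (sampling noise plus the biased diagonal terms)
print(mmd_np(a, c))  # clearly larger, since the two distributions differ
```

Because the diagonal entries k(x_i, x_i) = 1 are included, this estimate is biased upward by roughly 2/n for equal sample sizes n. That is why two samples from the same distribution give a small positive value rather than exactly zero.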