Created October 12, 2017 20:23
Save previtus/07e5c0bbb348bfba7cd0ab28767f6fc0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# TensorFlow code inside a Keras metric, with debug printing of TF tensors.
def grouped_mse(k=3, debug=True):
    """Build a Keras metric: MSE between per-group means of ``k`` consecutive rows.

    Consecutive rows (axis 0) of ``y_true`` and ``y_pred`` are split into
    groups of ``k``. Each group is replaced by its mean (``tf.segment_mean``).
    The squared difference between the grouped means is then averaged over
    the last axis. Trailing rows that do not fill a complete group (e.g. in a
    short final minibatch) are ignored.

    This is usable as a metric but not as a loss: the result has one value
    per group rather than one per sample, so it would have to be expanded
    back to the original size to serve as a loss.

    Args:
        k: Number of consecutive rows per group.
        debug: When True (the default, matching the original behaviour),
            ``remainder``, ``real_size`` and ``n`` are printed via
            ``K.print_tensor`` whenever the metric is evaluated.

    Returns:
        A function ``f(y_true, y_pred)`` returning a tensor with one MSE value
        per complete group.
    """
    def _trace(tensor, message):
        # K.print_tensor returns an identity tensor that prints when evaluated;
        # the returned tensor must be used downstream for the print to happen.
        if debug:
            return K.print_tensor(tensor, message=message)
        return tensor

    def f(y_true, y_pred):
        group_by = tf.constant(k)
        # Count rows along axis 0, not total elements. Slicing and
        # segment_mean both work on axis 0, so tf.size() was only correct for
        # 1-D or (batch, 1) inputs. For those, the value is identical.
        real_size = tf.shape(y_pred)[0]
        remainder = _trace(tf.truncatemod(real_size, group_by), "remainder is: ")
        # Ignore the leftover rows; this can happen in the last minibatch.
        # The slices depend on `remainder`, so its debug print always fires.
        real_size = _trace(real_size - remainder, "real_size is: ")
        y_true = y_true[0:real_size]
        y_pred = y_pred[0:real_size]
        # Floor division keeps `n` an integer tensor. With `/`, Python 3 true
        # division would give float64, and segment_mean rejects non-integer
        # segment ids. The division is exact because real_size is a multiple
        # of k.
        n = _trace(real_size // group_by, "n is: ")
        # Segment ids [0..0, 1..1, ..., n-1..n-1], each repeated group_by times.
        idx = tf.range(n)
        idx = tf.reshape(idx, [-1, 1])       # n x 1 column.
        idx = tf.tile(idx, [1, group_by])    # n x group_by matrix.
        idx = tf.reshape(idx, [-1])          # Flatten back to a vector.
        y_pred_by_k = tf.segment_mean(y_pred, idx)
        y_true_by_k = tf.segment_mean(y_true, idx)
        return K.mean(K.square(y_pred_by_k - y_true_by_k), axis=-1)

    return f
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.