Created April 24, 2017 11:51
Gist: 0xnurl/d2a4b708fbd2de8dd08ce7f70fc9a3e8
Keras metric for non-null target accuracy, mainly for Named Entity Recognition models
import tensorflow as tf
from keras import backend as K


def non_null_label_accuracy(y_true, y_pred):
    """Calculate accuracy, excluding targets whose true label is the null label (index 0).

    Useful when the null label is over-represented in the data, as in Named Entity Recognition tasks.
    Typical y shape: (batch_size, sentence_length, num_labels)
    """
    y_true_argmax = K.argmax(y_true, -1)  # ==> (batch_size, sentence_length)
    y_pred_argmax = K.argmax(y_pred, -1)  # ==> (batch_size, sentence_length)
    # Flatten so that every token position becomes one element.
    y_true_argmax_flat = tf.reshape(y_true_argmax, [-1])
    y_pred_argmax_flat = tf.reshape(y_pred_argmax, [-1])
    # Indices of the positions whose true label is not the null label.
    non_null_targets_bool = K.not_equal(y_true_argmax_flat, K.zeros_like(y_true_argmax_flat))
    non_null_target_idx = K.flatten(K.cast(tf.where(non_null_targets_bool), 'int32'))
    # Keep only those positions, in both the targets and the predictions.
    y_true_without_null = K.gather(y_true_argmax_flat, non_null_target_idx)
    y_pred_without_null = K.gather(y_pred_argmax_flat, non_null_target_idx)
    # Fraction of non-null targets that were predicted correctly.
    mean = K.mean(K.cast(K.equal(y_pred_without_null, y_true_without_null), K.floatx()))
    # If the model uses masking, Keras forces the metric output to have the same shape
    # as the per-token targets, so broadcast the scalar to (batch_size, sentence_length).
    fake_shape_mean = K.ones_like(y_true_argmax, K.floatx()) * mean
    return fake_shape_mean
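To make the metric's semantics easier to check outside a Keras graph, here is a minimal NumPy sketch of the same computation. It is not the gist author's code: NumPy is swapped in for the Keras backend, the function names non_null_label_accuracy_np and plain_accuracy_np are hypothetical, and it assumes one-hot y_true with label 0 as the null label. The toy batch shows why the metric is useful: a model that predicts the null label everywhere still scores well on plain token accuracy but gets 0 on non-null accuracy.

```python
import numpy as np


def non_null_label_accuracy_np(y_true, y_pred):
    """Hypothetical NumPy reference: accuracy over positions whose true label is not 0.

    y_true, y_pred: arrays of shape (batch_size, sentence_length, num_labels),
    y_true one-hot, y_pred scores or probabilities.
    Returns a float, or nan if every target is the null label (no defined value).
    """
    true_labels = np.argmax(y_true, axis=-1)  # (batch_size, sentence_length)
    pred_labels = np.argmax(y_pred, axis=-1)  # (batch_size, sentence_length)
    non_null = true_labels != 0               # boolean mask of non-null targets
    if not non_null.any():
        return float("nan")
    # Boolean indexing keeps only the non-null positions (flattened).
    return float(np.mean(pred_labels[non_null] == true_labels[non_null]))


def plain_accuracy_np(y_true, y_pred):
    """Ordinary token-level accuracy over all positions, for comparison."""
    return float(np.mean(np.argmax(y_true, -1) == np.argmax(y_pred, -1)))


# Toy batch: 1 sentence, 5 tokens, 3 labels (0 = null, 1 and 2 = entity tags).
labels_true = np.array([[0, 0, 1, 0, 2]])
y_true = np.eye(3)[labels_true]                 # one-hot, shape (1, 5, 3)
# A degenerate model that predicts the null label everywhere.
y_pred = np.eye(3)[np.zeros_like(labels_true)]  # shape (1, 5, 3)

print(plain_accuracy_np(y_true, y_pred))           # 0.6
print(non_null_label_accuracy_np(y_true, y_pred))  # 0.0
```

Returning nan for a batch with no non-null targets is a design choice of this sketch; averaging over an empty set has no meaningful value, so callers may want to skip or mask such batches.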