Last active
July 24, 2021 07:24
-
-
Save rdisipio/bcebc22116361c75f7b7e32ce5bb3045 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def top_k_filtering(logits, top_k=5):
    """Keep the ``top_k`` largest logits along the last axis; set the rest to -inf.

    Intended for top-k sampling: after filtering, a softmax over the last
    axis gives zero probability to every masked entry.

    Args:
        logits: Float tensor whose last axis holds the scores to filter. Any
            leading (e.g. batch) dimensions are processed independently, row
            by row.
        top_k: Number of entries to keep per row. A value larger than the size
            of the last axis is clamped to that size. ``top_k <= 0`` disables
            filtering and returns ``logits`` unchanged.

    Returns:
        A tensor with the same shape and dtype as ``logits``. Every entry
        strictly smaller than its row's ``top_k``-th largest value is replaced
        by ``-inf``. Entries tied with that threshold are kept, so a row may
        retain more than ``top_k`` finite values.
    """
    if top_k <= 0:
        return logits
    # Clamp so tf.math.top_k never asks for more values than a row contains.
    # tf.shape (not .shape) keeps this valid when the last dimension is only
    # known at run time.
    k = tf.minimum(top_k, tf.shape(logits)[-1])
    # The k-th largest value of each row, kept as a trailing axis of size 1 so
    # that it broadcasts against `logits` in the comparison below.
    threshold = tf.math.top_k(logits, k=k).values[..., -1, None]
    indices_to_remove = logits < threshold
    # An elementwise select masks every row at once and preserves shape and
    # dtype. The previous sparse-tensor construction hard-coded row 0.
    neg_inf = tf.fill(
        tf.shape(logits), tf.constant(-float("inf"), dtype=logits.dtype)
    )
    return tf.where(indices_to_remove, neg_inf, logits)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment