Skip to content

Instantly share code, notes, and snippets.

@rdisipio
Last active July 24, 2021 07:24
Show Gist options
  • Save rdisipio/bcebc22116361c75f7b7e32ce5bb3045 to your computer and use it in GitHub Desktop.
Save rdisipio/bcebc22116361c75f7b7e32ce5bb3045 to your computer and use it in GitHub Desktop.
def top_k_filtering(logits, top_k=5):
    """Keep the ``top_k`` largest logits along the last axis; set the rest to -inf.

    Works for any leading (batch) shape: the threshold is computed per row of
    the last axis. Entries that tie with the k-th largest value are kept, so a
    row may retain more than ``top_k`` finite entries.

    Args:
        logits: floating-point tensor (or array-like) whose last axis holds
            the scores to filter.
        top_k: number of highest-scoring entries to keep per row. Values
            larger than the size of the last axis are clamped to that size,
            in which case nothing is filtered.

    Returns:
        A tensor with the same shape and dtype as ``logits`` in which every
        entry strictly below the row's k-th largest value is ``-inf``.
    """
    logits = tf.convert_to_tensor(logits)

    # Clamp k so asking for more entries than exist keeps the whole row
    # instead of making tf.math.top_k raise.
    k = tf.minimum(tf.cast(top_k, tf.int32), tf.shape(logits)[-1])

    # top_k returns values sorted in descending order, so [..., -1] is the
    # k-th largest value of each row; the trailing None keeps an axis of
    # size 1 so the comparison broadcasts across the last axis.
    kth_largest = tf.math.top_k(logits, k=k)[0][..., -1, None]
    indices_to_remove = logits < kth_largest

    # Build the fill value in the logits' own dtype (float16/32/64) and at
    # full shape so the select below never needs a dtype or shape coercion.
    neg_inf = tf.fill(
        tf.shape(logits), tf.constant(-float('inf'), dtype=logits.dtype)
    )

    # Element-wise select replaces the old sparse scatter/add round trip,
    # which only handled a batch of exactly one row.
    return tf.where(indices_to_remove, neg_inf, logits)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment