Skip to content

Instantly share code, notes, and snippets.

@JossWhittle
Created September 12, 2022 16:12
Show Gist options
  • Save JossWhittle/64042f2c42d154975987404b86f4666c to your computer and use it in GitHub Desktop.
Save JossWhittle/64042f2c42d154975987404b86f4666c to your computer and use it in GitHub Desktop.
import numpy as np
from numba import njit, prange
# Popcount table: entry b holds the number of set bits in the byte value b (0..255),
# obtained by unpacking every possible byte into its 8 bits and summing them.
BIT_COUNT_LOOKUP = (
    np.unpackbits(np.arange(256, dtype=np.uint8)[:, np.newaxis], axis=1)
    .sum(axis=1)
    .astype(np.uint8)
)
@njit(fastmath=True, parallel=True)
def fast_tanimoto_matrix(fingerprints, progress):
    """
    Compute a symmetric Tanimoto similarity matrix over a set of fingerprints of size (N, F//8).

    N is the number of fingerprints and F is the length of the boolean fingerprint; F//8 is the
    length of each fingerprint after it is bit-packed into uint8's.

    The similarity of fingerprints A and B is ``|A & B| / (|A| + |B| - |A & B|)``, where ``|x|``
    is the number of set bits. If both fingerprints have no bits set, the denominator is zero;
    such pairs are assigned a similarity of 1.0. This is consistent with the diagonal, where
    every fingerprint (including an all-zero one) has similarity 1.0 with itself.

    :param fingerprints: A 2-D numpy array of type uint8 representing the bit-packed boolean
        fingerprints, one fingerprint per row. It must be uint8: each byte is used as an index
        into the 256-entry BIT_COUNT_LOOKUP table, and Numba does not bounds-check that
        indexing by default.
    :param progress: A numba-progress object for reporting intermediate progress; its
        ``update(1)`` method is called once per completed row.
    :return: A symmetric (N, N) matrix of float32's with the Tanimoto similarity scores between
        each fingerprint pair, with 1.0 on the diagonal.
    """
    n = fingerprints.shape[0]
    n_bytes = fingerprints.shape[1]

    # Diagonal values are the Tanimoto similarity of a fingerprint with itself, i.e. 1.
    matrix = np.ones((n, n), dtype=np.float32)

    # Pre-compute the set-bit count of every fingerprint. Accumulate in int64 so the
    # union arithmetic below never mixes signed and unsigned integer types.
    counts = np.zeros(n, dtype=np.int64)
    for i in prange(n):
        total = 0
        for k in range(n_bytes):
            total += BIT_COUNT_LOOKUP[fingerprints[i, k]]
        counts[i] = total

    # For each fingerprint A in parallel, compare it with every later fingerprint B and fill
    # both triangles. Each (row, col) pair with col > row is written only by iteration `row`,
    # so the parallel writes never overlap.
    for row in prange(n):
        count_a = counts[row]
        fp_a = fingerprints[row]
        for col in range(row + 1, n):
            fp_b = fingerprints[col]
            # Bit count of A AND B, byte by byte, without allocating temporaries.
            common = 0
            for k in range(n_bytes):
                common += BIT_COUNT_LOOKUP[fp_a[k] & fp_b[k]]
            union = count_a + counts[col] - common
            # Guard the 0/0 case (both fingerprints empty): with fastmath the compiler may
            # assume no NaNs, so the division must not be evaluated here.
            if union == 0:
                similarity = 1.0
            else:
                similarity = common / union
            matrix[row, col] = similarity  # upper triangle
            matrix[col, row] = similarity  # mirrored lower triangle
        progress.update(1)
    return matrix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment