Created
December 13, 2018 16:23
-
-
Save kzinmr/791ca313c7ef0fe211b64cc04a409e7e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def calculate_cosine(query, X, k=10, threshold=0.5):
    """Return the top-k columns of X ranked by cosine similarity to query.

    Arguments:
        query: (dim,) or (dim, 1)-array (or array-like)
        X: (dim, N)-array of N column vectors (or array-like)
        k: maximum number of results to return
        threshold: only columns whose cosine similarity is strictly greater
            than this value are considered

    Returns:
        topk_score_indices: column indices into X (i.e. indices of the
            N-vectors), sorted by descending similarity; ties are returned
            in ascending column order
        topk_score: the cosine similarities of those columns, in the same
            order

        When no column passes the threshold, two empty lists are returned.

    Columns of X with zero norm (and every column when query has zero norm)
    have an undefined similarity (nan) and are never returned.
    """
    query = np.asarray(query)
    X = np.asarray(X)

    # (dim,) -> (dim, 1) if necessary
    if query.ndim == 1:
        query = query[:, None]
    assert query.ndim == 2 and query.shape[1] == 1

    # Zero-norm vectors produce 0/0 = nan; that is expected and handled by
    # the threshold comparison below, so silence the floating-point warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        # (dim, N) -> (1, N)
        x_norm = (X * X).sum(0, keepdims=True) ** .5
        # (dim, 1) -> (1, 1)
        q_norm = (query * query).sum(0, keepdims=True) ** .5
        # (1, N) -> (N,); reshape (not squeeze) so N == 1 stays 1-D
        cos_qx = ((query.T @ X) / x_norm / q_norm).reshape(-1)
        # nan > threshold is False, so undefined similarities are dropped here
        candidates = np.flatnonzero(cos_qx > threshold)

    if candidates.size == 0:
        return [], []

    # Rank only the surviving columns, then map the ranks back to the
    # original column indices of X so callers can look the vectors up.
    order = np.argsort(-cos_qx[candidates], kind="stable")[:k]
    topk_score_indices = candidates[order]
    topk_score = cos_qx[topk_score_indices]
    return topk_score_indices, topk_score
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment