Skip to content

Instantly share code, notes, and snippets.

@aaaddress1
Last active April 15, 2020 19:43
Show Gist options
  • Save aaaddress1/d39e4b2610d52b80735cc39f8a2f8dcb to your computer and use it in GitHub Desktop.
Save aaaddress1/d39e4b2610d52b80735cc39f8a2f8dcb to your computer and use it in GitHub Desktop.
cos_similarity.py
# co-occurrence matrix & cosine similarity, by [email protected]
def build_co_matrix(tokens: list, window_size: int = 1) -> dict:
    """Count how often each token appears within a window of every other token.

    Args:
        tokens: The tokenized text, in reading order.
        window_size: How many positions to the left and to the right of a
            token count as its context.

    Returns:
        A dict of dicts where ``matrix[a][b]`` is the number of times ``b``
        occurred within ``window_size`` positions of an occurrence of ``a``.
        Every row holds an entry (possibly 0) for every distinct token.
    """
    vocabulary = set(tokens)
    matrix = {word: dict.fromkeys(vocabulary, 0) for word in vocabulary}
    for position, current in enumerate(tokens):
        row = matrix[current]
        # Clip the window at the sequence edges rather than skipping it, so
        # tokens close to either end still count the neighbours they do have.
        # (Slicing already clips the right-hand side; only the left needs max.)
        left_edge = max(0, position - window_size)
        for previous in tokens[left_edge:position]:
            row[previous] += 1
        for following in tokens[position + 1:position + window_size + 1]:
            row[following] += 1
    return matrix


testSample = 'adr have 30cm and shenghao have 30cm'
in_sample = testSample.split()
corups = set(in_sample)  # distinct tokens; name kept because later code reads it
win_size = 1
co_matrix = build_co_matrix(in_sample, win_size)
def cos_similarity(token_a: str, token_b: str, eps: float = 1e-8) -> float:
    """Return the cosine similarity between two tokens' co-occurrence rows.

    Each token is represented by its row in the module-level ``co_matrix``
    (a dict of dicts of counts); both rows are expected to share the same keys.

    Args:
        token_a: First token; must be a key of ``co_matrix``.
        token_b: Second token; must be a key of ``co_matrix``.
        eps: Small constant added to each vector norm so that an all-zero row
            yields 0.0 instead of a division by zero. Pass 0.0 for the exact
            cosine (an all-zero row then raises ``ZeroDivisionError``).

    Returns:
        The dot product of the two rows divided by the product of their
        (eps-padded) Euclidean norms.

    Raises:
        KeyError: If either token has no row in ``co_matrix``.
    """
    import math

    row_a = co_matrix[token_a]
    row_b = co_matrix[token_b]
    norm_a = math.sqrt(sum(count ** 2 for count in row_a.values())) + eps
    norm_b = math.sqrt(sum(count ** 2 for count in row_b.values())) + eps
    # Iterate the row's own keys so the function depends on co_matrix only.
    dot_ab = float(sum(row_a[token] * row_b[token] for token in row_a))
    return dot_ab / (norm_a * norm_b)
# Demo: compare the contexts of 'adr' and 'shenghao'; prints roughly 0.7071.
print(cos_similarity(token_a='adr', token_b='shenghao'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment