@gallir
Last active June 19, 2024 20:03
Very fast function to compute the cosine similarity between two short texts where counting word occurrences is not needed (i.e. a binary bag of words). It also handles non-ASCII and otherwise unusual characters well.
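For example, "cheap flight to mallorca" and "cheap flight to ibiza" share three of their five distinct words, so the binary vectors are (1, 1, 1, 1, 0) and (1, 1, 1, 0, 1) and the cosine is 3 / (sqrt(4) * sqrt(4)) = 0.75. A hypothetical call to the function defined below (the exact value assumes inflection leaves these words unchanged):

cosine_similarity("cheap flight to mallorca", "cheap flight to ibiza")  # -> 0.75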
import math
import re
from collections import defaultdict

import inflection
import numpy as np
from unidecode import unidecode

# Own, faster implementation of cosine similarity, inspired by
# https://towardsdatascience.com/calculating-string-similarity-in-python-276e18a7d33a

_tokens_cache = defaultdict(lambda: None)
_phone_regex = re.compile(r'[^\d]|^0+')

# Punctuation handling: the characters in the first string are mapped to spaces,
# the third argument (a subset of string.punctuation) is deleted outright
_punct_to_space = '()[]-&:;./-.'
_nopunctuation = str.maketrans(_punct_to_space, ' ' * len(_punct_to_space),
                               '\'`´!"#$%*+,<=>?@\\^_`{|}~')
def cosine_similarity(text1, text2, cache=False, stopwords=None, enders=None):
    # Return the cosine similarity between text1 and text2
    tok1 = tok2 = None
    if cache:
        # Note: the cache is keyed by the raw text only, so it assumes the same
        # stopwords/enders are used across calls
        tok1 = _tokens_cache[text1]
        tok2 = _tokens_cache[text2]
    if tok1 is None:
        tok1 = get_tokens(text1, stopwords, enders)
        if cache:
            _tokens_cache[text1] = tok1
    if tok2 is None:
        tok2 = get_tokens(text2, stopwords, enders)
        if cache:
            _tokens_cache[text2] = tok2
    if not tok1 or not tok2:
        return 0.0
    if tok1 == tok2:
        return 1.0
    vocabulary = set(tok1 + tok2)
    if len(vocabulary) == len(tok1) + len(tok2):
        # No intersection at all
        return 0.0
    # Binary bag-of-words vectors over the joint vocabulary
    v1 = np.zeros(len(vocabulary))
    v2 = np.zeros(len(vocabulary))
    for i, w in enumerate(vocabulary):
        if w in tok1:
            v1[i] = 1
        if w in tok2:
            v2[i] = 1
    # Cosine = v1 DOT v2 / (norm-2(v1) * norm-2(v2)),
    # equivalent to, but about 2x faster than,
    # np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    return np.dot(v1, v2) / (math.sqrt(np.dot(v1, v1)) * math.sqrt(np.dot(v2, v2)))
def get_tokens(text, stopwords=None, enders=None):
    # Normalize (strip punctuation and accents, lowercase), drop stopwords and
    # one-character words, singularize, and return the tokens sorted
    text = text.translate(_nopunctuation)
    text = unidecode(text)
    text = text.lower()
    tokens = [inflection.singularize(w) for w in text.split()
              if len(w) > 1 and (not stopwords or w not in stopwords)]
    if enders:
        # Cut the token list at the first "ender" word
        for i, w in enumerate(tokens):
            if w in enders:
                tokens = tokens[:i]
                break
    return sorted(tokens)
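# Illustrative sketch of the tokenizer (the exact singular forms depend on
# inflection's rules, so treat the output as approximate):
# >>> get_tokens("Cafés & Restaurants in Málaga")
# ['cafe', 'in', 'malaga', 'restaurant']
# '&' is mapped to a space, unidecode strips the accents, plurals are
# singularized and the result is sorted.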
def phonenumber_equal(a, b):
    # Strip non-digit characters and leading zeros before comparing
    a = _phone_regex.sub('', a)
    b = _phone_regex.sub('', b)
    # Only equal if both have at least 9 digits and the digits match
    if len(a) > 8 and len(b) > 8 and a == b:
        return True
    return False
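# Example with made-up numbers: both strings reduce to '34912345678' (11 digits),
# so they compare equal; numbers with fewer than 9 digits always return False.
# >>> phonenumber_equal("+34 912 345 678", "0034 912 345 678")
# True
# >>> phonenumber_equal("12345", "12345")
# False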