Last active
June 19, 2024 20:03
-
-
Save gallir/e719f8c816b8c7d349a8d69d3678acbb to your computer and use it in GitHub Desktop.
Very fast function to compute the cosine similarity between two short texts when word counts are not needed (i.e. a binary bag of words). It also works well with unusual non-ASCII characters.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from unidecode import unidecode | |
import re | |
import sys | |
import inflection | |
import numpy as np | |
import math | |
from collections import defaultdict | |
# Using cosine_similarity, own faster implementation, inspired by | |
# https://towardsdatascience.com/calculating-string-similarity-in-python-276e18a7d33a | |
_tokens_cache = defaultdict(lambda: None) | |
_phone_regex = re.compile(r'[^\d]|^0+') | |
# last is a subset from string.punctuation | |
_nopunctuation = str.maketrans('()[]-&:;./-.', ' ', '\'`´!"#$%*+,<=>?@\\^_`{|}~') | |
def _cached_tokens(text, cache, stopwords, enders):
    """Return get_tokens(text, stopwords, enders), memoized in _tokens_cache when cache is true."""
    if not cache:
        return get_tokens(text, stopwords, enders)
    # The cache must not mix tokenizations made with different filters. Plain
    # text keys are kept for the unfiltered case; filtered results get a tuple
    # key. Assumes stopwords/enders are collections of hashable words
    # (strings) -- TODO confirm against callers.
    if stopwords or enders:
        key = (text, frozenset(stopwords or ()), frozenset(enders or ()))
    else:
        key = text
    tokens = _tokens_cache[key]
    if tokens is None:
        tokens = get_tokens(text, stopwords, enders)
        _tokens_cache[key] = tokens
    return tokens


def cosine_similarity(text1, text2, cache=False, stopwords=None, enders=None):
    """Return the cosine similarity of two texts as binary bags of words.

    Both texts are tokenized with ``get_tokens``. Word counts are ignored (a
    token is either present or absent), so the cosine of the two 0/1 vectors
    reduces to ``|A & B| / (sqrt(|A|) * sqrt(|B|))`` over the token sets.

    Args:
        text1, text2: the texts to compare.
        cache: if true, memoize tokenizations in the module-level
            ``_tokens_cache`` (unbounded).
        stopwords: optional collection of words to discard; forwarded to
            ``get_tokens``.
        enders: optional collection of tokens at which tokenization stops;
            forwarded to ``get_tokens``.

    Returns:
        A float in [0.0, 1.0]. 0.0 if either text yields no tokens or the
        texts share no token; 1.0 if the token lists are identical.
    """
    tok1 = _cached_tokens(text1, cache, stopwords, enders)
    tok2 = _cached_tokens(text2, cache, stopwords, enders)
    if not tok1 or not tok2:
        return 0.0
    if tok1 == tok2:
        return 1.0
    # For binary vectors, v1.v2 = |A & B| and ||v|| = sqrt(|A|): this equals
    # np.dot(v1, v2) / (norm(v1) * norm(v2)) without building any vectors and
    # with O(1) set membership instead of list scans.
    words1 = set(tok1)
    words2 = set(tok2)
    common = len(words1 & words2)
    if not common:
        # No intersection.
        return 0.0
    return common / (math.sqrt(len(words1)) * math.sqrt(len(words2)))
def get_tokens(text, stopwords=None, enders=None):
    """Split *text* into a sorted list of normalized word tokens.

    Steps:
    1. Transliterate to ASCII with ``unidecode``.
    2. Replace or delete punctuation via ``_nopunctuation``.
    3. Lower-case the text.
    4. Drop single-character words and stopwords.
    5. Singularize each remaining word with ``inflection``.

    Args:
        text: the text to tokenize.
        stopwords: optional collection of words to discard. They are compared
            with the lower-cased ASCII word *before* singularization.
        enders: optional collection of tokens. The first token (in text order,
            compared *after* singularization) found in it, and everything
            after it, is dropped.

    Returns:
        The surviving tokens, sorted; duplicates are kept.
    """
    # Transliterate before stripping punctuation: unidecode maps many non-ASCII
    # symbols to ASCII punctuation (curly quotes -> '"', '…' -> '...'), which
    # would otherwise stay glued to the words.
    text = unidecode(text).translate(_nopunctuation).lower()
    tokens = [
        inflection.singularize(word)
        for word in text.split()
        if len(word) > 1 and (not stopwords or word not in stopwords)
    ]
    if enders:
        cut = next(
            (i for i, token in enumerate(tokens) if token in enders),
            len(tokens),
        )
        tokens = tokens[:cut]
    return sorted(tokens)
def phonenumber_equal(a, b, min_digits=9):
    """Return True if phone numbers *a* and *b* have the same significant digits.

    Both strings are normalized with ``_phone_regex``. This removes every
    non-digit, plus any zeros at the very start of the string. The results must
    then match exactly. Numbers with fewer than *min_digits* significant digits
    never compare equal, to avoid matching short or incomplete numbers.

    Args:
        a, b: phone numbers as free-form strings.
        min_digits: minimum number of normalized digits required (default 9).
    """
    a = _phone_regex.sub('', a)
    b = _phone_regex.sub('', b)
    # The original `len(a) > 8 or len(b) > 8 and a == b` parsed as
    # `len(a) > 8 or (...)` and accepted any long `a`. Equal strings have equal
    # length, so checking `a` alone suffices.
    return a == b and len(a) >= min_digits
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment