""" | |
(Broken?) MinHash implementation attempting to optimize Jaccard similarity | |
for https://github.com/lemon24/reader/issues/202 (reader.entry_dedupe plugin). | |
--- | |
Current state: | |
It kinda works, but no matter how much I increase LOOPS, | |
the result doesn't seem to converge to the real similarity | |
(e.g. for 1_000_000, the weighted version has differences up to 0.10, | |
not much different from the 1000 ones; ???). | |
This is likely due to a bug, or me misunderstanding the algorithm. | |
2023-03 UPDATE: | |
With hashlib.md5() as the hash function, | |
it's within 0.02 of the real similarity, | |
and it does converge with more loops! | |
(I probably just looked at the hash() version for some reason.) | |
--- | |
Even if we fix it: | |
The L5-Minhash.pdf doc below says that to reduce the error rate | |
to < 0.05 99% of the times we need ~1000 loops. | |
I don't think this would be an optimiztion for "online" use. | |
(Haven't really thought how document size affects things, though.) | |
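
(Back-of-the-envelope, assuming the estimate is the mean of k i.i.d.
0/1 indicators: Hoeffding gives
P(|estimate - similarity| >= eps) <= 2 * exp(-2 * k * eps**2),
so k >= ln(2 / delta) / (2 * eps**2) = ln(200) / 0.005 ~= 1060
for eps=0.05, delta=0.01, which matches the ~1000 figure.)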

*However*, if we use a known array of randoms and a stable hash function,
we can precompute the min hashes and store them for each entry "offline",
and then just compare the stored hashes to get the actual similarity.
We might also use this to bucket/group entries by similarity.
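
A minimal sketch of the "offline" part (hypothetical names; assumes a
fixed seed and a stable hash like md5):

    SEED_RANDOMS = random.Random(0).choices(range(2**32), k=1000)

    def signature(tokens, n=4):
        # computed once per entry, then stored
        grams = set(_ngrams(tokens, n))
        return [
            min(hashlib.md5(repr((r, t)).encode()).digest() for t in grams)
            for r in SEED_RANDOMS
        ]

    def estimated_similarity(sig_one, sig_two):
        # later, compare stored signatures without needing the documents
        return sum(a == b for a, b in zip(sig_one, sig_two)) / len(sig_one)

(For the bucketing part, datasketch's MinHashLSH looks like it does
exactly this kind of grouping over stored MinHash objects.)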

---

Obviously, I don't need to roll my own;
datasketch below seems to do a great job.

---

https://en.wikipedia.org/wiki/Jaccard_index
https://www.cs.utah.edu/~jeffp/teaching/cs5140-S15/cs5140/L4-Jaccard+nGram.pdf
https://en.wikipedia.org/wiki/MinHash
https://en.wikipedia.org/wiki/MinHash#Incorporating_weights
https://www.cs.utah.edu/~jeffp/teaching/cs5140-S15/cs5140/L5-Minhash.pdf

---

Sample output (columns follow the `impls` dict below: js = exact Jaccard,
mh = the MinHash here, dmh = datasketch MinHash; jsw/mhw/dmhw are the
weighted variants, dmhn is naive weighting via enumerated n-grams;
the last row is the total seconds spent in each implementation):

$ time python minhash.py
  js   mh  dmh  jsw  mhw dmhn dmhw
0.50 0.41 0.48 0.95 0.94 0.94 0.99 one=40 one=39 two=1
1.00 1.00 1.00 0.95 0.94 0.93 0.00 one=40 one=38
0.14 0.17 0.14 0.72 0.72 0.72 0.92 one=40 one=37 two=3
0.17 0.13 0.16 0.79 0.73 0.76 0.00 one=50 one=47 two=1 three=1
0.12 0.18 0.12 0.82 0.80 0.83 0.00 one=50 one=55 two=5
0.11 0.13 0.12 0.75 0.71 0.73 0.00 one=70 one=63 two=5 three=1
0.12 0.15 0.12 0.68 0.59 0.68 0.88 one=70 one=60 two=10
0.00 0.02 0.13 0.00 0.23 0.14 0.25 times
python minhash.py  1.45s user 0.12s system 113% cpu 1.393 total
"""
from collections import Counter
from itertools import groupby
import hashlib
import random
import sys
import time

from reader.plugins.entry_dedupe import _ngrams

# make the reader test suite importable, so we can reuse its fixture data
sys.path.append('tests')
import test_plugins_entry_dedupe

from datasketch import MinHash, WeightedMinHashGenerator

# (one, two) pairs of tokenized summaries from the entry_dedupe test data;
# keep only the synthetic fixtures (summaries made of words like 'one', 'two')
DATA = []
for one, two, _ in test_plugins_entry_dedupe.IS_DUPLICATE_DATA:
    if not one.summary or 'one' not in one.summary:
        continue
    DATA.append((one.summary.split(), two.summary.split()))

def to_str(value):
    """Render a token list as its word counts, e.g. 'one=40 two=1'."""
    return ' '.join(f'{k}={v}' for k, v in Counter(value).items())

def jaccard(one, two, n):
    # exact Jaccard similarity of the sets of n-grams
    one = set(_ngrams(one, n))
    two = set(_ngrams(two, n))
    return len(one & two) / len(one | two)

def jaccard_weighted(one, two, n):
    # weighted Jaccard: multiset (Counter) intersection over union
    one = Counter(_ngrams(one, n))
    two = Counter(_ngrams(two, n))
    return sum((one & two).values()) / sum((one | two).values())

LOOPS = 1000

# sys.maxsize because https://stackoverflow.com/a/19133757
HASH_MAX = sys.maxsize
HASH = hash
# NOTE: hash() is randomized per process for strings (PYTHONHASHSEED),
# so it's not the "stable hash function" the docstring asks for;
# the md5 version below is stable, but at least 5x slower.
#HASH_MAX = b'\xff' * hashlib.md5().digest_size
#def HASH(thing): return hashlib.md5(repr(thing).encode('utf-8')).digest()

def minhash(one, two, n):
    one = set(_ngrams(one, n))
    two = set(_ngrams(two, n))
    loops = LOOPS
    # min_hashes[i] holds, for each of the two documents, the minimum
    # hash seen under the i-th (random-salted) hash function
    min_hashes = [[HASH_MAX] * 2 for _ in range(loops)]
    randoms = [random.random() for _ in range(loops)]
    for ic, counts in enumerate((one, two)):
        for t in counts:
            for ir, r in enumerate(randoms):
                # salt the hash with a per-loop random to simulate
                # a family of independent hash functions
                h = HASH((r, t))
                if h < min_hashes[ir][ic]:
                    min_hashes[ir][ic] = h
    # the fraction of hash functions whose minima collide
    # estimates the Jaccard similarity
    sim = sum(h_one == h_two for h_one, h_two in min_hashes) / loops
    return sim

def minhash_weighted(one, two, n):
    one = Counter(_ngrams(one, n))
    two = Counter(_ngrams(two, n))
    loops = LOOPS
    min_hashes = [[HASH_MAX] * 2 for _ in range(loops)]
    randoms = [random.random() for _ in range(loops)]
    for ic, counts in enumerate((one, two)):
        for t in counts:
            for ir, r in enumerate(randoms):
                # replicate each n-gram once per occurrence (ix), i.e.
                # the naive "incorporating weights" scheme from Wikipedia
                for ix in range(counts[t]):
                    h = HASH((r, ix, t))
                    if h < min_hashes[ir][ic]:
                        min_hashes[ir][ic] = h
    sim = sum(h_one == h_two for h_one, h_two in min_hashes) / loops
    return sim

def datasketch_minhash(one, two, n):
    m_one = MinHash(num_perm=LOOPS)
    for t in _ngrams(one, n):
        m_one.update(' '.join(t).encode('utf-8'))
    m_two = MinHash(num_perm=LOOPS)
    for t in _ngrams(two, n):
        m_two.update(' '.join(t).encode('utf-8'))
    return m_one.jaccard(m_two)

def datasketch_minhash_weighted(one, two, n):
    # this one only works for same-size sets:
    # WeightedMinHashGenerator expects fixed-dimension weight vectors,
    # and here the "vector" is the list of hashed n-grams itself
    one = list(one)
    two = list(two)
    if len(one) != len(two):
        return 0
    hashfunc = MinHash().hashfunc
    one = [hashfunc(' '.join(t).encode('utf-8')) for t in _ngrams(one, n)]
    two = [hashfunc(' '.join(t).encode('utf-8')) for t in _ngrams(two, n)]
    gen = WeightedMinHashGenerator(len(one), LOOPS)
    m_one = gen.minhash(one)
    m_two = gen.minhash(two)
    return m_one.jaccard(m_two)
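
# Hedged sketch, not wired into `impls` (function name made up): my
# understanding is that WeightedMinHashGenerator wants fixed-dimension
# weight vectors, so the intended usage is probably a shared n-gram
# vocabulary with per-document counts as weights, rather than the
# hash-values-as-vector hack above.
def datasketch_minhash_weighted_vocab(one, two, n):
    one = Counter(_ngrams(one, n))
    two = Counter(_ngrams(two, n))
    vocab = sorted(set(one) | set(two))
    gen = WeightedMinHashGenerator(len(vocab), LOOPS)
    m_one = gen.minhash([float(one.get(t, 0)) for t in vocab])
    m_two = gen.minhash([float(two.get(t, 0)) for t in vocab])
    return m_one.jaccard(m_two)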

def enumerated_ngrams(it, n):
    # make repeated n-grams distinct by appending their occurrence index,
    # so a plain (unweighted) MinHash effectively sees the multiset
    for _, group in groupby(sorted(_ngrams(it, n))):
        for i, t in enumerate(group):
            yield t + (str(i),)

def datasketch_minhash_weighted_naive(one, two, n):
    m_one = MinHash(num_perm=LOOPS)
    for t in enumerated_ngrams(one, n):
        m_one.update(' '.join(t).encode('utf-8'))
    m_two = MinHash(num_perm=LOOPS)
    for t in enumerated_ngrams(two, n):
        m_two.update(' '.join(t).encode('utf-8'))
    return m_one.jaccard(m_two)

impls = {
    'js': jaccard,
    'mh': minhash,
    'dmh': datasketch_minhash,
    'jsw': jaccard_weighted,
    'mhw': minhash_weighted,
    'dmhn': datasketch_minhash_weighted_naive,
    'dmhw': datasketch_minhash_weighted,
}

print(''.join(f'{name:>4} ' for name in impls))

times = {}
for one, two in DATA:
    sims = []
    for name, fn in impls.items():
        start = time.perf_counter()
        val = fn(one, two, 4)
        end = time.perf_counter()
        sims.append(val)
        times[name] = times.get(name, 0) + end - start
    print(
        f"{''.join(f'{s:.2f} ' for s in sims)}"
        f"{to_str(one)} "
        f"{to_str(two)}"
    )

print(
    f"{''.join(f'{s:.2f} ' for s in times.values())}"
    "times"
)