Skip to content

Instantly share code, notes, and snippets.

@lemon24
Last active March 19, 2023 21:46
Show Gist options
  • Save lemon24/b9af5ade919713406bda9603847d32e5 to your computer and use it in GitHub Desktop.
Save lemon24/b9af5ade919713406bda9603847d32e5 to your computer and use it in GitHub Desktop.
"""
(Broken?) MinHash implementation attempting to optimize Jaccard similarity
for https://github.com/lemon24/reader/issues/202 (reader.entry_dedupe plugin).
---
Current state:
It kinda works, but no matter how much I increase LOOPS,
the result doesn't seem to converge to the real similarity
(e.g. for 1_000_000, the weighted version has differences up to 0.10,
not much different from the 1000 ones; ???).
This is likely due to a bug, or me misunderstanding the algorithm.
2023-03 UPDATE:
With hashlib.md5() as the hash function,
it's within 0.02 of the real similarity,
and it does converge with more loops!
(I probably just looked at the hash() version for some reason.)
---
Even if we fix it:
The L5-Minhash.pdf doc below says that to reduce the error rate
to < 0.05 99% of the time we need ~1000 loops.
I don't think this would be an optimization for "online" use.
(Haven't really thought about how document size affects things, though.)
*However*, if we use a known array of randoms and a stable hash function,
we can precompute min_hashes and store it for each entry "offline",
and then just compare the hashes to get the actual similarity.
We might also use this to bucket/group entries by similarity.
---
Obviously, I don't need to roll my own;
datasketch below seems to do a great job.
---
https://en.wikipedia.org/wiki/Jaccard_index
https://www.cs.utah.edu/~jeffp/teaching/cs5140-S15/cs5140/L4-Jaccard+nGram.pdf
https://en.wikipedia.org/wiki/MinHash
https://en.wikipedia.org/wiki/MinHash#Incorporating_weights
https://www.cs.utah.edu/~jeffp/teaching/cs5140-S15/cs5140/L5-Minhash.pdf
---
$ time python minhash.py
js mh dmh jsw mhw dmhn dmhw
0.50 0.41 0.48 0.95 0.94 0.94 0.99 one=40 one=39 two=1
1.00 1.00 1.00 0.95 0.94 0.93 0.00 one=40 one=38
0.14 0.17 0.14 0.72 0.72 0.72 0.92 one=40 one=37 two=3
0.17 0.13 0.16 0.79 0.73 0.76 0.00 one=50 one=47 two=1 three=1
0.12 0.18 0.12 0.82 0.80 0.83 0.00 one=50 one=55 two=5
0.11 0.13 0.12 0.75 0.71 0.73 0.00 one=70 one=63 two=5 three=1
0.12 0.15 0.12 0.68 0.59 0.68 0.88 one=70 one=60 two=10
0.00 0.02 0.13 0.00 0.23 0.14 0.25 times
python minhash.py 1.45s user 0.12s system 113% cpu 1.393 total
"""
from collections import Counter
import sys
import random
import hashlib
import time
from itertools import groupby
from reader.plugins.entry_dedupe import _ngrams
sys.path.append('tests')
import test_plugins_entry_dedupe
from datasketch import MinHash, WeightedMinHashGenerator
# Benchmark corpus: one (words_of_one, words_of_two) pair per test case whose
# first summary is non-empty and contains the substring 'one'.
DATA = [
    (one.summary.split(), two.summary.split())
    for one, two, _ in test_plugins_entry_dedupe.IS_DUPLICATE_DATA
    if one.summary and 'one' in one.summary
]
def to_str(value):
    """Render the item counts of *value* as space-separated ``item=count`` pairs.

    Items are listed in first-seen order, e.g. ``['a', 'b', 'a']`` becomes
    ``'a=2 b=1'``; an empty iterable becomes ``''``.
    """
    tally = Counter(value)
    return ' '.join(f'{item}={count}' for item, count in tally.items())
def jaccard(one, two, n):
    """Return the exact Jaccard similarity of the n-gram sets of two sequences.

    Args:
        one, two: The sequences to compare (word lists in this script).
        n: The n-gram size passed to ``_ngrams``.

    Returns:
        ``len(A & B) / len(A | B)`` as a float, where A and B are the sets of
        n-grams of *one* and *two*; 1.0 if neither input yields any n-gram.
    """
    one = set(_ngrams(one, n))
    two = set(_ngrams(two, n))
    union = one | two
    if not union:
        # 0 / 0: two empty sets are conventionally treated as identical,
        # which is also what the min-hash estimators report for them.
        return 1.0
    return len(one & two) / len(union)
def jaccard_weighted(one, two, n):
    """Return the exact weighted (multiset) Jaccard similarity of two sequences.

    Repeated n-grams count: the numerator sums the per-n-gram minimum of the
    two counts (``Counter &``), the denominator the per-n-gram maximum
    (``Counter |``).

    Args:
        one, two: The sequences to compare (word lists in this script).
        n: The n-gram size passed to ``_ngrams``.

    Returns:
        The similarity as a float; 1.0 if neither input yields any n-gram.
    """
    one = Counter(_ngrams(one, n))
    two = Counter(_ngrams(two, n))
    total = sum((one | two).values())
    if not total:
        # 0 / 0: two empty multisets are conventionally treated as identical,
        # which is also what the min-hash estimators report for them.
        return 1.0
    return sum((one & two).values()) / total
# Number of hash rounds used for every signature; also handed to datasketch
# (num_perm / second WeightedMinHashGenerator argument) so that all the
# estimators below work with the same signature length.
LOOPS = 1000
# HASH is the hash function used by minhash() / minhash_weighted();
# HASH_MAX is the value their per-round minimums start from, so it must
# compare >= anything HASH can return.
# sys.maxsize because https://stackoverflow.com/a/19133757
HASH_MAX = sys.maxsize
HASH = hash
# Alternative pair based on MD5 digests (compared as bytes); per the module
# docstring this variant gets within 0.02 of the real similarity.
# at least 5x slower
#HASH_MAX = b'\xff' * hashlib.md5().digest_size
#def HASH(thing): return hashlib.md5(repr(thing).encode('utf-8')).digest()
def minhash(one, two, n):
    """Estimate the Jaccard similarity of two sequences' n-gram sets.

    LOOPS random seeds are drawn; for each seed, the smallest
    ``HASH((seed, ngram))`` over each input's n-gram set is computed
    (starting from HASH_MAX). The estimate is the fraction of seeds for
    which both inputs end up with the same minimum.
    """
    gram_sets = (set(_ngrams(one, n)), set(_ngrams(two, n)))
    seeds = [random.random() for _ in range(LOOPS)]

    agreeing = 0
    for seed in seeds:
        low_one, low_two = (
            # HASH_MAX goes first so an empty set yields HASH_MAX
            min([HASH_MAX, *(HASH((seed, gram)) for gram in grams)])
            for grams in gram_sets
        )
        agreeing += low_one == low_two
    return agreeing / LOOPS
def minhash_weighted(one, two, n):
    """Estimate the weighted Jaccard similarity of two sequences' n-grams.

    Works like the unweighted estimator, except that an n-gram occurring
    k times is hashed k times, as ``HASH((seed, i, ngram))`` for
    ``i in range(k)``, so repeated n-grams contribute one element per copy.
    The estimate is the fraction of the LOOPS seeds for which both inputs
    end up with the same minimum hash.
    """
    gram_counts = (Counter(_ngrams(one, n)), Counter(_ngrams(two, n)))
    seeds = [random.random() for _ in range(LOOPS)]

    agreeing = 0
    for seed in seeds:
        low_one, low_two = (
            # HASH_MAX goes first so an empty input yields HASH_MAX
            min([
                HASH_MAX,
                *(
                    HASH((seed, copy, gram))
                    for gram, count in counts.items()
                    for copy in range(count)
                ),
            ])
            for counts in gram_counts
        )
        agreeing += low_one == low_two
    return agreeing / LOOPS
def datasketch_minhash(one, two, n):
    """Estimate the Jaccard similarity of two sequences with datasketch.

    Each input gets its own ``MinHash(num_perm=LOOPS)``, updated with the
    UTF-8 bytes of every n-gram joined by single spaces; the result is the
    first sketch's ``jaccard()`` against the second.
    """
    sketches = []
    for words in (one, two):
        sketch = MinHash(num_perm=LOOPS)
        for gram in _ngrams(words, n):
            sketch.update(' '.join(gram).encode('utf-8'))
        sketches.append(sketch)
    m_one, m_two = sketches
    return m_one.jaccard(m_two)
def datasketch_minhash_weighted(one, two, n):
    """Estimate the weighted Jaccard similarity with datasketch's WeightedMinHash.

    WeightedMinHashGenerator works on fixed-dimension weight vectors, so
    both inputs are projected onto a shared vocabulary (the union of their
    n-grams); each vector holds the number of times the input contains
    the corresponding n-gram (0 if absent). This makes inputs of different
    lengths comparable and uses actual counts as weights.

    Args:
        one, two: The sequences to compare (word lists in this script).
        n: The n-gram size passed to ``_ngrams``.

    Returns:
        The estimated similarity as a float; without calling datasketch,
        1.0 if neither input yields any n-gram and 0.0 if only one does.
    """
    one = Counter(_ngrams(one, n))
    two = Counter(_ngrams(two, n))
    if not one or not two:
        # NOTE(review): datasketch is expected to reject an all-zero weight
        # vector (and a zero-dimension generator) -- confirm against its docs.
        return 1.0 if not one and not two else 0.0

    # Any fixed order works, as long as both vectors use the same one.
    vocabulary = list(one.keys() | two.keys())
    gen = WeightedMinHashGenerator(len(vocabulary), LOOPS)
    # Counter returns 0 for n-grams the input does not contain.
    m_one = gen.minhash([one[gram] for gram in vocabulary])
    m_two = gen.minhash([two[gram] for gram in vocabulary])
    return m_one.jaccard(m_two)
def enumerated_ngrams(it, n):
    """Yield every n-gram of *it*, tagged with its occurrence index.

    N-grams are visited in sorted order; the k-th copy of an n-gram
    (counting from 0) is yielded as the n-gram tuple with ``str(k)``
    appended, so repeated n-grams become distinct elements.
    """
    counts = Counter(_ngrams(it, n))
    for gram in sorted(counts):
        for occurrence in range(counts[gram]):
            yield gram + (str(occurrence),)
def datasketch_minhash_weighted_naive(one, two, n):
    """Estimate weighted Jaccard similarity with a plain datasketch MinHash.

    Weighting is emulated by feeding the sketches the output of
    ``enumerated_ngrams`` (n-grams tagged with an occurrence index) instead
    of the bare n-grams. Each element is joined with single spaces and
    UTF-8 encoded before being added to a ``MinHash(num_perm=LOOPS)``.
    """
    sketches = []
    for words in (one, two):
        sketch = MinHash(num_perm=LOOPS)
        for tagged in enumerated_ngrams(words, n):
            sketch.update(' '.join(tagged).encode('utf-8'))
        sketches.append(sketch)
    m_one, m_two = sketches
    return m_one.jaccard(m_two)
# Column label -> similarity implementation, in output column order.
impls = {
    'js': jaccard,
    'mh': minhash,
    'dmh': datasketch_minhash,
    'jsw': jaccard_weighted,
    'mhw': minhash_weighted,
    'dmhn': datasketch_minhash_weighted_naive,
    'dmhw': datasketch_minhash_weighted,
}

# Header row: one right-aligned label per implementation.
print(''.join(f'{label:>4} ' for label in impls))

# Accumulated wall-clock seconds per implementation, across all pairs.
times = {}
for one, two in DATA:
    sims = []
    for name, impl in impls.items():
        start = time.perf_counter()
        val = impl(one, two, 4)
        end = time.perf_counter()
        sims.append(val)
        times[name] = times.get(name, 0) + end - start
    # One row per pair: the similarities, then the word counts of each side.
    sims_part = ''.join(f'{sim:.2f} ' for sim in sims)
    one_part = f"{to_str(one):2} "
    two_part = f"{to_str(two):2}"
    print(sims_part + one_part + two_part)

# Final row: total time spent in each implementation.
totals_part = ''.join(f'{total:.2f} ' for total in times.values())
print(totals_part + "times")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment