Use MinHash to get Jaccard Similarity in PySpark
from numpy.random import RandomState
import pyspark.sql.functions as f
from pyspark import StorageLevel


def hashmin_jaccard_spark(
        sdf, node_col, edge_basis_col, suffixes=('A', 'B'),
        n_draws=100, storage_level=None, seed=42, verbose=False):
    """
    Calculate a sparse Jaccard similarity matrix using MinHash.

    Parameters
        sdf (pyspark.sql.DataFrame): A DataFrame containing at least two columns:
            one defining the nodes (similarity between which is to be calculated)
            and one defining the edges (the basis for node comparisons).
        node_col (str): the name of the DataFrame column containing node labels
        edge_basis_col (str): the name of the DataFrame column containing the edge labels
        suffixes (tuple): a tuple of length 2 containing the suffixes to be appended
            to `node_col` in the output
        n_draws (int): the number of permutations to do; this determines the precision
            of the Jaccard similarity (n_draws=100, the default, results in a
            similarity precision of 0.01)
        storage_level (pyspark.StorageLevel): PySpark object indicating how to persist
            the hashing stage of the process
        seed (int): seed for random state generation
        verbose (bool): if True, print how many records get hashed

    Returns
        pyspark.sql.DataFrame with columns `node_col + suffixes[0]`,
        `node_col + suffixes[1]`, and `jaccardSimilarity`.
    """
    HASH_PRIME = 2038074743
    left_name = node_col + suffixes[0]
    right_name = node_col + suffixes[1]

    # Draw one random linear permutation (a * x + b) % p per signature position;
    # adding 1 to the coefficients keeps `a` nonzero.
    rs = RandomState(seed)
    shifts = rs.randint(0, HASH_PRIME - 1, size=n_draws)
    coefs = rs.randint(0, HASH_PRIME - 1, size=n_draws) + 1

    hash_sdf = (
        sdf
        # Apply each of the n_draws permutations to the hashed edge column.
        .selectExpr(
            "*",
            *[
                f"((1L + hash({edge_basis_col})) * {a} + {b}) % {HASH_PRIME} as hash{n}"
                for n, (a, b) in enumerate(zip(coefs, shifts))
            ]
        )
        # The MinHash signature of a node is the minimum of each hash column
        # over all of that node's edges.
        .groupBy(node_col)
        .agg(
            f.array(*[f.min(f"hash{n}") for n in range(n_draws)]).alias("minHash")
        )
        .select(
            node_col,
            f.posexplode(f.col('minHash')).alias('hashIndex', 'minHash')
        )
        # Bucket nodes that share the same minimum hash at the same signature position.
        .groupBy('hashIndex', 'minHash')
        .agg(
            f.collect_list(f.col(node_col)).alias('nodeList'),
            f.collect_set(f.col(node_col)).alias('nodeSet')
        )
    )

    if storage_level is not None:
        hash_sdf = hash_sdf.persist(storage_level)

    hash_count = hash_sdf.count()
    if verbose:
        print('Hash dataframe count:', hash_count)

    # Self-join the buckets: every signature position where two nodes collide
    # contributes 1/n_draws to the estimated Jaccard similarity.
    adj_sdf = (
        hash_sdf.alias('a')
        .join(hash_sdf.alias('b'), ['hashIndex', 'minHash'], 'inner')
        .select(
            f.col('minHash'),
            f.explode(f.col('a.nodeList')).alias(left_name),
            f.col('b.nodeSet')
        )
        .select(
            f.col('minHash'),
            f.col(left_name),
            f.explode(f.col('nodeSet')).alias(right_name),
        )
        .groupBy(left_name, right_name)
        .agg((f.count('*') / n_draws).alias('jaccardSimilarity'))
    )

    return adj_sdf
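As a rough usage sketch (not part of the original gist), the following assumes a local SparkSession and a toy edge list; the column names `node` and `edge` and the sample rows are illustrative. With the default `n_draws=100`, the estimated similarities should land near the exact Jaccard values.

```python
# Illustrative only: toy data and column names are assumptions, not from the gist.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("minhash-demo").getOrCreate()

rows = [
    ("u1", "a"), ("u1", "b"), ("u1", "c"),
    ("u2", "b"), ("u2", "c"), ("u2", "d"),
    ("u3", "x"), ("u3", "y"),
]
sdf = spark.createDataFrame(rows, ["node", "edge"])

similarities = hashmin_jaccard_spark(
    sdf, node_col="node", edge_basis_col="edge", n_draws=100, verbose=True
)
similarities.orderBy("nodeA", "nodeB").show()
# Exact Jaccard for u1 vs u2 is |{b, c}| / |{a, b, c, d}| = 0.5, so the
# estimate for that pair should come out near 0.5.
```

The output has one row per colliding node pair (in both orders), with columns `nodeA`, `nodeB`, and `jaccardSimilarity`; self-pairs come out at 1.0.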
Thanks for the gist. I was writing some unit tests and noticed that the error bounds are out of whack. I think you need to change line 45 from `f"((1L + hash({edge_basis_col})) * {a} + {b}) % {HASH_PRIME} as hash{n}"` to something like `f"((1L + abs(hash({edge_basis_col}) % {HASH_PRIME})) * {a} + {b}) % {HASH_PRIME} as hash{n}"`?

From the source where you got the hash function permutations, they cite this paper as proof that this family of hash functions works. But a condition for the proof is that, in the linear permutations $a \cdot x + b$, we must satisfy $x \in [p] = \{0, 1, \ldots, p-1\}$. In the current implementation, $x$ is effectively `hash({edge_basis_col})`, which can be any integer (even negative), so we need to force it to fall in $[p]$.

I don't know how Spark's `hash()` works, so I can't really check whether this change actually makes the implementation theoretically sound ... but at least it passes my unit tests for the theoretical error bounds.
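To illustrate the range issue concretely, here is a small sketch (not from the original thread) that assumes a local SparkSession: Spark's `hash()` can return negative integers, and wrapping it as `abs(hash(...) % p)` keeps the permutation input inside $[0, p-1]$, which is what the proposed change does.

```python
# Sketch of the proposed fix: verify the wrapped hash always lands in [0, p - 1].
import pyspark.sql.functions as f
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
HASH_PRIME = 2038074743

demo = spark.createDataFrame([(s,) for s in "abcdefgh"], ["edge"])
checked = demo.selectExpr(
    "edge",
    "hash(edge) as rawHash",  # can be negative
    f"abs(hash(edge) % {HASH_PRIME}) as boundedHash",  # forced into [0, p - 1]
)
checked.show()
assert checked.filter(
    (f.col("boundedHash") < 0) | (f.col("boundedHash") >= HASH_PRIME)
).count() == 0
```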