Pyspark Minhash
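Both scripts below import build_spark_session from a local spark_session_builder module that is not included in this gist. A minimal sketch of what that helper might look like follows; the signature and the meaning of the two numeric arguments (taken here to be executor cores and executor memory in GB) are assumptions, not the original implementation.

# Hypothetical sketch of spark_session_builder.py (not part of this gist);
# the meaning of the two numeric arguments is assumed.
from pyspark.sql import SparkSession

def build_spark_session(master: str, num_cores: int, mem_gb: int) -> SparkSession:
    """Create a SparkSession attached to the given master URL (assumed signature)."""
    return (
        SparkSession.builder
        .master(master)
        .appName("minhash-dedup")
        .config("spark.executor.cores", str(num_cores))
        .config("spark.executor.memory", f"{mem_gb}g")
        .getOrCreate()
    )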
import time
import os

from pyspark.ml import Pipeline
from pyspark.ml.feature import RegexTokenizer, NGram, HashingTF, MinHashLSH
from pyspark.sql.functions import col
from spark_session_builder import build_spark_session

spark = build_spark_session("spark://cpu64-dy-c6i-16xlarge-1:7077", 32, 128)

# Stage 0 & 1: load the dataset and cap it at 1M rows
db = spark.read.parquet("/fsx/shared/pilev2_parquet/StackExchange_ver4_non_local_dedupped/dataset.parquet").limit(1_000_000)
# db.show()

start = time.time()
# spark.sparkContext.defaultParallelism = os.cpu_count()

# Pull the rows to the driver, re-parallelize them into 5,000 slices,
# then rebuild a DataFrame with the original schema
rdd = spark.sparkContext.parallelize(db.collect(), numSlices=5_000)
df = spark.createDataFrame(rdd, db.schema)

# Fit the feature pipeline on the parallelized data
model = Pipeline(stages=[
    RegexTokenizer(  # Stage 2: split text on runs of non-word characters
        pattern="[^A-Za-z_0-9]", inputCol="text", outputCol="tokens", minTokenLength=1
    ),
    NGram(n=5, inputCol="tokens", outputCol="ngrams"),                 # Stage 3: token 5-grams
    HashingTF(inputCol="ngrams", outputCol="vectors"),                 # Stage 4: hash n-grams to sparse vectors
    MinHashLSH(inputCol="vectors", outputCol="lsh", numHashTables=13)  # Stage 5: MinHash signatures
]).fit(df)
db_hashed = model.transform(df)

# Stage 6: self-join on the MinHash signatures, keeping pairs within 0.15 Jaccard
# distance; datasetA.id < datasetB.id drops self-matches and mirrored pairs
duplicates = model.stages[-1].approxSimilarityJoin(
    db_hashed,
    db_hashed,
    0.15,
    distCol="JaccardDistance"
).filter("datasetA.id < datasetB.id")
# duplicates.show()

duplicates.write.parquet("./duplicates", mode="overwrite")  # Stage 7
end = time.time()
print(f"Time taken: {end - start} for {db.count()} rows")

# Optional follow-up (left commented out): reduce the pair table to ids and
# drop those rows from the dataset
# duplicates = duplicates.select(
#     col("datasetA.id").alias("idA"),
#     col("datasetB.id").alias("idB"),
#     col("JaccardDistance")
# )
# duplicates = duplicates.filter("idA != idB")
# duplicates = duplicates.filter("idA < idB")
# duplicates.show()
# duplicates_ids = duplicates.select("idA").distinct().collect()
# duplicates_ids = [row.idA for row in duplicates_ids]
# db.filter(db.id.isin(duplicates_ids)).show()
# print(db.count())
# db = db.filter(~db.id.isin(duplicates_ids))
# db.write.parquet("./dataset_dedupped", mode="overwrite")
# print(db.count())
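The script above stops after writing the candidate duplicate pairs to ./duplicates; the commented-out lines sketch how those pairs could be used to prune the dataset. One minimal way to finish that step is sketched below, assuming the dataset keeps an "id" column; it reads back the ./duplicates output and drops one member of each pair with a left_anti join instead of collecting the ids to the driver.

# Sketch (not part of the original gist): drop one row of each duplicate pair
dup_ids = (
    spark.read.parquet("./duplicates")
    .select(col("datasetA.id").alias("id"))
    .distinct()
)
deduped = df.join(dup_ids, on="id", how="left_anti")
deduped.write.parquet("./dataset_dedupped", mode="overwrite")

The second script below runs the same pipeline on a handful of short strings as a sanity check.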
from pyspark.ml import Pipeline
from pyspark.ml.feature import RegexTokenizer, NGram, HashingTF, MinHashLSH
from pyspark.sql.functions import col, row_number, monotonically_increasing_id
from pyspark.sql.window import Window
from spark_session_builder import build_spark_session

spark = build_spark_session("spark://cpu64-dy-c6i-16xlarge-1:7077", 32, 256)

# Tiny toy dataset: rows 0/3 and rows 1/2 are near duplicates of each other
db = spark.createDataFrame([
    "Hello there 😊! I really like Spark ❤️!",
    "Can anyone suggest an efficient algorithm",
    "anyone suggest an efficient algorithm",
    "Hello there 7l | real|y like Spark!",
    "Hola, como estas? Me gusta mucho Spark!",
], "string").toDF("text")

# Add a stable integer id column (0, 1, 2, ...)
db = db.withColumn("id", row_number().over(Window.orderBy(monotonically_increasing_id())) - 1)
db.show()

model = Pipeline(stages=[
    RegexTokenizer(  # empty pattern tokenizes the text into individual characters
        pattern="", inputCol="text", outputCol="tokens", minTokenLength=1
    ),
    NGram(n=3, inputCol="tokens", outputCol="ngrams"),
    HashingTF(inputCol="ngrams", outputCol="vectors"),
    MinHashLSH(inputCol="vectors", outputCol="lsh", numHashTables=13)
]).fit(db)
db_hashed = model.transform(db)

# Alternative ways to add an id column (not used):
# db_hashed = db_hashed.withColumn("id", monotonically_increasing_id())
# db_hashed["id"] = [i for i in range(db_hashed.count())]

# Self-join on the MinHash signatures, keeping pairs within 0.85 Jaccard distance
duplicates = model.stages[-1].approxSimilarityJoin(
    db_hashed,
    db_hashed,
    0.85,
    distCol="JaccardDistance"
).select(
    col("datasetA.id").alias("idA"),
    col("datasetB.id").alias("idB"),
    col("JaccardDistance")
)
duplicates.show()

# Drop self-matches and mirrored pairs
duplicates = duplicates.filter("idA != idB")
duplicates = duplicates.filter("idA < idB")
duplicates.show()

# Collect the ids that appear as the smaller member of a pair and filter them out
duplicates_ids = duplicates.select("idA").distinct().collect()
print(duplicates_ids)
duplicates_ids = [row.idA for row in duplicates_ids]
# db.filter(db.id.isin(duplicates_ids)).show()
db.filter(~db.id.isin(duplicates_ids)).show()
# duplicates = duplicates.filter("datasetA.id < datasetB.id")
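A note on the two filters: every row matches itself at distance 0, so "idA != idB" removes the self-pairs, and "idA < idB" keeps only one ordering of each remaining pair. Filtering the collected idA values out of db then leaves a single representative row per group of near duplicates. The 0.85 distance threshold is deliberately loose for this toy data; the first script uses a much tighter 0.15 for the real corpus.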