This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Keep only the candidate pairs whose predicted match probability is at
# least 0.5, reduced to their two endpoint ids.
strong_edges = (
    output_df
    .filter(f.col('prob') >= 0.5)
    .select('edge.src', 'edge.dst')
)

# Graph over the existing `node` vertices, connected by the strong edges only.
strong_graph = GraphFrame(node, strong_edges)

# The checkpoint directory is configured before connected components are
# computed on the graph.
spark.sparkContext.setCheckpointDir("/tmp/match_checkpoints")

# One row per vertex: its component id, its source, and its id exposed as
# `source_id`.
comps = (
    strong_graph
    .connectedComponents()
    .select('component', 'source', f.col('id').alias('source_id'))
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from pyspark.sql.functions import pandas_udf | |
| import pandas as pd | |
@pandas_udf(returnType=t.DoubleType())
def pd_predict(feature):
    """Score a batch of feature rows with the tuned estimator.

    Args:
        feature: pandas Series with one feature row per element.

    Returns:
        pandas Series holding column 1 of ``predict_proba`` for every row
        (presumably the positive / "match" class -- confirm against the
        label encoding used when ``gs_rf`` was fitted).
    """
    rows = feature.values.tolist()
    model = gs_rf.best_estimator_
    class_probabilities = model.predict_proba(rows)
    return pd.Series(class_probabilities[:, 1])


# Attach the predicted probability to every candidate pair.
output_df = feature_df.withColumn('prob', pd_predict('features'))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the hand-labelled candidate pairs, dropping unlabelled and duplicate
# rows.
#
# header=True: without it spark.read.csv() names the columns _c0, _c1, ...,
# so the 'human_label' column referenced below could not be resolved.
# multiLine=True: the sample export joins the src/dst display values with
# '\nVS\n', so a quoted field may span several physical lines.
human_label = (
    spark.read.csv(
        "YOUR_STORAGE_PATH/candidate_pair_sample_LABELED.csv",
        header=True,
        multiLine=True,
    )
    .filter(f.col('human_label').isNotNull())
    .distinct()
)
| feature_df = distance_df.filter(f.col('overall_sim') > 0.06)\ | |
| .withColumn('rules_label', | |
| f.when((f.col('name_tfidf_sim') >= 0.999) | (f.col('overall_sim') >= 0.999), 1) | |
| .when(f.col('overall_sim') < 0.12, 0) | |
| .otherwise(None) | |
| )\ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Pairwise distance table for the candidate pairs.
distance_df = spark.read.parquet("YOUR_STORAGE_PATH/amazon_google_distance.parquet")

# Attributes shown side by side (src VS dst) in the labelling sample.
display_cols = ['name', 'description', 'manufacturer', 'price']

# The chain is wrapped in parentheses so it can span several lines; a
# bare line starting with `.select(...)` is a syntax error.
sample_df = (
    distance_df
    # Keep pairs whose overall similarity is strictly between 0 and 1.
    .filter((f.col('overall_sim') > 0) & (f.col('overall_sim') < 1))
    .select(
        'edge.src',
        'edge.dst',
        *[
            f.concat_ws('\nVS\n', 'src.' + c, 'dst.' + c).alias(c)
            for c in display_cols
        ],
        'overall_sim',
    )
    # Fixed seed, 2% sample without replacement.
    .sample(withReplacement=False, fraction=0.02, seed=42)
)

sample_df.write.mode('overwrite').csv("YOUR_STORAGE_PATH/candidate_pair_sample.csv")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@udf("double")
def dot(x, y):
    """Return the dot product of two vectors, or 0.0 if either is missing.

    Args:
        x: vector object exposing ``.dot``; may be None.
        y: vector accepted by ``x.dot``; may be None.

    Returns:
        The dot product as a Python float, or 0.0 when ``x`` or ``y`` is None.
    """
    if x is None or y is None:
        # Must be a float: a plain Python UDF declared as "double" does not
        # coerce an int result, so returning 0 would surface as NULL.
        return 0.0
    return float(x.dot(y))
| def null_safe_levenshtein_sim(c1, c2): | |
| output = f.when(f.col(c1).isNull() | f.col(c2).isNull(), 0)\ | |
| .otherwise(1 - f.levenshtein(c1, c2) / f.greatest(f.length(c1), f.length(c2))) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from pyspark.sql import functions as f | |
| from pyspark.sql import types as t | |
| from pyspark.sql import Window as w | |
| import numpy as np | |
| from graphframes import GraphFrame | |
# Columns to keep, listed in order.
keep_cols = [
    # Raw attributes
    'source',
    'name',
    'description',
    'manufacturer',
    'price',
    # *_swRemoved text columns
    'name_swRemoved',
    'description_swRemoved',
    'manufacturer_swRemoved',
    # *_swRemoved_tfidf columns
    'name_swRemoved_tfidf',
    'description_swRemoved_tfidf',
    'manufacturer_swRemoved_tfidf',
    # *_encoding columns
    'name_encoding',
    'description_encoding',
]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from pyspark.sql import functions as f | |
| from pyspark.sql import types as t | |
| from pyspark.sql import Window as w | |
| from pyspark.ml.linalg import DenseVector, SparseVector | |
| from pyspark.ml.feature import HashingTF, IDF, Tokenizer, RegexTokenizer, CountVectorizer, StopWordsRemover, NGram, Normalizer, VectorAssembler, Word2Vec, Word2VecModel, PCA | |
| from pyspark.ml import Pipeline, Transformer | |
| from pyspark.ml.linalg import VectorUDT, Vectors | |
| import tensorflow_hub as hub |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| df = google.select( | |
| f.lit('google').alias('source'), | |
| f.col('id').alias('source_id'), | |
| f.col('name'), f.col('description'), | |
| f.col('manufacturer'), | |
| f.col('price') | |
| )\ | |
| .union( | |
| amazon.select( | |
| f.lit('amazon').alias('source'), |
Newer | Older