Created
October 4, 2016 20:17
-
-
Save dyangrev/ed9b6f05169ee3a392004d402f536693 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.sql.DataFrame | |
import com.rockymadden.stringmetric.similarity.NGramMetric | |
import org.apache.spark.sql.SaveMode | |
import org.apache.spark.sql.functions._ | |
import sqlContext.implicits._ | |
import java.util.UUID | |
import com.datastax.spark.connector._ | |
object BlockingUtils { | |
object MatchingConfig {
  /**
   * How one pair of field values is scored.
   *
   * Sealed so that the map below has a meaningful value type and pattern matches
   * over a method can be checked by the compiler.
   */
  sealed trait MatchingMethod

  /** Score the pair with an n-gram metric using grams of size `gram`. */
  final case class NGramMethod(gram: Int) extends MatchingMethod

  /**
   * Score the pair by exact equality.
   *
   * This is a singleton: the configuration below and patterns written as
   * `case ExactMethod` all refer to this one value.
   */
  case object ExactMethod extends MatchingMethod {
    /** Kept so that existing `ExactMethod()` call sites still compile; returns the singleton. */
    def apply(): ExactMethod.type = this
  }

  // Single ordered source of truth (field name -> method), so the lookup map and
  // the ordered field list below cannot drift apart.
  private val methodsByField: Seq[(String, MatchingMethod)] = Seq(
    "normalized_first_name" -> NGramMethod(2),
    "normalized_middle_name" -> NGramMethod(2),
    "normalized_last_name" -> NGramMethod(2),
    "normalized_postal_address" -> NGramMethod(2),
    "normalized_email_address" -> NGramMethod(2),
    "normalized_phone_number" -> ExactMethod,
    "normalized_date_of_birth" -> ExactMethod,
    "normalized_primary_language" -> ExactMethod,
    "normalized_gender" -> ExactMethod)

  /** Matching method to use for each compared field, keyed by field name. */
  val config: Map[String, MatchingMethod] = methodsByField.toMap

  /** Names of the fields to compare, in declaration order. */
  val fieldsToCompare: Seq[String] = methodsByField.map(_._1)
}
object UDFS {
  /** Zero-argument UDF that returns a newly generated random UUID rendered as a string. */
  val uuidUdf = udf(() => UUID.randomUUID().toString)

  /**
   * Yields a function that runs `comparisonMethod` on two strings only when both are usable.
   *
   * A string counts as usable when it is non-null, non-empty and not the literal text "null".
   * The yielded function returns `Some(result)` when both strings are usable, `None` otherwise.
   */
  def applyComparison[T]: (String, String, (String, String) => T) => Option[T] = {
    (left: String, right: String, comparisonMethod: (String, String) => T) =>
      val bothUsable = Seq(left, right).forall { candidate =>
        Option(candidate).filter(_.nonEmpty).exists(_ != "null")
      }
      if (bothUsable) Some(comparisonMethod(left, right)) else None
  }

  /**
   * Builds an n-gram similarity UDF for the given gram size.
   *
   * The UDF returns -1.0 when either input is unusable (see `applyComparison`), and 0.0 when
   * the metric itself yields no value.
   */
  val nGramUdf = (gram: Int) => udf[Double, String, String] { (str1: String, str2: String) =>
    val similarity = (a: String, b: String) => NGramMetric(gram).compare(a, b).getOrElse(0.0)
    applyComparison(str1, str2, similarity).getOrElse(-1.0)
  }

  /**
   * Exact-match UDF: 1.0 when the two strings are equal, 0.0 when they differ, and -1.0 when
   * either input is unusable (see `applyComparison`).
   */
  val exactUdf = udf[Double, String, String] { (str1: String, str2: String) =>
    val equality = (a: String, b: String) => if (a == b) 1.0 else 0.0
    applyComparison(str1, str2, equality).getOrElse(-1.0)
  }

  /**
   * Mean of nine similarity scores, skipping every score equal to -1.0 (the value the
   * similarity UDFs above return for unusable inputs). Returns 0.0 when no score remains.
   */
  val meanUdf = udf[Double, Double, Double, Double, Double, Double, Double, Double, Double, Double] {
    (d1, d2, d3, d4, d5, d6, d7, d8, d9) =>
      val populated = Seq(d1, d2, d3, d4, d5, d6, d7, d8, d9).filterNot(_ == -1.0)
      if (populated.isEmpty) 0.0 else populated.sum / populated.size
  }
}
/**
 * Reads one Cassandra table through the `org.apache.spark.sql.cassandra` data source.
 *
 * @param tableName Name of the table to read
 * @param keyspace  Name of the keyspace holding the table (defaults to "doppler")
 * @param cluster   Name of the cluster passed to the data source (defaults to "cpark")
 * @return whatever the reader's `load()` produces for those options
 */
def loadTableIntoDF(tableName: String, keyspace: String = "doppler", cluster: String = "cpark") = {
  val readerOptions = Map(
    "table" -> tableName,
    "keyspace" -> keyspace,
    "cluster" -> cluster)
  val cassandraReader = sqlContext.read.format("org.apache.spark.sql.cassandra")
  cassandraReader.options(readerOptions).load()
}
/**
 * Self-joins a data frame on a blocking key.
 *
 * Two copies of `df` are made, with every column suffixed `_1` and `_2` respectively, and
 * joined where `profile_id_1 < profile_id_2` and, for each blocking column, both sides are
 * equal, non-null and not the literal string "null".
 *
 * @param df          Dataframe to operate on, usually the blocking normalized profile data frame
 * @param blockingKey Blocking columns; when more than one is given they act together as a
 *                    composite key (all must match)
 * @return the joined data frame holding the `_1` and `_2` columns of each matching pair
 */
def blockedAndJoinedDF(df: DataFrame, blockingKey: Seq[String]) = {
  // Copy of `df` with `suffix` appended to every column name.
  def withSuffix(suffix: String) =
    df.columns.foldLeft(df) { (renamed, name) => renamed.withColumnRenamed(name, s"$name$suffix") }

  // The individual conditions one blocking column contributes, in evaluation order.
  def conditionsFor(key: String) = {
    val left = col(s"${key}_1")
    val right = col(s"${key}_2")
    Seq(
      left === right,
      !isnull(left),
      !isnull(right),
      left !== lit("null"),
      right !== lit("null"))
  }

  // Keep each unordered pair once (and never pair a row with itself).
  val orderedPair = col("profile_id_1") < col("profile_id_2")
  val joinCondition = blockingKey.foldLeft(orderedPair) { (soFar, key) =>
    conditionsFor(key).foldLeft(soFar)(_ && _)
  }

  withSuffix("_1").join(withSuffix("_2"), joinCondition)
}
/**
 * Adds similarity columns to a blocked and joined data frame.
 *
 * For every name in `MatchingConfig.fieldsToCompare` that has an n-gram or exact entry in
 * `MatchingConfig.config`, a `<field>_sim` column is computed from `<field>_1` and `<field>_2`.
 * Names without such an entry add no column. Finally `mean_sim` is computed by `UDFS.meanUdf`
 * over the nine similarity columns listed below.
 *
 * @param blockedAndJoinedDF The blocked and joined data frame
 * @return the input extended with the `<field>_sim` columns and `mean_sim`
 */
def calculateSimilarityForBlockedAndJoinedDF(blockedAndJoinedDF: DataFrame) = {
  val scored = MatchingConfig.fieldsToCompare.foldLeft(blockedAndJoinedDF) { (acc, field) =>
    // `def`s so the columns are only looked up when a method is actually configured.
    def left = acc(s"${field}_1")
    def right = acc(s"${field}_2")

    val similarity = MatchingConfig.config.get(field).collect {
      case MatchingConfig.NGramMethod(gram) => UDFS.nGramUdf(gram)(left, right)
      case MatchingConfig.ExactMethod => UDFS.exactUdf(left, right)
    }
    similarity.fold(acc)(simColumn => acc.withColumn(s"${field}_sim", simColumn))
  }

  // Calculate the mean sim score
  val simColumns = Seq(
    "normalized_first_name_sim",
    "normalized_middle_name_sim",
    "normalized_last_name_sim",
    "normalized_postal_address_sim",
    "normalized_email_address_sim",
    "normalized_phone_number_sim",
    "normalized_date_of_birth_sim",
    "normalized_primary_language_sim",
    "normalized_gender_sim").map(name => scored(name))

  scored.withColumn("mean_sim", UDFS.meanUdf(simColumns: _*))
}
/**
 * Persist a data frame into Cassandra
 *
 * @param df       The dataframe to be persisted
 * @param keyspace Keyspace to persist to
 * @param table    Table to persist to
 * @param saveMode Save mode applied to the write; defaults to `SaveMode.Append`
 */
def persistDataFrame(df: DataFrame, keyspace: String, table: String, saveMode: SaveMode = SaveMode.Append): Unit = {
  df.write
    .mode(saveMode) // honour the caller's choice; this was previously hard-coded to Append
    .format("org.apache.spark.sql.cassandra")
    .options(Map("keyspace" -> keyspace, "table" -> table))
    .save()
}
/**
 * Load normalized profiles into a dataframe and cache it.
 *
 * Reads the "normalized_profiles" table, drops its `profile_id`, de-duplicates the remaining
 * rows, assigns each surviving row a new random `profile_id`, keeps the columns listed below
 * and caches the result.
 */
def loadNormalizedProfiles = {
  val keptColumns = Seq(
    "property",
    "profile_id",
    "normalized_first_name",
    "normalized_middle_name",
    "normalized_last_name",
    "normalized_email_address",
    "normalized_phone_number",
    "normalized_first_name_first_3",
    "normalized_last_name_first_3",
    "normalized_postal_address",
    "normalized_date_of_birth",
    "normalized_primary_language",
    "normalized_gender")

  // Without the original id, rows that differ only by id collapse under distinct.
  val deduplicated = loadTableIntoDF("normalized_profiles").drop("profile_id").distinct
  // Then every remaining row gets a new profile id.
  val reIdentified = deduplicated.withColumn("profile_id", UDFS.uuidUdf())

  reIdentified.select(keptColumns.head, keptColumns.tail: _*).cache()
}
/**
 * Returns the blocking keys to use for a property.
 *
 * Each inner sequence lists the column names of one blocking key (more than one name
 * presumably means a composite key). Only the two name-based keys vary by property; the
 * email, phone and last-name-prefix + postal-address keys are the same for every property.
 *
 * @param property property identifier; "30100", "GOVERN" and "LUCIA" have dedicated
 *                 name keys, anything else gets the default ones
 */
def getBlockingKeysByProperty(property: String): Seq[Seq[String]] = {
  val nameKeys: Seq[Seq[String]] = property match {
    case "30100" =>
      Seq(
        Seq("normalized_first_name"),
        Seq("normalized_last_name", "normalized_first_name_first_3"))
    case "GOVERN" | "LUCIA" =>
      Seq(
        Seq("normalized_first_name", "normalized_last_name_first_3"),
        Seq("normalized_last_name"))
    case _ =>
      Seq(
        Seq("normalized_first_name"),
        Seq("normalized_last_name"))
  }

  val sharedKeys: Seq[Seq[String]] = Seq(
    Seq("normalized_email_address"),
    Seq("normalized_phone_number"),
    Seq("normalized_last_name_first_3", "normalized_postal_address"))

  nameKeys ++ sharedKeys
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment