Skip to content

Instantly share code, notes, and snippets.

@paul-english
Created July 27, 2017 18:08
Show Gist options
  • Save paul-english/b72b5a0d3b27444bf5300e27a40f090c to your computer and use it in GitHub Desktop.
Save paul-english/b72b5a0d3b27444bf5300e27a40f090c to your computer and use it in GitHub Desktop.
EditDistance.scala: not the best choice for production (Spark SQL has `levenshtein` built in), but useful as a template for Spark ML transformers.
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{ ParamMap, Param }
import org.apache.spark.ml.util.{ Identifiable, DefaultParamsWritable, DefaultParamsReadable }
import org.apache.spark.sql.functions.{ udf, col }
import org.apache.spark.sql.types.{ StructType, StructField, DoubleType }
import org.apache.spark.sql.{ Dataset, DataFrame }
/**
 * Companion of the `EditDistance` transformer.
 *
 * Holds the pure edit-distance function, the Spark SQL UDF that wraps it, and the
 * `load` entry point inherited from `DefaultParamsReadable`.
 */
object EditDistance extends DefaultParamsReadable[EditDistance] {

  /** Loads an `EditDistance` from `path` by delegating to `DefaultParamsReadable.load`. */
  override def load(path: String): EditDistance = super.load(path)

  /** Returns the smallest of the three given integers. */
  def minimum(i1: Int, i2: Int, i3: Int): Int = math.min(math.min(i1, i2), i3)

  /**
   * Computes the Levenshtein edit distance between two strings, with every insertion,
   * deletion and substitution costing 1.
   *
   * `null` arguments are tolerated: two nulls are at distance 0, and a null paired with
   * a non-null string is at a distance equal to that string's length.
   *
   * @param s1 first string, may be `null`
   * @param s2 second string, may be `null`
   * @return the edit distance as a `Double`
   */
  def editDistance(s1: String, s2: String): Double =
    (Option(s1), Option(s2)) match {
      case (None, None)       => 0.0
      case (Some(only), None) => only.length.toDouble
      case (None, Some(only)) => only.length.toDouble
      case (Some(left), Some(right)) =>
        // Only the previous DP row is needed to build the next one, so keep a single
        // rolling row instead of the full (right.length + 1) x (left.length + 1) matrix.
        // Row 0 is the cost of building each prefix of `left` from the empty string.
        val firstRow = Array.range(0, left.length + 1)
        val lastRow = right.foldLeft(firstRow) { (previous, rightChar) =>
          val current = new Array[Int](previous.length)
          // Column 0: cost of deleting every character of `right` consumed so far.
          current(0) = previous(0) + 1
          for (i <- 1 until current.length) {
            current(i) =
              if (rightChar == left.charAt(i - 1)) previous(i - 1)
              else minimum(previous(i) + 1, current(i - 1) + 1, previous(i - 1) + 1)
          }
          current
        }
        lastRow(left.length).toDouble
    }

  /** Spark SQL UDF wrapping [[editDistance]]; takes two string columns and yields a double. */
  val edit_distance: org.apache.spark.sql.expressions.UserDefinedFunction =
    udf[Double, String, String](editDistance)
}
/**
 * Spark ML transformer that appends a column holding the edit distance between two
 * input columns, computed with the `EditDistance.edit_distance` UDF.
 *
 * Column names are configured through the `leftCol`, `rightCol` and `outputCol` params,
 * which default to "left", "right" and "output" respectively.
 *
 * @param uid unique identifier of this transformer instance
 */
class EditDistance(override val uid: String) extends Transformer with DefaultParamsWritable {

  /** Creates an instance with a random uid prefixed by "editdistance". */
  def this() = this(Identifiable.randomUID("editdistance"))

  /** Name of the first input column. */
  val leftCol: Param[String] = new Param[String](this, "leftCol", "left col")

  /** Name of the second input column. */
  val rightCol: Param[String] = new Param[String](this, "rightCol", "right col")

  /** Name of the column that receives the computed distance. */
  val outputCol: Param[String] = new Param[String](this, "outputCol", "output col")

  // Register the defaults once, after the params above are initialised, instead of
  // repeating the fallback literals at every read site.
  setDefault(leftCol -> "left", rightCol -> "right", outputCol -> "output")

  /** Sets the name of the first input column. */
  def setLeftCol(value: String): this.type = set(leftCol, value)

  /** Sets the name of the second input column. */
  def setRightCol(value: String): this.type = set(rightCol, value)

  /** Sets the name of the output column. */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Returns `df` with the output column added, containing the edit distance between
   * the left and right columns for each row.
   */
  override def transform(df: Dataset[_]): DataFrame =
    df.withColumn($(outputCol), EditDistance.edit_distance(col($(leftCol)), col($(rightCol))))

  /**
   * Returns a new instance with the same uid and params, plus the extra values in
   * `paramMap`. Returning `this` would ignore `paramMap` and share mutable param state
   * between the "copy" and the original.
   */
  override def copy(paramMap: ParamMap): EditDistance = defaultCopy(paramMap)

  /** Appends the non-nullable double output column to `schema`. */
  override def transformSchema(schema: StructType): StructType =
    schema.add(StructField($(outputCol), DoubleType, nullable = false))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment