Example of a Scala Spark transformer with a Python wrapper
from pyspark import keyword_only
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.ml.wrapper import JavaTransformer


class Stemmer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
    """Python wrapper around the Scala org.fct.spark.transformer.Stemmer transformer."""

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None):
        super(Stemmer, self).__init__()
        # Instantiate the underlying JVM transformer through the Py4J gateway.
        self._java_obj = self._new_java_obj("org.fct.spark.transformer.Stemmer", self.uid)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None):
        # Copy the keyword arguments captured by @keyword_only into the param map.
        kwargs = self._input_kwargs
        return self._set(**kwargs)
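The wrapper's params behave like those of any built-in PySpark transformer. Below is a minimal sketch of constructing and reconfiguring it; it assumes an active SparkSession whose classpath already contains the compiled Scala class, since _new_java_obj has to resolve the JVM object at construction time.

# Minimal sketch (assumes an active SparkSession with org.fct.spark.transformer.Stemmer
# on the classpath; otherwise _new_java_obj cannot resolve the JVM class).
stemmer = Stemmer(inputCol="words", outputCol="stemmed")

# Params can also be changed after construction through the same keyword-only API.
stemmer.setParams(outputCol="stems")
print(stemmer.getOutputCol())  # -> "stems"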
package org.fct.spark.transformer

import edu.stanford.nlp.process.Morphology
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}

/** Stems each token in a string-array column using Stanford CoreNLP's Morphology. */
class Stemmer(override val uid: String) extends UnaryTransformer[Seq[String], Seq[String], Stemmer] {

  def this() = this(Identifiable.randomUID("stemmer"))

  override protected def createTransformFunc: Seq[String] => Seq[String] = { strArray =>
    // Instantiate Morphology inside the function so it is created on the executors
    // rather than captured in the closure on the driver.
    val stemmer = new Morphology()
    strArray.map(originStr => stemmer.stem(originStr))
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == ArrayType(StringType),
      s"Input type must be ArrayType(StringType) but got $inputType.")
  }

  override protected def outputDataType: DataType = ArrayType(StringType, containsNull = false)
}
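To exercise the pair end to end, the Scala class (together with the Stanford CoreNLP dependency that provides Morphology) has to be packaged into a jar and placed on Spark's classpath. The sketch below is one way to wire it up from PySpark; the jar path and app name are placeholders, not part of the gist.

# End-to-end sketch: the jar path is hypothetical and must contain both
# org.fct.spark.transformer.Stemmer and the Stanford CoreNLP classes it uses.
from pyspark.sql import SparkSession
from pyspark.ml.feature import Tokenizer

spark = (SparkSession.builder
         .appName("stemmer-example")
         .config("spark.jars", "stemmer-transformer.jar")  # placeholder jar name
         .getOrCreate())

df = spark.createDataFrame([("it was the best of times",)], ["text"])

tokenizer = Tokenizer(inputCol="text", outputCol="words")
stemmer = Stemmer(inputCol="words", outputCol="stemmed")

result = stemmer.transform(tokenizer.transform(df))
result.select("words", "stemmed").show(truncate=False)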