Created
December 15, 2016 12:50
-
-
Save hekonsek/8132aed820fac43d1f7d499a1e5e6cbf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.ml.Pipeline | |
import org.apache.spark.ml.PipelineStage | |
import org.apache.spark.ml.Transformer | |
import org.apache.spark.ml.classification.LogisticRegression | |
import org.apache.spark.ml.feature.LabeledPoint | |
import org.apache.spark.ml.linalg.DenseVector | |
import org.apache.spark.ml.linalg.Vectors | |
import org.apache.spark.ml.param.ParamMap | |
import org.apache.spark.sql.Dataset | |
import org.apache.spark.sql.Row | |
import org.apache.spark.sql.SparkSession | |
import org.apache.spark.sql.catalyst.encoders.RowEncoder | |
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema | |
import org.apache.spark.sql.types.StructType | |
/**
 * Spark ML transformer that replaces every element of each row's feature
 * vector with its absolute value, leaving the schema unchanged.
 *
 * NOTE(review): transform() assumes the first value of every input row is a
 * DenseVector — confirm against the datasets fed into the pipeline.
 */
class AbsTransformer extends Transformer {

    // Spark's Identifiable contract requires a stage's uid to be STABLE for
    // the lifetime of the instance (it keys Params ownership and stage
    // identity inside a Pipeline). The original uid() returned a fresh
    // random UUID on every call, which breaks that contract — compute the
    // uid once per instance instead.
    private final String uid = UUID.randomUUID().toString()

    @Override
    Dataset<Row> transform(Dataset<?> dataset) {
        // Rewrite the backing double[] of each row's vector in place, then
        // re-emit the row under the original schema's encoder.
        (dataset as Dataset<GenericRowWithSchema>).map({ value ->
            def values = (value.values().first() as DenseVector).values()
            for (int i = 0; i < values.length; i++) {
                values[i] = Math.abs(values[i])
            }
            value as Row
        }, RowEncoder.apply(dataset.schema()))
    }

    @Override
    StructType transformSchema(StructType structType) {
        // Taking absolute values changes no column names or types.
        structType
    }

    @Override
    Transformer copy(ParamMap paramMap) {
        // Stateless apart from the (immutable) uid, so the instance itself
        // is a valid copy; the incoming ParamMap is ignored because this
        // transformer declares no params.
        this
    }

    @Override
    String uid() {
        uid
    }

    /**
     * Demo: fit a LogisticRegression on abs-transformed 2-D points, then
     * score a single held-out example and print its confidence column.
     */
    public static void main(String[] args) {
        def spark = SparkSession.builder().master('local[*]').getOrCreate()
        def trainingData = spark.createDataFrame([
                new LabeledPoint(1.0d, Vectors.dense([18.0d, -25.0d] as double[])),
                new LabeledPoint(1.0d, Vectors.dense([-15.0d, 20.0d] as double[])),
                new LabeledPoint(1.0d, Vectors.dense([10.0d, 27.0d] as double[])),
                new LabeledPoint(0.0d, Vectors.dense([0.0d, 5.0d] as double[])),
                new LabeledPoint(0.0d, Vectors.dense([0.0d, -6.0d] as double[])),
                new LabeledPoint(0.0d, Vectors.dense([0.0d, 3.0d] as double[]))
        ], LabeledPoint)
        def stages = [new AbsTransformer(), new LogisticRegression()] as PipelineStage[]
        def pipeLine = new Pipeline().setStages(stages)
        def model = pipeLine.fit(trainingData)
        def data = spark.createDataFrame([new LabeledPoint(1.0d, Vectors.dense([-20.0d, -20.0d] as double[]))], LabeledPoint)
        def result = model.transform(data)
        result.show()
        // Index 3 is presumably the probability column of the prediction
        // output (label, features, rawPrediction, probability, prediction)
        // — TODO confirm column order against result.show().
        def confidence = result.collectAsList().first().get(3)
        println confidence
    }

}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment