Last active
March 16, 2023 10:12
-
-
Save dmmiller612/883b30db1a16c7c24a383921c405066c to your computer and use it in GitHub Desktop.
Spark DL4J Dataframe/Dataset usage
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.ml.feature.*; | |
class ExampleUsage { | |
public void example(){ | |
List<Row> dubs = Lists.newArrayList( | |
RowFactory.create(1000.0, 1000.0, 1.0), | |
RowFactory.create(90.0, 90.0, 0.0) | |
); | |
DataFrame df = sqlContext.createDataFrame(dubs, createStruct()); | |
Pipeline p = new Pipeline().setStages(new PipelineStage[]{getAssembler(new String[]{"x", "y"}, "features")}); | |
DataFrame part2 = p.fit(df).transform(df).select("features", "label"); | |
SparkDl4jNetwork sparkDl4jNetwork = new SparkDl4jNetwork() | |
.setFeaturesCol("features") | |
.setLabelCol("label") | |
.setTrainingMaster(() -> new ParameterAveragingTrainingMaster.Builder(3) | |
.averagingFrequency(2) | |
.workerPrefetchNumBatches(2) | |
.batchSizePerWorker(2) | |
.build()) | |
.setMultiLayerConfiguration(getNNConfiguration()); | |
SparkDl4jModel sm = sparkDl4jNetwork.fit(part2); | |
MultiLayerNetwork mln = sm.getMultiLayerNetwork(); | |
Assert.assertNotNull(mln); | |
System.out.println(sm.output(Vectors.dense(0.0, 0.0))); | |
sm.write().save("somewhere"); | |
SparkDl4jModel spdm = SparkDl4jModel.load("somewhere"); | |
System.out.println(spdm.predict(Vectors.dense(0.0, 0.0))); | |
} | |
public static VectorAssembler getAssembler(String[] input, String output){ | |
return new VectorAssembler() | |
.setInputCols(input) | |
.setOutputCol(output); | |
} | |
private static StructType createStruct() { | |
return new StructType(new StructField[]{ | |
new StructField("x", DataTypes.DoubleType, true, Metadata.empty()), | |
new StructField("y", DataTypes.DoubleType, true, Metadata.empty()), | |
new StructField("label", DataTypes.DoubleType, true, Metadata.empty()) | |
}); | |
} | |
private static MultiLayerConfiguration getNNConfiguration(){ | |
return new NeuralNetConfiguration.Builder() | |
.seed(12345) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.iterations(1000) | |
.weightInit(WeightInit.UNIFORM) | |
.learningRate(0.1) | |
.updater(Updater.NESTEROVS) | |
.list() | |
.layer(0, new DenseLayer.Builder().nIn(2).nOut(100).weightInit(WeightInit.XAVIER).activation("relu").build()) | |
.layer(1, new DenseLayer.Builder().nIn(100).nOut(120).weightInit(WeightInit.XAVIER).activation("relu").build()) | |
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation("softmax").nIn(120).nOut(2).build()) | |
.pretrain(false).backprop(true) | |
.build(); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.ml.{PredictionModel, Predictor} | |
import org.apache.spark.ml.param.ParamMap | |
import org.apache.spark.ml.util._ | |
import org.apache.spark.mllib.regression.LabeledPoint | |
import org.apache.spark.mllib.linalg.Vector | |
import org.apache.spark.sql.DataFrame | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork | |
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer | |
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster | |
import org.deeplearning4j.util.ModelSerializer | |
import org.nd4j.linalg.dataset.DataSet | |
import org.nd4j.linalg.factory.Nd4j | |
import org.nd4j.linalg.util.FeatureUtil | |
/**
 * Spark ML [[Predictor]] that trains a DL4J network over a DataFrame.
 *
 * The caller must supply a [[MultiLayerConfiguration]] and a [[Serializer]]
 * (a serializable TrainingMaster factory) before calling `fit`.
 *
 * NOTE(review): `defaultCopy` copies only ML `Param`s, not these private vars,
 * so a copied instance loses its configuration — confirm whether copy is used.
 */
final class SparkDl4jNetwork(
  override val uid: String
) extends Predictor[Vector, SparkDl4jNetwork, SparkDl4jModel] {

  // Network architecture; required before fit.
  private var _multiLayerConfiguration: MultiLayerConfiguration = _
  // Number of label classes: > 1 => classification (labels one-hot encoded),
  // 1 => single-value (regression-style) targets.
  private var _numLabels: Int = 2
  // Serializable factory producing the TrainingMaster; required before fit.
  private var _trainingMaster: Serializer = _

  def this() = this(Identifiable.randomUID("dl4j"))

  /**
   * Converts the selected feature/label columns into DL4J DataSets and fits
   * the network on the cluster.
   *
   * @param dataset frame containing the configured features (Vector) and label (Double) columns
   * @return a fitted [[SparkDl4jModel]] wrapping the trained network
   */
  override def train(dataset: DataFrame): SparkDl4jModel = {
    // Fail fast with a clear message instead of an opaque NPE deep inside Spark/DL4J.
    require(_multiLayerConfiguration != null,
      "setMultiLayerConfiguration must be called before fit")
    require(_trainingMaster != null,
      "setTrainingMaster must be called before fit")
    val sparkNet = new SparkDl4jMultiLayer(
      dataset.sqlContext.sparkContext, _multiLayerConfiguration, _trainingMaster())
    val lps = dataset.select(getFeaturesCol, getLabelCol).rdd
      .map(row => new LabeledPoint(row.getAs[Double](getLabelCol), row.getAs[Vector](getFeaturesCol)))
      .map { item =>
        val features = item.features
        val label = item.label
        if (_numLabels > 1) {
          // Classification: expand the numeric label into a one-hot outcome vector.
          // label.toInt replaces the obscure asInstanceOf[Int] primitive cast.
          new DataSet(Nd4j.create(features.toArray), FeatureUtil.toOutcomeVector(label.toInt, _numLabels))
        } else {
          // Single-output target: the raw label value itself.
          new DataSet(Nd4j.create(features.toArray), Nd4j.create(Array(label)))
        }
      }
    sparkNet.fit(lps)
    new SparkDl4jModel(uid, sparkNet)
  }

  override def copy(extra: ParamMap): SparkDl4jNetwork = defaultCopy(extra)

  /** Sets the number of label classes (default 2). Returns this for chaining. */
  def setNumLabels(value: Int): SparkDl4jNetwork = {
    this._numLabels = value
    this
  }

  /** Sets the network architecture. Returns this for chaining. */
  def setMultiLayerConfiguration(multiLayerConfiguration: MultiLayerConfiguration): SparkDl4jNetwork = {
    this._multiLayerConfiguration = multiLayerConfiguration
    this
  }

  /** Sets the serializable TrainingMaster factory. Returns this for chaining. */
  def setTrainingMaster(tm: Serializer): SparkDl4jNetwork = {
    this._trainingMaster = tm
    this
  }
}
/**
 * Fitted model wrapping a [[SparkDl4jMultiLayer]]; supports prediction on
 * single vectors and persistence via [[MLWritable]].
 */
class SparkDl4jModel(override val uid: String, network: SparkDl4jMultiLayer)
  extends PredictionModel[Vector, SparkDl4jModel] with Serializable with MLWritable {

  override def copy(extra: ParamMap): SparkDl4jModel = {
    copyValues(new SparkDl4jModel(uid, network)).setParent(parent)
  }

  /**
   * Maps the network output to a single prediction: argmax for multi-class
   * output, the raw value for a single output.
   *
   * @throws RuntimeException if the network produces an empty output vector
   */
  override def predict(features: Vector): Double = {
    val v = output(features)
    if (v.size > 1) {
      // Classification: index of the highest-scoring class.
      v.argmax
    } else if (v.size == 1) {
      // Single-output network: the value itself.
      v.toArray(0)
    } else {
      // This branch is reached only when v.size == 0; the old message
      // ("must be greater than 1") misdescribed the actual condition.
      throw new RuntimeException("Network produced an empty output vector; expected size >= 1")
    }
  }

  /** The underlying trained DL4J network. */
  def getMultiLayerNetwork: MultiLayerNetwork = network.getNetwork

  /** Raw network output for a single input vector. */
  def output(vector: Vector): Vector = network.predict(vector)

  /** Writer that serializes the wrapped network (with updater state) to `path`. */
  protected[SparkDl4jModel] class SparkDl4jModelWriter(instance: SparkDl4jModel) extends MLWriter {
    override protected def saveImpl(path: String): Unit = {
      // Use the passed-in instance rather than silently capturing the outer `network`.
      ModelSerializer.writeModel(instance.getMultiLayerNetwork, path, true)
    }
  }

  override def write: MLWriter = new SparkDl4jModelWriter(this)
}
/** Companion providing [[MLReadable]] persistence support for [[SparkDl4jModel]]. */
object SparkDl4jModel extends MLReadable[SparkDl4jModel] {

  override def read: MLReader[SparkDl4jModel] = new SparkDl4jReader

  override def load(path: String): SparkDl4jModel = super.load(path)

  /** Restores a serialized network and wraps it in a fresh model instance. */
  private class SparkDl4jReader extends MLReader[SparkDl4jModel] {
    override def load(path: String): SparkDl4jModel = {
      val restored = ModelSerializer.restoreMultiLayerNetwork(path)
      // No TrainingMaster is needed for inference-only use, hence null here.
      val wrapper = new SparkDl4jMultiLayer(sc, restored, null)
      new SparkDl4jModel(Identifiable.randomUID("dl4j"), wrapper)
    }
  }
}
/**
 * Serializable factory for a [[ParameterAveragingTrainingMaster]].
 *
 * Implemented as a SAM so callers can pass a lambda (see the usage example);
 * extending [[Serializable]] lets Spark ship the factory to workers.
 *
 * NOTE(review): the name is misleading — this is a TrainingMaster factory,
 * not a serializer; consider renaming if the API can change.
 */
trait Serializer extends Serializable {
  def apply() : ParameterAveragingTrainingMaster
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.