Skip to content

Instantly share code, notes, and snippets.

@dmmiller612
Last active March 16, 2023 10:12
Show Gist options
  • Save dmmiller612/883b30db1a16c7c24a383921c405066c to your computer and use it in GitHub Desktop.
Spark DL4J Dataframe/Dataset usage
import org.apache.spark.ml.feature.*;
class ExampleUsage {

    /**
     * End-to-end demo: builds a tiny two-row DataFrame, assembles the raw
     * columns into a feature vector via a Spark ML Pipeline, trains a DL4J
     * network on it, then round-trips the fitted model through save/load.
     */
    public void example() {
        // Two labeled samples: (x, y, label).
        List<Row> rows = Lists.newArrayList(
                RowFactory.create(1000.0, 1000.0, 1.0),
                RowFactory.create(90.0, 90.0, 0.0)
        );
        DataFrame raw = sqlContext.createDataFrame(rows, createStruct());

        // Merge the "x"/"y" columns into a single "features" vector column.
        PipelineStage[] stages = new PipelineStage[]{getAssembler(new String[]{"x", "y"}, "features")};
        Pipeline pipeline = new Pipeline().setStages(stages);
        DataFrame prepared = pipeline.fit(raw).transform(raw).select("features", "label");

        // Configure distributed training: parameter averaging over batches of 3.
        SparkDl4jNetwork network = new SparkDl4jNetwork()
                .setFeaturesCol("features")
                .setLabelCol("label")
                .setTrainingMaster(() -> new ParameterAveragingTrainingMaster.Builder(3)
                        .averagingFrequency(2)
                        .workerPrefetchNumBatches(2)
                        .batchSizePerWorker(2)
                        .build())
                .setMultiLayerConfiguration(getNNConfiguration());

        SparkDl4jModel model = network.fit(prepared);
        MultiLayerNetwork underlying = model.getMultiLayerNetwork();
        Assert.assertNotNull(underlying);
        System.out.println(model.output(Vectors.dense(0.0, 0.0)));

        // Persist the model, reload it, and score through the restored copy.
        model.write().save("somewhere");
        SparkDl4jModel restored = SparkDl4jModel.load("somewhere");
        System.out.println(restored.predict(Vectors.dense(0.0, 0.0)));
    }

    /** Builds a VectorAssembler mapping the given input columns to one vector column. */
    public static VectorAssembler getAssembler(String[] input, String output) {
        VectorAssembler assembler = new VectorAssembler();
        assembler.setInputCols(input);
        return assembler.setOutputCol(output);
    }

    /** Schema for the demo frame: two double features plus a double label, all nullable. */
    private static StructType createStruct() {
        StructField x = new StructField("x", DataTypes.DoubleType, true, Metadata.empty());
        StructField y = new StructField("y", DataTypes.DoubleType, true, Metadata.empty());
        StructField label = new StructField("label", DataTypes.DoubleType, true, Metadata.empty());
        return new StructType(new StructField[]{x, y, label});
    }

    /**
     * Network topology: 2 -> 100 -> 120 -> 2 with SGD + Nesterov momentum,
     * softmax output trained against an MSE loss.
     */
    private static MultiLayerConfiguration getNNConfiguration() {
        DenseLayer hidden1 = new DenseLayer.Builder()
                .nIn(2).nOut(100)
                .weightInit(WeightInit.XAVIER)
                .activation("relu")
                .build();
        DenseLayer hidden2 = new DenseLayer.Builder()
                .nIn(100).nOut(120)
                .weightInit(WeightInit.XAVIER)
                .activation("relu")
                .build();
        OutputLayer out = new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                .activation("softmax")
                .nIn(120).nOut(2)
                .build();
        return new NeuralNetConfiguration.Builder()
                .seed(12345)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .iterations(1000)
                .weightInit(WeightInit.UNIFORM)
                .learningRate(0.1)
                .updater(Updater.NESTEROVS)
                .list()
                .layer(0, hidden1)
                .layer(1, hidden2)
                .layer(2, out)
                .pretrain(false).backprop(true)
                .build();
    }
}
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.DataFrame
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster
import org.deeplearning4j.util.ModelSerializer
import org.nd4j.linalg.dataset.DataSet
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.util.FeatureUtil
/**
 * Spark ML [[Predictor]] that trains a DL4J network on a DataFrame.
 *
 * Configure via the builder-style setters before calling `fit`:
 * the network topology ([[setMultiLayerConfiguration]]), the number of
 * output labels ([[setNumLabels]], default 2), and a serializable factory
 * for the distributed training master ([[setTrainingMaster]]).
 */
final class SparkDl4jNetwork(
  override val uid: String
) extends Predictor[Vector, SparkDl4jNetwork, SparkDl4jModel] {

  // Mutable configuration, populated through the setters below before fit().
  private var _multiLayerConfiguration: MultiLayerConfiguration = _
  private var _numLabels: Int = 2
  private var _trainingMaster: Serializer = _

  def this() = this(Identifiable.randomUID("dl4j"))

  /**
   * Trains the configured network on `dataset` and wraps the result.
   *
   * Each row's features/label pair is converted to a DL4J [[DataSet]]:
   * for classification (`_numLabels > 1`) the label is expanded into a
   * one-hot outcome vector; otherwise the raw label value is kept.
   */
  override def train(dataset: DataFrame): SparkDl4jModel = {
    val sparkNet = new SparkDl4jMultiLayer(dataset.sqlContext.sparkContext, _multiLayerConfiguration, _trainingMaster())
    val lps = dataset.select(getFeaturesCol, getLabelCol).rdd
      .map(row => new LabeledPoint(row.getAs[Double](getLabelCol), row.getAs[Vector](getFeaturesCol)))
      .map { item =>
        val features = item.features
        val label = item.label
        if (_numLabels > 1) {
          // `.toInt` (not asInstanceOf) for the explicit Double -> Int narrowing.
          new DataSet(Nd4j.create(features.toArray), FeatureUtil.toOutcomeVector(label.toInt, _numLabels))
        } else {
          new DataSet(Nd4j.create(features.toArray), Nd4j.create(Array(label)))
        }
      }
    sparkNet.fit(lps)
    new SparkDl4jModel(uid, sparkNet)
  }

  override def copy(extra: ParamMap): SparkDl4jNetwork = defaultCopy(extra)

  /** Sets the number of output labels; values > 1 trigger one-hot encoding. */
  def setNumLabels(value: Int): SparkDl4jNetwork = {
    this._numLabels = value
    this
  }

  /** Sets the DL4J network topology to train. */
  def setMultiLayerConfiguration(multiLayerConfiguration: MultiLayerConfiguration): SparkDl4jNetwork = {
    this._multiLayerConfiguration = multiLayerConfiguration
    this
  }

  /** Sets the serializable factory that builds the training master on the driver. */
  def setTrainingMaster(tm: Serializer): SparkDl4jNetwork = {
    this._trainingMaster = tm
    this
  }
}
/**
 * Fitted model produced by [[SparkDl4jNetwork]]; wraps a trained
 * [[SparkDl4jMultiLayer]] and exposes prediction plus ML persistence.
 */
class SparkDl4jModel(override val uid: String, network: SparkDl4jMultiLayer)
  extends PredictionModel[Vector, SparkDl4jModel] with Serializable with MLWritable {

  override def copy(extra: ParamMap): SparkDl4jModel = {
    copyValues(new SparkDl4jModel(uid, network)).setParent(parent)
  }

  /**
   * Scores `features` through the network.
   *
   * Multi-element output -> index of the largest activation (class id);
   * single-element output -> the raw value (regression).
   *
   * @throws RuntimeException if the network produced an empty output vector
   */
  override def predict(features: Vector): Double = {
    val v = output(features)
    if (v.size > 1) {
      v.argmax
    } else if (v.size == 1) {
      v.toArray(0)
    } else throw new RuntimeException("Output vector is empty; size must be at least 1")
  }

  /** The underlying trained DL4J network. */
  def getMultiLayerNetwork: MultiLayerNetwork = network.getNetwork

  /** Raw network output for a single input vector. */
  def output(vector: Vector): Vector = network.predict(vector)

  /** Writer that serializes only the DL4J network (not the Spark ML params) to `path`. */
  protected[SparkDl4jModel] class SparkDl4jModelWriter(instance: SparkDl4jModel) extends MLWriter {
    override protected def saveImpl(path: String): Unit = {
      ModelSerializer.writeModel(network.getNetwork, path, true)
    }
  }

  override def write: MLWriter = new SparkDl4jModelWriter(this)
}
/** Companion providing ML persistence: restores a model saved by `write`. */
object SparkDl4jModel extends MLReadable[SparkDl4jModel] {

  override def read: MLReader[SparkDl4jModel] = new SparkDl4jReader

  override def load(path: String): SparkDl4jModel = super.load(path)

  /** Reads the serialized DL4J network and rewraps it in a fresh model (no training master). */
  private class SparkDl4jReader extends MLReader[SparkDl4jModel] {
    override def load(path: String): SparkDl4jModel = {
      val restored = ModelSerializer.restoreMultiLayerNetwork(path)
      val wrapped = new SparkDl4jMultiLayer(sc, restored, null)
      new SparkDl4jModel(Identifiable.randomUID("dl4j"), wrapped)
    }
  }
}
/**
 * Serializable zero-arg factory for a [[ParameterAveragingTrainingMaster]].
 * Passing a factory (rather than the master itself) lets callers supply it
 * as a lambda that can be shipped across the Spark cluster safely.
 */
trait Serializer extends Serializable {
def apply() : ParameterAveragingTrainingMaster
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment