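// NOTE: the gist begins mid-snippet, so the imports below are a best-guess reconstruction
// (assumed, not part of the original) for the DataVec/DL4J/ND4J classes used in this code,
// roughly matching the 0.7.x-era APIs. `propsCon`, `logger`, and `sc` (the SparkContext) are
// assumed to be provided by the enclosing Spark driver. StringToIntegerSetTransform and
// StringMapForVWTransform appear to be project-specific transforms, so their packages are
// not guessed here.
import java.io.File
import java.util

import scala.collection.JavaConverters._

import org.datavec.api.records.reader.RecordReader
import org.datavec.api.records.reader.impl.csv.CSVRecordReader
import org.datavec.api.split.FileSplit
import org.datavec.api.transform.TransformProcess
import org.datavec.api.transform.schema.Schema
import org.datavec.api.writable.Writable
import org.datavec.spark.transform.SparkTransformExecutor
import org.datavec.spark.transform.misc.StringToWritablesFunction
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator
import org.deeplearning4j.eval.Evaluation
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.layers.{DenseLayer, OutputLayer}
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.deeplearning4j.optimize.listeners.ScoreIterationListener
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax
import org.nd4j.linalg.dataset.DataSet
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.lossfunctions.LossFunctions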
val driverProps = propsCon.value
val sourceFile = new File(driverProps.filePath)
logger.info(s"the exported file is ${sourceFile.getPath}")
val numLinesToSkip: Int = 0
val delimiter: String = ", "
val schema: Schema = new Schema.Builder() //.addColumnString("year").addColumnString("month")
  .addColumnString("accountNum").addColumnString("name").addColumnString("Date")
  .addColumnString("status").addColumnString("accountType").addColumnString("consumerTier")
  .addColumnString("oem").addColumnString("phone").addColumnString("postalCode")
  .addColumnString("province").addColumnString("city").addColumnString("street")
  .addColumnString("crmVin").addColumnString("model") //.addColumnString("crmModel1").addColumnString("crmModel2")
  .addColumnString("crmAssetType").build
//map each car model name to the integer class label used by the network
val labelMap = new util.HashMap[String, String]()
labelMap.put("SHARAN", "0")
labelMap.put("Tiguan", "1")
labelMap.put("Scirocco", "2")
labelMap.put("PASSAT", "3")
labelMap.put("NBS", "4")
labelMap.put("Golf", "5")
// Map("BEETLE" -> "0", "GOLF" -> "1", "PASSAT" -> "2", "SCIROCCO" -> "3", "SHARAN" -> "4", "TIGUAN" -> "5")
val tp = new TransformProcess.Builder(schema)
  .transform(new StringToIntegerSetTransform("province",
    List("prov0", "prov1", "prov2", "prov3", "prov4", "prov5", "prov6", "prov7", "prov8", "prov9",
      "prov10", "prov11", "prov12", "prov13", "prov14", "prov15", "prov16", "prov17", "prov18", "prov19",
      "prov20", "prov21", "prov22", "prov23", "prov24", "prov25", "prov26", "prov27", "prov28", "prov29",
      "prov30", "prov31", "prov32", "prov33").asJava,
    List("北京市", "广东省", "山东省", "江苏省", "河南省", "上海市", "河北省", "浙江省", "香港特别行政区", "陕西省",
      "湖南省", "重庆市", "福建省", "天津市", "云南省", "四川省", "广西壮", "安徽省", "海南省", "江西省",
      "湖北省", "山西省", "辽宁省", "台湾省", "黑龙江", "内蒙古", "澳门特别行政区", "贵州省", "甘肃省", "青海省",
      "新疆维吾尔自治区", "西藏区", "吉林省", "宁夏回").asJava))
  .removeAllColumnsExceptFor("prov0", "prov1", "prov2", "prov3", "prov4", "prov5", "prov6", "prov7", "prov8", "prov9",
    "prov10", "prov11", "prov12", "prov13", "prov14", "prov15", "prov16", "prov17", "prov18", "prov19",
    "prov20", "prov21", "prov22", "prov23", "prov24", "prov25", "prov26", "prov27", "prov28", "prov29",
    "prov30", "prov31", "prov32", "prov33", "model")
  .transform(new StringMapForVWTransform("model", labelMap)) //.stringMapTransform("model", labelMap)
  // .categoricalToInteger("prov0", "prov1", "prov2", "prov3", "prov4", "prov5", "prov6", "prov7", "prov8", "prov9",
  //   "prov10", "prov11", "prov12", "prov13", "prov14", "prov15", "prov16", "prov17", "prov18", "prov19",
  //   "prov20", "prov21", "prov22", "prov23", "prov24", "prov25", "prov26", "prov27", "prov28", "prov29",
  //   "prov30", "prov31", "prov32", "prov33")
  .build
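// Optional sanity check (not in the original gist): TransformProcess.getFinalSchema returns the
// schema produced after all transform steps, which is useful for confirming that the output is the
// 34 one-hot "provN" columns plus the mapped "model" column before executing the job on Spark.
logger.info(s"schema after transform:\n${tp.getFinalSchema}")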
val stringData = sc.textFile(sourceFile.getPath)
  .filter { line => !line.startsWith("0CRM Account Number") }
  .map { line =>
    if (line.contains(", "))
      line.replace(", ", "-")
    else
      line
  }
  .map { line =>
    if (line.contains("HONGMEIJIEDAO,JUJINGYUAN,11JIA101"))
      line.replace("HONGMEIJIEDAO,JUJINGYUAN,11JIA101", "HONGMEIJIEDAO_JUJINGYUAN_11JIA101")
    else
      line
  }
  .persist
//count the rows per model type (first word of column 13, the "model" column) and keep the ones with more than 15 rows
val types = stringData.map { line =>
  val splits = line.split(",")
  (splits(13).split(" ")(0), 1)
  // if(splits.length < 13)
  //   println(s"-----------------------------> $line")
}.groupByKey
  .map { pair => (pair._1, pair._2.size) }
  .filter { pair => pair._2 > 15 }
types.foreach { pair =>
  println(s"===label : ${pair._1} -> count : ${pair._2}")
}
// types.map {typeItem =>
//   if("夏朗".equals(typeItem))
//     "sh"
//   else
//     typeItem
// }
println(s"all the model types: ${types.collect.mkString(", ")}")
val rr: RecordReader = new CSVRecordReader
val toFun = new StringToWritablesFunction(rr)
val parsed = stringData.map {value => toFun.call(value)}
val parsedData = SparkTransformExecutor.execute(parsed, tp)
// parsedData.filter
val filterNoisyData = parsedData.filter { (writableList: util.List[Writable]) =>
  !"noisy".equals(writableList.get(writableList.size - 1).toString)
}
//save the data to disk first
// val exportPath = "/tmp/annData"
// filterNoisyData.saveAsTextFile(exportPath)
// val reloadData2 = sc.textFile(exportPath)
//
// val exportPath2 = "/tmp/annData2"
// reloadData2.map {
// line =>
// val str = line.substring(1, line.length - 1)
// str.split(", ").mkString(",")
// }.saveAsTextFile(exportPath2)
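// Sketch (not in the original gist): DataVec's org.datavec.spark.transform.misc.WritablesToStringFunction
// converts each List[Writable] back into a single delimited line, which would avoid the manual
// bracket-stripping done in the commented-out export above. The export path is only an example.
// val toCsv = new WritablesToStringFunction(",")
// filterNoisyData.map { row => toCsv.call(row) }.saveAsTextFile("/tmp/annDataCsv")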
val recordReader = new CSVRecordReader()
recordReader.initialize(new FileSplit(new File("/tmp/Data/"), Array("csv")))
// val forTestFile = new File("/tmp/annData2")
// sc.textFile(forTestFile.getPath).foreach {line => println(s"data is ===> ${line}")}
// val tuguanNum = sc.textFile(forTestFile.getPath).filter {line => line.endsWith("1")}.count
// println(s"Tiguan number is ==========> ${tuguanNum}")
//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in the neural network
val labelIndex: Int = 34 //35 values in each row of the transformed CSV: 34 one-hot province features followed by the integer model label (index 34)
val numClasses: Int = 6 //6 car model classes (SHARAN, Tiguan, Scirocco, PASSAT, NBS, Golf), with integer values 0 to 5
val batchSize: Int = 150 //number of examples to load into each DataSet (loading everything at once is not recommended for large data sets)
//val labelIndex: Int = 4 //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
// val numClasses: Int = 3 //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
// val batchSize: Int = 150 //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)
val iterator: DataSetIterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses)
val allData: DataSet = iterator.next
// val verifiedMax = Nd4j.getExecutioner.exec(new IMax(allData.getLabels), 1)
// var verifyIndex = 0
// while(verifyIndex < 10) {
// println(s"${allData.getFeatures.getRow(verifyIndex)}----${verifiedMax.getDouble(verifyIndex)}")
// verifyIndex += 1
// }
// println(s"get the first column of first record ${allData.get(0).get(0)}")
allData.shuffle()
val testAndTrain = allData.splitTestAndTrain(0.85) //Use 85% of data for training
val trainingData: DataSet = testAndTrain.getTrain
val testData: DataSet = testAndTrain.getTest
//We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
val normalizer = new NormalizerStandardize
normalizer.fit(trainingData) //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
normalizer.transform(trainingData) //Apply normalization to the training data
normalizer.transform(testData) //Apply normalization to the test data. This is using statistics calculated from the *training* set
val numInputs = 34
val outputNum = 6
val iterations = 1000
val seed = 6
val conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(iterations)
  .activation(Activation.TANH)
  .weightInit(WeightInit.XAVIER)
  .learningRate(0.1)
  .regularization(true).l2(1e-4)
  .list
  .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(10).build)
  .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).build)
  .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
    .activation(Activation.SOFTMAX).nIn(10).nOut(outputNum).build)
  .backprop(true).pretrain(false).build
//run the model
val model = new MultiLayerNetwork(conf)
model.init()
model.setListeners(new ScoreIterationListener(100))
var n = 0
val nEpochs = 1
iterator.reset()
while (n < nEpochs) {
  model.fit(iterator)
  n += 1
}
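// Sketch (not in the original gist): persist the trained network so it can be reloaded for scoring
// later. ModelSerializer is DL4J's model serialization utility; the output path is only an example,
// and the final `true` also saves the updater state.
import org.deeplearning4j.util.ModelSerializer
ModelSerializer.writeModel(model, new File("/tmp/vwModelClassifier.zip"), true)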
// model.fit(trainingData)
//evaluate the model on the test set
val labelNames = List("SHARAN", "Tiguan", "Scirocco", "PASSAT", "NBS", "Golf")
val eval = new Evaluation(labelNames.length)
val output = model.output(testData.getFeatureMatrix)
eval.eval(testData.getLabels, output)
println(eval.stats)
// val testIter = testData.iterator
// while (testIter.hasNext) {
// model.predict()
// val next: DataSet = testData.next
// val output: INDArray = model.output(next.getFeatureMatrix) //get the networks prediction
// eval.eval(next.getLabels, output) //check the prediction against the true class
// }
testData.setLabelNames(labelNames.asJava)
val predicted = model.predict(testData)
// val actuallyLabel = testData.getLabels
val actualLabels = testData.getLabelNamesList
var index = 0
val testIter = testData.iterator
val argMax = Nd4j.getExecutioner.exec(new IMax(testData.getLabels), 1)
while (testIter.hasNext) {
  testIter.next.outcome() //advance the iterator; the true label is read from argMax below instead
  // val actualLabel = testData.getLabelName(index)
  val outcome = labelNames(argMax.getDouble(index).toInt)
  val predictedLabel = predicted.get(index)
  index = index + 1
  println(s"the actual label is ${outcome} -- the predicted label is ${predictedLabel}")
}
println(s"the actual label list length is ${actualLabels.size} and the predicted label list length is ${predicted.size}")