// Databricks notebook source exported at Sat, 6 Aug 2016 14:28:53 UTC
// MAGIC %md
// MAGIC # Breckenridge Property Description Topic Modeling
// MAGIC This notebook turns the text of property descriptions for the Breckenridge, CO, US destination into topic probability distributions for subsequent math. The chief output is the set of LDA-determined topic distributions, which are analyzed for similarity scores in a separate R document.
// MAGIC
// MAGIC To get all the topic distributions, use the `clusteredDF` object.
// COMMAND ----------
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.types.{StructType, StructField, StringType, FloatType}
// just get the Breck descriptions
val customSchema = StructType(Array(
  StructField("country", StringType, true),
  StructField("idvalue", StringType, true),
  StructField("locality", StringType, true),
  StructField("lat", FloatType, true),
  StructField("lon", FloatType, true),
  StructField("propertyType", StringType, true),
  StructField("numBathroom", StringType, true),
  StructField("numBedrooms", StringType, true),
  StructField("description", StringType, true),
  StructField("region", StringType, true),
  StructField("datasource", StringType, true),
  StructField("countrycode", StringType, true)
))
// this reads in the CSV. The chief column is the 'description' column
val geoCSV = sqlContext.read.format("csv")
  .option("header", "true")
  .schema(customSchema)
  .load("/FileStore/tables/0lxpa4cl1456846124647/breckenridge.csv")
val fileName = "/tmp/geo.parquet"
geoCSV.filter("description is not null").write.mode(SaveMode.Overwrite).parquet(fileName)
val geo = sqlContext.read.parquet(fileName)
geo.printSchema
// COMMAND ----------
// this is all just to simulate "near" matches. In real life, there's no way we'd do this - the
// cleaning and topic modeling would be done on the geo.description column directly
// get just the "Internal" listings
val internals = geo.filter($"datasource" === "Internal")
// select 15% of them to duplicate
val toDupe = internals.sample(false, 0.15, 55667)
// rename the idvalue column to indicate we have a dupe
val renamer = udf { (idvalue: String) =>
  "dupe-of-" + idvalue
}
val renamed = toDupe.withColumn("idvalue", renamer(toDupe("idvalue")))
// change the description of the dupes by randomly dropping some of the words;
// this simulates slight changes. Each word is kept with probability pctIntact%,
// so at 100 the dupes keep every word (set to e.g. 95 to drop ~5% of the words).
// pctIntact is defined outside the udf so the output file name cell below can reuse it.
val pctIntact = 100
val fuzzer = udf { (description: String) =>
  val r = new scala.util.Random(55667)
  val theList = description.split(" ")
  val newList = theList.filter(elm => r.nextInt(100) < pctIntact)
  newList.mkString(" ")
}
val fuzzed = renamed.withColumn("description", fuzzer(renamed("description")))
// this is the labeled dataset
val labeled = internals.unionAll(fuzzed)
display(labeled)
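// COMMAND ----------
// Quick sanity check (a small addition, not part of the original flow): count how many
// simulated dupes landed in the labeled dataset versus untouched internal listings.
val dupeCount = labeled.filter($"idvalue".startsWith("dupe-of-")).count()
println(s"dupes: $dupeCount of ${labeled.count()} total labeled rows")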
// COMMAND ----------
val cleaner = udf { (description: String) =>
  description.toLowerCase()
    .replaceAll("&nbsp;", " ")  // HTML non-breaking-space entities -> spaces
    .replaceAll("\\.", ". ")    // make sure there is a space after each period
    .replaceAll("nbsp", " ")    // strip any leftover "nbsp" fragments
    .replaceAll("  ", " ")      // collapse double spaces
}
val geoClean = labeled.withColumn("description", cleaner(labeled("description")))
display(geoClean)
// COMMAND ----------
import org.apache.spark.ml.feature.RegexTokenizer
import org.apache.spark.ml.feature.StopWordsRemover
import org.apache.spark.ml.feature.CountVectorizer
// Split each document into words
val tokenizer = new RegexTokenizer()
  .setInputCol("description")
  .setOutputCol("words")
  .setGaps(false)
  .setPattern("\\p{L}+")
// Remove semantically uninteresting words like "the", "and", ...
val stopWordsFilter = new StopWordsRemover()
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("filteredWords")
  .setCaseSensitive(false)
// Simple counts
// Limit to the top `vocabSize` most common words and convert each document to a word-count vector
val vocabSize: Int = 10000
val countVectorizer = new CountVectorizer()
  .setInputCol(stopWordsFilter.getOutputCol)
  .setOutputCol("countFeatures")
  .setVocabSize(vocabSize)
  .setMinDF(2)
  .setMinTF(1)
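// COMMAND ----------
// A quick, self-contained peek at what the tokenizer and stop-word filter do before the
// full pipeline is fit below. The sample sentence is made up purely for illustration.
import sqlContext.implicits._
val sampleDF = Seq(
  ("sample-1", "Cozy ski-in, ski-out condo. Sleeps six and is steps from the Peak 8 lifts.")
).toDF("idvalue", "description")
display(stopWordsFilter.transform(tokenizer.transform(sampleDF)))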
// COMMAND ----------
import org.apache.spark.ml.Pipeline
val fePipeline = new Pipeline()
  .setStages(Array(tokenizer, stopWordsFilter, countVectorizer))
val fePipelineModel = fePipeline.fit(geoClean)
val featuresDF = fePipelineModel.transform(geoClean)
fePipelineModel.write.overwrite.save(s"/mnt/$MountName/fePipelineModel")
display(featuresDF)
// COMMAND ----------
// MAGIC %md
// MAGIC # Feature Engineering is Done
// MAGIC We've transformed words into numbers, so now we can build a model.
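// COMMAND ----------
// Optional inspection cell (an addition for clarity): peek at what "words into numbers"
// means here - the most frequent vocabulary terms the CountVectorizer learned, plus a few
// of the resulting count vectors.
import org.apache.spark.ml.feature.CountVectorizerModel
val cvModel = fePipelineModel.stages(2).asInstanceOf[CountVectorizerModel]
println(s"vocabulary size: ${cvModel.vocabulary.length}")
println(cvModel.vocabulary.take(25).mkString(", "))
display(featuresDF.select("idvalue", "countFeatures").limit(5))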
// COMMAND ----------
/***
 * TAKES TOO LONG FOR DEMO - 2 whole minutes
 ***/
import org.apache.spark.ml.clustering.LDA
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.ml.Pipeline
val numTopics = 600
val numIterations = 300
// Perform Latent Dirichlet Allocation over the simple counts
val countLDA = new LDA()
  .setK(numTopics)
  .setMaxIter(numIterations)
  .setSeed(55667)
  .setFeaturesCol(countVectorizer.getOutputCol)
  .setTopicDistributionCol("countTopicDistribution")
val clusterPipeline = new Pipeline()
  .setStages(Array(countLDA))
// teach a model how to transform text into topic probability distributions
val clusterPipelineModel = clusterPipeline.fit(featuresDF)
clusterPipelineModel.write.overwrite().save(s"/mnt/$MountName/clusterPipelineModel")
// COMMAND ----------
// MAGIC %md
// MAGIC # Transform some data
// MAGIC Almost every line of code above was to get to this point. Here is where we generate the topic probability distribution vectors, which are what the similarity math is done with.
// COMMAND ----------
/****
 * DEPENDS ON THE THING THAT TAKES TOO LONG TO DEMO
 ****/
val clusteredDF = clusterPipelineModel.transform(featuresDF)
clusteredDF.write.mode(SaveMode.Overwrite).parquet(s"/mnt/$MountName/clustered.parquet")
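// COMMAND ----------
// A minimal sketch of the "subsequent math": cosine similarity between the topic
// distributions of a listing and its simulated dupe. The real similarity analysis lives in
// the R document mentioned at the top; the "dupe-of-1940477" id is hypothetical and is only
// present if that listing happened to land in the 15% duplicate sample.
import org.apache.spark.mllib.linalg.Vector
def cosine(a: Vector, b: Vector): Double = {
  val dot = a.toArray.zip(b.toArray).map { case (x, y) => x * y }.sum
  val normA = math.sqrt(a.toArray.map(x => x * x).sum)
  val normB = math.sqrt(b.toArray.map(x => x * x).sum)
  dot / (normA * normB)
}
val pair = clusteredDF
  .filter($"idvalue" === "1940477" || $"idvalue" === "dupe-of-1940477")
  .select("countTopicDistribution")
  .collect()
  .map(_.getAs[Vector](0))
if (pair.length == 2) println(s"cosine similarity: ${cosine(pair(0), pair(1))}")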
// COMMAND ----------
// MAGIC %md
// MAGIC We've extracted features, built a model, and transformed data: Text -> Word-Count Vectors -> Topic Probability Distributions
// COMMAND ----------
val cDF = sqlContext.read.parquet(s"/mnt/$MountName/clustered.parquet")
display(cDF.filter("idvalue = 1940477").select("description", "countFeatures", "countTopicDistribution"))
val fileName = "clustered-labeled-o" + pctIntact + "-t" + numTopics + ".json"
cDF.repartition(1).write.mode(SaveMode.Overwrite).json(s"/mnt/$MountName/$fileName")
// COMMAND ----------
import org.apache.spark.ml.feature.CountVectorizerModel
import org.apache.spark.ml.clustering.LDAModel
import org.apache.spark.ml.PipelineModel
val persistedClusterPipelineModel = PipelineModel.load(s"/mnt/$MountName/clusterPipelineModel")
val persistedfePipelineModel = PipelineModel.load(s"/mnt/$MountName/fePipelineModel")
val ldaModel = persistedClusterPipelineModel.stages(0).asInstanceOf[LDAModel]
val topicsDF = ldaModel.describeTopics(maxTermsPerTopic = 10)
val vocabArray = persistedfePipelineModel.stages(2).asInstanceOf[CountVectorizerModel].vocabulary
val termWeightsPerTopicRDD = topicsDF.select($"termIndices", $"termWeights").map(row => {
  val terms = row.getSeq[Int](0)
  val termWeights = row.getSeq[Double](1)
  terms.map(idx => vocabArray(idx)).zip(termWeights)
})
println("\nTopics:\n")
// Call collect for display purposes only - otherwise keep things on the cluster
termWeightsPerTopicRDD.collect().zipWithIndex.take(151).foreach { case (topic, i) =>
  println(s"Topic $i")
  topic.foreach { case (term, weight) => println(s"$term\t\t\t$weight") }
  println("==========")
}
// point out topic 6, 84, 150
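// COMMAND ----------
// Convenience sketch (an addition): print only the called-out topics (6, 84, 150) rather
// than scrolling through the full dump above.
val callouts = Set(6, 84, 150)
termWeightsPerTopicRDD.collect().zipWithIndex
  .filter { case (_, i) => callouts.contains(i) }
  .foreach { case (topic, i) =>
    println(s"Topic $i")
    topic.foreach { case (term, weight) => println(s"$term\t\t\t$weight") }
    println("==========")
  }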