Created
January 27, 2016 21:52
-
-
Save asimjalis/965bd44657b90aeab887 to your computer and use it in GitHub Desktop.
Spark MLlib Twitter Quickstart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Load tweets. | |
import scala.util.control.Breaks._ | |
import scala.collection.JavaConversions._ | |
import twitter4j.{Twitter,Query,TwitterFactory} | |
// Singleton Twitter client (credentials are configured outside this
// script, e.g. twitter4j.properties) and the search query: English
// tweets, 100 per page, within the 2016-01-13 .. 2016-01-24 window.
val twitter = TwitterFactory.getSingleton
val query = {
  val q = new Query("lang:en")
  q.setCount(100)
  q.setSince("2016-01-13")
  q.setUntil("2016-01-24")
  q
}
/** Largest tweet id in a result page (page must be non-empty). */
def getMaxId(tweets:java.util.List[twitter4j.Status]) = {
  tweets.map(_.getId).reduce(_ max _)
}
/** Smallest tweet id in a result page (page must be non-empty). */
def getMinId(tweets:java.util.List[twitter4j.Status]) = {
  tweets.map(_.getId).reduce(_ min _)
}
// Fetch one page of results with ids strictly greater than lastId
// (sinceId is an exclusive lower bound in the Twitter search API).
// NOTE(review): mutates the shared `query` object as a side effect,
// so interleaving calls with getTweetsBeforeId changes its state.
def getTweetsAfterId(lastId:Long) = {
  query.setSinceId(lastId)
  twitter.search(query).getTweets()
}
// Fetch one page of results with ids strictly less than firstId
// (maxId is an inclusive upper bound, hence the firstId - 1).
// NOTE(review): the first call below passes firstId == Long.MinValue,
// so firstId - 1 wraps around to Long.MaxValue, which effectively
// disables the upper bound on that first call. Looks deliberate but
// worth confirming.
// NOTE(review): mutates the shared `query` object as a side effect.
def getTweetsBeforeId(firstId:Long) = {
  query.setMaxId(firstId - 1)
  twitter.search(query).getTweets()
}
import scala.util.control.NonFatal

// Page backwards through the search results (newest to oldest) until
// tweetMaxCount tweets are collected or the API returns an empty page.
// minId starts at Long.MinValue, so the first getTweetsBeforeId call
// computes Long.MinValue - 1, which wraps to Long.MaxValue (no bound).
var maxId = 0L // NOTE(review): never read in this script; kept as-is for compatibility
var minId = Long.MinValue
val tweetMaxCount = 10000
val tweetList = new java.util.ArrayList[twitter4j.Status]
// BUG FIX: `break` throws BreakControl and must be caught by an
// enclosing `breakable`; the original had none, so breaking out of the
// loop would abort the whole script with an unhandled throw.
breakable {
  while (tweetList.size < tweetMaxCount) {
    try {
      val tweets = getTweetsBeforeId(minId)
      println("tweetList.size=" + tweetList.size)
      tweetList.addAll(tweets)
      if (tweets.size != 0) {
        minId = getMinId(tweets)
      } else {
        break // empty page: no more results available
      }
    } catch {
      // BUG FIX: the original caught Throwable, which would also trap
      // the loop's own BreakControl plus fatal errors (OOM, interrupt).
      // NonFatal lets control-flow and fatal throwables propagate while
      // still treating API errors (e.g. rate limiting) as end-of-input.
      case NonFatal(e) => println(e.getMessage); break
    }
  }
}
// Build model. | |
import org.apache.spark.mllib.regression.LabeledPoint | |
import org.apache.spark.mllib.linalg._ | |
import org.apache.spark.mllib.stat.Statistics | |
import org.apache.spark.mllib.linalg.Vectors | |
import org.apache.spark.ml.Pipeline | |
import org.apache.spark.ml.evaluation.RegressionEvaluator | |
import org.apache.spark.ml.feature.VectorIndexer | |
import org.apache.spark.ml.regression.{RandomForestRegressionModel, | |
RandomForestRegressor} | |
/** Natural log of (x + 1), used to squash heavy-tailed count features.
  * Uses log1p, which is numerically accurate for small x and maps
  * x == 0 to exactly 0.0, unlike math.log(x + 1). */
def safeLog(x:Double) = {
  math.log1p(x)
}
// Build the training DataFrame: label = log-scaled retweet count,
// features = log-scaled follower count, media/mention/hashtag entity
// counts, and tweet text length. Only tweets with more than 10
// retweets are kept.
// NOTE(review): `sc.parallelize(tweetList)` relies on the implicit
// JavaConversions wrapper around the java.util.ArrayList, and `toDF`
// relies on the spark-shell SQL implicits being in scope.
val data = sc.parallelize(tweetList).
  filter(_.getRetweetCount > 10).
  map(t =>
    LabeledPoint(
      safeLog(t.getRetweetCount),
      Vectors.dense(
        safeLog(t.getUser.getFollowersCount),
        safeLog(t.getMediaEntities.size),
        safeLog(t.getUserMentionEntities.size),
        safeLog(t.getHashtagEntities.size),
        safeLog(t.getText.length)))).toDF
// Index the feature vector: columns with at most 4 distinct values are
// marked categorical so the tree algorithm can treat them as such.
val featureIndexer = {
  val indexer = new VectorIndexer()
  indexer.setInputCol("features")
  indexer.setOutputCol("indexedFeatures")
  indexer.setMaxCategories(4)
  indexer.fit(data)
}
// Hold out 30% of the data for evaluation.
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
// Random-forest regressor reading the indexed feature column.
val rf = new RandomForestRegressor()
rf.setLabelCol("label")
rf.setFeaturesCol("indexedFeatures")
// Pipeline: feature indexing followed by the forest.
val pipeline = new Pipeline().setStages(Array(featureIndexer, rf))
// Fit the whole pipeline (indexer + forest) on the training split.
val model = pipeline.fit(trainingData)
// Score the held-out split.
val predictions = model.transform(testData)
// Peek at a few (prediction, label, features) rows.
predictions.select("prediction", "label", "features").show(5)
// Root-mean-squared error of the predictions against the true labels.
val evaluator = new RegressionEvaluator()
evaluator.setLabelCol("label")
evaluator.setPredictionCol("prediction")
evaluator.setMetricName("rmse")
val rmse = evaluator.evaluate(predictions)
println(s"Root Mean Squared Error (RMSE) on test data = $rmse")
// Pull the fitted forest (stage 1 of the pipeline) to inspect it.
val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel]
// Bare expressions: the REPL echoes these values.
rfModel.numFeatures
rfModel.featureImportances
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment