Last active
October 26, 2015 01:34
-
-
Save toff63/43c5f7bf73702f040397 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.File | |
import scala.io.Source | |
import org.apache.log4j.Logger | |
import org.apache.log4j.Level | |
import org.apache.spark.SparkConf | |
import org.apache.spark.SparkContext | |
import org.apache.spark.SparkContext._ | |
import org.apache.spark.rdd._ | |
import org.apache.spark.mllib.recommendation.{ALS, Rating, MatrixFactorizationModel} | |
object MovieLensALS {

  /**
   * Entry point: trains ALS models on a MovieLens ratings dump plus a user's personal
   * ratings, selects the model with the lowest validation RMSE and reports its RMSE on
   * a held-out test split.
   *
   * @param args exactly two arguments: `movieLensHomeDir personalRatingsFile`
   */
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    if (args.length != 2) {
      println("Usage: /path/to/spark/bin/spark-submit --driver-memory 2g --class MovieLensALS " +
        "target/scala-*/movielens-als-assembly-*.jar movieLensHomeDir personalRatingsFile")
      sys.exit(1)
    }

    // set up environment
    val conf = new SparkConf()
      .setAppName("MovieLensALS")
      .set("spark.executor.memory", "2g")
    val sc = new SparkContext(conf)
    // Stop the context even if loading, training or evaluation throws.
    try {
      run(sc, movieLensHomeDir = args(0), personalRatingsFile = args(1))
    } finally {
      sc.stop()
    }
  }

  /**
   * Loads the data, splits it, grid-searches ALS hyper-parameters on the validation
   * split and prints the test RMSE of the best model.
   */
  private def run(sc: SparkContext, movieLensHomeDir: String, personalRatingsFile: String): Unit = {
    // load personal ratings
    val myRatingsRDD = sc.parallelize(loadRatings(personalRatingsFile), 1)

    // load ratings; lines that do not start with a digit (a CSV header row such as
    // "userId,movieId,rating,timestamp", or blank lines) are skipped instead of
    // crashing the numeric parse below
    val ratings = sc.textFile(new File(movieLensHomeDir, "ratings.csv").toString)
      .filter(line => line.nonEmpty && line.charAt(0).isDigit)
      .map { line =>
        val fields = line.split(",")
        // format: (timestamp % 10, Rating(userId, movieId, rating))
        (fields(3).toLong % 10, Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble))
      }

    val numRatings = ratings.count()
    val numMovies = ratings.map(_._2.product).distinct().count()
    val numUsers = ratings.map(_._2.user).distinct().count()
    println("Got " + numRatings + " ratings from " + numUsers + " users on " + numMovies + " movies.")

    // Split on the timestamp bucket: [0, 6) training, [6, 8) validation, [8, ...) test.
    val numPartitions = 4
    val training = ratings.filter(x => x._1 < 6)
      .values
      .union(myRatingsRDD)
      .repartition(numPartitions)
      .cache()
    val validation = ratings.filter(x => x._1 >= 6 && x._1 < 8)
      .values
      .repartition(numPartitions)
      .cache()
    val test = ratings.filter(x => x._1 >= 8)
      .values
      .cache()

    val numTraining = training.count()
    val numValidation = validation.count()
    val numTest = test.count()
    println("Training: " + numTraining + ", validation: " + numValidation + ", test: " + numTest)

    val ranks = Seq(8, 9, 10, 11, 12)
    val lambdas = Seq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)
    val numIterations = 10 to 20

    // Iterators all the way down, so each model is trained only when the search reaches
    // it rather than materializing the whole grid up front.
    def models(ratings: RDD[Rating]): Iterator[ALSModel] =
      for {
        rank <- ranks.iterator
        lambda <- lambdas.iterator
        numIteration <- numIterations.iterator
      } yield ALSModel(rank, lambda, numIteration, ALS.train(ratings, rank, numIteration, lambda))

    // Keep only the best-so-far model; release the factor RDDs of every loser right away.
    // A strictly smaller RMSE is required to replace the current best, so ties keep the
    // earlier model.
    val bestModel = models(training)
      .map(alsModel => ALSModel.addrmse(alsModel, validation, numValidation))
      .reduceLeft { (best, candidate) =>
        if (ALSModel.sort(candidate, best)) {
          release(best)
          candidate
        } else {
          release(candidate)
          best
        }
      }

    val rmseOnTest = computermse(bestModel.model, test, numTest)
    println(s"The best model was trained using rank ${bestModel.rank}, " +
      s"lambda ${bestModel.lambda} and ${bestModel.numIteration} iterations, " +
      s"and its RMSE on test is $rmseOnTest")
  }

  /** Unpersists the user/product factor RDDs of a model that is no longer needed. */
  private def release(alsModel: ALSModel): Unit = {
    alsModel.model.userFeatures.unpersist(blocking = false)
    alsModel.model.productFeatures.unpersist(blocking = false)
  }

  /**
   * One point of the hyper-parameter grid together with the model trained for it.
   *
   * @param rmse validation RMSE, `None` until [[ALSModel.addrmse]] has been applied
   */
  case class ALSModel(rank: Int,
                      lambda: Double,
                      numIteration: Long,
                      model: MatrixFactorizationModel,
                      rmse: Option[Double] = None)

  object ALSModel {
    /** Returns a copy of `alsModel` whose `rmse` is its RMSE on `validation`. */
    def addrmse(alsModel: ALSModel, validation: RDD[Rating], numValidation: Long): ALSModel =
      alsModel.copy(rmse = Some(computermse(alsModel.model, validation, numValidation)))

    /**
     * "Less than" on validation RMSE; a model without an RMSE sorts as `Double.MaxValue`.
     * NaN compares false both ways, so a NaN RMSE never displaces another model.
     */
    def sort(a1: ALSModel, a2: ALSModel): Boolean = {
      val rmse1: Double = a1.rmse.getOrElse(Double.MaxValue)
      val rmse2: Double = a2.rmse.getOrElse(Double.MaxValue)
      rmse1 < rmse2
    }
  }

  /**
   * Computes the root mean squared error of `model`'s predictions on `data`.
   *
   * Pairs for which `model.predict` returns no prediction (e.g. a user or product that
   * never appeared in training) drop out of the join. They are excluded from the
   * denominator as well, so the result is a true mean over the scored pairs.
   *
   * @param model trained factorization model
   * @param data  ratings to score
   * @param n     number of ratings in `data`; kept for source compatibility, no longer used
   *              as the divisor
   * @return the RMSE, or `Double.NaN` when no pair in `data` could be scored
   */
  def computermse(model: MatrixFactorizationModel, data: RDD[Rating], n: Long): Double = {
    val predictions: RDD[Rating] = model.predict(data.map(x => (x.user, x.product)))
    val predictionsAndRatings = predictions.map(x => ((x.user, x.product), x.rating))
      .join(data.map(x => ((x.user, x.product), x.rating)))
      .values
    // Sum squared errors and count scored pairs in one pass; fold (unlike reduce) is
    // safe on an empty RDD.
    val (sumSquaredError, scored) = predictionsAndRatings
      .map { case (predicted, actual) => ((predicted - actual) * (predicted - actual), 1L) }
      .fold((0.0, 0L)) { (a, b) => (a._1 + b._1, a._2 + b._2) }
    if (scored == 0L) Double.NaN else math.sqrt(sumSquaredError / scored)
  }

  /**
   * Loads the user's personal ratings from a local `::`-separated file whose first three
   * fields are `userId::movieId::rating`. Blank lines and non-positive ratings are skipped.
   *
   * @throws RuntimeException (via `sys.error`) if no positive rating is found
   */
  def loadRatings(path: String): Seq[Rating] = {
    val source = Source.fromFile(path)
    // Materialize before closing: Iterator.toSeq returns a lazy Stream, which would try
    // to read from the source after it has been closed.
    val ratings = try {
      source.getLines()
        .filter(_.trim.nonEmpty)
        .map { line =>
          val fields = line.split("::")
          Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
        }
        .filter(_.rating > 0.0)
        .toVector
    } finally {
      source.close()
    }
    if (ratings.isEmpty) {
      sys.error("No ratings provided.")
    } else {
      ratings
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment