import java.io.File
import scala.io.Source
import org.apache.log4j.Logger
import org.apache.log4j.Level
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd._
import org.apache.spark.mllib.recommendation.{ALS, Rating, MatrixFactorizationModel}
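/**
 * Trains ALS matrix-factorization models on the MovieLens ratings, picks the model with the
 * lowest RMSE on a validation split, and reports that model's RMSE on a held-out test split.
 */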
object MovieLensALS {

  def main(args: Array[String]) {
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    if (args.length != 2) {
      println("Usage: /path/to/spark/bin/spark-submit --driver-memory 2g --class MovieLensALS " +
        "target/scala-*/movielens-als-assembly-*.jar movieLensHomeDir personalRatingsFile")
      sys.exit(1)
    }
    // set up environment
    val conf = new SparkConf()
      .setAppName("MovieLensALS")
      .set("spark.executor.memory", "2g")
    val sc = new SparkContext(conf)

    // load personal ratings
    val myRatings = loadRatings(args(1))
    val myRatingsRDD = sc.parallelize(myRatings, 1)
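    // (these personal ratings are added to the training set below)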
    // load ratings and movie titles
    // note: the plain "," split assumes the csv files have no header row and that titles contain no embedded commas
    val movieLensHomeDir = args(0)

    val ratings = sc.textFile(new File(movieLensHomeDir, "ratings.csv").toString).map { line =>
      val fields = line.split(",")
      // format: (timestamp % 10, Rating(userId, movieId, rating))
      (fields(3).toLong % 10, Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble))
    }

    val movies = sc.textFile(new File(movieLensHomeDir, "movies.csv").toString).map { line =>
      val fields = line.split(",")
      // format: (movieId, movieName)
      (fields(0).toInt, fields(1))
    }.collect().toMap
    // basic statistics about the dataset
    val numRatings = ratings.count
    val numMovies = ratings.map(_._2.product).distinct.count
    val numUsers = ratings.map(_._2.user).distinct.count
    println("Got " + numRatings + " ratings from " + numUsers + " users on " + numMovies + " movies.")
    val numPartition = 4
    val training = ratings.filter(x => x._1 < 6)
      .values
      .union(myRatingsRDD)
      .repartition(numPartition)
      .cache()
    val validation = ratings.filter(x => x._1 >= 6 && x._1 < 8)
      .values
      .repartition(numPartition)
      .cache()
    val test = ratings.filter(x => x._1 >= 8)
      .values
      .cache()

    val numTraining = training.count()
    val numValidation = validation.count()
    val numTest = test.count()
    println("Training: " + numTraining + ", validation: " + numValidation + ", test: " + numTest)
    def models(ratings: RDD[Rating]) =
      for {
        rank <- Seq(8, 9, 10, 11, 12)
        lambda <- Seq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)
        numIteration <- 10.to(20)
      } yield ALSModel(rank, lambda, numIteration, ALS.train(ratings, rank, numIteration, lambda))

    val bestModel = models(training).map { alsModel => ALSModel.addrmse(alsModel, validation, numValidation) }
      .sortWith { ALSModel.sort }
      .head

    val rmseOnTest = computermse(bestModel.model, test, numTest)
    println(s"The best model was trained using rank ${bestModel.rank} and lambda ${bestModel.lambda}, " +
      s"and its RMSE on the test set is $rmseOnTest.")
    // clean up
    sc.stop()
  }
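  /** A trained model together with its hyperparameters and, once computed, its validation RMSE. */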
  case class ALSModel(rank: Int,
                      lambda: Double,
                      numIteration: Long,
                      model: MatrixFactorizationModel,
                      rmse: Option[Double] = None)

  object ALSModel {
    def addrmse(alsModel: ALSModel, validation: RDD[Rating], numValidation: Long) =
      alsModel.copy(rmse = Some(computermse(alsModel.model, validation, numValidation)))

    def sort(a1: ALSModel, a2: ALSModel): Boolean = {
      val rmse1: Double = a1.rmse.getOrElse(Double.MaxValue)
      val rmse2: Double = a2.rmse.getOrElse(Double.MaxValue)
      rmse1 < rmse2
    }
  }
  /** Compute RMSE (root mean squared error). */
  def computermse(model: MatrixFactorizationModel, data: RDD[Rating], n: Long): Double = {
    val predictions: RDD[Rating] = model.predict(data.map(x => (x.user, x.product)))
    val predictionsAndRatings = predictions.map(x => ((x.user, x.product), x.rating))
      .join(data.map(x => ((x.user, x.product), x.rating)))
      .values
    math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).reduce(_ + _) / n)
  }
  /** Load the personal ratings from a "::"-separated file. */
  def loadRatings(path: String): Seq[Rating] = {
    val lines = Source.fromFile(path).getLines()
    val ratings = lines.map { line =>
      val fields = line.split("::")
      Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
    }.filter(_.rating > 0.0)
    if (ratings.isEmpty) {
      sys.error("No ratings provided.")
    } else {
      ratings.toSeq
    }
  }
}