Skip to content

Instantly share code, notes, and snippets.

@ramv
Last active February 1, 2016 07:33
Show Gist options
  • Save ramv/a3335dd476af6cd9d830 to your computer and use it in GitHub Desktop.
Save ramv/a3335dd476af6cd9d830 to your computer and use it in GitHub Desktop.
Java implementation of an ALS recommender for MovieLens data
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.DoubleFlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.List;
/**
* Created by r.viswanadha on 10/7/15.
*/
public class MovieLensALS {
public static class RatingWithTimestamp extends Rating {
private long timestamp;
public RatingWithTimestamp(int userId, int movieId, float rating, long timestamp){
super(userId, movieId, rating);
this.timestamp = timestamp;
}
public long getTimestamp(){
return timestamp;
}
public Object productElement(int n) {
return null;
}
public int productArity() {
return 0;
}
public boolean canEqual(Object that) {
return false;
}
public boolean equals(Object that) {
return false;
}
}
public static class ParseRating implements Function<String, Rating> {
public Rating call(String str) throws Exception{
String[] fields = str.split("::");
if(fields.length != 4){
throw new Exception("The text file line does not have the right format. "+str);
}
return new RatingWithTimestamp( Integer.parseInt(fields[0]),
Integer.parseInt(fields[1]),
Float.parseFloat(fields[2]),
Long.parseLong(fields[3]));
}
}
public static class Movie{
private int movieId;
public int getMovieId() {
return movieId;
}
public void setMovieId(int movieId) {
this.movieId = movieId;
}
public String getTitle() {
return title;
}
public void setTitle(String title) {
this.title = title;
}
public ArrayList<String> getGenres() {
return genres;
}
public void setGenres(ArrayList<String> genres) {
this.genres = genres;
}
private String title;
private ArrayList<String> genres;
public Movie(int movieId, String title, ArrayList<String> genres){
this.movieId = movieId;
this.title = title;
this.genres = genres;
}
}
public static class ParseMovie implements Function<String,Movie>{
public Movie call(String str) throws Exception{
String[] splits = str.split("::");
if(splits.length!=3){
throw new Exception("Movie string format is incorrect. "+str);
}
ArrayList<String> list = new ArrayList<String>();
String[] genres = splits[2].split("|");
for(String genre : genres){
list.add(genre);
}
return new Movie(Integer.parseInt(splits[0]), splits[1], list);
}
}
public static class GenericRowFromRating implements Function<Rating, Row>{
public Row call(Rating v1) throws Exception {
Object[] objs = new Object[3];
objs[0] = v1.user();
objs[1] = v1.product();
objs[2] = v1.rating();
GenericRowWithSchema row = new GenericRowWithSchema(objs, getStructTypeForRating());
return row;
}
}
public static final StructType getStructTypeForRating(){
// Generate the schema based on the string of schema
List<StructField> fields = new ArrayList<StructField>();
fields.add(DataTypes.createStructField("userId", DataTypes.IntegerType, true));
fields.add(DataTypes.createStructField("movieId", DataTypes.IntegerType, true));
fields.add(DataTypes.createStructField("rating", DataTypes.DoubleType, true));
return DataTypes.createStructType(fields);
}
public static void explicitRatingsAls(JavaSparkContext jsc, String ratingsFile, String moviesFile){
JavaRDD<String> fileData= jsc.textFile(ratingsFile);
JavaRDD<Rating> ratingJavaRDD = fileData.map(new ParseRating()).cache();
double[] splitPercent = new double[]{0.8, 0.2};
JavaRDD<Rating>[] splits = ratingJavaRDD.randomSplit(splitPercent, 0L);
JavaRDD<Rating> trainingJavaRDD = splits[0].cache();
JavaRDD<Rating> testJavaRDD = splits[1].cache();
JavaRDD<Row> trainingRows = trainingJavaRDD.map(new GenericRowFromRating());
JavaRDD<Row> testRows = testJavaRDD.map(new GenericRowFromRating());
long numTraining = trainingRows.count();
long numTest = testRows.count();
System.out.println(String.format("Training: %d, Test; %d", numTraining, numTest));
ratingJavaRDD.unpersist(false);
ALS als = new ALS()
.setUserCol("userId")
.setItemCol("movieId")
.setRank(10)
.setMaxIter(15)
.setRegParam(0.1)
.setNumBlocks(10)
.setImplicitPrefs(true)
.setAlpha(0.3);
SQLContext sqlCtx = new SQLContext(jsc.sc());
DataFrame trainingDataFrame = sqlCtx.createDataFrame(trainingRows, getStructTypeForRating());
System.out.println(String.join(" | ",trainingDataFrame.columns()));
DataFrame testDataFrame = sqlCtx.createDataFrame(testRows, getStructTypeForRating());
ALSModel alsModel = als.fit(trainingDataFrame);
//alsModel.transform()
DataFrame testPredictionDF = alsModel.transform(testDataFrame).cache();
DataFrame trainingPredictionDF = alsModel.transform(trainingDataFrame).cache();
JavaDoubleRDD trainingPrediction = trainingPredictionDF.select("rating", "prediction").javaRDD().flatMapToDouble(new DoubleFlatMapFunction<Row>() {
public Iterable<Double> call(Row row) throws Exception {
System.out.println(String.format("rating: %s prediction:%s", row.get(0), row.get(1)));
Double err = Double.parseDouble(row.get(0).toString()) - Double.parseDouble(row.get(1).toString());
Double err2 = err * err;
ArrayList<Double> list = new ArrayList<Double>();
if (!err2.isNaN()) {
list.add(err);
}
return list;
}
});
double trainingPredictionMean = trainingPrediction.mean();
System.out.println("Training RMSE: "+Math.sqrt(trainingPredictionMean));
// Evaluate the model.
JavaDoubleRDD testPrediction = testPredictionDF.select("rating", "prediction").javaRDD().flatMapToDouble(new DoubleFlatMapFunction<Row>() {
public Iterable<Double> call(Row row) throws Exception {
System.out.println(String.format("rating: %s prediction:%s", row.get(0), row.get(1)));
Double err = Double.parseDouble(row.get(0).toString()) - Double.parseDouble(row.get(1).toString());
Double err2 = err * err;
ArrayList<Double> list = new ArrayList<Double>();
if(!err2.isNaN()){
list.add(err);
}
return list;
}
});
double mean = testPrediction.mean();
System.out.println("RMSE: "+Math.sqrt(mean));
// Inspect false positives.
// Note: We reference columns in 2 ways:
// (1) predictions("movieId") lets us specify the movieId column in the predictions
// DataFrame, rather than the movieId column in the movies DataFrame.
// (2) $"userId" specifies the userId column in the predictions DataFrame.
// We could also write predictions("userId") but do not have to since
// the movies DataFrame does not have a column "userId."
JavaRDD<String> movieStringsRDD = jsc.textFile(moviesFile);
JavaRDD<Movie> movieJavaRDD = movieStringsRDD.map(new ParseMovie()).cache();
DataFrame movieDF = sqlCtx.createDataFrame(movieJavaRDD, Movie.class);
DataFrame falsePositives = testPredictionDF
.join(movieDF)
.where(testPredictionDF.col("movieId")
.equalTo(movieDF.col("movieId"))
.and(testPredictionDF.col("rating").$less$eq(1))
.and(testPredictionDF.col("prediction").$greater$eq(4))
)
.select(testPredictionDF.col("userId"),
testPredictionDF.col("movieId"),
movieDF.col("title"),
testPredictionDF.col("rating"),
testPredictionDF.col("prediction"));
long numFalsePositives = falsePositives.count();
if(numFalsePositives>0){
Row[] rows = falsePositives.collect();
for(Row row : rows){
System.out.println("\t"+row.toString());
}
}
}
/**
* Run the program with sample_movielens_ratings.txt as the first argument and sample_movielens_movies.txt as the second argument
* You can find these files at https://github.com/apache/spark/blob/master/data/mllib/sample_movielens_data.txt and
* https://github.com/apache/spark/blob/master/data/mllib/sample_movielens_movies.txt
*/
public static void main(String[] args){
//Flags.setFromCommandLineArgs(THE_OPTIONS, args);
// Startup the Spark Conf.
SparkConf conf = new SparkConf()
.setAppName("Recommendation Engine POC").setMaster("local");
conf.set("spark.serializer", org.apache.spark.serializer.KryoSerializer.class.getName());
JavaSparkContext jsc = new JavaSparkContext(conf);
explicitRatingsAls(jsc,
args[0],
args[1]);
jsc.stop();
//createStreaming(conf);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment