Last active
February 1, 2016 07:33
-
-
Save ramv/a3335dd476af6cd9d830 to your computer and use it in GitHub Desktop.
Java Implementation of an ALS Recommender for MovieLens Data
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.SparkConf; | |
import org.apache.spark.api.java.JavaDoubleRDD; | |
import org.apache.spark.api.java.JavaRDD; | |
import org.apache.spark.api.java.JavaSparkContext; | |
import org.apache.spark.api.java.function.DoubleFlatMapFunction; | |
import org.apache.spark.api.java.function.Function; | |
import org.apache.spark.ml.recommendation.ALS; | |
import org.apache.spark.ml.recommendation.ALSModel; | |
import org.apache.spark.mllib.recommendation.Rating; | |
import org.apache.spark.sql.DataFrame; | |
import org.apache.spark.sql.Row; | |
import org.apache.spark.sql.SQLContext; | |
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; | |
import org.apache.spark.sql.types.DataTypes; | |
import org.apache.spark.sql.types.StructField; | |
import org.apache.spark.sql.types.StructType; | |
import java.util.ArrayList; | |
import java.util.List; | |
/** | |
* Created by r.viswanadha on 10/7/15. | |
*/ | |
public class MovieLensALS { | |
public static class RatingWithTimestamp extends Rating { | |
private long timestamp; | |
public RatingWithTimestamp(int userId, int movieId, float rating, long timestamp){ | |
super(userId, movieId, rating); | |
this.timestamp = timestamp; | |
} | |
public long getTimestamp(){ | |
return timestamp; | |
} | |
public Object productElement(int n) { | |
return null; | |
} | |
public int productArity() { | |
return 0; | |
} | |
public boolean canEqual(Object that) { | |
return false; | |
} | |
public boolean equals(Object that) { | |
return false; | |
} | |
} | |
public static class ParseRating implements Function<String, Rating> { | |
public Rating call(String str) throws Exception{ | |
String[] fields = str.split("::"); | |
if(fields.length != 4){ | |
throw new Exception("The text file line does not have the right format. "+str); | |
} | |
return new RatingWithTimestamp( Integer.parseInt(fields[0]), | |
Integer.parseInt(fields[1]), | |
Float.parseFloat(fields[2]), | |
Long.parseLong(fields[3])); | |
} | |
} | |
public static class Movie{ | |
private int movieId; | |
public int getMovieId() { | |
return movieId; | |
} | |
public void setMovieId(int movieId) { | |
this.movieId = movieId; | |
} | |
public String getTitle() { | |
return title; | |
} | |
public void setTitle(String title) { | |
this.title = title; | |
} | |
public ArrayList<String> getGenres() { | |
return genres; | |
} | |
public void setGenres(ArrayList<String> genres) { | |
this.genres = genres; | |
} | |
private String title; | |
private ArrayList<String> genres; | |
public Movie(int movieId, String title, ArrayList<String> genres){ | |
this.movieId = movieId; | |
this.title = title; | |
this.genres = genres; | |
} | |
} | |
public static class ParseMovie implements Function<String,Movie>{ | |
public Movie call(String str) throws Exception{ | |
String[] splits = str.split("::"); | |
if(splits.length!=3){ | |
throw new Exception("Movie string format is incorrect. "+str); | |
} | |
ArrayList<String> list = new ArrayList<String>(); | |
String[] genres = splits[2].split("|"); | |
for(String genre : genres){ | |
list.add(genre); | |
} | |
return new Movie(Integer.parseInt(splits[0]), splits[1], list); | |
} | |
} | |
public static class GenericRowFromRating implements Function<Rating, Row>{ | |
public Row call(Rating v1) throws Exception { | |
Object[] objs = new Object[3]; | |
objs[0] = v1.user(); | |
objs[1] = v1.product(); | |
objs[2] = v1.rating(); | |
GenericRowWithSchema row = new GenericRowWithSchema(objs, getStructTypeForRating()); | |
return row; | |
} | |
} | |
public static final StructType getStructTypeForRating(){ | |
// Generate the schema based on the string of schema | |
List<StructField> fields = new ArrayList<StructField>(); | |
fields.add(DataTypes.createStructField("userId", DataTypes.IntegerType, true)); | |
fields.add(DataTypes.createStructField("movieId", DataTypes.IntegerType, true)); | |
fields.add(DataTypes.createStructField("rating", DataTypes.DoubleType, true)); | |
return DataTypes.createStructType(fields); | |
} | |
public static void explicitRatingsAls(JavaSparkContext jsc, String ratingsFile, String moviesFile){ | |
JavaRDD<String> fileData= jsc.textFile(ratingsFile); | |
JavaRDD<Rating> ratingJavaRDD = fileData.map(new ParseRating()).cache(); | |
double[] splitPercent = new double[]{0.8, 0.2}; | |
JavaRDD<Rating>[] splits = ratingJavaRDD.randomSplit(splitPercent, 0L); | |
JavaRDD<Rating> trainingJavaRDD = splits[0].cache(); | |
JavaRDD<Rating> testJavaRDD = splits[1].cache(); | |
JavaRDD<Row> trainingRows = trainingJavaRDD.map(new GenericRowFromRating()); | |
JavaRDD<Row> testRows = testJavaRDD.map(new GenericRowFromRating()); | |
long numTraining = trainingRows.count(); | |
long numTest = testRows.count(); | |
System.out.println(String.format("Training: %d, Test; %d", numTraining, numTest)); | |
ratingJavaRDD.unpersist(false); | |
ALS als = new ALS() | |
.setUserCol("userId") | |
.setItemCol("movieId") | |
.setRank(10) | |
.setMaxIter(15) | |
.setRegParam(0.1) | |
.setNumBlocks(10) | |
.setImplicitPrefs(true) | |
.setAlpha(0.3); | |
SQLContext sqlCtx = new SQLContext(jsc.sc()); | |
DataFrame trainingDataFrame = sqlCtx.createDataFrame(trainingRows, getStructTypeForRating()); | |
System.out.println(String.join(" | ",trainingDataFrame.columns())); | |
DataFrame testDataFrame = sqlCtx.createDataFrame(testRows, getStructTypeForRating()); | |
ALSModel alsModel = als.fit(trainingDataFrame); | |
//alsModel.transform() | |
DataFrame testPredictionDF = alsModel.transform(testDataFrame).cache(); | |
DataFrame trainingPredictionDF = alsModel.transform(trainingDataFrame).cache(); | |
JavaDoubleRDD trainingPrediction = trainingPredictionDF.select("rating", "prediction").javaRDD().flatMapToDouble(new DoubleFlatMapFunction<Row>() { | |
public Iterable<Double> call(Row row) throws Exception { | |
System.out.println(String.format("rating: %s prediction:%s", row.get(0), row.get(1))); | |
Double err = Double.parseDouble(row.get(0).toString()) - Double.parseDouble(row.get(1).toString()); | |
Double err2 = err * err; | |
ArrayList<Double> list = new ArrayList<Double>(); | |
if (!err2.isNaN()) { | |
list.add(err); | |
} | |
return list; | |
} | |
}); | |
double trainingPredictionMean = trainingPrediction.mean(); | |
System.out.println("Training RMSE: "+Math.sqrt(trainingPredictionMean)); | |
// Evaluate the model. | |
JavaDoubleRDD testPrediction = testPredictionDF.select("rating", "prediction").javaRDD().flatMapToDouble(new DoubleFlatMapFunction<Row>() { | |
public Iterable<Double> call(Row row) throws Exception { | |
System.out.println(String.format("rating: %s prediction:%s", row.get(0), row.get(1))); | |
Double err = Double.parseDouble(row.get(0).toString()) - Double.parseDouble(row.get(1).toString()); | |
Double err2 = err * err; | |
ArrayList<Double> list = new ArrayList<Double>(); | |
if(!err2.isNaN()){ | |
list.add(err); | |
} | |
return list; | |
} | |
}); | |
double mean = testPrediction.mean(); | |
System.out.println("RMSE: "+Math.sqrt(mean)); | |
// Inspect false positives. | |
// Note: We reference columns in 2 ways: | |
// (1) predictions("movieId") lets us specify the movieId column in the predictions | |
// DataFrame, rather than the movieId column in the movies DataFrame. | |
// (2) $"userId" specifies the userId column in the predictions DataFrame. | |
// We could also write predictions("userId") but do not have to since | |
// the movies DataFrame does not have a column "userId." | |
JavaRDD<String> movieStringsRDD = jsc.textFile(moviesFile); | |
JavaRDD<Movie> movieJavaRDD = movieStringsRDD.map(new ParseMovie()).cache(); | |
DataFrame movieDF = sqlCtx.createDataFrame(movieJavaRDD, Movie.class); | |
DataFrame falsePositives = testPredictionDF | |
.join(movieDF) | |
.where(testPredictionDF.col("movieId") | |
.equalTo(movieDF.col("movieId")) | |
.and(testPredictionDF.col("rating").$less$eq(1)) | |
.and(testPredictionDF.col("prediction").$greater$eq(4)) | |
) | |
.select(testPredictionDF.col("userId"), | |
testPredictionDF.col("movieId"), | |
movieDF.col("title"), | |
testPredictionDF.col("rating"), | |
testPredictionDF.col("prediction")); | |
long numFalsePositives = falsePositives.count(); | |
if(numFalsePositives>0){ | |
Row[] rows = falsePositives.collect(); | |
for(Row row : rows){ | |
System.out.println("\t"+row.toString()); | |
} | |
} | |
} | |
/** | |
* Run the program with sample_movielens_ratings.txt as the first argument and sample_movielens_movies.txt as the second argument | |
* You can find these files at https://github.com/apache/spark/blob/master/data/mllib/sample_movielens_data.txt and | |
* https://github.com/apache/spark/blob/master/data/mllib/sample_movielens_movies.txt | |
*/ | |
public static void main(String[] args){ | |
//Flags.setFromCommandLineArgs(THE_OPTIONS, args); | |
// Startup the Spark Conf. | |
SparkConf conf = new SparkConf() | |
.setAppName("Recommendation Engine POC").setMaster("local"); | |
conf.set("spark.serializer", org.apache.spark.serializer.KryoSerializer.class.getName()); | |
JavaSparkContext jsc = new JavaSparkContext(conf); | |
explicitRatingsAls(jsc, | |
args[0], | |
args[1]); | |
jsc.stop(); | |
//createStreaming(conf); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment