Skip to content

Instantly share code, notes, and snippets.

@ramv
Created April 28, 2017 21:34
Show Gist options
  • Save ramv/1b9f507e8db742ce6897e1dc11dd0caa to your computer and use it in GitHub Desktop.
Save ramv/1b9f507e8db742ce6897e1dc11dd0caa to your computer and use it in GitHub Desktop.
Similarity Analysis
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import java.io.Serializable;
import java.util.*;
public class SimilarityDimSum {
private static final Logger LOGGER = LoggerFactory.getLogger(SimilarityDimSum.class);
/**
 * A directed similarity edge: {@code first} is similar to {@code second} with strength
 * {@code score}.
 *
 * <p>Instances are used as values in pair RDDs that get shuffled, so the class is
 * {@link Serializable}. The natural ordering compares scores only and is therefore not
 * consistent with {@code equals} (which is not overridden).
 */
public static class Relation implements Comparable<Relation>, Serializable {
    private static final long serialVersionUID = 1L;

    private Double score;
    private String first;
    private String second;

    /**
     * Creates a relation between two string identifiers.
     *
     * @param first      identifier the relation is keyed on
     * @param second     identifier of the related element
     * @param similarity similarity score of the pair
     */
    public Relation(String first, String second, Double similarity) {
        this.first = first;
        this.second = second;
        this.score = similarity;
    }

    /**
     * Creates a relation between two numeric identifiers, stored as their decimal strings.
     *
     * @param first      identifier the relation is keyed on
     * @param second     identifier of the related element
     * @param similarity similarity score of the pair
     */
    public Relation(Long first, Long second, Double similarity) {
        this.first = first.toString();
        this.second = second.toString();
        this.score = similarity;
    }

    /**
     * Exchanges {@code first} and {@code second} in place, leaving the score untouched.
     * (Previously this built a swapped copy and threw it away, so it had no effect.)
     */
    public void swap() {
        String previousFirst = first;
        first = second;
        second = previousFirst;
    }

    /**
     * Orders relations by ascending score. The identifiers are ignored, so relations with
     * different {@code first} values are still compared purely on score.
     *
     * @param other relation to compare against
     * @return negative, zero or positive as this score is less than, equal to or greater
     *         than the other score
     * @throws NullPointerException if either score is {@code null}
     */
    @Override
    public int compareTo(Relation other) {
        return this.score.compareTo(other.score);
    }

    public Double getScore() {
        return score;
    }

    public void setScore(Double score) {
        this.score = score;
    }

    public String getFirst() {
        return first;
    }

    public void setFirst(String first) {
        this.first = first;
    }

    public String getSecond() {
        return second;
    }

    public void setSecond(String second) {
        this.second = second;
    }

    /**
     * Renders the relation as a three-element JSON-style array of quoted strings:
     * {@code ["first","second","score"]}. A {@code null} score is printed as
     * {@code null} instead of throwing.
     */
    @Override
    public String toString() {
        return String.format("[\"%s\",\"%s\",\"%s\"]", first, second, score);
    }
}
/**
 * Computes item-to-item similarities from raw ratings.
 *
 * <p>The ratings are laid out with one row per user (see {@code getUserVectors}), so the
 * columns compared by {@code computeSimilarities} are items.
 *
 * @param ratingJavaRDD ratings to analyse
 * @param threshold     threshold forwarded to the column-similarity computation
 * @param k             maximum number of relations kept per item
 * @return the retained item-item relations
 */
public static JavaRDD<Relation> computeItemItemSimilarities(JavaRDD<DmRating> ratingJavaRDD, double threshold, int k){
    return computeSimilarities(getUserVectors(ratingJavaRDD), threshold, k);
}
/**
 * Computes user-to-user similarities from raw ratings.
 *
 * <p>The ratings are laid out with one row per item (see {@code getItemVectors}), so the
 * columns compared by {@code computeSimilarities} are users.
 *
 * @param ratingJavaRDD ratings to analyse
 * @param threshold     threshold forwarded to the column-similarity computation
 * @param k             maximum number of relations kept per user
 * @return the retained user-user relations
 */
public static JavaRDD<Relation> computeUserUserSimilarities(JavaRDD<DmRating> ratingJavaRDD, double threshold, int k){
    return computeSimilarities(getItemVectors(ratingJavaRDD), threshold, k);
}
/**
 * Computes the {@code k} strongest column-to-column similarities for every column of
 * {@code ratingsMatrix}.
 *
 * <p>{@code RowMatrix.columnSimilarities} reports every pair only once, as an
 * upper-triangular entry with {@code i < j}. Each entry is therefore emitted in both
 * directions before grouping; otherwise a column would never see its lower-index
 * neighbours and the highest-index column would get no results at all.
 *
 * <p>The three {@code count()} calls exist only for logging and each one triggers a
 * Spark job; the intermediate RDDs are cached so that work is not repeated.
 *
 * @param ratingsMatrix matrix whose columns are compared
 * @param threshold     threshold passed to {@code columnSimilarities}
 * @param k             maximum number of relations kept per column
 * @return for each column id (as {@code Relation.getFirst()}), up to {@code k} relations
 *         with the highest scores
 */
public static JavaRDD<Relation> computeSimilarities(RowMatrix ratingsMatrix, double threshold, int k){
    CoordinateMatrix similaritiesMatrix = ratingsMatrix.columnSimilarities(threshold);
    JavaRDD<MatrixEntry> entries = similaritiesMatrix.entries().toJavaRDD().cache();
    LOGGER.info("Number of entries: {}", entries.count());

    JavaPairRDD<String, Relation> similaritiesRdd = entries.mapPartitionsToPair((Iterator<MatrixEntry> entriesIterator) -> {
        ArrayList<Tuple2<String,Relation>> list = new ArrayList<>();
        while(entriesIterator.hasNext()){
            MatrixEntry matrixEntry = entriesIterator.next();
            // Similarity is symmetric but only (i, j) with i < j is reported:
            // key the pair on both endpoints so each side can rank the other.
            Relation forward = new Relation(matrixEntry.i(), matrixEntry.j(), matrixEntry.value());
            Relation backward = new Relation(matrixEntry.j(), matrixEntry.i(), matrixEntry.value());
            list.add(new Tuple2<>(forward.getFirst(), forward));
            list.add(new Tuple2<>(backward.getFirst(), backward));
        }
        return list;
    }).cache();
    LOGGER.info("Number of similarities: {}", similaritiesRdd.count());

    JavaPairRDD<String,Iterable<Relation>> grouped = similaritiesRdd.groupByKey().cache();
    JavaRDD<Relation> topSimilarities = grouped.flatMap((Tuple2<String, Iterable<Relation>> item) -> {
        // Copy into a fresh list because getTopSimilarities sorts its argument in place.
        ArrayList<Relation> simList = new ArrayList<>();
        for (Relation relation : item._2()) {
            simList.add(relation);
        }
        return getTopSimilarities(simList, k);
    });
    LOGGER.info("Number of top similarities: {}", topSimilarities.count());
    return topSimilarities;
}
/**
 * Computes the cosine similarity of two vectors from their dot product and norms.
 *
 * @param dotProduct dot product of the two vectors
 * @param norm1      norm of the first vector
 * @param norm2      norm of the second vector
 * @return {@code dotProduct / (norm1 * norm2)}; follows IEEE-754 division, so a zero
 *         norm yields {@code NaN} or an infinity rather than an exception
 */
public static double cosineSimilarity(double dotProduct, double norm1, double norm2){
    double normProduct = norm1 * norm2;
    return dotProduct / normProduct;
}
/**
 * Returns the {@code k} highest-ranked relations of {@code grouped}.
 *
 * <p>The list is sorted in place in reverse natural order (see
 * {@code Relation.compareTo}), so the caller's list is reordered as a side effect.
 *
 * @param grouped relations to rank; sorted in place
 * @param k       maximum number of relations to return
 * @return {@code grouped} itself when it holds at most {@code k} elements, otherwise a
 *         {@code subList} view of its first {@code k} elements
 */
public static List<Relation> getTopSimilarities(ArrayList<Relation> grouped, int k){
    Collections.sort(grouped, Collections.<Relation>reverseOrder());
    return grouped.size() <= k ? grouped : grouped.subList(0, k);
}
/**
 * Orders grouped-rating tuples by their integer key (the first tuple element) in
 * ascending order. Serializable so it can be handed to RDD operations such as
 * {@code max}.
 */
public static class DmRatingComparator implements Comparator<Tuple2<Integer, Iterable<DmRating>>>, Serializable {
    /**
     * Compares two tuples by key only; the grouped ratings are not inspected.
     *
     * @return negative, zero or positive as the first key is less than, equal to or
     *         greater than the second key
     */
    @Override
    public int compare(Tuple2<Integer, Iterable<DmRating>> o1, Tuple2<Integer, Iterable<DmRating>> o2) {
        Integer leftKey = o1._1();
        Integer rightKey = o2._1();
        return leftKey.compareTo(rightKey);
    }
}
/**
 * Builds a matrix with one sparse row per user and one column per item id.
 *
 * <p>Each row holds that user's ratings, indexed by {@code modVideoId}. When a user has
 * several ratings for the same item only one is kept (whichever is seen last).
 *
 * <p>NOTE(review): {@code modVideoId} is used directly as a column index, so ids are
 * assumed to be non-negative — confirm against the producer of {@code DmRating}.
 *
 * @param ratingJavaRDD ratings to arrange; must not be empty, because the largest item
 *                      id is obtained with a reduce
 * @return a row matrix of width {@code maxItemId + 1}
 */
public static RowMatrix getUserVectors(JavaRDD<DmRating> ratingJavaRDD){
    // Only the largest item id is needed to size the vectors; a map/reduce finds it
    // without the full shuffle that grouping every rating by item would cost.
    final int maxItemId = ratingJavaRDD.map(rating -> rating.modVideoId).reduce(Math::max);

    JavaPairRDD<Integer, Iterable<DmRating>> groupedByUser = ratingJavaRDD.groupBy(rating -> rating.modUserId);
    // One vector per user, built record by record so a whole partition is never buffered.
    JavaRDD<Vector> vectorJavaRDD = groupedByUser.map(userRatings -> {
        HashMap<Integer, Tuple2<Integer, Double>> videoIdRating = new HashMap<>();
        for (DmRating rating : userRatings._2()) {
            // make sure we only use one rating per video
            videoIdRating.put(rating.modVideoId, new Tuple2<>(rating.modVideoId, rating.rating));
        }
        return Vectors.sparse(maxItemId + 1, videoIdRating.values());
    });
    return new RowMatrix(vectorJavaRDD.rdd());
}
/**
 * Builds a matrix with one sparse row per item (video) and one column per user id.
 *
 * <p>Layout of the result:
 * <pre>
 *             userId 0   userId 1   userId 2 ...
 *   video A    rating     rating     rating
 *   video B    rating     rating     rating
 * </pre>
 * Each row holds the ratings one video received, indexed by {@code modUserId}. When a
 * user rated the same video several times only one rating is kept (whichever is seen
 * last).
 *
 * <p>NOTE(review): {@code modUserId} is used directly as a column index, so ids are
 * assumed to be non-negative — confirm against the producer of {@code DmRating}.
 *
 * @param ratingJavaRDD ratings to arrange; must not be empty, because the largest user
 *                      id is obtained with a reduce
 * @return a row matrix of width {@code maxUserId + 1}
 */
public static RowMatrix getItemVectors(JavaRDD<DmRating> ratingJavaRDD){
    // Only the largest user id is needed to size the vectors; a map/reduce finds it
    // without the full shuffle that grouping every rating by user would cost.
    final int maxUserId = ratingJavaRDD.map(rating -> rating.modUserId).reduce(Math::max);

    JavaPairRDD<Integer, Iterable<DmRating>> groupedByItem = ratingJavaRDD.groupBy(rating -> rating.modVideoId);
    // One vector per video, built record by record so a whole partition is never buffered.
    JavaRDD<Vector> vectorJavaRDD = groupedByItem.map(itemRatings -> {
        HashMap<Integer, Tuple2<Integer, Double>> userIdRating = new HashMap<>();
        for (DmRating rating : itemRatings._2()) {
            // make sure we only use one rating per user
            userIdRating.put(rating.modUserId, new Tuple2<>(rating.modUserId, rating.rating));
        }
        return Vectors.sparse(maxUserId + 1, userIdRating.values());
    });
    return new RowMatrix(vectorJavaRDD.rdd());
}
/**
 * Computes similarities between ALS feature vectors.
 *
 * <p>Currently a plain delegate to {@code computeAlsFeatureSimilarityDimSum}; the
 * original author left the choice of similarity strategy as an open TODO.
 *
 * @param features  pairs of (id, feature vector)
 * @param threshold threshold forwarded to the similarity computation
 * @param k         maximum number of relations kept per id
 * @return the retained relations
 */
public static JavaRDD<Relation> computeAlsFeatureSimilarity(JavaRDD<Tuple2<Object, double[]>> features, double threshold, int k){
    //TODO figure out how to do similarity computation.
    JavaRDD<Relation> similarities = computeAlsFeatureSimilarityDimSum(features, threshold, k);
    return similarities;
}
/**
 * Computes similarities between ALS feature vectors via column similarities.
 *
 * <p>Every (id, vector) pair becomes one matrix column: component {@code i} of the
 * vector is stored at row {@code i}, column {@code id}. Column similarities of that
 * matrix are therefore similarities between ids.
 *
 * <p>Ids may be any boxed {@link Number} (Integer, Long, ...) and are read with
 * {@code longValue()}; previously anything but {@code Integer} failed with a
 * {@code ClassCastException}.
 *
 * @param features  pairs of (numeric id, feature vector)
 * @param threshold threshold forwarded to {@code computeSimilarities}
 * @param k         maximum number of relations kept per id
 * @return the retained relations
 * @throws ClassCastException if an id is not a {@code Number} (raised when the RDD is
 *                            evaluated)
 */
public static JavaRDD<Relation> computeAlsFeatureSimilarityDimSum(JavaRDD<Tuple2<Object, double[]>> features, double threshold, int k){
    JavaRDD<MatrixEntry> alsRows = features.mapPartitions(ti2 -> {
        ArrayList<MatrixEntry> list = new ArrayList<>();
        while(ti2.hasNext()){
            Tuple2<Object, double[]> t2 = ti2.next();
            double[] vals = t2._2();
            // Accept any numeric id type, not just Integer.
            long id = ((Number) t2._1()).longValue();
            for(int i=0; i<vals.length; i++){
                // row = feature index, column = id
                list.add(new MatrixEntry(i, id, vals[i]));
            }
        }
        return list;
    });
    // Cached before the entries are handed to the coordinate-matrix conversion.
    JavaRDD<MatrixEntry> cachedProductAlsRows = alsRows.cache();
    CoordinateMatrix alsCoordinateRowMatrix = new CoordinateMatrix(cachedProductAlsRows.rdd());
    RowMatrix alsRowMatrix = alsCoordinateRowMatrix.toRowMatrix();
    return computeSimilarities(alsRowMatrix, threshold, k);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment