Created
April 28, 2017 21:34
-
-
Save ramv/1b9f507e8db742ce6897e1dc11dd0caa to your computer and use it in GitHub Desktop.
Similarity Analysis
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

import java.io.Serializable;
import java.util.*;
public class SimilarityDimSum { | |
private static final Logger LOGGER = LoggerFactory.getLogger(SimilarityDimSum.class); | |
public static class Relation implements Comparable<Relation>{ | |
private Double score; | |
private String first; | |
private String second; | |
public void swap(){ | |
new Relation(second, first, score); | |
} | |
public Relation(String first, String second, Double similarity){ | |
this.first = first; | |
this.second = second; | |
this.score = similarity; | |
} | |
public Relation(Long first, Long second, Double similarity){ | |
this.first = first.toString(); | |
this.second = second.toString(); | |
this.score = similarity; | |
} | |
public int compareTo(Relation other){ | |
if(!this.first.equals(other.first)){ | |
//TODO ERROR .. what to do?? | |
} | |
return this.score.compareTo(other.score); | |
} | |
public Double getScore() { | |
return score; | |
} | |
public void setScore(Double score) { | |
this.score = score; | |
} | |
public String getFirst() { | |
return first; | |
} | |
public void setFirst(String first) { | |
this.first = first; | |
} | |
public String getSecond() { | |
return second; | |
} | |
public void setSecond(String second) { | |
this.second = second; | |
} | |
@Override | |
public String toString(){ | |
return String.format("[\"%s\",\"%s\",\"%s\"]", first, second, score.toString()); | |
} | |
} | |
public static JavaRDD<Relation> computeItemItemSimilarities(JavaRDD<DmRating> ratingJavaRDD, double threshold, int k){ | |
RowMatrix ratingsMatrix = getUserVectors(ratingJavaRDD); | |
return computeSimilarities(ratingsMatrix, threshold, k); | |
} | |
public static JavaRDD<Relation> computeUserUserSimilarities(JavaRDD<DmRating> ratingJavaRDD, double threshold, int k){ | |
RowMatrix ratingsMatrix = getItemVectors(ratingJavaRDD); | |
return computeSimilarities(ratingsMatrix, threshold, k); | |
} | |
public static JavaRDD<Relation> computeSimilarities(RowMatrix ratingsMatrix, double threshold, int k){ | |
CoordinateMatrix similaritiesMatrix = ratingsMatrix.columnSimilarities(threshold); | |
JavaRDD<MatrixEntry> entries = similaritiesMatrix.entries().toJavaRDD().cache(); | |
LOGGER.info("Number of entries: {}", entries.count()); | |
JavaPairRDD<String, Relation> similaritiesRdd = entries.mapPartitionsToPair((Iterator<MatrixEntry> entriesIterator) -> { | |
ArrayList<Tuple2<String,Relation>> list = new ArrayList<>(); | |
while(entriesIterator.hasNext()){ | |
MatrixEntry matrixEntry = entriesIterator.next(); | |
Relation sim = new Relation(matrixEntry.i(), matrixEntry.j(), matrixEntry.value()); | |
list.add(new Tuple2<>(sim.getFirst(), sim)); | |
} | |
return list; | |
}).cache(); | |
LOGGER.info("Number of similarities: {}", similaritiesRdd.count()); | |
JavaPairRDD<String,Iterable<Relation>> grouped = similaritiesRdd.groupByKey().cache(); | |
JavaRDD<Relation> topSimilarities = grouped.flatMap((Tuple2<String, Iterable<Relation>> item) -> { | |
Iterator<Relation> sim = item._2().iterator(); | |
ArrayList<Relation> simList = new ArrayList<>(); | |
while(sim.hasNext()){ | |
simList.add(sim.next()); | |
} | |
return getTopSimilarities(simList, k); | |
}); | |
LOGGER.info("Number of top similarities: {}", topSimilarities.count()); | |
return topSimilarities; | |
} | |
public static double cosineSimilarity(double dotProduct, double norm1, double norm2){ | |
return dotProduct/(norm1 * norm2); | |
} | |
public static List<Relation> getTopSimilarities(ArrayList<Relation> grouped, int k){ | |
Comparator<Relation> reverse = Collections.reverseOrder(); | |
grouped.sort(reverse); | |
if(grouped.size()>k) { | |
return grouped.subList(0, k); | |
}else{ | |
return grouped; | |
} | |
} | |
public static class DmRatingComparator implements Comparator<Tuple2<Integer, Iterable<DmRating>>>, Serializable { | |
@Override | |
public int compare(Tuple2<Integer, Iterable<DmRating>> o1, Tuple2<Integer, Iterable<DmRating>> o2) { | |
return o1._1().compareTo(o2._1()); | |
} | |
} | |
public static RowMatrix getUserVectors(JavaRDD<DmRating> ratingJavaRDD){ | |
JavaPairRDD<Integer, Iterable<DmRating>> groupedByItem = ratingJavaRDD.groupBy( rating -> rating.modVideoId); | |
int maxItemId = groupedByItem.max(new DmRatingComparator())._1(); | |
JavaPairRDD<Integer, Iterable<DmRating>> groupedByUser = ratingJavaRDD.groupBy(rating -> rating.modUserId); | |
JavaRDD<Vector> vectorJavaRDD = groupedByUser.mapPartitions(itemIter ->{ | |
ArrayList<Vector> vectors = new ArrayList<>(); | |
while(itemIter.hasNext()){ | |
Tuple2<Integer, Iterable<DmRating>> item = itemIter.next(); | |
HashMap<Integer, Tuple2<Integer, Double>> videoIdRating = new HashMap<>(); | |
Iterator<DmRating> videos = item._2().iterator(); | |
while(videos.hasNext()){ | |
DmRating rating = videos.next(); | |
// make sure we only use one rating per video | |
videoIdRating.put(rating.modVideoId, new Tuple2<>(rating.modVideoId, rating.rating)); | |
} | |
vectors.add(Vectors.sparse(maxItemId+1, videoIdRating.values())); | |
} | |
return vectors; | |
}); | |
return new RowMatrix(vectorJavaRDD.rdd()); | |
} | |
public static RowMatrix getItemVectors(JavaRDD<DmRating> ratingJavaRDD){ | |
JavaPairRDD<Integer, Iterable<DmRating>> groupedByUser = ratingJavaRDD.groupBy(rating -> rating.modUserId); | |
int maxUserId = groupedByUser.max(new DmRatingComparator())._1(); | |
JavaPairRDD<Integer, Iterable<DmRating>> groupedByItem = ratingJavaRDD.groupBy(rating -> rating.modVideoId); | |
/* | |
* videoID 1 : [ userID, userId, userId .. ] | |
* videoID 2 : [ userID, userId, userId .. ] | |
* | |
* videoId1 videoId2 | |
* userId 1 | |
* userId 1 | |
*/ | |
JavaRDD<Vector> vectorJavaRDD = groupedByItem.mapPartitions((Iterator<Tuple2<Integer, Iterable<DmRating>>> tuple2Iterator) -> { | |
ArrayList<Vector> vectors = new ArrayList<>(); | |
while(tuple2Iterator.hasNext()){ | |
Tuple2<Integer, Iterable<DmRating>> item = tuple2Iterator.next(); | |
HashMap<Integer, Tuple2<Integer, Double>> userIdRating = new HashMap<>(); | |
Iterator<DmRating> videos = item._2().iterator(); | |
while(videos.hasNext()){ | |
DmRating rating = videos.next(); | |
// make sure we only use one rating per user | |
userIdRating.put(rating.modUserId, new Tuple2<>(rating.modUserId, rating.rating)); | |
} | |
vectors.add(Vectors.sparse(maxUserId+1, userIdRating.values())); | |
} | |
return vectors; | |
}); | |
return new RowMatrix(vectorJavaRDD.rdd()); | |
} | |
public static JavaRDD<Relation> computeAlsFeatureSimilarity(JavaRDD<Tuple2<Object, double[]>> features, double threshold, int k){ | |
//TODO figure out how to do similarity computation. | |
return computeAlsFeatureSimilarityDimSum(features, threshold, k); | |
} | |
public static JavaRDD<Relation> computeAlsFeatureSimilarityDimSum(JavaRDD<Tuple2<Object, double[]>> features, double threshold, int k){ | |
JavaRDD<MatrixEntry> alsRows = features.mapPartitions(ti2 -> { | |
ArrayList<MatrixEntry> list = new ArrayList<>(); | |
while(ti2.hasNext()){ | |
Tuple2<Object, double[]> t2 = ti2.next(); | |
double[] vals = t2._2(); | |
Integer id = (Integer) t2._1(); | |
for(int i=0; i<vals.length; i++){ | |
list.add(new MatrixEntry(i, id, vals[i])); | |
} | |
} | |
return list; | |
}); | |
JavaRDD<MatrixEntry> cachedProductAlsRows = alsRows.cache(); | |
CoordinateMatrix alsCoordinateRowMatrix = new CoordinateMatrix(cachedProductAlsRows.rdd()); | |
RowMatrix alsRowMatrix = alsCoordinateRowMatrix.toRowMatrix(); | |
return computeSimilarities(alsRowMatrix, threshold, k); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment