Skip to content

Instantly share code, notes, and snippets.

@treper
Last active December 22, 2015 19:39
Show Gist options
  • Save treper/e095dd208c6daffe6057 to your computer and use it in GitHub Desktop.
Save treper/e095dd208c6daffe6057 to your computer and use it in GitHub Desktop.
Cluster Tudou tags using k-means. The tag vectors are generated with word2vec and filtered against the Tudou tag database.
import spark.util.Vector
// Number of vector components kept per tag; parseVector reads this many
// space-separated values after the leading tag token.
val word_vec_size: Int = 150
/**
 * Parses one line of the tag-vector file into a Vector.
 *
 * The line is split on single spaces. Token 0 is skipped (the script uses it as
 * the tag key elsewhere), and the following `word_vec_size` tokens are parsed
 * as doubles. A line with fewer tokens yields a shorter vector; a non-numeric
 * token throws NumberFormatException.
 *
 * @param line space-separated record: a leading token followed by components
 * @return the components found at token positions 1 .. word_vec_size
 */
def parseVector(line: String): Vector = {
  val tokens = line.split(' ')
  val components = tokens.slice(1, 1 + word_vec_size)
  new Vector(components.map(s => s.toDouble))
}
/**
 * Returns the index of the center nearest to `p` by squared Euclidean distance.
 *
 * Ties resolve to the lowest index because only a strictly smaller distance
 * replaces the current best. If `centers` is empty the result is 0, which is
 * not a valid index into `centers`; callers must pass at least one center.
 *
 * @param p       the point to classify
 * @param centers candidate cluster centers
 * @return index into `centers` of the nearest center
 */
def closestPoint(p: Vector, centers: Array[Vector]): Int = {
  var bestIndex = 0
  var closest = Double.PositiveInfinity
  // Plain while loop: this runs once per point per k-means iteration, so it
  // avoids the per-iteration closure of a Range-based for.
  var i = 0
  while (i < centers.length) {
    val dist = p.squaredDist(centers(i))
    if (dist < closest) {
      closest = dist
      bestIndex = i
    }
    i += 1
  }
  bestIndex
}
/**
 * Returns the component-wise mean of the vectors in `ps`.
 *
 * @param ps vectors to average; must be non-empty (ps(0) seeds the sum)
 * @return a new Vector holding the mean; the inputs are left unmodified
 */
def average(ps: Seq[Vector]) : Vector = {
  val numVectors = ps.size
  // Copy ps(0)'s backing array before accumulating. spark.util.Vector's
  // constructor wraps (does not copy) the array it is given, and `+=` adds into
  // the receiver's array in place (confirm against the Spark version in use).
  // Without the clone the running sum would overwrite ps(0).
  val out = new Vector(ps(0).elements.clone())
  for (i <- 1 until numVectors) {
    out += ps(i)
  }
  out / numVectors
}
// Number of clusters to build.
val K = 500
// Iteration stops once the summed squared movement of all centroids in one
// pass is no larger than this threshold.
val convergeDist = 1e-6
// Tag library: tab-separated lines; field 1 is mapped to the integer in
// field 0 and the result is collected to the driver.
// NOTE(review): tag_id_map is not referenced anywhere else in this script.
val tag_lib_file = sc.textFile("hdfs://finger-test2:54310/home/TagHierarchy/lis_label_library")
val tag_id_map = tag_lib_file.map { line =>
  // Split once instead of once per field.
  val fields = line.split("\t")
  (fields(1), fields(0).toInt)
}.collectAsMap()
// Tag vectors keyed by each line's first space-separated token. Cached because
// every k-means pass below rescans this RDD.
val tvf = sc.textFile("hdfs://finger-test2:54310/home/TagHierarchy/tag_vector")
val data = tvf.map(line => (line.split(' ')(0), parseVector(line))).cache()
val count = data.count()
println("Number of records " + count)
// Initial centers: K records sampled without replacement, fixed seed 42.
var centroids = data.takeSample(false, K, 42).map(x => x._2)
// takeSample cannot return more records than exist, so a small input yields
// fewer than K centers; say so explicitly rather than failing obscurely later.
if (centroids.length < K) {
  println("Warning: only " + centroids.length + " initial centroids sampled (K = " + K + ")")
}
var tempDist = 1.0
// Lloyd iterations: assign every point to its nearest center, move each center
// to the mean of its points, and repeat until the total squared movement of
// the centers in one pass is no larger than convergeDist.
do {
  // Bind the centers to a local val so the task closure references a stable
  // local rather than the shell-level var.
  val currentCentroids = centroids
  val closest = data.map(p => (closestPoint(p._2, currentCentroids), p._2))
  val pointsGroup = closest.groupByKey()
  // Only clusters that received at least one point have a key here; an empty
  // cluster is simply absent (the old code crashed looking it up).
  val newCentroids = pointsGroup.mapValues(ps => average(ps)).collectAsMap()
  tempDist = 0.0
  // Iterate over the clusters that actually moved: this never indexes past the
  // number of sampled centroids and leaves empty clusters' centers unchanged.
  for ((i, newCenter) <- newCentroids) {
    tempDist += currentCentroids(i).squaredDist(newCenter)
    currentCentroids(i) = newCenter
  }
  println("Finished iteration (delta = " + tempDist + ")")
} while (tempDist > convergeDist)
// Final assignment pass: label each tag with the index of its nearest
// converged centroid, then gather the tag keys that share a cluster.
val closest = data.map { case (tag, vec) => (closestPoint(vec, centroids), tag) }
val pointsGroup = closest.groupByKey()
// Writes one record per non-empty cluster: (cluster index, its tags).
pointsGroup.saveAsTextFile("hdfs://finger-test2:54310/home/TagHierarchy/tag_cluster")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment