Last active
December 22, 2015 19:39
-
-
Save treper/e095dd208c6daffe6057 to your computer and use it in GitHub Desktop.
Cluster Tudou tags using k-means; the tag vectors are generated with word2vec and filtered by the Tudou tag database.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import spark.util.Vector | |
// Number of embedding values expected on every tag-vector line.
val word_vec_size = 150

/** Parses one space-separated "<tag> <v1> <v2> ..." line into a Vector of its values.
  *
  * The first token (the tag) is skipped and the next `word_vec_size` tokens are parsed as
  * doubles; any tokens beyond that are ignored.
  *
  * @throws IllegalArgumentException if fewer than `word_vec_size` values follow the tag
  *   (for example a "<count> <dim>" header line or a truncated record). A shorter vector would
  *   otherwise be accepted silently and only fail later, inside the distance/averaging code.
  * @throws NumberFormatException if a value token is not a valid double.
  */
def parseVector(line: String): Vector = {
  val values = line.split(' ').slice(1, word_vec_size + 1).map(_.toDouble)
  require(values.length == word_vec_size,
    "expected " + word_vec_size + " values but found " + values.length +
      " in line: " + line.take(80))
  new Vector(values)
}
/** Returns the index of the entry in `centers` with the smallest `p.squaredDist(...)`.
  *
  * Ties resolve to the lowest index, because only a strictly smaller distance replaces the
  * current best. If `centers` is empty, or no distance compares below +Infinity (e.g. all
  * NaN), the result is 0.
  */
def closestPoint(p: Vector, centers: Array[Vector]): Int = {
  var bestIndex = 0
  var closest = Double.PositiveInfinity
  for (i <- centers.indices) {
    val dist = p.squaredDist(centers(i))
    if (dist < closest) {
      closest = dist
      bestIndex = i
    }
  }
  bestIndex
}
/** Returns the element-wise mean of the vectors in `ps`; none of the inputs are modified.
  *
  * @throws IllegalArgumentException if `ps` is empty (the mean is undefined).
  */
def average(ps: Seq[Vector]): Vector = {
  require(ps.nonEmpty, "cannot average an empty group of vectors")
  // spark.util.Vector's `+=` adds into the receiver's backing array. Wrapping
  // ps.head.elements directly would therefore overwrite the first input vector,
  // so accumulate into a private copy instead.
  val sum = new Vector(ps.head.elements.clone())
  ps.iterator.drop(1).foreach(v => sum += v)
  sum / ps.size
}
// k-means parameters: the number of clusters, and the threshold on the summed squared
// centroid movement per iteration below which iteration stops.
val K = 500
val convergeDist = 1e-6

// Tab-separated tag library. Field 1 becomes the key and field 0 (parsed as Int) the value,
// so this presumably maps tag name -> tag id (TODO confirm against the file's schema).
// NOTE(review): tag_id_map is not referenced anywhere else in this script.
val tag_lib_file = sc.textFile("hdfs://finger-test2:54310/home/TagHierarchy/lis_label_library")
val tag_id_map = tag_lib_file.map { line =>
  val fields = line.split("\t")
  (fields(1), fields(0).toInt)
}.collectAsMap()

// Space-separated "<tag> <v1> ... <vN>" lines, keyed by their first token. The RDD is cached
// because every k-means iteration scans the full data set again.
val tvf = sc.textFile("hdfs://finger-test2:54310/home/TagHierarchy/tag_vector")
val data = tvf.map(line => (line.split(' ')(0), parseVector(line))).cache()
val count = data.count()
println("Number of records " + count)

// Initial centroids: K records sampled without replacement (seed 42); only the vectors are kept.
var centroids = data.takeSample(false, K, 42).map(x => x._2)
// takeSample returns fewer than K records when the data set is smaller than K, so make the
// reduced cluster count visible instead of letting it happen silently.
if (centroids.length < K)
  println("Only " + centroids.length + " records available; clustering into " +
    centroids.length + " clusters instead of " + K)
// Iterate k-means until centroids stop moving. Each pass:
// 1. assigns every vector to its closest centroid;
// 2. replaces each centroid that received vectors with their average;
// 3. stops once the summed squared centroid movement is <= convergeDist.
var tempDist = 1.0
do {
  // `centroids` is captured by this closure when the job is submitted, so each pass
  // assigns points against the current iteration's centroids.
  val closest = data.map(p => (closestPoint(p._2, centroids), p._2))
  val newCentroids = closest.groupByKey().mapValues(ps => average(ps)).collectAsMap()
  // Only clusters that received at least one point appear in newCentroids. Walking the
  // map, rather than indexing 0 until K, lets an empty cluster keep its old centroid
  // (zero movement). It also tolerates having fewer than K centroids.
  tempDist = newCentroids.iterator.map { case (i, c) => centroids(i).squaredDist(c) }.sum
  for ((i, c) <- newCentroids) {
    centroids(i) = c
  }
  println("Finished iteration (delta = " + tempDist + ")")
} while (tempDist > convergeDist)
// Final assignment: label every tag with the index of its closest converged centroid,
// gather the tags belonging to each index, and write the groups out as text.
val closest = data.map { case (tag, vec) => (closestPoint(vec, centroids), tag) }
val pointsGroup = closest.groupByKey()
pointsGroup.saveAsTextFile("hdfs://finger-test2:54310/home/TagHierarchy/tag_cluster")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment