/**
 * Java stub invoked by Python mllib HierarchicalClustering.train().
 *
 * Configures a HierarchicalClustering instance with the given number of clusters,
 * iteration limit and retry limit, runs it on `data`, and returns the trained model.
 * The seed is applied only when it is non-null.
 */
def trainHierarchicalClusteringModel(
    data: JavaRDD[Vector],
    k: Int,
    maxIterations: Int,
    maxRetries: Int,
    seed: java.lang.Long): HierarchicalClusteringModel = {
  val clustering = new HierarchicalClustering()
    .setNumClusters(k)
    .setMaxIterations(maxIterations)
    .setMaxRetries(maxRetries)

  // `seed` is a boxed Long so the caller can pass null to mean "no seed given";
  // Option(null) is None, so setSeed is skipped in that case.
  Option(seed).foreach(s => clustering.setSeed(s))

  try {
    clustering.run(data)
  } finally {
    // Unpersist the input RDD without blocking, whether or not run() succeeded.
    data.rdd.unpersist(blocking = false)
  }
}
@inherit_doc
class HierarchicalClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader):
    """A clustering model derived from the hierarchical clustering method.

    >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4, 2)
    >>> rdd = sc.parallelize(data)
    >>> model = HierarchicalClustering.train(rdd, 2)
    >>> len(model.clusterCenters)
    2
    >>> model.predict(array([0.0, 0.0])) == model.predict(array([1.0, 1.0]))
    True
    >>> model.predict(array([8.0, 9.0])) == model.predict(array([9.0, 8.0]))
    True
    >>> abs(model.WSSSE(rdd) - 2.82842712) < 10e-8
    True
    >>> len(model.toLinkageMatrix())
    1
    >>> len(model.toAdjacencyList())
    2
    >>> sparse_data = [
    ...     SparseVector(3, {1: 1.0}),
    ...     SparseVector(3, {1: 1.1}),
    ...     SparseVector(3, {2: 1.0}),
    ...     SparseVector(3, {2: 1.1})
    ... ]
    >>> sparse_rdd = sc.parallelize(sparse_data)
    >>> model = HierarchicalClustering.train(sparse_rdd, 2)
    >>> model.predict(array([0., 1., 0.])) == model.predict(array([0, 1.1, 0.]))
    True
    >>> model.predict(array([0., 0., 1.])) == model.predict(array([0, 0, 1.1]))
    True
    >>> model.predict(sparse_data[0]) == model.predict(sparse_data[1])
    True
    >>> model.predict(sparse_data[2]) == model.predict(sparse_data[3])
    True
    >>> len(model.clusterCenters)
    2
    >>> abs(model.WSSSE(sparse_rdd) - 0.2) < 10e-2
    True
    >>> len(model.toLinkageMatrix())
    1
    >>> len(model.toAdjacencyList())
    2
    """

    def predict(self, x):
        """Find the cluster to which x belongs in this model.

        If ``x`` is an RDD, each of its elements is converted to a vector
        and the mapped RDD is passed to the wrapped model's ``predict``;
        otherwise ``x`` itself is converted and passed as a single point.
        """
        if isinstance(x, RDD):
            points = x.map(_convert_to_vector)
        else:
            points = _convert_to_vector(x)
        return self.call("predict", points)

    def toAdjacencyList(self):
        """Convert a cluster dendrogram to an adjacency list with distances as their weights."""
        adjacency = self.call("toJavaAdjacencyList")
        return adjacency

    def toLinkageMatrix(self):
        """Return the result of the wrapped model's ``toJavaLinkageMatrix``.

        NOTE(review): presumably a linkage-matrix representation of the
        cluster dendrogram -- confirm against the JVM implementation.
        """
        linkage = self.call("toJavaLinkageMatrix")
        return linkage

    @property
    def clusterCenters(self):
        """Get the cluster centers, represented as a list of NumPy arrays."""
        raw_centers = self.call("getCenters")
        # Each converted center is turned into an array via its toArray().
        return [vec.toArray() for vec in _java2py(self._sc, raw_centers)]

    def WSSSE(self, rdd):
        """Get Within Set Sum of Squared Error (WSSSE).

        Every element of ``rdd`` is converted to a vector before the mapped
        RDD is handed to the wrapped model's ``WSSSE``.
        """
        vectors = rdd.map(_convert_to_vector)
        return self.call("WSSSE", vectors)
class HierarchicalClustering(object):
    """Entry point for training a :class:`HierarchicalClusteringModel`."""

    @classmethod
    def train(cls, rdd, k, maxIterations=100, maxRetries=10, seed=None):
        """Train a hierarchical clustering model.

        :param rdd: RDD of training points; each element is converted to a
            vector before being sent to the JVM.
        :param k: value forwarded as the ``k`` argument of
            ``trainHierarchicalClusteringModel``.
        :param maxIterations: forwarded unchanged (default 100).
        :param maxRetries: forwarded unchanged (default 10).
        :param seed: forwarded unchanged; ``None`` by default.
        :return: a :class:`HierarchicalClusteringModel` wrapping the result
            of the JVM call.
        """
        vectors = rdd.map(_convert_to_vector)
        java_model = callMLlibFunc(
            "trainHierarchicalClusteringModel",
            vectors,
            k,
            maxIterations,
            maxRetries,
            seed,
        )
        return HierarchicalClusteringModel(java_model)