Skip to content

Instantly share code, notes, and snippets.

@yu-iskw
Last active August 29, 2015 14:24
Show Gist options
  • Save yu-iskw/4d7ede75475ba9dc6f9a to your computer and use it in GitHub Desktop.
Save yu-iskw/4d7ede75475ba9dc6f9a to your computer and use it in GitHub Desktop.
Removed HierarchicalClustering Python API (Scala stub and PySpark wrapper)

Scala Code

  /**
   * Java stub for Python mllib HierarchicalClustering.run().
   *
   * Builds a HierarchicalClustering estimator from the Python-side arguments,
   * trains it on `data` and returns the fitted model.
   *
   * @param data          training points handed over from Python
   * @param k             number of clusters (passed to setNumClusters)
   * @param maxIterations passed to setMaxIterations
   * @param maxRetries    passed to setMaxRetries
   * @param seed          random seed; boxed so that Python's `None` arrives as
   *                      `null`, in which case setSeed is not called
   * @return the trained HierarchicalClusteringModel
   */
  def trainHierarchicalClusteringModel(
    data: JavaRDD[Vector],
    k: Int,
    maxIterations: Int,
    maxRetries: Int,
    seed: java.lang.Long): HierarchicalClusteringModel = {
    val algo = new HierarchicalClustering()
        .setNumClusters(k)
        .setMaxIterations(maxIterations)
        .setMaxRetries(maxRetries)

    // Only override the estimator's seed when the caller supplied one.
    if (seed != null) algo.setSeed(seed)

    // Persist the input for the duration of training so that the unpersist in
    // `finally` has a matching persist. Training is iterative (maxIterations),
    // so an uncached input would presumably be recomputed on each pass.
    // JavaRDD.persist marks the same underlying RDD that `data.rdd` refers to.
    try {
      algo.run(data.persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK))
    } finally {
      data.rdd.unpersist(blocking = false)
    }
  }

Python Code

@inherit_doc
class HierarchicalClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader):

    """Trained hierarchical clustering model.

    Every method delegates to the wrapped Java model through ``self.call``.

    >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4, 2)
    >>> rdd = sc.parallelize(data)
    >>> model = HierarchicalClustering.train(rdd, 2)
    >>> len(model.clusterCenters)
    2
    >>> model.predict(array([0.0, 0.0])) == model.predict(array([1.0, 1.0]))
    True
    >>> model.predict(array([8.0, 9.0])) == model.predict(array([9.0, 8.0]))
    True
    >>> abs(model.WSSSE(rdd) - 2.82842712) < 10e-8
    True
    >>> len(model.toLinkageMatrix())
    1
    >>> len(model.toAdjacencyList())
    2

    >>> sparse_data = [
    ...     SparseVector(3, {1: 1.0}),
    ...     SparseVector(3, {1: 1.1}),
    ...     SparseVector(3, {2: 1.0}),
    ...     SparseVector(3, {2: 1.1})
    ... ]
    >>> sparse_rdd = sc.parallelize(sparse_data)
    >>> model = HierarchicalClustering.train(sparse_rdd, 2)
    >>> model.predict(array([0., 1., 0.])) == model.predict(array([0, 1.1, 0.]))
    True
    >>> model.predict(array([0., 0., 1.])) == model.predict(array([0, 0, 1.1]))
    True
    >>> model.predict(sparse_data[0]) == model.predict(sparse_data[1])
    True
    >>> model.predict(sparse_data[2]) == model.predict(sparse_data[3])
    True
    >>> len(model.clusterCenters)
    2
    >>> abs(model.WSSSE(sparse_rdd) - 0.2) < 10e-2
    True
    >>> len(model.toLinkageMatrix())
    1
    >>> len(model.toAdjacencyList())
    2
    """

    def predict(self, x):
        """Return the cluster the model assigns to ``x``.

        ``x`` is either a single vector-like value or an RDD of them. Every
        value is normalized with ``_convert_to_vector`` before it is sent to
        the Java ``predict`` call, whose result is returned unchanged.
        """
        if not isinstance(x, RDD):
            return self.call("predict", _convert_to_vector(x))
        return self.call("predict", x.map(_convert_to_vector))

    def toAdjacencyList(self):
        """Return the cluster dendrogram as an adjacency list weighted by distance."""
        return self.call("toJavaAdjacencyList")

    def toLinkageMatrix(self):
        """Return the cluster dendrogram in linkage-matrix form."""
        return self.call("toJavaLinkageMatrix")

    @property
    def clusterCenters(self):
        """The model's cluster centers, as a list of NumPy arrays."""
        java_centers = self.call("getCenters")
        return [center.toArray() for center in _java2py(self._sc, java_centers)]

    def WSSSE(self, rdd):
        """Return the Within Set Sum of Squared Error (WSSSE) of the points in ``rdd``."""
        # NOTE(review): the doctest expectations (2.828... == 4 * sqrt(0.5) for the
        # dense example, ~0.2 for the sparse one) match a sum of plain Euclidean
        # distances to the centers rather than squared distances -- confirm
        # against the JVM implementation before relying on the name.
        points = rdd.map(_convert_to_vector)
        return self.call("WSSSE", points)

class HierarchicalClustering(object):
    """Entry point for training a :class:`HierarchicalClusteringModel`."""

    @classmethod
    def train(cls, rdd, k, maxIterations=100, maxRetries=10, seed=None):
        """Train a hierarchical clustering model.

        :param rdd: RDD of points; each element is normalized with
            ``_convert_to_vector`` before being handed to the JVM.
        :param k: number of clusters to produce.
        :param maxIterations: maximum number of iterations (default 100).
        :param maxRetries: maximum number of retries (default 10).
        :param seed: random seed, or ``None`` to leave the trainer's seed unset.
        :return: a :class:`HierarchicalClusteringModel` wrapping the trained
            Java model.
        """
        vectors = rdd.map(_convert_to_vector)
        java_model = callMLlibFunc("trainHierarchicalClusteringModel", vectors,
                                   k, maxIterations, maxRetries, seed)
        return HierarchicalClusteringModel(java_model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment