Skip to content

Instantly share code, notes, and snippets.

@MishaelRosenthal
Created September 26, 2017 00:20
Show Gist options
  • Select an option

  • Save MishaelRosenthal/0eb599e4f8e0f3d8020f1d94085d07bd to your computer and use it in GitHub Desktop.

Select an option

Save MishaelRosenthal/0eb599e4f8e0f3d8020f1d94085d07bd to your computer and use it in GitHub Desktop.
Almost optimal Disjoint-set data structure: https://en.wikipedia.org/wiki/Disjoint-set_data_structure
package com.twitter.timelines.data_processing.ad_hoc.feature_selection.utils
import scala.annotation.tailrec
import scala.collection.GenTraversableOnce
/**
* Almost optimal, immutable disjoint-set (union-find) data structure with path compression:
* https://en.wikipedia.org/wiki/Disjoint-set_data_structure
*/
object DisjointSet {

  /** Builds a disjoint-set with no elements, i.e. one whose parent mapping is empty. */
  def empty[T]: DisjointSet[T] = new DisjointSet[T](Map.empty[T, T])
}
/**
 * Persistent (immutable) disjoint-set forest stored as a child -> parent map.
 *
 * An element is a representative (root) when it maps to itself or has no entry
 * in the map. Every operation returns a new instance; the receiver is never
 * modified. `find` and `union` re-point every element they walk over directly
 * at the representative (path compression). There is no union by rank/size:
 * `union(x, y)` always keeps the representative of `x`.
 *
 * The mapping must be acyclic apart from self-links at roots; a cycle such as
 * `Map(1 -> 2, 2 -> 1)` makes every lookup that reaches it loop forever.
 *
 * @param parentMapping child -> parent links
 * @tparam T element type, used as a map key
 */
case class DisjointSet[T](parentMapping: Map[T, T]) {

  /**
   * Merges the set containing `x` with the set containing `y`.
   * Elements that are not present yet are added. The representative of `x`
   * becomes the representative of the merged set.
   *
   * @return the disjoint-set after the merge, with both walked paths compressed
   */
  def union(x: T, y: T): DisjointSet[T] =
    // The root of x is the first element of the concatenated path, so it is the
    // one compressPath picks; the root of y is then re-pointed like any other node.
    compressPath(findPath(x) ++ findPath(y))._2

  /**
   * Applies [[union]] to every pair of `trav`, in traversal order.
   *
   * @return the disjoint-set after all merges
   */
  def unionTrav(trav: GenTraversableOnce[(T, T)]): DisjointSet[T] =
    trav.foldLeft(this) { case (acc, (x, y)) => acc.union(x, y) }

  /**
   * Looks up the representative of `x`. An element that is not present yet is
   * added as a singleton set and is its own representative.
   *
   * @return the representative and the disjoint-set with the walked path compressed
   */
  def find(x: T): (T, DisjointSet[T]) = compressPath(findPath(x))

  /**
   * Looks up the representative of every element of `trav`, threading the
   * compressed structure from one lookup to the next.
   *
   * @return a map element -> representative, and the disjoint-set after all lookups
   */
  def findTrav(trav: Traversable[T]): (Map[T, T], DisjointSet[T]) =
    trav.foldLeft((Map.empty[T, T], this)) {
      case ((representatives, current), elem) =>
        val (representative, next) = current.find(elem)
        (representatives.updated(elem, representative), next)
    }

  /**
   * Computes the partition described by this structure.
   *
   * @return all disjoint sets, and the fully compressed disjoint-set
   */
  def sets(): (Set[Set[T]], DisjointSet[T]) = {
    // A root may occur only as a value when the instance was built from a raw map
    // (e.g. Map(1 -> 2)); include values so that such a root is part of its own set.
    val elements: Set[T] = parentMapping.keySet ++ parentMapping.values
    val (representatives, compressed) = findTrav(elements)
    val partition = representatives.groupBy(_._2).values.map(_.keySet).toSet
    (partition, compressed)
  }

  /**
   * Walks parent links from `x` up to its representative.
   *
   * @return the visited elements, representative first and `x` last
   */
  protected def findPath(x: T): Iterator[T] = {
    @tailrec
    def climb(node: T, visited: List[T]): List[T] =
      parentMapping.get(node) match {
        case Some(parent) if parent != node => climb(parent, parent :: visited)
        // Self-link or missing entry: node is the root and already heads `visited`.
        case _ => visited
      }
    climb(x, List(x)).iterator
  }

  /**
   * Points every element of `path` directly at the first element of `path`.
   *
   * @param path non-empty iterator whose first element is the representative;
   *             an empty iterator makes `next()` throw NoSuchElementException
   * @return the representative and the disjoint-set with the updated links
   */
  protected def compressPath(path: Iterator[T]): (T, DisjointSet[T]) = {
    val representative = path.next()
    // Fold over the underlying map and wrap it once at the end instead of
    // allocating a new DisjointSet for every element of the path.
    val links = path.foldLeft(parentMapping.updated(representative, representative)) {
      (acc, element) => acc.updated(element, representative)
    }
    (representative, DisjointSet(links))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment