Created
September 26, 2017 00:20
-
-
Save MishaelRosenthal/0eb599e4f8e0f3d8020f1d94085d07bd to your computer and use it in GitHub Desktop.
Almost optimal Disjoint-set data structure:
https://en.wikipedia.org/wiki/Disjoint-set_data_structure
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package com.twitter.timelines.data_processing.ad_hoc.feature_selection.utils | |
| import scala.annotation.tailrec | |
| import scala.collection.GenTraversableOnce | |
| /** | |
| * Almost optimal Disjoint-set data structure: | |
| * https://en.wikipedia.org/wiki/Disjoint-set_data_structure | |
| */ | |
object DisjointSet {

  /**
   * Builds a disjoint-set that tracks no elements yet (its parent mapping is empty).
   *
   * @tparam T type of the elements the structure will hold
   * @return a [[DisjointSet]] backed by an empty parent mapping
   */
  def empty[T]: DisjointSet[T] =
    new DisjointSet[T](Map.empty[T, T])
}
/**
 * Immutable disjoint-set (union-find) forest with path compression.
 *
 * Every operation returns a new instance instead of mutating this one, so callers
 * must keep using the returned structure to benefit from the compression performed.
 *
 * @param parentMapping maps each known element to its parent; an element that maps
 *                      to itself (or has no entry) is treated as the root of its set
 * @tparam T type of the elements
 */
case class DisjointSet[T](parentMapping: Map[T, T]) {

  /**
   * Merges the set containing `x` with the set containing `y`.
   *
   * The representative of `x`'s set becomes the representative of the merged set,
   * and every node visited on both lookup paths is re-pointed directly at it.
   * Elements not seen before are added as part of the merge.
   *
   * @return the structure after the merge
   */
  def union(x: T, y: T): DisjointSet[T] = {
    // The first element of the concatenated path is x's root, which compressPath
    // picks as the representative for everything that follows (including y's root).
    val (_, merged) = compressPath(findPath(x) ++ findPath(y))
    merged
  }

  /**
   * Applies [[union]] to every pair of `trav`, in traversal order.
   *
   * @return the structure after all merges
   */
  def unionTrav(trav: GenTraversableOnce[(T, T)]): DisjointSet[T] =
    trav.foldLeft(this) { (forest, pair) => forest.union(pair._1, pair._2) }

  /**
   * Looks up the representative of the set containing `x`.
   *
   * An element that is not tracked yet is registered as its own representative.
   *
   * @return the representative together with the path-compressed structure
   */
  def find(x: T): (T, DisjointSet[T]) = compressPath(findPath(x))

  /**
   * Looks up the representative of every element of `trav`, threading the
   * compressed structure from one lookup to the next.
   *
   * @return a map from each element to its representative, and the structure
   *         obtained after all lookups
   */
  def findTrav(trav: Traversable[T]): (Map[T, T], DisjointSet[T]) = {
    val initial = (Map.empty[T, T], this)
    trav.foldLeft(initial) { (state, elem) =>
      val (found, forest) = state
      val (root, nextForest) = forest.find(elem)
      (found + (elem -> root), nextForest)
    }
  }

  /**
   * Materialises the partition of all tracked elements.
   *
   * @return the set of disjoint sets, and the structure compressed by the lookups
   */
  def sets(): (Set[Set[T]], DisjointSet[T]) = {
    val (rootOf, compressed) = findTrav(parentMapping.keys)
    val byRoot = rootOf.groupBy { case (_, root) => root }
    val partition = byRoot.values.map(_.keySet).toSet
    (partition, compressed)
  }

  /**
   * Walks parent links from `x` up to the root of its set.
   *
   * @return the visited nodes ordered root first and `x` last; never empty
   */
  protected def findPath(x: T): Iterator[T] = {
    @tailrec
    def climb(node: T, visited: List[T]): List[T] =
      parentMapping.get(node) match {
        case Some(parent) if parent != node => climb(parent, parent :: visited)
        // No entry, or a self-loop: `node` is the root and already heads `visited`.
        case _ => visited
      }

    climb(x, x :: Nil).iterator
  }

  /**
   * Re-points every node of `path` directly at the path's first element.
   *
   * @param path non-empty iterator whose first element is the representative to use
   * @return that representative and the structure with the rewritten parent links
   */
  protected def compressPath(path: Iterator[T]): (T, DisjointSet[T]) = {
    val representative = path.next()
    // The representative is (re)inserted as its own parent first, then the remaining
    // nodes are attached to it one by one, in path order.
    val selfRooted = parentMapping + (representative -> representative)
    val rewired = path.foldLeft(selfRooted) { (links, node) =>
      links + (node -> representative)
    }
    (representative, DisjointSet(rewired))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment