Last active
December 23, 2019 17:01
-
-
Save kalexmills/4e2724ffa5ba7d1f90fdf7f0de9242e0 to your computer and use it in GitHub Desktop.
MultiSet data structure in Scala w/ Cats (Foldable and Monad)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.niftysoft.gennit.util | |
import cats._ | |
import cats.implicits._ | |
import scala.annotation.tailrec | |
/**
 * An immutable multiset (bag): a collection in which every element carries a
 * multiplicity, i.e. the number of times it occurs.
 *
 * The backing map stores one entry per distinct element. Every operation defined
 * here keeps the invariant that stored multiplicities are strictly positive, so an
 * element is "in" the multiset exactly when it has an entry.
 *
 * @param data map from each distinct element to its multiplicity
 * @tparam V element type
 */
case class MultiSet[V] private (data: Map[V, Int]) {

  /** Keeps only the elements satisfying `f`; multiplicities are preserved. */
  def filter(f: V => Boolean): MultiSet[V] =
    new MultiSet(data.filter { case (v, _) => f(v) })

  /** Number of occurrences of `elem`, or 0 when it is absent. */
  def multiplicity(elem: V): Int = data.getOrElse(elem, 0)

  /** True when `elem` occurs at least once. */
  def contains(elem: V): Boolean = multiplicity(elem) > 0

  /**
   * Multiplies the multiplicity of every element by `factor`.
   *
   * A non-positive factor yields the empty multiset, since an element cannot occur
   * zero or a negative number of times.
   */
  def mult(factor: Int): MultiSet[V] =
    if (factor <= 0) new MultiSet(Map.empty[V, Int])
    else new MultiSet(data.map { case (v, count) => (v, count * factor) })

  /**
   * Adds `num` occurrences of `elem`.
   *
   * `num` may be negative, in which case occurrences are removed; if the resulting
   * multiplicity is not positive the element is dropped entirely.
   */
  def addMany(elem: V, num: Int): MultiSet[V] = {
    val updated = multiplicity(elem) + num
    new MultiSet(if (updated > 0) data.updated(elem, updated) else data - elem)
  }

  /** Removes a single occurrence of `elem`; a no-op when `elem` is absent. */
  def excl(elem: V): MultiSet[V] = data.get(elem) match {
    case None               => this
    case Some(n) if n <= 1  => new MultiSet(data - elem)
    case Some(n)            => new MultiSet(data.updated(elem, n - 1))
  }

  /** Removes every occurrence of `elem`. */
  def exclAll(elem: V): MultiSet[V] = new MultiSet(data - elem)

  /** Adds a single occurrence of `elem`. */
  def incl(elem: V): MultiSet[V] = addMany(elem, 1)

  /** Multiset difference with a set (each member counted once). */
  def diff(other: Set[V]): MultiSet[V] = diff(MultiSet(other))

  /** Multiset difference with a sequence (duplicates counted). */
  def diff(other: Seq[V]): MultiSet[V] = diff(MultiSet(other: _*))

  /**
   * Multiset difference: each multiplicity is reduced by the multiplicity in
   * `other`; elements whose result is not positive are dropped.
   */
  def diff(other: MultiSet[V]): MultiSet[V] =
    new MultiSet(
      data
        .map { case (v, count) => (v, count - other.multiplicity(v)) }
        .filter { case (_, count) => count > 0 }
    )

  /** Multiset sum with a set (each member counted once). */
  def sum(other: Set[V]): MultiSet[V] = sum(MultiSet(other))

  /** Multiset sum with a sequence (duplicates counted). */
  def sum(other: Seq[V]): MultiSet[V] = sum(MultiSet(other: _*))

  /** Multiset sum: multiplicities of both operands are added together. */
  def sum(other: MultiSet[V]): MultiSet[V] =
    other.data.foldLeft(this) { case (acc, (v, count)) => acc.addMany(v, count) }

  /** Multiset union with a sequence (duplicates counted). */
  def union(other: Seq[V]): MultiSet[V] = union(MultiSet(other: _*))

  /** Multiset union with a set (each member counted once). */
  def union(other: Set[V]): MultiSet[V] = union(MultiSet(other))

  /** Multiset union: each element gets the larger of its two multiplicities. */
  def union(other: MultiSet[V]): MultiSet[V] =
    new MultiSet(
      other.data.foldLeft(data) { case (acc, (v, count)) =>
        if (count > acc.getOrElse(v, 0)) acc.updated(v, count) else acc
      }
    )

  /** Multiset intersection with a sequence (duplicates counted). */
  def intersect(other: Seq[V]): MultiSet[V] = intersect(MultiSet(other: _*))

  /** Multiset intersection with a set (each member counted once). */
  def intersect(other: Set[V]): MultiSet[V] = intersect(MultiSet(other))

  /**
   * Multiset intersection: each element gets the smaller of its two
   * multiplicities; elements absent from either operand are dropped.
   */
  def intersect(other: MultiSet[V]): MultiSet[V] =
    new MultiSet(
      data
        .map { case (v, count) => (v, math.min(count, other.multiplicity(v))) }
        .filter { case (_, count) => count > 0 }
    )

  /** All elements as a list, each repeated according to its multiplicity. */
  def toList: List[V] = iterator.toList

  /**
   * Iterates over the elements, repeating each one according to its multiplicity.
   * Distinct elements appear in the iteration order of the backing map.
   */
  def iterator: Iterator[V] =
    // Iterator.fill yields nothing for a non-positive count, so a stray zero or
    // negative entry can never be emitted.
    data.iterator.flatMap { case (v, count) => Iterator.fill(count)(v) }

  /** Two multisets are equal when they hold the same elements with the same multiplicities. */
  override def equals(o: Any): Boolean = o match {
    case that: MultiSet[_] => data == that.data
    case _                 => false
  }

  /** Hash derived from the backing map, consistent with [[equals]]. */
  override def hashCode(): Int = data.hashCode()

  /** Comma-separated rendering of every occurrence, e.g. `a, a, b`. */
  override def toString(): String = iterator.mkString(", ")
}
/** Factory methods and Cats type-class instances for [[MultiSet]]. */
object MultiSet {

  /** The empty multiset. */
  def apply[A](): MultiSet[A] = new MultiSet(Map.empty[A, Int])

  /** Builds a multiset from the given elements; repeated arguments raise the multiplicity. */
  def apply[A](x: A*): MultiSet[A] =
    new MultiSet(x.groupBy(identity).map { case (v, occurrences) => (v, occurrences.length) })

  /** Builds a multiset in which every member of the set occurs exactly once. */
  def apply[A](x: Set[A]): MultiSet[A] =
    new MultiSet(x.iterator.map(elem => elem -> 1).toMap)

  /**
   * Functor instance. Elements that `f` sends to the same value have their
   * multiplicities added together, so the total number of occurrences is preserved.
   */
  implicit val functorForMultiset: Functor[MultiSet] = new Functor[MultiSet] {
    def map[A, B](fa: MultiSet[A])(f: A => B): MultiSet[B] =
      // Accumulate with addMany rather than building a Map keyed by f(a): a plain
      // Map would keep only one count when two elements share the same image.
      fa.data.foldLeft(MultiSet[B]()) { case (acc, (a, count)) => acc.addMany(f(a), count) }
  }

  /** Foldable instance: folds visit every occurrence of every element. */
  implicit val foldableForMultiset: Foldable[MultiSet] = new Foldable[MultiSet] {
    def foldLeft[A, B](fa: MultiSet[A], b: B)(f: (B, A) => B): B =
      fa.iterator.foldLeft(b)(f)

    def foldRight[A, B](fa: MultiSet[A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] =
      // Delegate to the Cats List instance so the fold stays lazy in Eval; the
      // standard-library List.foldRight would apply f to every element up front.
      Foldable[List].foldRight(fa.toList, lb)(f)
  }

  /**
   * Monad instance. Binding an element that occurs `n` times contributes `n` copies
   * of the multiset produced for it, and contributions from different elements are
   * added together.
   */
  implicit val monadForMultiset: Monad[MultiSet] = new Monad[MultiSet] {
    def pure[A](x: A): MultiSet[A] = new MultiSet(Map(x -> 1))

    // Reuse the direct Functor implementation instead of the flatMap/pure derivation.
    override def map[A, B](fa: MultiSet[A])(f: A => B): MultiSet[B] =
      functorForMultiset.map(fa)(f)

    def flatMap[A, B](fa: MultiSet[A])(f: A => MultiSet[B]): MultiSet[B] =
      // Visit each distinct element once, scale its result by the multiplicity, and
      // combine with sum so overlapping results from different elements accumulate.
      fa.data.foldLeft(MultiSet[B]()) { case (acc, (a, count)) => acc.sum(f(a).mult(count)) }

    def tailRecM[A, B](a: A)(f: A => MultiSet[Either[A, B]]): MultiSet[B] = {
      // `pending` holds (value, multiplicity) pairs still to be processed; `acc`
      // collects finished results. A Left is expanded in place, carrying its
      // multiplicity into everything it produces.
      @tailrec
      def go(pending: List[(Either[A, B], Int)], acc: MultiSet[B]): MultiSet[B] =
        pending match {
          case Nil =>
            acc
          case (Right(b), count) :: rest =>
            go(rest, acc.addMany(b, count))
          case (Left(next), count) :: rest =>
            go(f(next).mult(count).data.toList ::: rest, acc)
        }

      go(f(a).data.toList, MultiSet[B]())
    }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.niftysoft.gennit.util | |
import org.scalatest._ | |
import cats.implicits._ | |
/**
 * Behavioural tests for [[MultiSet]].
 *
 * Empty multisets are created with an explicit element type (`MultiSet[Int]()`), so
 * the element type never depends on inference from a later method call.
 */
class MultiSetSpec extends FlatSpec with Matchers {

  "MultiSet" should "work on empty set" in {
    val x = MultiSet[Int]()
    x.contains(1) shouldEqual (false)
    x.iterator.toList shouldEqual (List.empty)
  }

  it should "sum arguments to apply correctly" in {
    val x: MultiSet[Int] = MultiSet(1, 1)
    x.multiplicity(1) shouldEqual (2)
    x.iterator.toList shouldEqual (List(1, 1))
  }

  it should "remove things when excl is called" in {
    val x = MultiSet[Int]().addMany(1, 2)
    x.excl(1).iterator.toList shouldEqual (List(1))
  }

  it should "remove every copy when exclAll is called" in {
    val x = MultiSet[Int]().addMany(1, 3).addMany(2, 1)
    x.exclAll(1).iterator.toList shouldEqual (List(2))
  }

  it should "add a single copy when incl is called" in {
    val x = MultiSet[Int]().incl(7).incl(7)
    x.multiplicity(7) shouldEqual (2)
  }

  it should "admit multiple elements" in {
    val x = MultiSet[Int]().addMany(2, 6)
    x.multiplicity(2) shouldEqual (6)
    x.contains(2) shouldEqual (true)
    x.iterator.toList shouldEqual (List(2, 2, 2, 2, 2, 2))
  }

  it should "implement difference" in {
    val x = MultiSet[Int]().addMany(0, 4).addMany(1, 3)
    val y = MultiSet[Int]().addMany(0, 2).addMany(1, 5)
    val xdiffy = x.diff(y)
    val ydiffx = y.diff(x)
    xdiffy.iterator.toList.sorted shouldEqual (List(0, 0))
    ydiffx.iterator.toList.sorted shouldEqual (List(1, 1))
  }

  it should "implement sums" in {
    val x = MultiSet[Int]().addMany(0, 4).addMany(1, 3)
    val y = MultiSet[Int]().addMany(0, 2).addMany(1, 5)
    val xplusy = x.sum(y)
    xplusy.iterator.toList.sorted shouldEqual (List(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1))
  }

  it should "admit unions with sets" in {
    val x = MultiSet[Int]().addMany(0, 4)
    val y = Set(1, 2, 3)
    x.union(y).iterator.toList.sorted shouldEqual (List(0, 0, 0, 0, 1, 2, 3))
  }

  it should "implement unions" in {
    val x = MultiSet[Int]().addMany(0, 4).addMany(1, 3)
    val y = MultiSet[Int]().addMany(0, 2).addMany(1, 5)
    val xuy = x.union(y)
    xuy.iterator.toList.sorted shouldEqual (List(0, 0, 0, 0, 1, 1, 1, 1, 1))
  }

  it should "implement intersections" in {
    val x = MultiSet[Int]().addMany(0, 4).addMany(1, 3)
    val y = MultiSet[Int]().addMany(0, 2).addMany(1, 5)
    val xny = x.intersect(y)
    xny.iterator.toList.sorted shouldEqual (List(0, 0, 1, 1, 1))
  }

  it should "be creatable from sets" in {
    val x = MultiSet(Set(1, 2, 3))
    x.iterator.toList.sorted shouldEqual (List(1, 2, 3))
  }

  it should "admit non-deterministic monadic computations" in {
    val consecutive = MultiSet(Set(1, 2, 3))
    consecutive
      .flatMap(n => MultiSet[Int]().addMany(n, 2))
      .iterator.toList.sorted shouldEqual (List(1, 1, 2, 2, 3, 3))

    val odds = MultiSet(Set(1, 3, 5))
    odds
      .flatMap(n => MultiSet(n, n + 1))
      .iterator.toList.sorted shouldEqual (List(1, 2, 3, 4, 5, 6))
  }

  it should "multiply existing elements when asked" in {
    val x = MultiSet(Set('a', 'b', 'c')).mult(3)
    x.iterator.toList.sorted shouldEqual (List('a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c'))
  }

  it should "allow unary multplication" in {
    val x = MultiSet[Int]().addMany(1, 3)
    x.flatMap(n => MultiSet(n, n)).iterator.toList.sorted shouldEqual (List(1, 1, 1, 1, 1, 1))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Note: this is not a high-performance data structure, but it should be useful on slow paths where MultiSet semantics are needed.
It's possible that the explicit `Functor` implementation is faster than using the auto-generated implementation based on `map` and `pure` from `Monad`, so I'm leaving it. Have yet to see problems with implicit resolution, though I suppose that could happen in theory.