-
-
Save calvinlfer/05869a0b2a57e04a59e2c68a1a7c3082 to your computer and use it in GitHub Desktop.
A generic, type-safe partitioning of a list of enum values into one list per enum case
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//> using scala 3.6.3 | |
import scala.deriving.Mirror | |
import scala.compiletime.{erasedValue} | |
/** Lifts every element type of the tuple type `Ts` into a `List`.
  *
  * Reduces element by element: `Partitioned[(Int, String)]` becomes
  * `List[Int] *: List[String] *: EmptyTuple`, i.e. `(List[Int], List[String])`,
  * and `Partitioned[EmptyTuple]` is `EmptyTuple`.
  */
type Partitioned[Ts <: Tuple] <: Tuple = Ts match
  case EmptyTuple   => EmptyTuple
  case head *: tail => List[head] *: Partitioned[tail]
/** Splits `xs` into a tuple of lists, one list per case of the sum type `A`.
  *
  * The cases, and therefore the position of each list in the result, are taken
  * from `m.MirroredElemTypes` of the compiler-supplied `Mirror.SumOf[A]`. Each
  * list is statically typed as its case, so callers need no casts.
  *
  * This must be `inline`: it forwards the mirror's element types to the inline
  * `partitionAll`, whose compile-time match needs them as concrete types.
  */
inline def partition[A](xs: List[A])(using m: Mirror.SumOf[A]): Partitioned[m.MirroredElemTypes] =
  type Cases = m.MirroredElemTypes
  partitionAll[A, Cases](xs)
/** Recursively splits `xs` into one list per element type of the tuple type `Ts`.
  *
  * The first type `t` of `Ts` claims every element of `xs` that passes an
  * `isInstanceOf[t]` test; only the unclaimed elements are handed on to the
  * remaining types. Consequently:
  *   - each element lands in the list of the first type in `Ts` it matches;
  *   - relative input order is preserved inside each list;
  *   - elements that match none of the types (for example `null`) are
  *     discarded once `Ts` is exhausted.
  *
  * The method must stay `inline`: `t` only becomes a concrete type after
  * inlining, which is what lets the type test and the cast target the actual
  * subtype at each call site.
  *
  * @tparam A  element type of the input list
  * @tparam Ts tuple of types to split by, in output order
  * @param xs  the values to split
  * @return    a tuple holding a `List[t]` for each `t` in `Ts`
  */
inline def partitionAll[A, Ts <: Tuple](xs: List[A]): Partitioned[Ts] =
  inline erasedValue[Ts] match
    case _: EmptyTuple =>
      EmptyTuple
    case _: (t *: ts) =>
      // One pass: matching elements go Left (already narrowed to `t`), the rest go Right.
      val (matching, rest) = xs.partitionMap[t, A] { x =>
        if x.isInstanceOf[t] then Left(x.asInstanceOf[t]) else Right(x)
      }
      matching *: partitionAll[A, ts](rest)
/** Example enum with three parameterless (singleton) cases, used by the demo. */
enum E:
  case A
  case B
  case C
/** Demo: split a list of `E` values into one list per enum case and print them. */
@main def runPartition(): Unit =
  val xs: List[E] = List(E.A, E.B, E.A, E.C, E.B)
  // `E`'s cases are singletons, so the mirror's element types are `E.A.type`,
  // `E.B.type` and `E.C.type` (there is no type `E.A`). The ascription lets the
  // compiler verify the result type instead of relying on a comment.
  val partitioned: (List[E.A.type], List[E.B.type], List[E.C.type]) = partition(xs)
  // The typed tuple destructures directly.
  val (as, bs, cs) = partitioned
  // Enum singletons print by case name, so the output is e.g. `As: List(A, A)`.
  println(s"As: $as") // As: List(A, A)
  println(s"Bs: $bs") // Bs: List(B, B)
  println(s"Cs: $cs") // Cs: List(C)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment