Skip to content

Instantly share code, notes, and snippets.

@d-plaindoux
Last active December 19, 2024 11:02
Show Gist options
  • Save d-plaindoux/5c1b26a7eede27601d7c0341999b15f9 to your computer and use it in GitHub Desktop.
Save d-plaindoux/5c1b26a7eede27601d7c0341999b15f9 to your computer and use it in GitHub Desktop.
Monomorphic function registry in Scala 3 (thanks to @Ptival for the example it is based on: https://gist.github.com/Ptival/fbc92e38000f0771453bd2af571e1c1c 🫶)
import :=:.Refl
import Type.*
import TypeRepr.*
import scala.annotation.tailrec
/** Propositional type equality: a value of type `L :=: R` witnesses that `L` and `R`
  * are the same type.
  *
  * The only constructor, `Refl`, can only be built at `X :=: X`. Pattern matching on
  * `Refl()` therefore lets the compiler equate the two type arguments inside that case.
  */
enum :=:[L, R]:
  case Refl[X]() extends :=:[X, X]
/** Type-level codes for the object language: booleans, function types `Dom ->: Cod`,
  * integers and strings.
  *
  * In this file the cases appear only at the type level, as indices for `TypeRepr` and
  * `Embed`. Case order is kept stable because it determines each case's `ordinal`.
  */
enum Type:
  case BoolT()
  case ->:[Dom <: Type, Cod <: Type]()
  case IntT()
  case StringRT()
end Type
/** Interprets a type code `T` as the corresponding Scala type:
  *   - `BoolT`      becomes `Boolean`
  *   - `Dom ->: Cod` becomes `Embed[Dom] => Embed[Cod]` (recursively)
  *   - `IntT`       becomes `Int`
  *   - `StringRT`   becomes `String`
  */
type Embed[T <: Type] = T match
  case BoolT         => Boolean
  case ->:[dom, cod] => Embed[dom] => Embed[cod]
  case IntT          => Int
  case StringRT      => String
/** Value-level (runtime) representation of a type code `T`.
  *
  * Each case fixes the index `T`, so pattern matching on a `TypeRepr[T]` tells the
  * compiler which `T` it is dealing with. Case order is kept stable because it
  * determines each case's `ordinal`.
  */
enum TypeRepr[T <: Type]:
  case BoolR extends TypeRepr[BoolT]
  case FunctionR[Dom <: Type, Cod <: Type](a: TypeRepr[Dom], b: TypeRepr[Cod]) extends TypeRepr[Dom ->: Cod]
  case IntR extends TypeRepr[IntT]
  case StringR extends TypeRepr[StringRT]
end TypeRepr
/** Builds the representation of the function type `A ->: B`. The extension receiver
  * `a` is the domain and `b` is the codomain.
  *
  * Because the name ends in `:`, the operator is right-associative, so chains such as
  * `IntR ->>: IntR ->>: IntR` nest to the right, matching curried function types.
  */
extension [A <: Type](a: TypeRepr[A])
  def ->>:[B <: Type](b: TypeRepr[B]): TypeRepr[A ->: B] =
    TypeRepr.FunctionR(a, b)
/** Decides whether two type representations are structurally equal.
  *
  * @return `Some(Refl())` when `t1` and `t2` are structurally equal. This is a proof
  *         that the indices `A` and `B` coincide. Otherwise returns `None`.
  */
def eqTypeRepr[A <: Type, B <: Type](t1: TypeRepr[A], t2: TypeRepr[B]): Option[A :=: B] =
  (t1, t2) match
    case (FunctionR(dom1, cod1), FunctionR(dom2, cod2)) =>
      // Each successful `Refl()` match adds one type equation (domains first, then
      // codomains). Together they let the final `Refl()` typecheck at `A :=: B`.
      eqTypeRepr(dom1, dom2) match
        case None => None
        case Some(Refl()) =>
          eqTypeRepr(cod1, cod2) match
            case None         => None
            case Some(Refl()) => Some(Refl())
    case (BoolR, BoolR)     => Some(Refl())
    case (IntR, IntR)       => Some(Refl())
    case (StringR, StringR) => Some(Refl())
    case _                  => None
/** A named value stored together with the runtime representation of its type.
  *
  * @param name  lookup key. It need not be unique: several entries may share a name
  *              and differ only by type.
  * @param entry runtime representation of `T`, compared against the requested type on lookup
  * @param value the stored value, typed as the Scala embedding of `T`
  */
case class RegistryEntry[T <: Type](
    name: String,
    entry: TypeRepr[T],
    value: Embed[T]
)
/** A heterogeneous list of registry entries. Each entry hides its own type index behind
  * a wildcard.
  *
  * Uses Scala 3's `?` wildcard syntax instead of `_`, which is slated for deprecation
  * in type position.
  */
type Registry = List[RegistryEntry[?]]
/** Looks up the first entry whose name is `name` and whose type representation equals
  * `entry`.
  *
  * Entries with a matching name but a different type are skipped, so several entries
  * (overloads) may share one name.
  *
  * @return the entry's value, statically typed as `Embed[T]`, or `None` when no entry
  *         matches both name and type
  */
@tailrec
def find[T <: Type](registry: Registry, name: String, entry: TypeRepr[T]): Option[Embed[T]] =
  registry match
    case Nil => None
    case head :: rest =>
      if head.name == name then
        // A successful equality proof refines head's hidden type index to `T`,
        // which lets `head.value` be returned as an `Embed[T]`.
        eqTypeRepr(head.entry, entry) match
          case Some(Refl()) => Some(head.value)
          case None         => find(rest, name, entry)
      else find(rest, name, entry)
/** Curried integer addition. */
val addIntR: Int => Int => Int = (x: Int) => (y: Int) => x + y

/** Curried string concatenation. */
val addStringR: String => String => String = (x: String) => (y: String) => x + y
/** Example registry that overloads the name "add". The integer and string versions
  * are told apart only by their type representations.
  */
val myRegistry: Registry =
  // Curried binary-operator signatures: X -> X -> X.
  val intBinOp = IntR ->>: IntR ->>: IntR
  val stringBinOp = StringR ->>: StringR ->>: StringR
  List(
    RegistryEntry("add", intBinOp, addIntR),
    RegistryEntry("add", stringBinOp, addStringR)
  )
/** Demonstrates type-directed lookup. It resolves both "add" overloads from
  * `myRegistry` by their signatures and applies each one.
  *
  * @return `None` if either lookup fails. Otherwise the pair of results, which with
  *         the entries in `myRegistry` should be `(3, "ab")`.
  */
private def main(): Option[(Int, String)] =
  for
    add    <- find(myRegistry, "add", IntR ->>: IntR ->>: IntR)
    concat <- find(myRegistry, "add", StringR ->>: StringR ->>: StringR)
  yield (add(1)(2), concat("a")("b"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment