Last active
December 8, 2020 18:51
-
-
Save polytypic/fd85880ff1081acac308084b619a3178 to your computer and use it in GitHub Desktop.
Curious case of GADTs in Scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Below is a simple attempt at using GADTs in Scala based on
//
//   https://github.com/palladin/idris-snippets/blob/master/src/HOAS.idr
object Hoas {
  // With case classes one can directly write down what looks like a GADT:
  // every constructor fixes (refines) the type index of `Expr`.
  sealed trait Expr[A]

  /** A literal value of type `A`. */
  final case class Val[A](value: A) extends Expr[A]

  /** Combines the values of `lhs` and `rhs` with the binary function `bin`. */
  final case class Bin[A, B, C](bin: (A, B) => C, lhs: Expr[A], rhs: Expr[B])
      extends Expr[C]

  /** Selects `onTrue` or `onFalse` depending on the value of `condition`. */
  final case class If[A](
      condition: Expr[Boolean],
      onTrue: Expr[A],
      onFalse: Expr[A]
  ) extends Expr[A]

  /** Applies `function` to `argument`. */
  final case class App[A, B](function: Expr[A => B], argument: Expr[A])
      extends Expr[B]

  /** A function in higher-order abstract syntax: the body is a Scala function
    * from the argument expression to the result expression.
    */
  final case class Lam[A, B](lambda: Expr[A] => Expr[B]) extends Expr[A => B]

  /** The fixed point of a functional; this is how recursion is expressed. */
  final case class Fix[A, B](fix: Expr[(A => B) => A => B]) extends Expr[A => B]

  // And here is the factorial function:
  val fact: Expr[Int => Int] = Fix(
    Lam((f: Expr[Int => Int]) =>
      Lam((x: Expr[Int]) =>
        If(
          Bin((_: Int) == (_: Int), x, Val(0)),
          Val(1),
          Bin(
            (_: Int) * (_: Int),
            x,
            App(f, Bin((_: Int) - (_: Int), x, Val(1)))
          )
        )
      )
    )
  )

  // This `eval` function also superficially seems fine:
  def eval[T](expr: Expr[T]): T = expr match {
    case Bin(f, x, y) => f(eval(x), eval(y))
    // Only the branch selected by the condition is evaluated.
    case If(b, c, a) => eval(if (eval(b)) c else a)
    case App(f, x)   => eval(f)(eval(x))
    // The Scala-level argument is fed back into the HOAS body as a literal.
    case Lam(f) => x => eval(f(Val(x)))
    case Fix(e) => {
      val f = eval(e)
      // `rec` passes (an eta-expansion of) itself to the functional, which
      // ties the recursive knot. Note the `Any` types: see the remarks below.
      def rec(x: Any): Any = f(rec(_))(x)
      rec(_)
    }
    case Val(x) => x
  }

  // And one even gets the expected result:
  val one_hundred_twenty: Int = eval(App(fact, Val(5)))

  // The interesting thing, however, is that the types inside the `eval`
  // function are not what one might expect. For example, the type of `f` is
  // `(Any, Any) => T` and the type of `x` and `y` is `Any`. This means that
  // the expression `f(eval(x), eval(y))` is not typed precisely. If one were
  // to introduce a bug by flipping the arguments, writing
  // `f(eval(y), eval(x))`, the code would still pass the type checker. This
  // doesn't happen in languages with proper support for GADTs.
  //
  // First of all, I think that this is bad. Writing code in a
  // straightforward manner just doesn't give you the expected guarantees.
  //
  // Second of all, how should we implement this example in Scala 2 in a type
  // safe manner? (Asking, because I don't yet know how.)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// This is my second attempt to encode GADTs in Scala. The example is based on
//
//   https://github.com/palladin/idris-snippets/blob/master/src/HOAS.idr
//
// and the encoding technique is inspired by the paper
//
//   GADTs for the OCaml Masses
//   http://homepage.cs.uiowa.edu/~astump/papers/icfp09.pdf
//
// and is basically a Scott encoding.
object HoasScott {
  // We need the identity type constructor a bit later...
  type Id[T] = T

  // The `ExprDestructor` trait defines the cases of an expression. The
  // parameter `Result` is the type constructor for the result of
  // destructuring an expression. Each case's return type carries the refined
  // type index, so handlers see precisely typed constructor fields.
  trait ExprDestructor[Result[_]] {
    def Val[A](value: A): Result[A]
    def Bin[A, B, C](bin: (A, B) => C, lhs: Expr[A], rhs: Expr[B]): Result[C]
    def If[A](
        condition: Expr[Boolean],
        onTrue: Expr[A],
        onFalse: Expr[A]
    ): Result[A]
    def App[A, B](function: Expr[A => B], argument: Expr[A]): Result[B]
    def Lam[A, B](lambda: Expr[A] => Expr[B]): Result[A => B]
    def Fix[A, B](fix: Expr[(A => B) => A => B]): Result[A => B]
  }

  // The `Expr` trait is the actual type of expressions and can be called to
  // destructure the expression.
  trait Expr[T] {
    def apply[Result[_]](handler: ExprDestructor[Result]): Result[T]
  }

  // The `Expr` object implements the `ExprDestructor` for the `Expr` trait
  // itself to construct `Expr` values: each constructor captures its fields
  // and forwards them to the matching case of whatever handler it is given.
  object Expr extends ExprDestructor[Expr] {
    def Val[A](value: A): Expr[A] = new Expr[A] {
      def apply[Result[_]](handler: ExprDestructor[Result]): Result[A] =
        handler.Val(value)
    }
    def Bin[A, B, C](bin: (A, B) => C, lhs: Expr[A], rhs: Expr[B]): Expr[C] =
      new Expr[C] {
        def apply[Result[_]](handler: ExprDestructor[Result]): Result[C] =
          handler.Bin(bin, lhs, rhs)
      }
    def If[A](
        condition: Expr[Boolean],
        onTrue: Expr[A],
        onFalse: Expr[A]
    ): Expr[A] = new Expr[A] {
      def apply[Result[_]](handler: ExprDestructor[Result]): Result[A] =
        handler.If(condition, onTrue, onFalse)
    }
    def App[A, B](function: Expr[A => B], argument: Expr[A]): Expr[B] =
      new Expr[B] {
        def apply[Result[_]](handler: ExprDestructor[Result]): Result[B] =
          handler.App(function, argument)
      }
    def Lam[A, B](lambda: Expr[A] => Expr[B]): Expr[A => B] =
      new Expr[A => B] {
        def apply[Result[_]](handler: ExprDestructor[Result]): Result[A => B] =
          handler.Lam(lambda)
      }
    def Fix[A, B](fix: Expr[(A => B) => A => B]): Expr[A => B] =
      new Expr[A => B] {
        def apply[Result[_]](handler: ExprDestructor[Result]): Result[A => B] =
          handler.Fix(fix)
      }
  }
  import Expr._

  // The factorial function expression.
  val fact: Expr[Int => Int] = Fix(
    Lam((f: Expr[Int => Int]) =>
      Lam((x: Expr[Int]) =>
        If(
          Bin((_: Int) == (_: Int), x, Val(0)),
          Val(1),
          Bin(
            (_: Int) * (_: Int),
            x,
            App(f, Bin((_: Int) - (_: Int), x, Val(1)))
          )
        )
      )
    )
  )

  // The evaluation handler. It captures no per-call state, so a single shared
  // instance is reused by every (recursive) `eval` call instead of allocating
  // a fresh handler object each time.
  private object Evaluator extends ExprDestructor[Id] {
    def Val[A](value: A): A = value
    def Bin[A, B, C](bin: (A, B) => C, lhs: Expr[A], rhs: Expr[B]): C =
      bin(eval(lhs), eval(rhs))
    // Only the branch selected by the condition is evaluated.
    def If[A](
        condition: Expr[Boolean],
        onTrue: Expr[A],
        onFalse: Expr[A]
    ): A = eval(if (eval(condition)) onTrue else onFalse)
    def App[A, B](function: Expr[A => B], argument: Expr[A]): B =
      eval(function)(eval(argument))
    // The Scala-level argument is fed back into the HOAS body as a literal.
    def Lam[A, B](lambda: Expr[A] => Expr[B]): A => B =
      (argument: A) => eval(lambda(Expr.Val(argument)))
    def Fix[A, B](fix: Expr[(A => B) => A => B]): A => B = {
      val fn = eval(fix)
      // `rec` passes (an eta-expansion of) itself to the functional, which
      // ties the recursive knot -- this time with precise types `A` and `B`.
      def rec(x: A): B = fn(rec(_))(x)
      rec(_)
    }
  }

  // Typed interpreter for typed expressions.
  def eval[T](expr: Expr[T]): T = expr.apply[Id](Evaluator)

  // This evaluates to 120.
  val one_hundred_twenty: Int = eval(App(fact, Val(5)))

  // While this encoding basically works, it has several downsides:
  // - Some amount of boilerplate is required to define a GADT this way.
  // - Destructuring does not directly support nesting (or other nice pattern
  //   matching features).
  // - Impossible cases are not filtered out.
  // - The encoding requires constructing an object per constructor.
  // - Destructuring requires a handler object; handlers that capture no
  //   per-call state (like `Evaluator` above) can at least be shared
  //   singletons, but handlers closing over local data must be allocated
  //   per call.
  //
  // Are there better ways to encode GADTs in Scala?
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment