/*
 * Gist Blaisorblade/1622b0809effb9a56061 (last active August 29, 2015 14:06):
 * "Trying to implement a typed tree transform ([T]Exp[T] => Exp[∆[T]]) in Scala,
 * where ∆[_] is a type function, with some successes."
 *
 * GitHub notes that this file contains bidirectional Unicode text that may be
 * interpreted or compiled differently than what appears; review it in an editor
 * that reveals hidden Unicode characters.
 */
// Also at https://gist.github.com/Blaisorblade/1622b0809effb9a56061 | |
package ilc | |
/** | |
* Incremental lambda calculus - an attempt at a *typed* Scala implementation. | |
* Should you want more details on the particular transform, see http://inc-lc.github.io/ and our | |
* PLDI paper at http://www.informatik.uni-marburg.de/~pgiarrusso/papers/pldi14-ilc-author-final.pdf | |
* — the transform is in Fig. 4(g). | |
* | |
* However, I **quickly** wrote a minimal introduction here so that you can make sense of the code. | |
* I am not really explaining why this is useful — there's the paper for it. | |
* | |
 * The goal is to turn a program `t : σ → τ` into its incremental version/derivative `Derive(t)`,
 * which maps an input value and a change to that input to a change of the output.
 * Hence `Derive(t)` has type `∆ (σ → τ) = σ → ∆σ → ∆τ`. More generally,
 * you can say `Derive`'s type is `[T] T => ∆T`.
* | |
* (I won't explain here what happens to the context, because the HOAS representation of functions | |
* hides that away anyway, so contexts don't show up in the code. The type I gave is true for closed terms). | |
* | |
* At the same time, the derivative has the type of a change (because it is a nil change, but | |
* I won't try explaining that here). | |
* | |
* This transform takes a term t : T to its nil change/derivative (it's the same) Derive(t) : ∆T. | |
* ∆T is the type of changes to T, and it's a type function. On functions it's: | |
* | |
* ∆ (σ → τ) = σ → ∆σ → ∆τ | |
* | |
* Here's the transform on terms: | |
* | |
* Derive(λx. t) = λx dx. Derive(t) | |
* Derive(s t) = Derive(s) t Derive(t) | |
* Derive(x) = dx | |
* | |
* You also need extra cases for base types and constants. Below are examples for integers and arithmetic. | |
* | |
* Tested with Scala 2.11.0-M8 and 2.11.2 — 2.10.x would unsoundly allow things which fail here, see below. | |
* You can even run it to observe the transform in action! | |
*/ | |
/* | |
 * We represent the universe of simple types in Scala at the type level.
* Implementation inspired by http://apocalisp.wordpress.com/2010/06/13/type-level-programming-in-scala-part-3-boolean/. | |
*/ | |
/**
 * A simple type of the object language, encoded as a Scala type.
 *
 * Its type members act as type-level functions over this encoding.
 */
sealed trait Type {
  /** The encoded type of changes to values of this type (∆ applied to this type). */
  type DT <: Type

  /** The Scala type of runtime values of this type. */
  type Eval
}
/** A base (non-function) object-language type whose values are Scala values of type `BaseT`. */
sealed trait BaseType[BaseT] extends Type {
  type Eval = BaseT
}
/** The object-language integer type; changes to integers are again integers. */
sealed trait BaseInt extends BaseType[Int] {
  type DT = BaseInt
}
/* | |
* I tried adding subtyping here using Scala's subtyping, but it does not work. | |
* Eval uses BaseS and BaseT invariantly. | |
* Note that this would have been (unsoundly) accepted from 2.8.x to 2.10.x. | |
* To correctly encode subtyping, the next step might be to also embed | |
* subtyping judgements for lambda_<:, but this is not really practical | |
* *for integration with LMS*. However, it would still allow showing that derive | |
* works for lambda_sub, especially with a meta-interpreter of subtyping judgements. | |
*/ | |
/**
 * Function type `SPar =>: TPar` of the object language.
 *
 * Values are Scala functions from `SPar#Eval` to `TPar#Eval`. Changes follow
 * `∆(σ → τ) = σ → ∆σ → ∆τ`: a function change takes a base input and a change to it.
 * Because the operator name ends in `:`, `A =>: B =>: C` parses as `A =>: (B =>: C)`.
 */
sealed trait =>:[SPar <: Type, TPar <: Type] extends Type {
  type S = SPar
  type T = TPar

  type BaseS = S#Eval
  type BaseT = T#Eval

  type Eval = Function1[BaseS, BaseT]
  type DT = S =>: (S#DT =>: T#DT)
}
//A first attempt at writing the type. | |
//def derive[T <: Type](v: T#Eval): T#DT#Eval = ??? | |
//Cool! It worked! Let's move on to the real thing. | |
// A typeclass for (erased) change structures. | |
/**
 * A typeclass for (erased) change structures over the encoded type `ReprT`.
 *
 * `T` is the Scala type of base values and `DT` the Scala type of their changes.
 */
trait ΔBase[ReprT <: Type] {
  type T = ReprT#Eval
  type DT = ReprT#DT#Eval

  /** Update: applies the change `dt` to the value `t`. */
  def ⊕(t: T, dt: DT): T

  /** Difference: a change that goes from `tOld` to `tNew`. */
  def ⊖(tNew: T, tOld: T): DT

  /** Composition: combines the changes `dt1` and `dt2` into a single change. */
  def ∘(dt1: DT, dt2: DT): DT
}
/** Infix syntax for the operations of the `ΔBase` typeclass. */
object Ops {
  /** Enables `t ⊕ dt` and `tNew ⊖ tOld` on base values of the encoded type `ReprT`. */
  implicit class InfixBaseValueOps[ReprT <: Type](t: ReprT#Eval)(implicit Δt: ΔBase[ReprT]) {
    type T = ReprT#Eval
    type DT = ReprT#DT#Eval

    def ⊕(dt: DT): T = Δt.⊕(t, dt)

    def ⊖(tOld: T): DT = Δt.⊖(t, tOld)
  }

  /** Enables `dt1 ∘ dt2` on changes to values of the encoded type `ReprT`. */
  implicit class InfixChangeValueOps[ReprT <: Type](dt: ReprT#DT#Eval)(implicit Δt: ΔBase[ReprT]) {
    type DT = ReprT#DT#Eval

    def ∘(dt2: DT): DT = Δt.∘(dt, dt2)
  }
}
import Ops._ | |
/** Change structure for integers: a change is a difference, and changes compose by addition. */
class ΔInt extends ΔBase[BaseInt] {
  def ⊕(t: T, dt: DT): T = dt + t

  def ⊖(tNew: T, tOld: T): DT = tNew - tOld

  def ∘(dt1: DT, dt2: DT): DT = dt2 + dt1
}
/**
 * Change structure for function types, built from the change structures of the
 * argument type `ReprS` and of the result type `ReprT`.
 *
 * A function change `df` maps an input `x` and an input change `dx` to a change of the output.
 */
class ΔFun[ReprS <: Type, ReprT <: Type](implicit Δs: ΔBase[ReprS], Δt: ΔBase[ReprT]) extends ΔBase[ReprS =>: ReprT] {
  /** Pointwise update: at each `x`, applies `dt` with the nil change `x ⊖ x`. */
  def ⊕(t: T, dt: DT): T =
    x => t(x) ⊕ dt(x)(x ⊖ x)

  /** Change from `tOld` to `tNew`: maps `x` and `dx` to `tNew(x ⊕ dx) ⊖ tOld(x)`. */
  def ⊖(tNew: T, tOld: T): DT =
    x => dx => tNew(x ⊕ dx) ⊖ tOld(x)

  /**
   * Composes `dt1` (taking `f` to `f1`) with `dt2` (taking `f1` to `f2`).
   *
   * `dt1(x)(x ⊖ x)` goes from `f(x)` to `f1(x)`, and then `dt2(x)(dx)` goes from
   * `f1(x)` to `f2(x ⊕ dx)`, so the result goes from `f(x)` to `f2(x ⊕ dx)`.
   * Passing `dx` to both changes would apply the input change twice. For example,
   * with `f = 2x`, `f1 = 3x`, `f2 = 5x`, it would give `8(x + dx) − 5x` instead of
   * `5(x + dx) − 2x`.
   */
  def ∘(dt1: DT, dt2: DT): DT =
    x => dx => dt1(x)(x ⊖ x) ∘ dt2(x)(dx)
}
//Theoretically useful, but I never needed this. | |
//type Δ[T <: Type, DTP] = ΔBase[T] { type DT = DTP } | |
//Can we synthesize an instance by looking at the type? That would require macros. Luckily, we might not need that. | |
//def f[T <: Type]: ΔBase[T] = ??? | |
/** A term of the object language, typed by the encoded type `TP`. */
trait Exp[TP <: Type] {
  /** The term's type. Subclasses see this member as `T`, which can shadow a type parameter of the same name. */
  type T = TP

  /** The derivative (nil change) of this term; its type is the change type `T#DT`. */
  def derive: Exp[T#DT]
}
/** Integer literal `t`. A constant never changes, so its derivative is the zero change. */
final case class Num(t: Int) extends Exp[BaseInt] {
  def derive: Exp[T#DT] = Num(t = 0)
}
/** Integer addition `a + b`; its derivative adds the derivatives of the two operands. */
case class Plus(a: Exp[BaseInt], b: Exp[BaseInt]) extends Exp[BaseInt] {
  def derive: Exp[T#DT] = {
    val da: Exp[BaseInt] = a.derive
    val db: Exp[BaseInt] = b.derive
    Plus(da, db)
  }
}
/** Application `fun arg`. Derivative: `Derive(s t) = Derive(s) t Derive(t)`. */
case class App[S <: Type, T <: Type](fun: Exp[S =>: T], arg: Exp[S]) extends Exp[T] {
  def derive: Exp[T#DT] = {
    // Derive(fun) applied to the unchanged argument yields a function of the argument's change.
    val partial: Exp[S#DT =>: T#DT] = App(fun.derive, arg)
    App(partial, arg.derive)
  }
}
/** A variable name. */
trait Name {
  /** The textual form of this name. */
  def name: String

  /** The name of the change variable associated with this name. */
  def derive: Name = DerivedName(base = this)

  // XXX: hack. We'd need a proper pretty-printer, but that is a lot of boilerplate,
  // so for now a name simply prints as its text.
  override def toString: String = name
}
/** A name given directly by its text (for example the fresh names `x0`, `x1`, ...). */
case class BaseName(name: String) extends Name
/** The name of the change variable for `base`: the base name prefixed with "d" (so `x` becomes `dx`). */
case class DerivedName(base: Name) extends Name {
  def name: String = s"d${base.name}"
}
/** A variable occurrence; its derivative is the variable holding its change (`x` becomes `dx`). */
case class Var[T <: Type](name: Name) extends Exp[T] {
  def derive: Var[T#DT] = Var[T#DT](name.derive)
}
/** First-order lambda `λv. body`. Derivative: `Derive(λx. t) = λx dx. Derive(t)`. */
case class Fun[SPar <: Type, UPar <: Type](v: Var[SPar], body: Exp[UPar]) extends Exp[SPar =>: UPar] {
  type S = SPar
  type U = UPar

  def derive: Exp[T#DT] = {
    val dv = v.derive
    Fun(v, Fun(dv, body.derive))
  }
}
// Don't name the type parameters S and T: a parameter named T would be shadowed inside the class by Exp's type member T.
// We also need to record the bound variable...
/** Higher-order abstract syntax lambda: `fun` builds the body, and `v` is the recorded bound variable. */
case class HOASFun[SPar <: Type, TPar <: Type](v: Var[SPar], fun: Exp[SPar] => Exp[TPar]) extends Exp[SPar =>: TPar] {
  /**
   * Derivative `λv dv. Derive(body)`.
   *
   * The recorded variable `v` is reused, and `v.derive` names the change variable. Without
   * this, the code would still typecheck but would pick the "wrong" variable when converting
   * to a first-order representation. Derivation on variables is too "nominal" to be expressed
   * in HOAS otherwise. The change argument itself is ignored; the derived body refers to it
   * by name through `Var.derive`.
   */
  def derive: Exp[T#DT] =
    HOASFun(v, arg => HOASFun(v.derive, _ => fun(arg).derive))

  /** First-order view: instantiates the HOAS body with the recorded variable `v`. */
  def toFun: Exp[T] = Fun(v, fun(v))

  override def toString: String = toFun.toString
}
/** Helpers for building HOAS functions with automatically generated fresh variables. */
trait HoasWrappers {
  // Number of the last fresh variable handed out; -1 means none yet.
  private var counter = -1

  /** Returns a new variable named `x0`, `x1`, `x2`, ... in call order. */
  def fresh[S <: Type](): Var[S] = {
    counter = counter + 1
    Var[S](BaseName(s"x$counter"))
  }

  /** Builds a HOAS function from `fun`, recording a fresh bound variable for later derivation. */
  def funBase[S <: Type, T <: Type](fun: Exp[S] => Exp[T]): Exp[S =>: T] = {
    val bound = fresh[S]()
    HOASFun(bound, fun)
  }

  // Curried type application. Complete "signature":
  //   fun[S <: Type][T <: Type](fun: Exp[S] => Exp[T]): Exp[S =>: T] = funBase[S, T](fun)
  // but typically you only write
  //   fun[S](x => body)
  // and Exp[S] is used as the type annotation for x.
  def fun[S <: Type]: CurriedFun[S] = new CurriedFun[S]

  /** Second stage of `fun`: the argument type `S` is fixed, and the result type `T` is inferred. */
  class CurriedFun[S <: Type] {
    def apply[T <: Type](fun: Exp[S] => Exp[T]): Exp[S =>: T] = funBase[S, T](fun)
  }
}
/** Fixpoint `fix body`, where `body : T → T`. Derivative: `fix (Derive(body) (fix body))`. */
case class Fix[T <: Type](body: Exp[T =>: T]) extends Exp[T] {
  def derive: Exp[T#DT] = {
    // `this` is `Fix(body)`, so this equals Fix(App(body.derive, Fix(body))).
    val dBodyAtFix: Exp[T#DT =>: T#DT] = App(body.derive, this)
    Fix(dBodyAtFix)
  }
}
//XXX: not a real test. | |
object DeriveTest extends scala.App with HoasWrappers {
  // The order of terms is vals first, defs after, to ensure fresh vars are
  // generated in the same order as they appear.
  //
  // This can be useful for inspecting that fresh variable generation happens in
  // the operationally expected way (since no reduction is done, no fresh
  // variable should be generated and discarded).
  //
  // NOTE(review): printing a HOASFun calls `toFun`, which applies the HOAS body. Bodies
  // that themselves call `fun` (like `power`'s) therefore generate fresh variables while
  // being printed, so the order of the statements below affects the printed names.
  val id = fun[BaseInt](v => v)
  val power = fix[BaseInt =>: BaseInt =>: BaseInt](self => fun(base => fun(exponent => base /*more interesting body needed*/)))
  def ap[S <: Type, T <: Type] = fun[S =>: T](g => fun[S](a => App(g, a)))
  def fix[T <: Type](body: Exp[T] => Exp[T]): Exp[T] = Fix(fun(body))

  // Identity and a (placeholder) recursive function, before and after derivation.
  println(id)
  println(id.derive)
  println(power)
  println(power.derive)

  // Application combinator specialized to integers.
  val apInt = ap[BaseInt, BaseInt]
  println(apInt)
  println(apInt.derive)
}
//Writing derive as a pattern-matching method exposes a bug in type-refinement. See BugReport.scala | |
trait TryExternalDerive {
  /**
   * Casts between `Exp` types. It is needed because matching on `Exp[T]` does not let scalac
   * derive the refined result type (a bug in its GADT type refinement). That cast is what GADT
   * matching compiles to anyway when scalac manages; here we rely on our own knowledge that
   * each use is sound.
   */
  def unsafeCoerce[T <: Type, U <: Type](a: Exp[T]): Exp[U] = a.asInstanceOf[Exp[U]]

  /** Derives a variable: `x` becomes `dx`, typed by the change type `T#DT`. */
  def deriveVar[T <: Type](v: Var[T]): Var[T#DT] = Var(v.name.derive)

  /**
   * Derives `term` by pattern matching, as an external alternative to `Exp.derive`.
   *
   * `Num`, `Plus`, `Fun`, `App` and `Var` are derived here. Any other term (e.g. `HOASFun`
   * or `Fix`) falls back to its own `derive` method instead of failing with a MatchError.
   */
  def derive[T <: Type](term: Exp[T]): Exp[T#DT] = {
    term match {
      case Num(_) =>
        // A constant never changes: its derivative is the zero change, as in Num.derive.
        unsafeCoerce(Num(0))
      case Plus(a, b) =>
        unsafeCoerce(Plus(derive(a), derive(b)))
      case Fun(v, body) =>
        unsafeCoerce(Fun(v, Fun(deriveVar(v), derive(body))))
      case App(fun, arg) =>
        App(App(derive(fun), arg), derive(arg))
      case v: Var[_] =>
        deriveVar(v)
      case other =>
        // Not handled externally yet: delegate to the term's own derivation.
        other.derive
    }
  }
}
/* | |
// A related bug report | |
object BugReportReduced { | |
trait Type { | |
type DT <: Type | |
} | |
trait BaseInt extends Type { | |
type DT = BaseInt | |
} | |
trait Exp[TP <: Type] | |
class Num extends Exp[BaseInt] | |
def derive[T <: Type](term: Exp[T]): Any = { | |
val res0: Exp[T] = term match { case _: Num => (??? : Exp[T#DT]) } | |
val res1: Exp[T#DT] = term match { case _: Num => (??? : Exp[T#DT]) } | |
res1: Exp[T#DT] // fails, the prefix of the type projection in the case body underwent wanted GADT refinement, but this is not reflected here. | |
val res1b = term match { case _ : Num => (new Num : Exp[T#DT]) } //works | |
res1b: Exp[T#DT] //fails | |
type X = T#DT | |
val res2 = term match { case _: Num => (??? : Exp[X]) } | |
res2: Exp[T#DT] // works | |
val res2b = term match { case _: Num => (new Num : Exp[X]) } //fails | |
res2b: Exp[T#DT] // this part works | |
} | |
def derive2[T <: Type](term: Exp[T]): Exp[T] = { | |
term match { | |
case _ : Num => (??? : Exp[BaseInt]) | |
case _ : Num => (??? : Exp[T]) | |
} | |
} | |
def derive2Harder[T <: Type](term: Exp[T]): Exp[T] = { | |
val res = term match { | |
//case _ : Num => (??? : Exp[BaseInt]) | |
case _ : Num => (??? : Exp[T]) //also broken! | |
} | |
res | |
} | |
def derive3[T <: Type](term: Exp[T]): Exp[BaseInt] = { | |
term match { | |
case _ : Num => (??? : Exp[BaseInt]) | |
case _ : Num => (??? : Exp[T]) | |
} | |
} | |
def derive3Harder[T <: Type](term: Exp[T]): Exp[BaseInt] = { | |
val res = term match { | |
//case _ : Num => (??? : Exp[BaseInt]) | |
case _ : Num => (??? : Exp[T]) | |
} | |
res | |
} | |
} | |
*/ |
// (End of the GitHub gist page: "Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment")