Last active
October 18, 2017 16:27
-
-
Save pshirshov/a58581e5910e33471b6eac38e52b69d8 to your computer and use it in GitHub Desktop.
HList and implicit-based specializations
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// hlist definition
/** A heterogeneous list: each element keeps its own static type.
  *
  * A list is either [[HNil]] (empty) or an [[HPair]] of a head value and a tail list.
  */
sealed trait HList

/** Non-empty list cell holding `head` and the remaining list `tail`.
  *
  * Rendered as `head ::: tail`, e.g. `HPair(1, HPair(true, HNil))` prints as
  * `1 ::: true ::: HNil`.
  */
final case class HPair[A, B <: HList](head: A, tail: B) extends HList {
  override def toString: String = s"$head ::: $tail"
}

/** The empty list, terminating every [[HPair]] chain. */
case object HNil extends HList
// instantiation
/** Syntax for building [[HList]]s with a right-associative `:::` operator.
  *
  * Operators ending in `:` are right-associative, so `a ::: b ::: c` parses as
  * `a ::: (b ::: c)` and each `:::` is invoked on its RIGHT operand. Two implicit
  * conversions supply that method:
  *  - [[ToHList]] when the right operand is already an [[HList]]: prepends the left operand.
  *  - [[ToHListEnd]] for any other value: starts a new list ending in `HNil`, so
  *    `1 ::: true` yields `HPair(1, HPair(true, HNil))` without writing `HNil`.
  *
  * When the right operand is an `HList`, both conversions are eligible. `ToHList` is chosen
  * because its parameter type (`H <: HList`) is more specific. That is why `x ::: HNil`
  * prepends to `HNil` instead of storing `HNil` as an element.
  *
  * NOTE: if the right operand already has a `:::` member of its own (e.g. a `List`), that
  * member is used and neither conversion applies.
  */
object HListFactory {

  /** Treats a plain value `v` as the last element of a new list. */
  implicit class ToHListEnd[H](v: H) {
    /** Returns `HPair(b, HPair(v, HNil))`. */
    def :::[T](b: T): HPair[T, HPair[H, HNil.type]] = HPair(b, HPair(v, HNil))
  }

  /** Prepends to an existing list `v`. */
  implicit class ToHList[H <: HList](v: H) {
    /** Returns `HPair(b, v)`. */
    def :::[T](b: T): HPair[T, H] = HPair(b, v)
  }
}
// Recursion over the list: per-shape "specialization" is simulated by implicit resolution
/** Type class computing the number of elements in an [[HList]] whose static type is `L`. */
trait IsEnumerable[L <: HList] {
  /** Number of elements in `l`. */
  def length(l: L): Int
}

object IsEnumerable {

  /** Base case: the empty list has length 0. */
  implicit def hlistIsEnumerable0: IsEnumerable[HNil.type] = new IsEnumerable[HNil.type] {
    override def length(l: HNil.type): Int = 0
  }

  /** Inductive case: one more than the length of the tail, using the tail's own instance. */
  implicit def hlistIsEnumerableN[H0, T0 <: HList : IsEnumerable]: IsEnumerable[H0 HPair T0] =
    new IsEnumerable[H0 HPair T0] {
      override def length(l: HPair[H0, T0]): Int = 1 + implicitly[IsEnumerable[T0]].length(l.tail)
    }

  /** Length of `hl`, resolved from its static type. */
  def length[H <: HList : IsEnumerable](hl: H): Int =
    implicitly[IsEnumerable[H]].length(hl)
}
// basic list operations, one-way recursion
/** Type class giving head/tail/last/init access to a non-empty [[HList]] of static type `Lst`.
  *
  * The abstract type members carry the static result types:
  *  - `H`: the first element, `T`: the list after the first element;
  *  - `L`: the last element, `I`: the list without its last element.
  */
trait IsHPair[Lst <: HList] {
  type H
  type T <: HList
  type L
  type I <: HList
  def head(l : Lst): H
  def tail(l : Lst): T
  def last(l : Lst): L
  def init(l : Lst): I
  /** Element-wise equality of two lists with the same static type. */
  def ==(l: Lst, other: Lst): Boolean
}

object IsHPair {

  /** Refinement fixing only `H` and `T`. */
  type Aux[Lst <: HList, H0, T0 <: HList] = IsHPair[Lst] {type H = H0; type T = T0}

  /** Refinement fixing all four type members. Each instance below returns this type,
    * which is also an `Aux`.
    */
  type Full[Lst <: HList, H0, T0 <: HList, L0, I0 <: HList] =
    IsHPair[Lst] {type H = H0; type T = T0; type L = L0; type I = I0}

  /** Base case: a one-element list. Its head is also its last element, and its tail and init
    * are both `HNil`.
    */
  implicit def hlistIsPair0[H0]: Full[HPair[H0, HNil.type], H0, HNil.type, H0, HNil.type] =
    new IsHPair[HPair[H0, HNil.type]] {
      override type H = H0
      override type T = HNil.type
      override type L = H0
      override type I = HNil.type
      override def head(l: HPair[H0, HNil.type]): H = l.head
      override def tail(l: HPair[H0, HNil.type]): T = HNil
      override def last(l: HPair[H0, HNil.type]): L = l.head
      override def init(l: HPair[H0, HNil.type]): I = HNil
      override def ==(l: HPair[H0, HNil.type], other: HPair[H0, HNil.type]): Boolean =
        other.head == l.head
    }

  /** Inductive case: `HPair[H0, T0]` where the tail `T0` is itself non-empty.
    *
    * `last` and `init` delegate to the tail's instance `ev`. The result types are taken from
    * that instance (`ev.L`, `ev.I`) via a dependent method type. Without this they would
    * collapse to the projections `IsHPair[T0]#L` / `IsHPair[T0]#I`, which carry no
    * usable static type.
    */
  implicit def hlistIsPairN[H0, T0 <: HList](implicit ev: IsHPair[T0]): Full[HPair[H0, T0], H0, T0, ev.L, HPair[H0, ev.I]] =
    new IsHPair[HPair[H0, T0]] {
      override type H = H0
      override type T = T0
      override type L = ev.L
      override type I = HPair[H0, ev.I]
      override def head(l: HPair[H0, T0]): H = l.head
      override def tail(l: HPair[H0, T0]): T = l.tail
      override def last(l: HPair[H0, T0]): L = ev.last(l.tail)
      override def init(l: HPair[H0, T0]): I = HPair(l.head, ev.init(l.tail))
      override def ==(l: HPair[H0, T0], other: HPair[H0, T0]): Boolean =
        l.head == other.head && ev.==(l.tail, other.tail)
    }

  /** Extension methods for any list that has an [[IsHPair]] instance, i.e. any non-empty list.
    *
    * Each accessor resolves the instance at the call site and returns that instance's member
    * type. For example, `(1 ::: true).hlast` is statically a `Boolean`, not an opaque projection.
    */
  implicit class HListPairOps[H <: HList : IsHPair](hl: H) {
    type IP = IsHPair[H]

    def hhead(implicit ev: IP): ev.H = ev.head(hl)

    def htail(implicit ev: IP): ev.T = ev.tail(hl)

    def hlast(implicit ev: IP): ev.L = ev.last(hl)

    def hinit(implicit ev: IP): ev.I = ev.init(hl)

    /** Element-wise equality via the type class.
      *
      * NOTE: `x == y` syntax never reaches this overload. Every value already has
      * `Any.==(Any)`, so no implicit conversion is attempted. It is only reachable by
      * calling it explicitly, e.g. `HListPairOps(x) == y`.
      */
    def ==(o: H): Boolean = implicitly[IP].==(hl, o)
  }
}
// joining two lists by replacing HNil in Lst1 with Lst2
/** Type class appending list `Lst2` after list `Lst1`.
  *
  * It does this by replacing `Lst1`'s terminating `HNil` with `Lst2`. `J` is the static type
  * of the joined list.
  */
trait IsJoinable[Lst1 <: HList, Lst2 <: HList] {
  type J <: HList
  def join(l: Lst1, other: Lst2): J
}

object IsJoinable {

  /** Instance type for a non-empty left list `HPair[H0, T0]`. Does not refine `J` by itself. */
  type Aux[H0, T0 <: HList, Lst2 <: HList] = IsJoinable[HPair[H0, T0], Lst2]

  /** Base case: joining `HNil` with `other` yields `other` unchanged. */
  implicit def joinableFromHNil[Lst2 <: HList]: IsJoinable[HNil.type, Lst2] {type J = Lst2} =
    new IsJoinable[HNil.type, Lst2] {
      override type J = Lst2
      override def join(l: HNil.type, other: Lst2): J = other
    }

  /** Inductive case: keep the head and join the tail with `other`.
    *
    * The result type uses the tail instance's `ev.J` (a dependent method type) rather than the
    * projection `IsJoinable[T0, Lst2]#J`, so the joined list keeps a precise static type.
    */
  implicit def joinAux[H0, T0 <: HList, Lst2 <: HList](implicit ev: IsJoinable[T0, Lst2]): Aux[H0, T0, Lst2] {type J = HPair[H0, ev.J]} =
    new IsJoinable[HPair[H0, T0], Lst2] {
      override type J = HPair[H0, ev.J]
      override def join(l: HPair[H0, T0], other: Lst2): J = HPair(l.head, ev.join(l.tail, other))
    }

  /** Adds `hjoin` to every [[HList]] (including `HNil`). */
  implicit class HListJoinOps[Lst1 <: HList](hl: Lst1) {
    /** Returns the elements of this list followed by the elements of `o`. */
    def hjoin[Lst2 <: HList](o: Lst2)(implicit ev: IsJoinable[Lst1, Lst2]): ev.J =
      ev.join(hl, o)
  }
}
// TODO: zip, hfold

// ---- usage examples / smoke checks ----
import HListFactory._

// Construction: a trailing HNil is optional.
val l0 = 1 ::: HNil
val l1 = 1 ::: true
val l2 = 1 ::: true ::: 1.0
val l2_1 = 2 ::: true ::: 1.0
val l2_2 = 2 ::: true ::: 1.0 ::: HNil
assert(l1 == 1 ::: true ::: HNil)
assert(l2_1 == l2_2)
assert(l2 == l2)
assert(l2 != l2_1)
assert(l2.toString == "1 ::: true ::: 1.0 ::: HNil")

import IsEnumerable._
assert(length(l0) == 1)
assert(length(l1) == 2)
assert(length(l2) == 3)

import IsHPair._
assert(l0.hhead == 1)
assert(l1.hhead == 1)
assert(l2.hhead == 1)
assert(l0.htail == HNil)
assert(l1.htail == true ::: HNil)
assert(l2.htail == true ::: 1.0 ::: HNil)
assert(l0.hlast == 1)
assert(l1.hlast == true)
assert(l2.hlast == 1.0)
assert(l0.hinit == HNil)
assert(l1.hinit == 1 ::: HNil)
assert(l2.hinit == 1 ::: true ::: HNil)

import IsJoinable._
// `:::` binds tighter than `==`, so the right-hand sides below are whole lists.
assert(l2.hjoin(HNil) == 1 ::: true ::: 1.0)
assert(l2.hjoin(l2_2) == 1 ::: true ::: 1.0 ::: 2 ::: true ::: 1.0)
assert(HNil.hjoin(l2_2) == 2 ::: true ::: 1.0)

// Lists of functions keep each element's own function type.
val f1: Int => String = _.toString
val f2: String => Double = _.toDouble
val f3: Double => String = d => s"=$d"
val lf1 = f1 ::: HNil
val lf2 = f1 ::: f2 ::: f3 ::: HNil
assert(length(lf1) == 1)
assert(length(lf2) == 3)
// Destructure and compose: Int -> String -> Double -> String.
val HPair(g1, HPair(g2, HPair(g3, HNil))) = lf2
assert(g3(g2(g1(42))) == "=42.0")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment