Skip to content

Instantly share code, notes, and snippets.

@rjsvaljean
Last active February 8, 2016 19:44
Show Gist options
  • Select an option

  • Save rjsvaljean/7be5918f5a4b08738a72 to your computer and use it in GitHub Desktop.

Select an option

Save rjsvaljean/7be5918f5a4b08738a72 to your computer and use it in GitHub Desktop.
Implicit context passing between stages of a pipeline with shapeless
package com.xdotai
import shapeless._
import shapeless.ops.hlist.{Prepend, Align, Remove}
object TypedPipelineComposition {
  import ShapelessExt._

  /** Runs a single pipeline stage against an accumulated context.
    *
    * Given a stage `f` that takes an `HList` `In` and returns an `HList` `Out`, and a
    * `Context` `HList`, this checks at compile time that every element type of `In` is
    * present in `Context`, extracts those elements from `Context`, reorders them to match
    * `In`, and calls `f` with the result.
    *
    * If `Context` contains the same element type more than once, the occurrence closest to
    * the head of `Context` is the one passed to `f` (the higher-priority
    * `ShapelessExt.Intersection` instance keeps the first matching element).
    *
    * @tparam In      the stage's input `HList`
    * @tparam Out     the stage's output `HList`
    * @tparam Context the `HList` of values available to the stage
    * @tparam Temp    the elements of `Context` whose types appear in `In`, still in `Context` order
    * @param f       the stage to run
    * @param context the values available so far
    * @param ev      selects from `Context` the elements required by `In`
    * @param ev1     reorders that selection into exactly `In`
    * @return the output of `f`
    */
  def step[In <: HList,
           Out <: HList,
           Context <: HList,
           Temp <: HList](f: In => Out)
                         (context: Context)
                         (implicit
                          ev: Intersection.Aux[Context, In, Temp],
                          ev1: Align[Temp, In]): Out =
    f(context.intersect[In].align[In])

  // How the composeN functions work:
  // - `start` produces the initial context, and stages s1 ... sN-1 run in order.
  // - The compiler checks that each stage's input is a subset (by element type) of
  //   everything produced by the stages before it.
  // - The full context is accumulated by prepending each stage's output (`newer ::: older`).
  //   Each stage is then given only the subset of that context it asks for, via `step`.
  // - Because newer outputs are prepended, a stage's output shadows any earlier value of
  //   the same type for all later stages.
  // - Only the last stage's output is returned.
  // NOTE: making `start` a 0-arity function instead of just another HList function is an
  // arbitrary choice.

  /** Composes `start`, `s1` and `s2`, returning the output of `s2`.
    *
    * `s1` receives values from `start`'s output. `s2` receives values from `s1`'s and
    * `start`'s outputs combined, with `s1`'s taking precedence for duplicated types.
    */
  def compose3[
    Temp1 <: HList,
    Temp2 <: HList,
    In2 <: HList,
    In3 <: HList,
    Out1 <: HList,
    Out2 <: HList,
    Out1And2 <: HList,
    Out3 <: HList](start: () => Out1,
                   s1: In2 => Out2,
                   s2: In3 => Out3)
                  (implicit
                   ev1: Intersection.Aux[Out1, In2, Temp1],
                   ev2: Align[Temp1, In2],
                   ev3: Prepend.Aux[Out2, Out1, Out1And2],
                   ev4: Intersection.Aux[Out1And2, In3, Temp2],
                   ev5: Align[Temp2, In3]
                  ): Out3 = {
    val v1 = start()
    val v2 = step(s1)(v1)
    val v1And2 = v2 ::: v1
    val v3 = step(s2)(v1And2)
    v3
  }

  /** Composes `start`, `s1`, `s2` and `s3`, returning the output of `s3`.
    *
    * Each stage receives values from all earlier outputs combined, with more recent outputs
    * taking precedence for duplicated types.
    */
  def compose4[
    Temp1 <: HList,
    Temp2 <: HList,
    Temp3 <: HList,
    In2 <: HList,
    In3 <: HList,
    In4 <: HList,
    Out1 <: HList,
    Out2 <: HList,
    Out1And2 <: HList,
    Out3 <: HList,
    Out1And2And3 <: HList,
    Out4 <: HList](start: () => Out1,
                   s1: In2 => Out2,
                   s2: In3 => Out3,
                   s3: In4 => Out4)
                  (implicit
                   ev1: Intersection.Aux[Out1, In2, Temp1],
                   ev2: Align[Temp1, In2],
                   ev3: Prepend.Aux[Out2, Out1, Out1And2],
                   ev4: Intersection.Aux[Out1And2, In3, Temp2],
                   ev5: Align[Temp2, In3],
                   ev6: Prepend.Aux[Out3, Out1And2, Out1And2And3],
                   ev7: Intersection.Aux[Out1And2And3, In4, Temp3],
                   ev8: Align[Temp3, In4]
                  ): Out4 = {
    val v1 = start()
    val v2 = step(s1)(v1)
    val v1And2 = v2 ::: v1
    val v3 = step(s2)(v1And2)
    // Named (rather than inlined into `step`) for consistency with compose5.
    val v1And2And3 = v3 ::: v1And2
    val v4 = step(s3)(v1And2And3)
    v4
  }

  /** Composes `start`, `s1`, `s2`, `s3` and `s4`, returning the output of `s4`.
    *
    * Each stage receives values from all earlier outputs combined, with more recent outputs
    * taking precedence for duplicated types.
    */
  def compose5[
    Temp1 <: HList,
    Temp2 <: HList,
    Temp3 <: HList,
    Temp4 <: HList,
    In2 <: HList,
    In3 <: HList,
    In4 <: HList,
    In5 <: HList,
    Out1 <: HList,
    Out2 <: HList,
    Out1And2 <: HList,
    Out3 <: HList,
    Out1And2And3 <: HList,
    Out4 <: HList,
    Out1And2And3And4 <: HList,
    Out5 <: HList](start: () => Out1,
                   s1: In2 => Out2,
                   s2: In3 => Out3,
                   s3: In4 => Out4,
                   s4: In5 => Out5)
                  (implicit
                   ev1: Intersection.Aux[Out1, In2, Temp1],
                   ev2: Align[Temp1, In2],
                   ev3: Prepend.Aux[Out2, Out1, Out1And2],
                   ev4: Intersection.Aux[Out1And2, In3, Temp2],
                   ev5: Align[Temp2, In3],
                   ev6: Prepend.Aux[Out3, Out1And2, Out1And2And3],
                   ev7: Intersection.Aux[Out1And2And3, In4, Temp3],
                   ev8: Align[Temp3, In4],
                   ev9: Prepend.Aux[Out4, Out1And2And3, Out1And2And3And4],
                   ev10: Intersection.Aux[Out1And2And3And4, In5, Temp4],
                   ev11: Align[Temp4, In5]
                  ): Out5 = {
    val v1 = start()
    val v2 = step(s1)(v1)
    val v1And2 = v2 ::: v1
    val v3 = step(s2)(v1And2)
    val v1And2And3 = v3 ::: v1And2
    val v4 = step(s3)(v1And2And3)
    val v1And2And3And4 = v4 ::: v1And2And3
    val v5 = step(s4)(v1And2And3And4)
    v5
  }
}
object ShapelessExt {

  // TODO: Remove once
  // https://github.com/milessabin/shapeless/commit/384ce3cadd5488c1d818ec4b3712379e3d993086
  // makes it into a release. At the time of writing, the latest release is 2.2.5, which
  // does not contain it.

  /** Type class selecting the elements of `L` whose types also occur in `M`.
    *
    * The selected elements keep their order in `L`. Every match consumes one occurrence of
    * that type from `M`. Because the matching instance has higher priority than the skipping
    * one, the earliest matching element of `L` is the one kept.
    */
  trait Intersection[L <: HList, M <: HList] extends DepFn1[L] with Serializable { type Out <: HList }

  /** Fallback instances, only used when the instances in `object Intersection` do not apply. */
  trait LowPriorityIntersection {
    type Aux[L <: HList, M <: HList, Out0 <: HList] = Intersection[L, M] { type Out = Out0 }

    /** The head of `L` is not wanted (or no longer available in `M`): skip it. */
    implicit def hlistIntersection1[H, T <: HList, M <: HList]
      (implicit rest: Intersection[T, M]): Aux[H :: T, M, rest.Out] =
      new Intersection[H :: T, M] {
        type Out = rest.Out
        def apply(hl: H :: T): Out = rest(hl.tail)
      }
  }

  object Intersection extends LowPriorityIntersection {
    /** Summons the `Intersection` instance for `L` and `M`, keeping its precise `Out`. */
    def apply[L <: HList, M <: HList](implicit inst: Intersection[L, M]): Aux[L, M, inst.Out] = inst

    /** Base case: nothing to select from an empty `L`. */
    implicit def hnilIntersection[M <: HList]: Aux[HNil, M, HNil] =
      new Intersection[HNil, M] {
        type Out = HNil
        def apply(hn: HNil): Out = HNil
      }

    /** The head's type occurs in `M`: keep it and continue with that occurrence removed from `M`. */
    implicit def hlistIntersection2[H, T <: HList, M <: HList, MR <: HList]
      (implicit
       removal: Remove.Aux[M, H, (H, MR)],
       rest: Intersection[T, MR]
      ): Aux[H :: T, M, H :: rest.Out] =
      new Intersection[H :: T, M] {
        type Out = H :: rest.Out
        def apply(hl: H :: T): Out = {
          val kept = hl.head
          kept :: rest(hl.tail)
        }
      }
  }

  /** Adds `.intersect[M]` syntax to any `HList`. */
  implicit class ShaplessHlistExtOps[L <: HList](hlist: L) {
    def intersect[M <: HList](implicit ev: Intersection[L, M]): ev.Out = ev(hlist)
  }
}
package com.xdotai
import shapeless._
import org.scalatest.{ShouldMatchers, FlatSpec}
/** Specs for `TypedPipelineComposition`.
  *
  * NOTE(review): `ShouldMatchers` is deprecated in later ScalaTest releases in favour of
  * `Matchers`; confirm against the ScalaTest version in use before upgrading.
  */
class TypedPipelineCompositionSpec extends FlatSpec with ShouldMatchers {
  import TypedPipelineComposition._

  // Gives the `it should ...` tests a subject, so reports read
  // "TypedPipelineComposition should compose3" rather than "should compose3".
  behavior of "TypedPipelineComposition"

  it should "compose3" in {
    // stage3 only needs B (from stage1); stage2's C is carried in the context but unused.
    compose3(stage1, stage2, stage3) should be(D("bd") :: HNil)
  }

  it should "compose4" in {
    compose4(stage1, stage2, stage3, stage4) should be(E("A + B + C + D") :: HNil)
  }

  it should "compose5" in {
    compose5(stage1, stage2, stage3, stage4, stage5) should be(F("A + B + C + D + E") :: HNil)
  }

  it should "pass a stage the most recent value when a type is produced more than once" in {
    // stage1 yields A("a") and relabelA yields a newer A("b2"). consumeA must receive the
    // newer A, because each stage's output is prepended to the context.
    compose3(stage1, relabelA, consumeA) should be(D("b2") :: HNil)
  }

  case class A(v: String)
  case class B(v: String)
  case class C(v: String)
  case class D(v: String)
  case class E(v: String)
  case class F(v: String)

  val stage1: () => A :: B :: HNil =
    () => A("a") :: B("b") :: HNil
  val stage2: (A :: HNil) => C :: HNil =
    (a: A :: HNil) => C(a.head.v + "c") :: HNil
  val stage3: (B :: HNil) => D :: HNil =
    (b: B :: HNil) => D(b.head.v + "d") :: HNil
  // stage4 and stage5 report the class names of their inputs, which checks that every
  // required type arrived in the requested order.
  val stage4: A :: B :: C :: D :: HNil => E :: HNil = { in =>
    val (a: A, b: B, c: C, d: D) = in.tupled
    E(s"${a.getClass.getSimpleName} + ${b.getClass.getSimpleName} + ${c.getClass.getSimpleName} + ${d.getClass.getSimpleName}") :: HNil
  }
  val stage5: A :: B :: C :: D :: E :: HNil => F :: HNil = { in =>
    val (a: A, b: B, c: C, d: D, e: E) = in.tupled
    F(s"${a.getClass.getSimpleName} + ${b.getClass.getSimpleName} + ${c.getClass.getSimpleName} + ${d.getClass.getSimpleName} + ${e.getClass.getSimpleName}") :: HNil
  }

  // Produces a second A, derived from B, so the context holds two values of type A.
  val relabelA: (B :: HNil) => A :: HNil =
    (b: B :: HNil) => A(b.head.v + "2") :: HNil
  // Echoes whichever A it is given, so the test can tell the two apart.
  val consumeA: (A :: HNil) => D :: HNil =
    (a: A :: HNil) => D(a.head.v) :: HNil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment