// Source: GitHub gist avibryant/00bc2ca495766f86c62b2573263cd26e (created October 17, 2017 05:56).
// GitHub flagged this file as possibly containing hidden or bidirectional Unicode text;
// review it in an editor that reveals hidden Unicode characters.
// A type-level implementation of the broadcasting rules from NumPy,
// such that incompatible shapes are a compile-time error.
// See e.g. http://scipy.github.io/old-wiki/pages/EricsBroadcastingDoc
// for more info on the constraints this encodes.
// Note: this is only the shape logic. It does not include,
// and is agnostic to, any particular multi-dimensional array implementation.
/**
 * Type-level description of an array shape.
 *
 * A shape is either a single [[Dimension]] (one axis) or a [[By]] pairing an outer axis with a
 * nested inner shape, so `A By (B By C)` has outermost axis `A` and innermost axis `C`.
 *
 * Extending `Product with Serializable` (which every case class below already is) keeps
 * inferred least-upper-bounds such as `List(One(), N[Int]())` at `List[Dimension]` instead of
 * `List[Product with Serializable with Dimension]`.
 */
sealed trait Shape extends Product with Serializable

/** A single axis: either [[One]] or [[N]]. */
sealed trait Dimension extends Shape

/** An axis of size one; under broadcasting it is compatible with any other axis. */
case class One() extends Dimension {
  /** Builds the shape whose outermost axis is this one and whose remaining axes are `inner`. */
  def by[X <: Shape](inner: X): By[One, X] = By(this, inner)
}

/**
 * An axis whose size is identified only by the phantom type `A`; two `N` axes are
 * considered compatible only when their tags are the same type.
 */
case class N[A]() extends Dimension {
  /** Builds the shape whose outermost axis is this one and whose remaining axes are `inner`. */
  def by[X <: Shape](inner: X): By[N[A], X] = By(this, inner)
}

/** A shape with `outer` as its leading axis and `inner` as the remaining (trailing) axes. */
case class By[D <: Dimension, X <: Shape](outer: D, inner: X) extends Shape
/** Entry points for the type-level shape operations. */
object Shape {

  /**
   * Broadcasts shape `x` against shape `y`.
   *
   * The result shape `Z` is chosen by implicit search for a [[Broadcaster]]; when no
   * instance exists for `X` and `Y`, the call does not compile.
   */
  def broadcast[X <: Shape, Y <: Shape, Z <: Shape](x: X, y: Y)(implicit b: Broadcaster[X, Y, Z]): Z =
    b.apply(x, y)

  /** Adds an axis to `x`; the resulting shape `Z` is chosen by the implicit [[NewAxis]] instance. */
  def newAxis[X <: Shape, Z <: Shape](x: X)(implicit n: NewAxis[X, Z]): Z =
    n.apply(x)
}
/**
 * Type class witnessing that shapes `X` and `Y` broadcast to shape `Z`.
 *
 * Instances follow NumPy's broadcasting rules: the two shapes are aligned at their innermost
 * (trailing) axes, each aligned pair of axes must be the same or have one side be [[One]],
 * and any extra leading axes of the higher-rank shape are carried through unchanged. When no
 * instance can be derived, the shapes are incompatible and the call site fails to compile.
 *
 * @tparam X shape of the left operand
 * @tparam Y shape of the right operand
 * @tparam Z broadcast result shape
 */
trait Broadcaster[X <: Shape, Y <: Shape, Z <: Shape] {
  /** Builds the broadcast shape value from the two operand shape values. */
  def apply(x: X, y: Y): Z
}

/** Evidence that shapes `X` and `Y` have the same rank (number of axes). */
sealed trait SameRank[X <: Shape, Y <: Shape]

object SameRank {
  /** Any two single axes both have rank one. */
  implicit def dimensions[X <: Dimension, Y <: Dimension]: SameRank[X, Y] =
    new SameRank[X, Y] {}

  /** Two nested shapes have equal rank exactly when their inner shapes do. */
  implicit def nested[X <: Dimension, B <: Shape, Y <: Dimension, C <: Shape](
    implicit inner: SameRank[B, C]
  ): SameRank[By[X, B], By[Y, C]] =
    new SameRank[By[X, B], By[Y, C]] {}
}

/** Evidence that the rank of shape `X` is greater than or equal to the rank of shape `Y`. */
sealed trait RankAtLeast[X <: Shape, Y <: Shape]

object RankAtLeast {
  /** Every shape has at least the rank of a single axis. */
  implicit def overDimension[X <: Shape, Y <: Dimension]: RankAtLeast[X, Y] =
    new RankAtLeast[X, Y] {}

  /** Peel one leading axis off both shapes and compare the remainders. */
  implicit def nested[X <: Dimension, B <: Shape, Y <: Dimension, C <: Shape](
    implicit inner: RankAtLeast[B, C]
  ): RankAtLeast[By[X, B], By[Y, C]] =
    new RankAtLeast[By[X, B], By[Y, C]] {}
}

/**
 * Single-axis base cases and the rank-mismatch rules for [[Broadcaster]].
 *
 * With the rank evidence below, these rules and `Broadcaster.outerInner` are mutually
 * exclusive, so at most one derivation exists for any pair of shapes. The priority split is
 * kept so the public layout of the instances is unchanged.
 */
trait BroadcasterLowPriority {
  /** Two size-one axes broadcast to a size-one axis. */
  implicit def one2one: Broadcaster[One, One, One] = new Broadcaster[One, One, One] {
    def apply(x: One, y: One): One = x
  }

  /** A size-one axis stretches to match an `N[A]` axis. */
  implicit def one2n[A]: Broadcaster[One, N[A], N[A]] = new Broadcaster[One, N[A], N[A]] {
    def apply(x: One, y: N[A]): N[A] = y
  }

  /** An `N[A]` axis absorbs a size-one axis. */
  implicit def n2one[A]: Broadcaster[N[A], One, N[A]] = new Broadcaster[N[A], One, N[A]] {
    def apply(x: N[A], y: One): N[A] = x
  }

  /** Two axes with the same tag `A` broadcast to that axis. */
  implicit def n2n[A]: Broadcaster[N[A], N[A], N[A]] = new Broadcaster[N[A], N[A], N[A]] {
    def apply(x: N[A], y: N[A]): N[A] = x
  }

  /**
   * The left shape has more axes: keep its leading axis and broadcast its remaining axes
   * against the whole right shape.
   *
   * `rank` restricts this rule to rank(x) > rank(y), i.e. rank(B) >= rank(C). Without it, the
   * rule also matched equal-rank shapes and, for shapes of rank >= 3, could align leading
   * instead of trailing axes or collide with `rightInner`.
   */
  implicit def leftInner[X <: Dimension, B <: Shape, C <: Shape, D <: Shape](
    implicit rank: RankAtLeast[B, C],
    innerBroadcaster: Broadcaster[B, C, D]
  ): Broadcaster[By[X, B], C, By[X, D]] = new Broadcaster[By[X, B], C, By[X, D]] {
    def apply(x: By[X, B], y: C): By[X, D] = By(x.outer, innerBroadcaster(x.inner, y))
  }

  /**
   * The right shape has more axes: keep its leading axis and broadcast the whole left shape
   * against its remaining axes. `rank` restricts this rule to rank(y) > rank(x).
   */
  implicit def rightInner[X <: Dimension, B <: Shape, C <: Shape, D <: Shape](
    implicit rank: RankAtLeast[C, B],
    innerBroadcaster: Broadcaster[B, C, D]
  ): Broadcaster[B, By[X, C], By[X, D]] = new Broadcaster[B, By[X, C], By[X, D]] {
    def apply(x: B, y: By[X, C]): By[X, D] = By(y.outer, innerBroadcaster(x, y.inner))
  }
}

object Broadcaster extends BroadcasterLowPriority {
  /**
   * Two nested shapes of equal rank: broadcast their leading axes pairwise, then recurse
   * into the inner shapes.
   *
   * `sameRank` is what keeps the alignment at the trailing axes. Pairing the outermost axes
   * of shapes with different ranks would, e.g., turn (Foo, 1) with (1, 1, 1) into
   * (Foo, 1, 1) rather than NumPy's (1, Foo, 1).
   */
  implicit def outerInner[X <: Dimension, Y <: Dimension, Z <: Dimension, B <: Shape, C <: Shape, D <: Shape](
    implicit sameRank: SameRank[B, C],
    outerBroadcaster: Broadcaster[X, Y, Z],
    innerBroadcaster: Broadcaster[B, C, D]
  ): Broadcaster[By[X, B], By[Y, C], By[Z, D]] = new Broadcaster[By[X, B], By[Y, C], By[Z, D]] {
    def apply(x: By[X, B], y: By[Y, C]): By[Z, D] =
      By(outerBroadcaster(x.outer, y.outer), innerBroadcaster(x.inner, y.inner))
  }
}
/**
 * Type class witnessing that inserting a size-one axis as the innermost axis of shape `X`
 * yields shape `Z`.
 */
trait NewAxis[X <: Shape, Z <: Shape] {
  /** Builds the shape value with the extra innermost [[One]] axis. */
  def apply(x: X): Z
}

object NewAxis {
  /** A lone [[One]] axis becomes `One By One`. */
  implicit def one: NewAxis[One, By[One, One]] = new NewAxis[One, By[One, One]] {
    def apply(x: One): By[One, One] = x.by(One())
  }

  /** A lone `N[A]` axis becomes `N[A] By One`. */
  implicit def n[A]: NewAxis[N[A], By[N[A], One]] = new NewAxis[N[A], By[N[A], One]] {
    def apply(x: N[A]): By[N[A], One] = x.by(One())
  }

  /**
   * A nested shape keeps its leading axis and inserts the new axis into its inner shape.
   *
   * The result type is stated explicitly, as for `one` and `n`. Scala 3 requires this for
   * implicit definitions. Scala 2 also ignores an untyped implicit at use sites earlier in
   * the same file.
   */
  implicit def by[D <: Dimension, X <: Shape, Z <: Shape](
    implicit innerNewAxis: NewAxis[X, Z]
  ): NewAxis[By[D, X], By[D, Z]] = new NewAxis[By[D, X], By[D, Z]] {
    def apply(x: By[D, X]): By[D, Z] = By(x.outer, innerNewAxis(x.inner))
  }
}
/**
 * Compile-time examples of [[Shape.broadcast]] and [[Shape.newAxis]].
 *
 * Each annotated `val` only compiles if an implicit instance produces the annotated shape.
 * The commented-out lines are combinations that are expected to be rejected at compile time.
 */
object Example {
  def example[Foo, Bar]: Unit = {
    val scalar = One()
    // Explicit `()`: plain `N[Foo]` relies on auto-application of the companion's `apply()`,
    // which is deprecated in Scala 2.13 and an error in Scala 3.
    val vector = N[Foo]()
    val vector2 = N[Bar]()

    val z1: One = Shape.broadcast(scalar, scalar)
    val z2: N[Foo] = Shape.broadcast(vector, vector)
    val z3: N[Foo] = Shape.broadcast(scalar, vector)
    val z4: N[Foo] = Shape.broadcast(vector, scalar)
    // Fails to compile: axes N[Foo] and N[Bar] are incompatible.
    // val z5 = Shape.broadcast(vector, vector2)

    val matrix1: One By N[Foo] = scalar by vector
    val matrix2: N[Bar] By N[Foo] = vector2 by vector
    val matrix3: N[Foo] By N[Foo] = vector by vector
    val z6: One By N[Foo] = Shape.broadcast(matrix1, matrix1)
    val z7: N[Bar] By N[Foo] = Shape.broadcast(matrix2, matrix2)
    val z8: N[Bar] By N[Foo] = Shape.broadcast(matrix1, matrix2)
    // Fails to compile: leading axes N[Bar] and N[Foo] are incompatible.
    // val z9 = Shape.broadcast(matrix2, matrix3)

    // A vector broadcasts against the trailing axis of a matrix.
    val z10: One By N[Foo] = Shape.broadcast(vector, matrix1)
    val z11: N[Bar] By N[Foo] = Shape.broadcast(vector, matrix2)
    val z12: N[Foo] By N[Foo] = Shape.broadcast(vector, matrix3)
    // Fails to compile: N[Bar] does not match the trailing axis N[Foo].
    // val z13 = Shape.broadcast(vector2, matrix1)

    // A scalar broadcasts against any shape, on either side.
    val z14: One By N[Foo] = Shape.broadcast(scalar, matrix1)
    val z15: N[Bar] By N[Foo] = Shape.broadcast(scalar, matrix2)
    val z16: N[Foo] By N[Foo] = Shape.broadcast(scalar, matrix3)
    val z17: One By N[Foo] = Shape.broadcast(matrix1, scalar)
    val z18: N[Bar] By N[Foo] = Shape.broadcast(matrix2, scalar)
    val z19: N[Foo] By N[Foo] = Shape.broadcast(matrix3, scalar)

    // newAxis appends a size-one innermost axis each time it is applied.
    val z20: N[Foo] By (N[Foo] By One) = Shape.newAxis(matrix3)
    val z21: N[Foo] By (N[Foo] By (One By One)) = Shape.newAxis(z20)
  }
}
// Discussion of this gist happens in its GitHub comments (sign in to GitHub to join the conversation).