Last active: January 11, 2024 14:50
-
-
Save biboudis/3c4823d905dbadd6d61a7ce3fa47ff64 to your computer and use it in GitHub Desktop.
Type your matrices for great good
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.compiletime.ops._ | |
import scala.compiletime.ops.int._ | |
import scala.compiletime.ops.any._ | |
/**
 * Typed matrices, after "Type your matrices for great good: a Haskell library of typed
 * matrices and applications (functional pearl)".
 * https://dl.acm.org/doi/10.1145/3406088.3409019
 *
 * Dimensions are encoded as types: `Unit` is a dimension of size 1, `Either[A, B]` is the
 * concatenation of dimensions `A` and `B`, `(A, B)` is their product, and `Null` stands for
 * the empty dimension.
 */
object Test {
  /** Representation with explicit dimension types; the public `Test.Matrix` alias is built on it. */
  object Internal {
    /**
     * A block matrix over elements `E` with column dimension `C` and row dimension `R`.
     *
     *  - `One(e)` is the 1x1 matrix holding `e`.
     *  - `Join(m1, m2)` puts `m1` and `m2` side by side: same rows, columns `Either[A, B]`.
     *  - `Fork(m1, m2)` combines `m1` and `m2` row-wise: same columns, rows `Either[A, B]`.
     */
    enum Matrix[E, C, R] {
      case One[E](e: E) extends Matrix[E, Unit, Unit]
      case Join[E, R, A, B](m1: Matrix[E, A, R], m2: Matrix[E, B, R]) extends Matrix[E, Either[A, B], R]
      case Fork[E, C, A, B](m1: Matrix[E, C, A], m2: Matrix[E, C, B]) extends Matrix[E, C, Either[A, B]]
    }

    /** Number of elements in dimension `D`, as a compile-time `Int`. */
    type Count[D] <: Int = D match {
      case Null => 0
      case Unit => 1
      // `a` and `b` are dimension types, not Ints: count each component, then combine.
      case Either[a, b] => Count[a] + Count[b]
      case (a, b) => Count[a] * Count[b]
    }

    /**
     * Canonical dimension type with `N` elements (assumes `N >= 0`).
     * Even `N` becomes two copies of `FromNat[N / 2]`; odd `N` gets one extra leading `Unit`.
     */
    type FromNat[N <: Int] = N match {
      case 0 => Null
      case 1 => Unit
      // Catch-all for the remaining (N >= 2) case: split by parity and halve.
      case N => FromNatB[N % 2 == 0, FromNat[N / 2]]
    }

    /** Helper for `FromNat`: `B` tells whether the size is even, `M` is the half-size dimension. */
    type FromNatB[B <: Boolean, M] = B match {
      case true => Either[M, M]
      case false => Either[Unit, Either[M, M]]
    }

    /** Keeps the `Either` structure of `D` and replaces every other leaf by `FromNat` of its size. */
    type Normalize[D] = D match {
      case Either[a, b] => Either[Normalize[a], Normalize[b]]
      case D => FromNat[Count[D]]
    }

    /**
     * Applies the abide law `Join(Fork(a, c), Fork(b, d)) == Fork(Join(a, b), Join(c, d))`
     * and recurses into the sub-matrices.
     *
     * The top-level rewrite fires only when the original children of a `Join` are both `Fork`s.
     * A single call may therefore still return a `Join` on top; `zipWith` re-applies it as needed.
     */
    def abideJF[Cols, Rows](m: Matrix[Int, Cols, Rows]): Matrix[Int, Cols, Rows] = {
      import Matrix._
      m match {
        case Join(Fork(a, c), Fork(b, d)) => Fork(Join(abideJF(a), abideJF(b)), Join(abideJF(c), abideJF(d)))
        case One(e) => One(e)
        case Join(a, b) => Join(abideJF(a), abideJF(b))
        case Fork(a, b) => Fork(abideJF(a), abideJF(b))
      }
    }

    /**
     * Combines two matrices of the same shape element-wise with the curried function `f`.
     *
     * If one operand is a `Fork` and the other a `Join`, the `Join` operand is rewritten with
     * `abideJF` and the combination is retried.
     */
    def zipWith[Cols, Rows](f: Int => Int => Int, m1: Matrix[Int, Cols, Rows], m2: Matrix[Int, Cols, Rows]): Matrix[Int, Cols, Rows] = {
      import Matrix._
      (m1, m2) match {
        case (One(a), One(b)) => One(f(a)(b))
        case (Join(a, b), Join(c, d)) => Join(zipWith(f, a, c), zipWith(f, b, d))
        case (Fork(a, b), Fork(c, d)) => Fork(zipWith(f, a, c), zipWith(f, b, d))
        // Constructor patterns are required: a bare `Fork` / `Join` is a stable-identifier
        // pattern that compares against the companion object and never matches a matrix.
        case (x @ Fork(_, _), y @ Join(_, _)) => zipWith(f, x, abideJF(y))
        case (x @ Join(_, _), y @ Fork(_, _)) => zipWith(f, abideJF(x), y)
      }
    }

    extension [Cols, Rows](m1: Matrix[Int, Cols, Rows])
      /** Element-wise sum of two matrices of the same shape. */
      def + (m2: Matrix[Int, Cols, Rows]): Matrix[Int, Cols, Rows] = zipWith(a => b => a + b, m1, m2)

    /**
     * Matrix product of `m1` (columns `CR`, rows `Rows`) and `m2` (columns `Cols`, rows `CR`).
     *
     * Uses the block-matrix identities:
     *  - `Join(a, b) * Fork(c, d) = a*c + b*d`
     *  - `Fork(a, b) * c          = Fork(a*c, b*c)`
     *  - `c * Join(a, b)          = Join(c*a, c*b)`
     */
    def comp[CR, Rows, Cols](m1: Matrix[Int, CR, Rows], m2: Matrix[Int, Cols, CR]): Matrix[Int, Cols, Rows] = {
      import Matrix._
      (m1, m2) match {
        case (One(a), One(b)) => One(a * b)
        case (Join(a, b), Fork(c, d)) => comp(a, c) + comp(b, d)
        case (Fork(a, b), c) => Fork(comp(a, c), comp(b, c))
        case (c, Join(a, b)) => Join(comp(c, a), comp(c, b))
      }
    }

    /** Builds a matrix of the given dimensions from nested arrays (no instances are defined here). */
    trait FromLists[Cols, Rows] {
      def fromLists(arr: Array[Array[Int]]): Matrix[Int, Cols, Rows]
    }
  }

  /** An `Int` matrix with `Cols` columns and `Rows` rows, both given as type-level naturals. */
  opaque type Matrix[Cols <: Int, Rows <: Int] = Internal.Matrix[Int, Internal.FromNat[Cols], Internal.FromNat[Rows]]

  /** Sample matrices. The `e` parameter is currently unused by all of them. */
  trait Tests {
    import Internal.Matrix._

    /** 2x2 identity. */
    def iden2x2(e: Int): Matrix[2, 2] = {
      Fork(Join(One(1), One(0)),
           Join(One(0), One(1)))
    }

    /** All-ones matrix with 1 column and 3 rows. */
    def ones1x3(e: Int): Matrix[1, 3] = {
      Fork(One(1), Fork(One(1), One(1)))
    }

    /** All-ones matrix with 3 columns and 1 row. */
    def ones3x1(e: Int): Matrix[3, 1] = {
      Join(One(1), Join(One(1), One(1)))
    }

    /** 3x3 identity. */
    def iden3x3(e: Int): Matrix[3, 3] = {
      Fork(Join(One(1), Join(One(0), One(0))),
           Fork(Join(One(0), Join(One(1), One(0))),
                Join(One(0), Join(One(0), One(1)))))
    }
  }
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.