Skip to content

Instantly share code, notes, and snippets.

@jxnl
Last active August 29, 2015 14:16
Show Gist options
  • Save jxnl/f2ad632fc48b90dda282 to your computer and use it in GitHub Desktop.
Save jxnl/f2ad632fc48b90dda282 to your computer and use it in GitHub Desktop.
Made the mistake of not defining the row length in the template...
package learning
object LinearAlgebra {

  /** A dense row vector of `Double`s.
    *
    * Entries are kept in an immutable `List`, so positional access through `apply`
    * walks the list. Every binary operation taking another [[Row]] requires both
    * operands to have the same length and throws `IllegalArgumentException`
    * otherwise (a plain `zip` would silently truncate the longer operand).
    *
    * @param r the entries of the vector, in order
    */
  class Row(r: Double*) extends Seq[Double] {
    val row: List[Double] = r.toList
    val length: Int = row.length

    /** Throws `IllegalArgumentException` unless `that` has the same length as this row. */
    private def requireAligned(that: Row): Unit =
      require(that.length == length, "Sizes don't align")

    /** True when `p` holds for every pair of corresponding entries of two equal-length rows. */
    private def allPairs(that: Row)(p: (Double, Double) => Boolean): Boolean = {
      requireAligned(that)
      (row zip that.row).forall { case (a, b) => p(a, b) }
    }

    /** Dot product of two equal-length rows (`0.0` when both are empty). */
    def *(that: Row): Double = {
      requireAligned(that)
      (row zip that.row).map { case (a, b) => a * b }.sum
    }

    /** Element-wise sum of two equal-length rows. */
    def +(that: Row): Row = {
      requireAligned(that)
      Row((row zip that.row).map { case (a, b) => a + b }: _*)
    }

    /** Row-vector/matrix product: a `1 x n` row times an `n x m` matrix gives a row of length `m`.
      *
      * @param that a matrix whose row count (`that.length`) equals this row's length
      * @return the row whose entry `i` is this row dotted with column `i` of `that`
      * @throws IllegalArgumentException if the sizes do not align
      */
    def *(that: Matrix): Row = {
      // Each column of `that` has `that.length` entries, so that is the size that must match;
      // comparing against `that.width` only worked for square matrices.
      require(this.length == that.length, "Sizes don't align")
      Row((0 until that.width).map(i => this * that.col(i)): _*)
    }

    /** Selects the entries at the given 1-based indices, in the order given. */
    def basis(B: Int*): Row = Row(B.map(i => row(i - 1)): _*)

    override def toString: String = row.mkString("[", " ", "]")

    // Entry-wise comparisons: true only if the relation holds at every position.
    def <=(that: Row): Boolean = allPairs(that)(_ <= _)
    def >=(that: Row): Boolean = allPairs(that)(_ >= _)
    def <(that: Row): Boolean = allPairs(that)(_ < _)
    def >(that: Row): Boolean = allPairs(that)(_ > _)

    /** Element-wise difference of two equal-length rows. */
    def -(that: Row): Row = {
      requireAligned(that)
      Row((row zip that.row).map { case (a, b) => a - b }: _*)
    }

    // Scalar operations, applied to every entry.
    def *(that: Double): Row = Row(row.map(_ * that): _*)
    def /(that: Double): Row = Row(row.map(_ / that): _*)
    def +(that: Double): Row = Row(row.map(_ + that): _*)
    def -(that: Double): Row = Row(row.map(_ - that): _*)

    def iterator: Iterator[Double] = row.iterator
    def apply(idx: Int): Double = row(idx)
  }

  /** A dense matrix stored as a list of equal-length [[Row]]s, top to bottom.
    *
    * `width` is the number of columns and `length` the number of rows. Note that
    * `dim` is `(width, length)`, i.e. (columns, rows).
    *
    * @param rs the rows of the matrix; all must have the same length
    * @throws IllegalArgumentException if the rows differ in length
    */
  class Matrix(rs: Row*) extends Seq[Row] {
    val rows: List[Row] = rs.toList
    // `width` must be initialised before the `require` below reads it: vals are set in
    // declaration order, so reading it earlier sees the default 0 and rejects every matrix
    // with non-empty rows. An empty matrix gets width 0 instead of failing on `rs(0)`.
    val width: Int = rows.headOption.fold(0)(_.length)
    val length: Int = rows.length
    val dim: Tuple2[Int, Int] = (width, length)
    require(rows.forall(_.length == width), "Rows are not the same size")

    override def toString: String = rows.mkString("\n")

    /** Matrix/column-vector product: entry `i` is row `i` dotted with `row`. */
    def *(row: Row): Row = Row(rows.map(r => r * row): _*)

    /** Row `idx` (0-based). */
    def row(idx: Int): Row = apply(idx)

    /** Column `idx` (0-based) as a [[Row]] with one entry per matrix row.
      *
      * @throws IllegalArgumentException if `idx` is outside `0 until width`
      */
    def col(idx: Int): Row = {
      require(idx >= 0 && idx < width, s"Column index $idx out of range for width $width")
      Row(rows.map(r => r(idx)): _*)
    }

    /** Transpose: column `i` of this matrix becomes row `i` of the result. */
    def T: Matrix = Matrix((0 until width).map(i => col(i)): _*)

    /** Square sub-matrix built from the columns at the given 1-based indices.
      *
      * @param B exactly `length` column indices, each in `1 to width`
      */
    def basis(B: Int*): Matrix = {
      require(B.length == length && B.forall(i => i >= 1 && i <= width))
      Matrix(rows.map(_.basis(B: _*)): _*)
    }

    /** Element-wise sum of two matrices with the same `dim`. */
    def +(that: Matrix): Matrix = {
      require(this.dim == that.dim, "Dimentions must be equal")
      Matrix(rows.zip(that.rows).map { case (a, b) => a + b }: _*)
    }

    /** Matrix product: an `m x n` matrix times an `n x p` matrix gives an `m x p` matrix.
      *
      * @param that a matrix whose row count equals this matrix's `width`
      * @throws IllegalArgumentException if the inner dimensions differ
      */
    def *(that: Matrix): Matrix = {
      require(width == that.length, "Inner Dimentions must be equal")
      // The columns of `that` are the same for every output row, so extract them once.
      val thatCols = (0 until that.width).map(c => that.col(c))
      Matrix(rows.map(r => Row(thatCols.map(c => r * c): _*)): _*)
    }

    def apply(idx: Int): Row = rows(idx)
    def apply(i: Int, j: Int): Double = rows(i)(j)
    def iterator: Iterator[Row] = rows.iterator
  }

  /** Factory: `Matrix(Row(1.0, 2.0), Row(3.0, 4.0))`. */
  object Matrix {
    def apply(data: Row*): Matrix = new Matrix(data: _*)
  }

  /** Factory: `Row(1.0, 2.0, 3.0)`. */
  object Row {
    def apply(data: Double*): Row = new Row(data: _*)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment