Skip to content

Instantly share code, notes, and snippets.

@praeclarum
Created May 26, 2016 05:06
Show Gist options
  • Save praeclarum/08ae7122c7ebd6bfab31fa9f52faf205 to your computer and use it in GitHub Desktop.
Save praeclarum/08ae7122c7ebd6bfab31fa9f52faf205 to your computer and use it in GitHub Desktop.
Basic matrix math in F#
module Drone.Control.Matrix
open System
/// A lazily evaluated matrix expression. Apart from MArray (a dense,
/// materialised array), each case wraps other matrices and computes an
/// element on demand through the Item indexer.
type Matrix =
    /// (rows, columns) matrix of zeros.
    | MZero of int * int
    /// (size, value): size x size matrix with value on the diagonal, 0 elsewhere.
    | MEye of int * float
    /// Dense matrix backed by a 2-D array indexed [row, column].
    | MArray of float[,]
    /// Transpose of the wrapped matrix.
    | MTranspose of Matrix
    /// Square matrix with the given diagonal entries and 0 elsewhere.
    | MDiagonal of float[]
    /// Wrapped matrix with every element multiplied by the scalar.
    | MScaled of float * Matrix
    /// Matrix product; each element is a dot product computed on access.
    | MMul of Matrix * Matrix
    /// Element-wise sum.
    | MAdd of Matrix * Matrix
    /// Element-wise negation.
    | MNegate of Matrix
    /// (ci, m): column ci of m as a (rows x 1) matrix.
    | MColumn of int * Matrix
    /// (startRow, startCol, rows, cols, m): a rows x cols window of m whose
    /// top-left element is m.[startRow, startCol]. Bounds are not checked here.
    | MSub of int * int * int * int * Matrix

    /// Fails unless a and b have identical dimensions.
    member private a.Mess (b : Matrix) =
        if a.NumRows <> b.NumRows || a.NumColumns <> b.NumColumns then
            failwithf "Mismatched matrices: %dx%d vs %dx%d"
                a.NumRows a.NumColumns b.NumRows b.NumColumns

    /// Lazy scalar multiple.
    static member ( * ) (s : float, m : Matrix) = MScaled (s, m)

    /// Lazy matrix product. Dimensions are checked here, as + and - do, so a
    /// mismatched product fails at construction instead of on first access.
    static member ( * ) (a : Matrix, b : Matrix) =
        if a.NumColumns <> b.NumRows then
            failwithf "Mismatched matrices trying to be multiplied: %dx%d vs %dx%d"
                a.NumRows a.NumColumns b.NumRows b.NumColumns
        MMul (a, b)

    /// Lazy element-wise sum; fails if the dimensions differ.
    static member ( + ) (a : Matrix, b : Matrix) =
        a.Mess b
        MAdd (a, b)

    /// Lazy negation; negating an MNegate unwraps it instead of nesting.
    static member ( ~- ) (a : Matrix) =
        match a with
        | MNegate m -> m
        | _ -> MNegate (a)

    /// Lazy difference a + (-b); fails if the dimensions differ.
    static member ( - ) (a : Matrix, b : Matrix) =
        a.Mess b
        MAdd (a, -b)

    /// Element at row r, column c, evaluated recursively from the expression.
    member m.Item (r : int, c : int) : float =
        match m with
        | MEye (_, v) -> if r = c then v else 0.0
        | MArray a -> a.[r, c]
        | MTranspose inner -> inner.[c, r]
        | MZero _ -> 0.0
        | MDiagonal d -> if r = c then d.[r] else 0.0
        | MScaled (s, inner) -> s * inner.[r, c]
        | MColumn (ci, inner) -> inner.[r, ci]
        | MAdd (a, b) -> a.[r, c] + b.[r, c]
        | MNegate inner -> -inner.[r, c]
        | MSub (sr, sc, _, _, inner) -> inner.[sr + r, sc + c]
        | MMul (a, b) ->
            // MMul can also be built directly from the union case, bypassing
            // the check in ( * ), so the dimensions are re-checked here.
            let alen = a.NumColumns
            let blen = b.NumRows
            if alen <> blen then
                failwithf "Mismatched matrices trying to be multiplied: %dx%d vs %dx%d"
                    a.NumRows alen
                    blen b.NumColumns
            let mutable s = 0.0
            for i = 0 to alen - 1 do
                s <- s + a.[r, i] * b.[i, c]
            s

    /// Number of rows of the expression.
    member m.NumRows : int =
        match m with
        | MEye (s, _) -> s
        | MArray a -> a.GetLength (0)
        | MTranspose inner -> inner.NumColumns
        | MZero (r, _) -> r
        | MDiagonal a -> a.Length
        | MScaled (_, inner) -> inner.NumRows
        | MColumn (_, inner) -> inner.NumRows
        | MAdd (inner, _) -> inner.NumRows
        | MNegate inner -> inner.NumRows
        | MSub (_, _, r, _, _) -> r
        | MMul (inner, _) -> inner.NumRows

    /// Number of columns of the expression.
    member m.NumColumns : int =
        match m with
        | MEye (s, _) -> s
        | MArray a -> a.GetLength (1)
        | MTranspose inner -> inner.NumRows
        | MZero (_, c) -> c
        | MDiagonal a -> a.Length
        | MScaled (_, inner) -> inner.NumColumns
        | MColumn _ -> 1
        | MAdd (inner, _) -> inner.NumColumns
        | MNegate inner -> inner.NumColumns
        | MSub (_, _, _, c, _) -> c
        | MMul (_, inner) -> inner.NumColumns

    /// Row-major rendering such as "[1.0000, 2.0000; 3.0000, 4.0000]":
    /// columns separated by ", ", rows by "; ", four decimal places.
    override m.ToString () =
        use w = new System.IO.StringWriter ()
        w.Write "["
        let mutable head = ""
        for r = 0 to m.NumRows - 1 do
            for c = 0 to m.NumColumns - 1 do
                w.Write (sprintf "%s%.4f" head m.[r, c])
                head <- ", "
            head <- "; "
        w.Write "]"
        w.ToString ()
/// Row count of m.
let mrows (m : Matrix) = m.NumRows

/// Column count of m.
let mcols (m : Matrix) = m.NumColumns

/// Element (r, c) of m, evaluated through the Matrix indexer.
let mget (m : Matrix) (r : int) (c : int) : float = m.Item (r, c)
/// Transpose of m. Simplifies where it can: a double transpose unwraps,
/// a zero matrix swaps its dimensions, and identity/diagonal matrices are
/// returned unchanged (they are symmetric). Everything else is wrapped lazily.
let mt (m : Matrix) : Matrix =
    match m with
    | MTranspose original -> original
    | MZero (rows, cols) -> MZero (cols, rows)
    | MEye _ | MDiagonal _ -> m
    | other -> MTranspose other
/// Column ci of m as a lazy (rows x 1) matrix view.
/// Fails unless 0 <= ci < mcols m. The lower bound was previously unchecked,
/// so a negative index produced a view that failed only on element access.
let mcol (m : Matrix) ci : Matrix =
    if ci < 0 || ci >= mcols m then
        failwithf "Cannot index column %d of %dx%d matrix."
            ci
            (mrows m) (mcols m)
    MColumn (ci, m)
/// Cholesky decomposition of the square array inp, computed column by column.
/// Returns the lower-triangular factor as an MArray. Only the lower triangle
/// (including the diagonal) of inp is read. No positivity check is made, so a
/// non-positive pivot yields NaN or infinite entries instead of an error.
let cholesky (inp : float[,]) =
    let size = inp.GetLength 0
    let lower = Array2D.zeroCreate<float> size size
    // inp.[i, k] minus the dot product of rows i and k of the factor over
    // columns 0 .. k-1. Accumulating from the last term down reproduces the
    // right-nested sum t0 + (t1 + (... + 0.0)) exactly, bit for bit.
    let residual i k =
        let mutable acc = 0.0
        for j = k - 1 downto 0 do
            acc <- lower.[i, j] * lower.[k, j] + acc
        inp.[i, k] - acc
    for col in 0 .. size - 1 do
        let pivot = Math.Sqrt (residual col col)
        lower.[col, col] <- pivot
        for row in col + 1 .. size - 1 do
            lower.[row, col] <- residual row col / pivot
    MArray lower
/// Materialises m into a dense float[,]. An MArray returns its backing array
/// itself (no copy), so writes to the result are visible through m.
let marray (m : Matrix) : float[,] =
    match m with
    | MArray dense -> dense
    | expr -> Array2D.init expr.NumRows expr.NumColumns (fun r c -> expr.[r, c])
/// Inverse of m, computed densely by Math.NET Numerics and returned as an MArray.
let minv (m : Matrix) : Matrix =
    let dense = MathNet.Numerics.LinearAlgebra.Double.DenseMatrix.OfArray (marray m)
    dense.Inverse().ToArray () |> MArray
/// "Square root" of m in the Cholesky sense: returns the Factor of Math.NET's
/// Cholesky decomposition as an MArray. Per the Math.NET docs, this is the
/// lower-triangular L with m = L * L^T; m is expected to be symmetric
/// positive definite.
let msqrt (m : Matrix) : Matrix =
    let dense = MathNet.Numerics.LinearAlgebra.Double.DenseMatrix.OfArray (marray m)
    let factor = dense.Cholesky().Factor
    MArray (factor.ToArray ())
/// Assembles a dense block matrix. mats is a sequence of block rows. Each
/// block row is a sequence of matrices laid left to right, and the block
/// rows are stacked top to bottom.
/// An empty mats yields a 0x0 matrix. Fails if a block row is empty, if the
/// blocks in a row differ in height, or if block rows differ in total width.
/// Without these checks, mismatched blocks were silently truncated or
/// zero-filled, or caused an out-of-range write.
let mat (mats : Matrix seq seq) : Matrix =
    let blockRows = mats |> Seq.map Array.ofSeq |> Array.ofSeq
    if blockRows.Length = 0 then
        MArray (Array2D.zeroCreate 0 0)
    else
        // The first block row fixes the overall width; every other row must match.
        let totalCols = blockRows.[0] |> Array.sumBy mcols
        for bi = 0 to blockRows.Length - 1 do
            let row = blockRows.[bi]
            if row.Length = 0 then
                failwithf "Block row %d of a block matrix is empty." bi
            let height = mrows row.[0]
            for block in row do
                if mrows block <> height then
                    failwithf "Mismatched block heights in block row %d: %d vs %d"
                        bi height (mrows block)
            let width = row |> Array.sumBy mcols
            if width <> totalCols then
                failwithf "Mismatched block row widths: row %d is %d wide, expected %d"
                    bi width totalCols
        let totalRows = blockRows |> Array.sumBy (fun row -> mrows row.[0])
        let a : float[,] = Array2D.zeroCreate totalRows totalCols
        let mutable ri = 0
        for row in blockRows do
            let nr = mrows row.[0]
            let mutable ci = 0
            for block in row do
                let nc = mcols block
                for r = 0 to nr - 1 do
                    for c = 0 to nc - 1 do
                        a.[ri + r, ci + c] <- mget block r c
                ci <- ci + nc
            ri <- ri + nr
        MArray a
/// Column vector (n x 1 dense matrix) holding the elements of es in order.
let mvec (es : float seq) : Matrix =
    let elems = Array.ofSeq es
    MArray (Array2D.init elems.Length 1 (fun i _ -> elems.[i]))
/// Evaluates the expression m into a dense MArray. For an MArray input, the
/// backing array is shared, not copied.
let meval (m : Matrix) = m |> marray |> MArray
/// Lazy nr x nc sub-matrix view of m whose top-left element is m.[sr, sc].
/// Fails at construction if the window does not lie inside m, matching the
/// eager check in mcol. An unchecked view reports a size it cannot serve,
/// and over MZero/MEye it silently returns zeros.
let msub sr sc nr nc (m : Matrix) : Matrix =
    let inBounds =
        sr >= 0 && sc >= 0 && nr >= 0 && nc >= 0
        && sr + nr <= mrows m && sc + nc <= mcols m
    if not inBounds then
        failwithf "Cannot take %dx%d sub-matrix at (%d, %d) of %dx%d matrix."
            nr nc sr sc (mrows m) (mcols m)
    MSub (sr, sc, nr, nc, m)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment