Skip to content

Instantly share code, notes, and snippets.

@steinwaywhw
Created March 5, 2014 15:07
Show Gist options
  • Save steinwaywhw/9369033 to your computer and use it in GitHub Desktop.
Save steinwaywhw/9369033 to your computer and use it in GitHub Desktop.
Matrix in ATS
#include "share/atspre_staload.hats"
// A matrix is an alias for the prelude's mtrxszref; the code in this file
// reads its row and column counts back with mtrxszref_get_nrow and
// mtrxszref_get_ncol.
typedef matrix (a:t@ype) = mtrxszref (a)
// Returns a newly created matrix whose row and column counts are swapped
// and whose element at (c, r) is the input's element at (r, c).
extern fun {a:t@ype} matrix_transpose (matrix (a)): matrix (a)
// Calls the closure once per element, passing the element, its row index
// and its column index, going through each row from the first column to
// the last; the matrix that was passed in is returned.
extern fun {a:t@ype} foreach (matrix (a), (a, size_t, size_t) -<cloref1> void): matrix (a)
// Writes an int matrix to standard output, ending the line after the last
// column of every row.
extern fun matrix_print (matrix (int)): void
(*
** Visit every element of m in row-major order, handing f the element
** together with its row and column index. The matrix itself is the
** result, so calls can be chained.
*)
implement {a} foreach (m, f) = m where {
  val nrow = mtrxszref_get_nrow (m)
  val ncol = mtrxszref_get_ncol (m)
  val zero = size_of_int (0)
  val one = size_of_int (1)
  // A single cursor (i, j) walks the matrix instead of two nested loops.
  fun walk (i: size_t, j: size_t):<cloref1> void =
    if i >= nrow then ()
    // Current row is exhausted: restart at column 0 of the next row.
    else if j >= ncol then walk (i + one, zero)
    else let
      val () = f (m[i, j], i, j)
    in
      walk (i, j + one)
    end
  val () = walk (zero, zero)
}
(*
** Print an int matrix to standard output, one row per line, with a single
** space between neighbouring columns. Without the separator the digits of
** adjacent elements run together, so e.g. the rows [1, 23] and [12, 3]
** would both come out as "123".
*)
implement matrix_print (m) = () where {
  val ncol = mtrxszref_get_ncol (m)
  val _ = foreach<int> (m, lam (e, r, c) =<cloref1> () where {
    val () = fprint (stdout_ref, e)
    // c + 1 = ncol marks the last column; written this way there is no
    // size_t subtraction that could wrap around.
    val () =
      if c + size_of_int (1) = ncol
        then println! ()
        else fprint (stdout_ref, " ")
  })
}
(*
** Build the transpose of m: a fresh matrix with the row and column counts
** swapped, in which position (c, r) holds the element found at (r, c) of m.
** The input matrix is only read.
**
** NOTE(review): the new matrix is pre-filled with m[0,0], so this
** presumably fails the subscript check when m has no elements -- confirm
** before calling it on an empty matrix.
*)
implement {a} matrix_transpose (m) = flipped where {
  val rows = mtrxszref_get_nrow (m)
  val cols = mtrxszref_get_ncol (m)
  // Any value of type a will do as the initial fill; every cell is
  // overwritten by the traversal below.
  val seed = m[0,0]
  val flipped = mtrxszref_make_elt (cols, rows, seed)
  val _ = foreach (m, lam (x, i, j) =<cloref1> (flipped[j, i] := x))
}
(*
** Demo entry point: create a 5-by-8 matrix of zeros, transpose it and
** print the result.
*)
implement main0 () = let
  val nrow = size_of_int (5)
  val ncol = size_of_int (8)
  val zeros = mtrxszref_make_elt<int> (nrow, ncol, 0)
in
  matrix_print (matrix_transpose (zeros))
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment