Last active
July 22, 2020 10:54
-
-
Save mclements/560be1c891044183d4b83dd2ad8fb5bd to your computer and use it in GitHub Desktop.
Hack to call BLAS from MLton by accessing the flat, row-major array that underlies Array2
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Fixities for the index-arithmetic operators used below: *! binds like *,
   while +! and -! bind like + and -.  All three are left-associative. *)
infix 7 *!
infix 6 +!
infix 6 -!
(* Stand-in for MLton's internal SeqIndex: sequence indices are plain Int.int.
   The "!" operators are bound to the ordinary (overflow-checked) Int ops. *)
structure SeqIndex = struct
   open Int
   val op +! = Int.+
   val op -! = Int.-
   val op *! = Int.*
   local
      (* Unsigned less-than on two's-complement ints: every negative value
         compares as larger than every non-negative value. *)
      fun ltu (lhs, rhs) =
         if (lhs < 0) = (rhs < 0)
            then lhs < rhs   (* same sign: signed and unsigned order agree *)
         else rhs < 0        (* mixed signs: lhs is smaller iff rhs is negative *)
      structure S = IntegralComparisons (type t = Int.int
                                         val (op <) = ltu)
   in
      val ltu = S.<
      val leu = S.<=
      val gtu = S.>
      val geu = S.>=
   end
   (* Index <-> int conversions are the identity for this representation. *)
   fun toIntUnsafe i = i
   fun fromIntUnsafe i = i
end
(* Shim providing the parts of MLton's internal Primitive structure that the
   Array2 code below uses, with array allocation and element access routed
   through the Unsafe structure. *)
structure Primitive = struct
   (* Remember the Basis Vector; it is re-exported as Vector at the end,
      presumably because `open Primitive` below would otherwise replace it. *)
   structure V = Vector
   open Primitive
   structure Array = struct
      open Array
      (* Unchecked allocation and element access. *)
      fun alloc n = Unsafe.Array.alloc n
      fun new (n, x) = Unsafe.Array.create (n, x)
      fun unsafeSub (arr, i) = Unsafe.Array.sub (arr, i)
      fun unsafeUpdate (arr, i, x) = Unsafe.Array.update (arr, i, x)
      structure Slice = ArraySlice
   end
   structure Vector = V
end
(* The Basis ARRAY2 signature plus direct access to the flat array that
   stores the elements. *)
signature ARRAY2EXT = sig
   include ARRAY2
   (* getArray a: the flat array backing a. *)
   val getArray : 'a array -> 'a Array.array
   (* makeArray (arr, rows, cols): view a flat array as a rows x cols array. *)
   and makeArray : 'a Array.array * int * int -> 'a array
end
structure Array2ext : ARRAY2EXT = (* new name, new signature, 2 new functions *)
struct
   (* Index arithmetic and comparisons, all taken from SeqIndex. *)
   val op +! = SeqIndex.+!
   val op + = SeqIndex.+
   val op -! = SeqIndex.-!
   val op - = SeqIndex.-
   val op *! = SeqIndex.*!
   val op * = SeqIndex.*
   val op < = SeqIndex.<
   val op <= = SeqIndex.<=
   val op > = SeqIndex.>
   val op >= = SeqIndex.>=
   val ltu = SeqIndex.ltu
   val leu = SeqIndex.leu
   val gtu = SeqIndex.gtu
   val geu = SeqIndex.geu

   (* Elements are stored in one flat array in row-major order: element
      (r, c) lives at index r * cols + c (see unsafeSpot'). *)
   type 'a array = {array: 'a Array.array,
                    rows: SeqIndex.int,
                    cols: SeqIndex.int}

   (* new functions *)

   (* The flat backing array of an Array2 (shared, not copied). *)
   fun 'a getArray ({array, ...} : 'a array) = array

   (* View the flat array a as an m x n Array2, without copying.
      When Primitive.Controls.safe holds, raises Size if m or n is negative
      or if Array.length a <> m * n.  The accessors below index the backing
      array with unchecked operations once (r, c) is checked against
      rows/cols, so inconsistent dimensions would access out of bounds. *)
   fun 'a makeArray (a : 'a Array.array, m, n) : 'a array =
      if Primitive.Controls.safe
         andalso (m < 0 orelse n < 0
                  orelse (m * n handle Overflow => raise Size)
                         <> Array.length a)
         then raise Size
      else {array = a, rows = m, cols = n}

   (* Dimension accessors: primed versions return SeqIndex.int, unprimed
      versions return int. *)
   fun dimensions' ({rows, cols, ...}: 'a array) = (rows, cols)
   fun dimensions ({rows, cols, ...}: 'a array) =
      (SeqIndex.toIntUnsafe rows, SeqIndex.toIntUnsafe cols)
   fun nRows' ({rows, ...}: 'a array) = rows
   fun nRows ({rows, ...}: 'a array) = SeqIndex.toIntUnsafe rows
   fun nCols' ({cols, ...}: 'a array) = cols
   fun nCols ({cols, ...}: 'a array) = SeqIndex.toIntUnsafe cols

   type 'a region = {base: 'a array,
                     row: int,
                     col: int,
                     nrows: int option,
                     ncols: int option}

   local
      (* Validate the 1-D range starting at start with num elements (or
         running to max when num is NONE) against the extent max, returning
         the half-open bounds (start, stop).  With Primitive.Controls.safe,
         an invalid range raises Subscript.  The SOME case compares num
         against max -! start (after checking 0 <= start <= max) so that a
         huge start cannot make start +! num overflow. *)
      fun checkSliceMax' (start: int,
                          num: SeqIndex.int option,
                          max: SeqIndex.int): SeqIndex.int * SeqIndex.int =
         case num of
            NONE => if Primitive.Controls.safe
                       then let
                               val start =
                                  (SeqIndex.fromInt start)
                                  handle Overflow => raise Subscript
                            in
                               if gtu (start, max)
                                  then raise Subscript
                               else (start, max)
                            end
                    else (SeqIndex.fromIntUnsafe start, max)
          | SOME num => if Primitive.Controls.safe
                           then let
                                   val start =
                                      (SeqIndex.fromInt start)
                                      handle Overflow => raise Subscript
                                in
                                   if (start < 0 orelse num < 0
                                       orelse start > max
                                       orelse num > max -! start)
                                      then raise Subscript
                                   else (start, start +! num)
                                end
                        else (SeqIndex.fromIntUnsafe start,
                              SeqIndex.fromIntUnsafe start +! num)
      (* As checkSliceMax', but num is an int; an int that does not fit a
         SeqIndex.int is reported as Subscript. *)
      fun checkSliceMax (start: int,
                         num: int option,
                         max: SeqIndex.int): SeqIndex.int * SeqIndex.int =
         if Primitive.Controls.safe
            then (checkSliceMax' (start, Option.map SeqIndex.fromInt num, max))
                 handle Overflow => raise Subscript
         else checkSliceMax' (start, Option.map SeqIndex.fromIntUnsafe num, max)
   in
      (* Resolve a region to half-open row and column bounds.  checkRegion'
         takes SeqIndex.int counts; checkRegion takes int counts. *)
      fun checkRegion' {base, row, col, nrows, ncols} =
         let
            val (rows, cols) = dimensions' base
            val (startRow, stopRow) = checkSliceMax' (row, nrows, rows)
            val (startCol, stopCol) = checkSliceMax' (col, ncols, cols)
         in
            {startRow = startRow, stopRow = stopRow,
             startCol = startCol, stopCol = stopCol}
         end
      fun checkRegion {base, row, col, nrows, ncols} =
         let
            val (rows, cols) = dimensions' base
            val (startRow, stopRow) = checkSliceMax (row, nrows, rows)
            val (startCol, stopCol) = checkSliceMax (col, ncols, cols)
         in
            {startRow = startRow, stopRow = stopRow,
             startCol = startCol, stopCol = stopCol}
         end
   end

   (* The region covering all of a. *)
   fun wholeRegion (a : 'a array): 'a region =
      {base = a, row = 0, col = 0, nrows = NONE, ncols = NONE}

   datatype traversal = RowMajor | ColMajor

   local
      (* Build a rows x cols array whose backing store comes from
         doit (rows * cols); negative or overflowing sizes raise Size. *)
      fun make (rows, cols, doit) =
         if Primitive.Controls.safe
            andalso (rows < 0 orelse cols < 0)
            then raise Size
         else {array = doit (rows * cols handle Overflow => raise Size),
               rows = rows,
               cols = cols}
   in
      fun alloc' (rows, cols) =
         make (rows, cols, Primitive.Array.alloc)
      fun array' (rows, cols, init) =
         make (rows, cols, fn size => Primitive.Array.new (size, init))
   end

   local
      (* Convert int dimensions to SeqIndex.int (Size on overflow). *)
      fun make (rows, cols, doit) =
         if Primitive.Controls.safe
            then let
                    val rows =
                       (SeqIndex.fromInt rows)
                       handle Overflow => raise Size
                    val cols =
                       (SeqIndex.fromInt cols)
                       handle Overflow => raise Size
                 in
                    doit (rows, cols)
                 end
         else doit (SeqIndex.fromIntUnsafe rows,
                    SeqIndex.fromIntUnsafe cols)
   in
      fun alloc (rows, cols) =
         make (rows, cols, fn (rows, cols) => alloc' (rows, cols))
      fun array (rows, cols, init) =
         make (rows, cols, fn (rows, cols) => array' (rows, cols, init))
   end

   (* The empty 0 x 0 array. *)
   fun array0 (): 'a array =
      {array = Primitive.Array.alloc 0,
       rows = 0,
       cols = 0}

   (* Flat index of (r, c) without bounds checking. *)
   fun unsafeSpot' ({cols, ...}: 'a array, r, c) =
      r *! cols +! c
   (* Flat index of (r, c); with Primitive.Controls.safe, raises Subscript
      unless 0 <= r < rows and 0 <= c < cols (unsigned comparison rejects
      negative indices). *)
   fun spot' (a as {rows, cols, ...}: 'a array, r, c) =
      if Primitive.Controls.safe
         andalso (geu (r, rows) orelse geu (c, cols))
         then raise Subscript
      else unsafeSpot' (a, r, c)

   fun unsafeSub' (a as {array, ...}: 'a array, r, c) =
      Primitive.Array.unsafeSub (array, unsafeSpot' (a, r, c))
   fun sub' (a as {array, ...}: 'a array, r, c) =
      Primitive.Array.unsafeSub (array, spot' (a, r, c))
   fun unsafeUpdate' (a as {array, ...}: 'a array, r, c, x) =
      Primitive.Array.unsafeUpdate (array, unsafeSpot' (a, r, c), x)
   fun update' (a as {array, ...}: 'a array, r, c, x) =
      Primitive.Array.unsafeUpdate (array, spot' (a, r, c), x)

   local
      (* Convert int indices to SeqIndex.int (Subscript on overflow). *)
      fun make (r, c, doit) =
         if Primitive.Controls.safe
            then let
                    val r =
                       (SeqIndex.fromInt r)
                       handle Overflow => raise Subscript
                    val c =
                       (SeqIndex.fromInt c)
                       handle Overflow => raise Subscript
                 in
                    doit (r, c)
                 end
         else doit (SeqIndex.fromIntUnsafe r,
                    SeqIndex.fromIntUnsafe c)
   in
      fun sub (a, r, c) =
         make (r, c, fn (r, c) => sub' (a, r, c))
      fun update (a, r, c, x) =
         make (r, c, fn (r, c) => update' (a, r, c, x))
   end

   (* Build an array from a list of rows.  The column count is taken from
      the first row; any row of a different length raises Size. *)
   fun 'a fromList (rows: 'a list list): 'a array =
      case rows of
         [] => array0 ()
       | row1 :: _ =>
            let
               val cols = length row1
               val a as {array, cols = cols', ...} =
                  alloc (length rows, cols)
               val _ =
                  List.foldl
                  (fn (row: 'a list, i) =>
                   let
                      val max = i +! cols'
                      val i' =
                         List.foldl (fn (x: 'a, i) =>
                                     (if i >= max
                                         then raise Size
                                      else (Primitive.Array.unsafeUpdate (array, i, x)
                                            ; i +! 1)))
                         i row
                   in if i' = max
                         then i'
                      else raise Size
                   end)
                  0 rows
            in
               a
            end

   (* Row r as a fresh vector (a contiguous slice of the backing array). *)
   fun row' ({array, rows, cols}, r) =
      if Primitive.Controls.safe andalso geu (r, rows)
         then raise Subscript
      else
         ArraySlice.vector (Primitive.Array.Slice.slice (array, r *! cols, SOME cols))
   fun row (a, r) =
      if Primitive.Controls.safe
         then let
                 val r =
                    (SeqIndex.fromInt r)
                    handle Overflow => raise Subscript
              in
                 row' (a, r)
              end
      else row' (a, SeqIndex.fromIntUnsafe r)

   (* Column c as a fresh vector. *)
   fun column' (a as {rows, cols, ...}: 'a array, c) =
      if Primitive.Controls.safe andalso geu (c, cols)
         then raise Subscript
      else
         Primitive.Vector.tabulate (rows, fn r => unsafeSub' (a, r, c))
   fun column (a, c) =
      if Primitive.Controls.safe
         then let
                 val c =
                    (SeqIndex.fromInt c)
                    handle Overflow => raise Subscript
              in
                 column' (a, c)
              end
      else column' (a, SeqIndex.fromIntUnsafe c)

   (* Fold f over the elements of a region in the given traversal order,
      passing SeqIndex.int coordinates. *)
   fun foldi' trv f b (region as {base, ...}) =
      let
         val {startRow, stopRow, startCol, stopCol} = checkRegion region
      in
         case trv of
            RowMajor =>
               let
                  fun loopRow (r, b) =
                     if r >= stopRow then b
                     else let
                             fun loopCol (c, b) =
                                if c >= stopCol then b
                                else loopCol (c +! 1, f (r, c, sub' (base, r, c), b))
                          in
                             loopRow (r +! 1, loopCol (startCol, b))
                          end
               in
                  loopRow (startRow, b)
               end
          | ColMajor =>
               let
                  fun loopCol (c, b) =
                     if c >= stopCol then b
                     else let
                             fun loopRow (r, b) =
                                if r >= stopRow then b
                                else loopRow (r +! 1, f (r, c, sub' (base, r, c), b))
                          in
                             loopCol (c +! 1, loopRow (startRow, b))
                          end
               in
                  loopCol (startCol, b)
               end
      end
   fun foldi trv f b a =
      foldi' trv (fn (r, c, x, b) =>
                  f (SeqIndex.toIntUnsafe r,
                     SeqIndex.toIntUnsafe c,
                     x, b)) b a
   fun fold trv f b a =
      foldi trv (fn (_, _, x, b) => f (x, b)) b (wholeRegion a)
   fun appi trv f =
      foldi trv (fn (r, c, x, ()) => f (r, c, x)) ()
   fun app trv f = fold trv (f o #1) ()
   fun modifyi trv f (r as {base, ...}) =
      appi trv (fn (r, c, x) => update (base, r, c, f (r, c, x))) r
   fun modify trv f a = modifyi trv (f o #3) (wholeRegion a)
   fun tabulate trv (rows, cols, f) =
      let
         val a = alloc (rows, cols)
         val () = modifyi trv (fn (r, c, _) => f (r, c)) (wholeRegion a)
      in
         a
      end

   (* Copy the region src into dst with its top-left corner at
      (dst_row, dst_col).  Raises Subscript if either region is invalid.
      src and dst may be the same array with overlapping regions: the loop
      directions are chosen so that every source element is read before it
      can be overwritten. *)
   fun copy {src = src as {base, ...}: 'a region,
             dst, dst_row, dst_col} =
      let
         val {startRow, stopRow, startCol, stopCol} = checkRegion src
         val nrows = stopRow -! startRow
         val ncols = stopCol -! startCol
         val {startRow = dst_row, startCol = dst_col, ...} =
            checkRegion' {base = dst, row = dst_row, col = dst_col,
                          nrows = SOME nrows,
                          ncols = SOME ncols}
         fun forUp (start, stop, f: SeqIndex.int -> unit) =
            let
               fun loop i =
                  if i >= stop
                     then ()
                  else (f i; loop (i + 1))
            in loop start
            end
         fun forDown (start, stop, f: SeqIndex.int -> unit) =
            let
               fun loop i =
                  if i < start
                     then ()
                  else (f i; loop (i - 1))
            in loop (stop -! 1)
            end
         (* Moving towards higher indices means walking backwards, and vice
            versa.  When the rows differ, each destination row is distinct
            from its source row, so only the row order matters.  When the rows
            coincide, the column order decides whether a row is clobbered:
            shifting right must go right-to-left, so forDown is used when
            startCol <= dst_col.  The previous forUp here smeared the first
            element across the row. *)
         val forRows = if startRow <= dst_row then forDown else forUp
         val forCols = if startCol <= dst_col then forDown else forUp
      in forRows (0, nrows, fn r =>
         forCols (0, ncols, fn c =>
                  unsafeUpdate' (dst, dst_row +! r, dst_col +! c,
                                 unsafeSub' (base, startRow +! r, startCol +! c))))
      end
end
local
   (* CBLAS_TRANSPOSE flags; only NoTrans is used below. *)
   datatype cblasTranspose = NoTrans | Trans | ConjTrans | ConjNoTrans
   (* Integer values of the CBLAS_ORDER enum. *)
   fun cblasOrder Array2ext.RowMajor = 101
     | cblasOrder Array2ext.ColMajor = 102
   (* Integer values of the CBLAS_TRANSPOSE enum. *)
   fun cblasTranspose NoTrans = 111
     | cblasTranspose Trans = 112
     | cblasTranspose ConjTrans = 113
     | cblasTranspose ConjNoTrans = 114
   (* cblas_dgemm (order, transA, transB, M, N, K, alpha, A, lda, B, ldb,
      beta, C, ldc) computes C := alpha * A * B + beta * C.  A and B are only
      read, so their backing arrays are passed directly rather than first
      being copied into vectors. *)
   val call = _import "cblas_dgemm" public:
      int * int * int * int * int * int * real
      * real Array.array * int * real Array.array * int
      * real * real Array.array * int -> unit;
in
   (* matmul3 (a, b): the m x n product of the m x k matrix a and the k x n
      matrix b, computed by cblas_dgemm on the row-major backing arrays.
      Raises General.Size if the inner dimensions differ. *)
   fun matmul3(a : real Array2ext.array,
               b : real Array2ext.array) : real Array2ext.array =
      let
         open Array2ext
         val ((m,k), (k',n)) = (dimensions a, dimensions b)
         val () = if k <> k' then raise General.Size else ()
         val arrayc = Array.array(m*n,0.0)
         (* BLAS rejects leading dimensions below 1, so degenerate shapes must
            not reach cblas_dgemm.  If any dimension is 0 the product is the
            (possibly empty) m x n zero matrix, which arrayc already holds. *)
         val () =
            if m = 0 orelse n = 0 orelse k = 0 then ()
            else call (cblasOrder RowMajor, cblasTranspose NoTrans,
                       cblasTranspose NoTrans, m, n, k,
                       1.0, getArray a, k, getArray b, n,
                       0.0, arrayc, n)
      in
         makeArray (arrayc, m, n)
      end
end;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment