Last active
February 19, 2018 15:28
-
-
Save Octachron/4e833a22844fd90cd6d15b0af927dab0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(** Plain floating-point number. *)
type scalar = float

(** Phantom tags. Each one is a single-constructor polymorphic variant
    wrapping a payload type ['a]. The [Phantom] module further down uses
    [z], [one] and [two] as tensor ranks and [two], [three] and [four] as
    dimensions. *)
type 'a one = [`one of 'a]
type 'a z = [`zero of 'a]
type 'a two = [`two of 'a]
type 'a three = [`three of 'a]

(* Fixed: this tag was declared as [`three], which made [four] the very
   same type as [three], so dimension-3 and dimension-4 values could not
   be told apart by the type checker. *)
type 'a four = [`four of 'a]
(** [map2 f x y] builds a fresh array whose [n]-th cell is [f x.(n) y.(n)].
    The result is as long as the shorter of [x] and [y]; surplus cells of
    the longer array are ignored. *)
let map2 f x y =
  let len = min (Array.length x) (Array.length y) in
  Array.init len (fun n -> f x.(n) y.(n))
(** [('a, 'b, 'c) any] is the row ['c] itself, constrained to contain at
    most the tags [`zero], [`one] and [`two]. Every tag carries the
    conjunctive payload ['b & 'a], so as soon as a tag is known to be
    present its payload forces ['a] and ['b] to agree.

    This abbreviation is not referenced by the rest of the visible code. *)
type ('a,'b,'c) any =
  [< `zero of 'b & 'a | `one of 'b & 'a | `two of 'b & 'a] as 'c
(** [('a, 'b, 'c, _) product] encodes, at the type level, the rank of a
    product [x * y]: ['a] is the rank of [x], ['b] the rank of [y], and
    ['c] receives the rank of the result.

    The outer row is indexed by the rank tag of [x]; its payload is tied
    to ['b] and to an inner row indexed by the rank tag of [y], whose
    payload is in turn tied to ['c] and to the resulting rank:

    - [`zero] with [`zero], [`one], [`two] gives [z], [one], [two]
      (a scalar is broadcast over the other operand);
    - [`one] with [`zero], [`one], [`two] always gives [one];
    - [`two] with [`zero], [`one], [`two] gives [two], [one], [two].

    The last parameter must be a triple ['p1 * 'p2 * 'p3]; only ['p1] is
    used, as the payload of the resulting rank. *)
type ('a, 'b,'c, 'parameters) product =
  [< `zero of 'b & (* scalar broadcasting *)
       [< `zero of 'c & 'p1 z
       | `one of 'c & 'p1 one
       | `two of 'c & 'p1 two]
  | `one of 'b &
      [< `zero of 'c & 'p1 one
      | `one of 'c & 'p1 one
      | `two of 'c & 'p1 one]
  | `two of 'b &
      [< `zero of 'c & 'p1 two
      | `one of 'c & 'p1 one
      | `two of 'c & 'p1 two]
  ] as 'a
  constraint 'parameters = 'p1 * 'p2 * 'p3
(** [('a, 'b, 'c, _) sum] encodes, at the type level, the rank of a sum
    [x + y]: ['a] is the rank of [x], ['b] the rank of [y], and ['c]
    receives the rank of the result.

    The outer row is indexed by the rank tag of [x] and the inner rows by
    the rank tag of [y]:

    - [`zero] with [`zero], [`one], [`two] gives [z], [one], [two]
      (a scalar is broadcast over the other operand);
    - [`one] with [`zero] or [`one] gives [one];
    - [`two] with [`zero] or [`two] gives [two].

    Mixing [`one] with [`two] is absent from the rows and is therefore
    rejected. The last parameter must be a triple ['p1 * 'p2 * 'p3]; only
    ['p1] is used, as the payload of the resulting rank. *)
type ('a, 'b,'c, 'parameters) sum =
  [< `zero of 'b & (* scalar broadcasting *)
       [< `zero of 'c & 'p1 z
       | `one of 'c & 'p1 one
       | `two of 'c & 'p1 two]
  | `one of 'b &
      [< `zero of 'c & 'p1 one
      | `one of 'c & 'p1 one ]
  | `two of 'b &
      [< `zero of 'c & 'p1 two
      | `two of 'c & 'p1 two]
  ] as 'a
  constraint 'parameters = 'p1 * 'p2 * 'p3
(** Raised when the flat storage of a matrix does not hold 4, 9 or 16
    cells, i.e. when it is not a 2x2, 3x3 or 4x4 matrix. *)
exception Unexpected_matrix_dimension

(** Raised by the product when the run-time ranks of its two operands,
    carried in that order, form a combination it does not handle. *)
exception Unexpected_ranks of int * int
(** Small float tensors (scalars, vectors, square matrices) whose
    dimension and rank are tracked by phantom type parameters, so that
    ill-formed sums and products are rejected by the type checker. *)
module Phantom: sig
  (** A tensor. ['dim] and ['rank] are phantom parameters: they do not
      appear in the representation. *)
  type (+'dim,+'rank) t

  (** Shorthands. The parameter ['x] is a pair ['a * 'b] supplying the
      payloads of the dimension tag and of the rank tag. *)
  type +'x scalar = ('a, 'b z) t constraint 'x = 'a * 'b
  type +'x vec2 = ('a two,'b one) t constraint 'x = 'a * 'b
  type +'x vec3 = ('a three,'b one) t constraint 'x = 'a * 'b
  type +'x vec4 = ('a four,'b one) t constraint 'x = 'a * 'b
  type +'x mat2 = ('a two,'b two) t constraint 'x = 'a * 'b
  type +'x mat3 = ('a three,'b two) t constraint 'x = 'a * 'b
  type +'x mat4 = ('a four,'b two) t constraint 'x = 'a * 'b

  (** [scalar x] is the rank-0 tensor holding [x]. *)
  val scalar: float -> _ scalar

  (** [vec2 x y] is the rank-1 tensor with components [x], [y]. *)
  val vec2: float -> float -> _ vec2

  (** [vec3 x y z] is the rank-1 tensor with components [x], [y], [z]. *)
  val vec3: float -> float -> float -> _ vec3

  (** [vec4 x y z t] is the rank-1 tensor with components [x] to [t]. *)
  val vec4: float -> float -> float -> float -> _ vec4

  (** [mat2 a b] is the rank-2 tensor whose flat storage is the
      components of [a] followed by those of [b]. *)
  val mat2: _ vec2 -> _ vec2 -> _ mat2

  (** Same as {!mat2} with three 3-component vectors. *)
  val mat3: _ vec3 -> _ vec3 -> _ vec3 -> _ mat3

  (** Same as {!mat2} with four 4-component vectors. *)
  val mat4: _ vec4 -> _ vec4 -> _ vec4 -> _ vec4 -> _ mat4

  (** [a + b] adds two tensors of the same dimension. A scalar operand is
      added to every component of the other operand; otherwise the
      addition is component-wise. *)
  val (+): ('a,('rank1,'rank2,'rank3,_) sum ) t
    -> ('a,'rank2) t -> ('a,'rank3) t

  (** [a * b] multiplies two tensors of the same dimension. A scalar
      operand scales every component of the other operand, two vectors
      are multiplied component-wise, a vector and a matrix yield a
      vector, and two matrices yield a matrix.
      @raise Unexpected_ranks if the run-time ranks are not among the
      combinations above.
      @raise Unexpected_matrix_dimension if two matrices are multiplied
      and the first one does not hold 4, 9 or 16 cells. *)
  val ( * ) : ('dim, ('rank1, 'rank2, 'rank3, _ ) product) t
    -> ('dim, 'rank2) t ->
    ('dim,'rank3) t

  (** [floor s] is the largest integer not greater than the value held by
      the scalar [s]. *)
  val floor: _ scalar -> int
end = struct
  (* [rank] is 0 for scalars, 1 for vectors and 2 for matrices; [data] is
     the flat storage of the components. *)
  type (+'dim,+'rank) t = {rank:int; data: float array }
  type +'x scalar = ('a, 'b z) t constraint 'x = 'a * 'b
  type +'x vec2 = ('a two,'b one) t constraint 'x = 'a * 'b
  type +'x vec3 = ('a three,'b one) t constraint 'x = 'a * 'b
  type +'x vec4 = ('a four,'b one) t constraint 'x = 'a * 'b
  type +'x mat2 = ('a two,'b two) t constraint 'x = 'a * 'b
  type +'x mat3 = ('a three,'b two) t constraint 'x = 'a * 'b
  type +'x mat4 = ('a four,'b two) t constraint 'x = 'a * 'b

  let scalar x = { rank = 0; data = [|x|] }
  let vec2 x y = { rank = 1; data = [|x; y|] }
  let vec3 x y z = { rank = 1; data = [|x; y; z|] }
  let vec4 x y z t = { rank = 1; data = [|x; y; z; t|] }

  let mat2 {data=a;_} {data=b;_} =
    { rank = 2; data = [| a.(0); a.(1); b.(0); b.(1) |] }

  let mat3 {data=a;_} {data=b;_} {data=c;_} =
    { rank = 2; data = [| a.(0); a.(1); a.(2);
                          b.(0); b.(1); b.(2);
                          c.(0); c.(1); c.(2) |]
    }

  let mat4 {data=a;_} {data=b;_} {data=c;_} {data=d;_} =
    { rank = 2; data = [| a.(0); a.(1); a.(2); a.(3);
                          b.(0); b.(1); b.(2); b.(3);
                          c.(0); c.(1); c.(2); c.(3);
                          d.(0); d.(1); d.(2); d.(3);
                       |]
    }

  (* Apply [f] to every component, keeping the rank. *)
  let map f x = { x with data = Array.map f x.data }

  (* Broadcast: combine the single component of the scalar [x] with every
     component of [y]. *)
  let smap f x y = map (f x.data.(0)) y

  (* Side length of a square matrix, recovered from its storage size. *)
  let dim a = match Array.length a.data with
    | 4 -> 2
    | 9 -> 3
    | 16 -> 4
    | _ -> raise Unexpected_matrix_dimension

  (* Inside this definition [*] and [+] still denote integer arithmetic:
     the binding is not recursive and the tensor [+] is defined later. *)
  let ( * ) a b = match a.rank, b.rank with
    | 0, _ -> smap ( *. ) a b
    | _, 0 -> smap ( *. ) b a
    | 1, 1 -> { rank = 1; data = map2 ( *. ) a.data b.data }
    | 1, 2 | 2, 1 ->
      let vec, mat =
        if a.rank = 1 then a.data, b.data else b.data, a.data in
      (* Fixed: the size is the length of the vector operand. It used to
         be read from [a], which is the matrix in the mat * vec case. *)
      let n = Array.length vec in
      (* vec * mat walks [mat] contiguously from [i * n];
         mat * vec walks it with stride [n] from [i]. *)
      let outer, inner = if a.rank = 1 then n, 1 else 1, n in
      let component i =
        let acc = ref 0. in
        (* Fixed: the upper bound of a [for] loop is inclusive, so the
           last index is [n - 1]; [n] ran one step out of bounds. *)
        for j = 0 to n - 1 do
          acc := vec.(j) *. mat.(i * outer + j * inner) +. !acc
        done;
        !acc
      in
      { rank = 1; data = Array.init n component }
    | 2, 2 ->
      let n = dim a in
      let lhs = a.data and rhs = b.data in
      (* NOTE(review): if the flat storage is read so that vec * mat and
         mat * vec above are the usual row-vector and column-vector
         products, this loop yields [b . a] rather than [a . b]; the
         intended layout should be confirmed. *)
      let entry i j =
        let acc = ref 0. in
        (* Fixed: same inclusive-bound off-by-one as above. *)
        for k = 0 to n - 1 do
          acc := lhs.(i * n + k) *. rhs.(k * n + j) +. !acc
        done;
        !acc
      in
      let data =
        Array.init (Array.length lhs) (fun idx -> entry (idx / n) (idx mod n))
      in
      { rank = 2; data }
    | x, y -> raise (Unexpected_ranks (x,y))

  let (+) a b =
    if a.rank = 0 then
      smap (+.) a b
    else if b.rank = 0 then
      smap (+.) b a
    else { a with data = Array.map2 (+.) a.data b.data }

  (* Fixed: [int_of_float] alone truncates toward zero, which is not a
     floor for negative values. The inner [floor] is the standard
     float-to-float one, since this binding is not recursive. *)
  let floor a = int_of_float (floor a.data.(0))
end
open Phantom

(* Sample values exercising the constructors and the two operators. *)
let v = vec2 0. 1.
let m = mat2 v v

(* matrix * matrix and vector * matrix *)
let x = m * m
let z = v * m

(* This binding is shadowed by the second [w] below. *)
let w = vec3 0. 0. 1.
let s = scalar 0.

(* vector + vector, then scalar broadcasting on either operator *)
let w = v + vec2 1. 1.
let v' = s + v
let v'' = s * v

(* component-wise vector * vector *)
let v''' = v' * v''

(* Kept commented out: these mix a 3-component and a 2-component vector
   and are meant to be rejected by the type checker.
let error = vec3 0. 0. 1. + vec2 1. 0.
let error2 = vec3 0. 0. 1. * vec2 1. 0.
*)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment