Last active
January 2, 2022 19:09
-
-
Save SteveBronder/b22e5e577912d6cf4fe7a284b726a295 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(** [collect_decl_sizes mir_section] pairs the name of every declaration
    statement in [mir_section] whose declared type is [Type.Sized] with the
    dimension expressions [SizedType.dims_of] returns for that type.

    Only the statements of [mir_section] themselves are inspected:
    declarations with unsized types, non-declaration statements, and
    anything nested inside another statement are skipped. The result keeps
    the order of [mir_section]. *)
let collect_decl_sizes (mir_section : 'b list) =
  mir_section
  |> List.filter_map ~f:(function
       | Stmt.Fixed.
           { pattern=
               Stmt.Fixed.Pattern.Decl
                 {decl_id; decl_type= Type.Sized sized_type; _}
           ; _ } ->
           Some (decl_id, SizedType.dims_of sized_type)
       | _ -> None )
(** Collapse statements in [for] loops so that they no longer depend on the
    loop.

    The goal is to take loops in sets of statements such as
    {[
      vector[10] X1 = // fill...;
      vector[10] X2;
      for (i in 1:10) {
        X2[i] = exp(X1[i]);
      }
    ]}
    and reduce them to
    {[
      X2 = exp(X1);
    ]}

    As a first implementation, working only on vectors, the steps are:
    + Build a map from variable names to their declared sizes (currently
      only the top-level declarations of [prepare_data] are collected, and
      only [prepare_data] is rewritten).
    + Iterate over each part of the program, searching for loops.
    + Within each loop, remove the indices from the lhs and rhs if, on both
      sides, (a) only the unmodified loop variable is used as the single
      top-level index, (b) every function called has an overload that works
      on vectors, and (c) the lower bound of the loop is [1].

    Using the above we want to generate code such as
    {[
      for (i in 1:10) {
        X2[1:10] = exp(X1[1:10]);
      }
    ]}
    Another pass would then see that this statement does not depend on the
    loop and hoist it out, producing
    {[
      X2[1:10] = exp(X1[1:10]);
      for (i in 1:10) {
      }
    ]}
    after which dead code elimination and another optimization pass, seeing
    that [X2] and [X1] both have size 10, would reduce this to
    {[
      X2 = exp(X1);
    ]}
    This is also safe when [X1] and [X2] differ in size, since the
    additional pass can instead produce
    {[
      X2 = exp(X1[1:10]);
    ]}

    Current status: only the lhs candidate detection is implemented. A loop
    is inspected when its lower bound is the integer literal [1], its upper
    bound is an integer literal and its body is a [Block]. An assignment in
    that block is a candidate when its lhs is indexed by exactly the loop
    variable and the lhs variable was declared with a single dimension
    equal to the loop's upper bound. Candidates are not rewritten yet, so
    the returned program has the same statements as [mir].

    Raises (via [String.Map.of_alist_exn]) if two top-level sized
    declarations in [prepare_data] share a name; Stan scoping presumably
    rules this out, but that is not checked here. *)
let collapse_loops_mir (mir : Program.Typed.t) =
  (* [dims] is a single dimension whose size is the integer literal
     [upper], i.e. the declaration is exactly covered by a [1:upper] loop. *)
  let spans_loop_range dims ~upper =
    match dims with
    | [Expr.Fixed.{pattern= Expr.Fixed.Pattern.Lit (Int, size); _}] ->
        String.equal size upper
    | _ -> false
  in
  (* [idx] is exactly one [Single] index that is the loop variable itself,
     not an expression built from it. *)
  let indexed_by_loopvar idx ~loopvar =
    match idx with
    | [Index.Single Expr.Fixed.{pattern= Expr.Fixed.Pattern.Var name; _}] ->
        String.equal name loopvar
    | _ -> false
  in
  (* A lhs variable missing from [decl_sizes] (e.g. declared inside a
     nested block) is simply not a candidate, rather than raising as
     [Map.find_exn] would. *)
  let is_collapse_candidate decl_sizes ~loopvar ~upper lhs_name lhs_idx =
    indexed_by_loopvar lhs_idx ~loopvar
    &&
    match Map.find decl_sizes lhs_name with
    | Some dims -> spans_loop_range dims ~upper
    | None -> false
  in
  let rewrite_loop_body_stmt decl_sizes ~loopvar ~upper
      (Stmt.Fixed.{pattern; _} as stmt) =
    match pattern with
    | Stmt.Fixed.Pattern.Assignment ((lhs_name, _, lhs_idx), _rhs)
      when is_collapse_candidate decl_sizes ~loopvar ~upper lhs_name lhs_idx
      ->
        (* TODO: check the rhs and that every function it calls has a
           vectorized overload, then replace the loop index with [1:upper]
           on both sides. Until then candidates are kept unchanged. *)
        stmt
    | _ -> stmt
  in
  let collapse_loops decl_sizes (Stmt.Fixed.{pattern; _} as stmt) =
    match pattern with
    | Stmt.Fixed.Pattern.For
        ( { loopvar
          ; lower= Expr.Fixed.{pattern= Expr.Fixed.Pattern.Lit (Int, "1"); _}
          ; upper=
              Expr.Fixed.{pattern= Expr.Fixed.Pattern.Lit (Int, upper_lit); _}
          ; body=
              ( Stmt.Fixed.{pattern= Stmt.Fixed.Pattern.Block body_stmts; _}
              as loop_body )
          ; _ } as for_loop ) ->
        let body_stmts =
          List.map body_stmts
            ~f:(rewrite_loop_body_stmt decl_sizes ~loopvar ~upper:upper_lit)
        in
        { stmt with
          pattern=
            Stmt.Fixed.Pattern.For
              { for_loop with
                body=
                  {loop_body with pattern= Stmt.Fixed.Pattern.Block body_stmts}
              } }
    | _ -> stmt
  in
  let decl_sizes =
    String.Map.of_alist_exn (collect_decl_sizes mir.prepare_data)
  in
  { mir with
    prepare_data= List.map ~f:(collapse_loops decl_sizes) mir.prepare_data }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment