Skip to content

Instantly share code, notes, and snippets.

@SteveBronder
Last active January 2, 2022 19:09
Show Gist options
  • Save SteveBronder/b22e5e577912d6cf4fe7a284b726a295 to your computer and use it in GitHub Desktop.
Save SteveBronder/b22e5e577912d6cf4fe7a284b726a295 to your computer and use it in GitHub Desktop.
let collect_decl_sizes (mir_section : 'b list) =
  (* Pair every top-level declaration in [mir_section] whose type is
     [Type.Sized] with the size expressions returned by
     [SizedType.dims_of]; all other statements (including declarations
     nested inside blocks or loops) are dropped. Order is preserved. *)
  List.filter_map mir_section ~f:(fun stmt ->
      match stmt.Stmt.Fixed.pattern with
      | Stmt.Fixed.Pattern.Decl {decl_id; decl_type= Type.Sized sized; _} ->
          Some (decl_id, SizedType.dims_of sized)
      | _ -> None )
(**
* Rewrite statements in for loops so they no longer depend on the loop variable.
* The goal here is to take for loops in sets of statements such as
* ```
* vector[10] X1 = // fill...;
* vector[10] X2;
* for (i in 1:10) {
* X2[i] = exp(X1[i]);
* }
* ```
* and reduce them to
* ```
* X2 = exp(X1);
* ```
* As a first implementation that only handles vectors, the steps are:
* 1. Generate a map of ('variable names', sizes list) over all the blocks
* 2. Iterate over each part of the program searching for loops
* 3. Within each loop, remove the indices from the lhs and rhs if
* the following conditions are true for both sides
* a. Only the unmodified single top level index loopvar is used
* b. Any functions called have overloaded function impls that
* work with vectors
* c. The lower bound of the loop is 1
* Using the above we want to generate code such as
* ```
* for (i in 1:10) {
* X2[1:10] = exp(X1[1:10]);
* }
* ```
* Another pass would then see that this statement does not
* depend on the loop and can be removed to produce
* ```
* X2[1:10] = exp(X1[1:10]);
* for (i in 1:10) {
* }
* ```
* and then dead code elimination (DCE) plus another optimization pass would see that
* X2 and X1 are of size 10 to reduce this to
* ```
* X2 = exp(X1);
* ```
* This is also safe when X1 and X2 differ in size: in that case the
* additional pass can produce
* ```
* X2 = exp(X1[1:10]);
* ```
* instead.
*)
let collapse_loops_mir (mir : Program.Typed.t) =
  (* First step of the loop-collapsing pass.
     Only top-level for loops of [prepare_data] are inspected, and only
     when the lower bound is the literal 1, the upper bound is an integer
     literal, and the body is a block. Inside such a loop, an assignment
     is treated as a collapse candidate when two conditions hold:
     - its left-hand side is indexed by exactly the bare loop variable;
     - the assigned variable was declared at the top level of
       [prepare_data] with a single size equal to the loop upper bound.
     The rewrite itself is still a TODO, so the returned program is
     currently identical to [mir].
     [String.Map.of_alist_exn] raises if two top-level declarations share a
     name (presumably impossible for a type-checked program; confirm). *)
  let decl_sizes =
    String.Map.of_alist_exn (collect_decl_sizes mir.prepare_data)
  in
  (* [true] when [idx] is a single index that is just the variable
     [loopvar], with no arithmetic applied to it. *)
  let indexed_by_loopvar loopvar idx =
    match idx with
    | [Index.Single Expr.Fixed.{pattern= Expr.Fixed.Pattern.Var name; _}] ->
        String.equal name loopvar
    | _ -> false
  in
  (* [true] when [name] was declared with exactly one size and that size
     is the integer literal [upper]. Names not declared at the top level of
     [prepare_data] (for example locals of the loop body) are simply not
     candidates, rather than making the lookup raise [Not_found]. *)
  let sized_like_loop name upper =
    match Map.find decl_sizes name with
    | Some [Expr.Fixed.{pattern= Expr.Fixed.Pattern.Lit (Int, dim); _}] ->
        String.equal dim upper
    | _ -> false
  in
  (* Examine one statement of a loop body whose loop variable is [loopvar]
     and whose literal upper bound is [upper]. *)
  let collapse_stmt loopvar upper (Stmt.Fixed.{pattern; _} as stmt) =
    match pattern with
    | Stmt.Fixed.Pattern.Assignment ((name, _, idx), _rhs)
      when sized_like_loop name upper && indexed_by_loopvar loopvar idx ->
        (* TODO: check the rhs under the same conditions (including that
           every called function has a vectorised overload), then rewrite
           both sides to the slice [1:upper]. Until then, leave it as is. *)
        stmt
    | _ -> stmt
  in
  (* Rebuild a matching for loop with each body statement passed through
     [collapse_stmt]; any other statement is returned unchanged. *)
  let collapse_loops (Stmt.Fixed.{pattern; _} as stmt) =
    match pattern with
    | Stmt.Fixed.Pattern.For
        ( { loopvar
          ; lower= Expr.Fixed.{pattern= Expr.Fixed.Pattern.Lit (Int, "1"); _}
          ; upper= Expr.Fixed.{pattern= Expr.Fixed.Pattern.Lit (Int, upper); _}
          ; body= Stmt.Fixed.{pattern= Block block_list; _} as inner_body
          ; _ } as for_loop ) ->
        let block_list =
          List.map ~f:(collapse_stmt loopvar upper) block_list
        in
        { stmt with
          pattern=
            Stmt.Fixed.Pattern.For
              {for_loop with body= {inner_body with pattern= Block block_list}}
        }
    | _ -> stmt
  in
  {mir with prepare_data= List.map ~f:collapse_loops mir.prepare_data}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment