Compiler pass to translate branching code to map over function
Gist baggepinnen/6fcb15c4f2a61cce9c78aa884da51b0c, last active May 29, 2019 10:50
# I would like to use Cassette.jl to implement a compiler pass. The task is to translate code
# that contains a branch into a `map` over a function that contains the branch, but only when
# the argument types call for it. An example of the transformation I would like to do is from
code = Meta.@lower if x > 0
    return x^2
else
    return -x^2
end
# to
code2 = Meta.@lower map(x.particles) do x
    if x > 0
        return x^2
    else
        return -x^2
    end
end
# where the translation should only be done if `x` is of the special type `Particles` defined below:
struct Particles{T} <: Real
    particles::Vector{T}
end
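# A minimal sketch of the intended semantics, written by hand rather than as a compiler pass:
# dispatch alone decides whether a scalar function `f` is mapped over the particles or called
# directly. `map_particles` and `signed_square` are hypothetical names, not part of Cassette.jl
# or of the original code; the struct is repeated identically so that the snippet runs on its own.
struct Particles{T} <: Real
    particles::Vector{T}
end

# Apply `f` to every particle and wrap the results in a new `Particles`
map_particles(f, x::Particles) = Particles(map(f, x.particles))
# Any other argument type is passed straight through to `f`
map_particles(f, x) = f(x)

# Usage: the same branching function works for both kinds of input
signed_square(y) = y > 0 ? y^2 : -y^2
map_particles(signed_square, Particles([2.0, -3.0]))  # a Particles whose `particles` field is [4.0, -9.0]
map_particles(signed_square, -3.0)                    # -9.0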
# Without this transformation, code like the following fails, as one would expect:
Base.:(^)(p::Particles, r) = Particles(p.particles .^ r)
Base.:(>)(p::Particles, r) = Particles(p.particles .> r) # Elementwise; `map(>, p.particles, r)` with a scalar `r` stops after one element
p = Particles(randn(10))
function negsquare(x)
    if x > 0
        return x^2
    else
        return -x^2
    end
end
# julia> negsquare(p)
# ERROR: TypeError: non-boolean (Particles) used in boolean context
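# The error above comes from Julia requiring an actual `Bool` as the condition of an `if`
# (or `?:`), so the `Particles{Bool}` returned by `>` cannot decide which branch to take.
# A minimal sketch reproducing this without the `^` and `>` methods; `branch_on` is a
# hypothetical name used only for illustration.
struct Particles{T} <: Real
    particles::Vector{T}
end

function branch_on(c)
    if c  # lowered to a conditional jump, which only accepts a Bool
        return :then_branch
    else
        return :else_branch
    end
end

branch_on(true)  # :then_branch
err = try
    branch_on(Particles([true, false]))  # non-Bool condition
catch e
    e
end
err isa TypeError  # true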
# If the code were translated to
# function negsquare(x)
#     Particles(map(x.particles) do x
#         if x > 0
#             return x^2
#         else
#             return -x^2
#         end
#     end)
# end
# julia> negsquare(p)
# Particles([0.404953, 0.210984, -1.00176, 0.253796, -0.00620389, 0.831144, -0.0240916, -1.90169, 0.875192, 1.4788])
# I would get the desired result.
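# The hand translation above, written out as runnable code so it can be compared with applying
# the scalar function particle by particle. `negsquare_mapped` and `negsquare_scalar` are
# hypothetical names chosen so as not to replace the original `negsquare`; the struct is
# repeated identically so that the snippet stands alone.
struct Particles{T} <: Real
    particles::Vector{T}
end

function negsquare_mapped(x::Particles)
    # The body of the branching function becomes the function mapped over the particles
    Particles(map(x.particles) do xi
        if xi > 0
            return xi^2
        else
            return -xi^2
        end
    end)
end

# Scalar version of the same branch, for comparison
negsquare_scalar(y) = y > 0 ? y^2 : -y^2

q = Particles([2.0, -3.0, 0.5])
negsquare_mapped(q).particles                                   # [4.0, -9.0, 0.25]
negsquare_mapped(q).particles == negsquare_scalar.(q.particles)  # true: same as broadcasting the scalar function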
# I have made some progress with Cassette: I can filter out calls that do not operate on `Particles`,
# and I can identify branches in the lowered code. However, I cannot figure out how to transform the
# relevant code into the `map` statement above. My attempt follows, with three "QUESTION" and "TODO"
# comments marking the open points:
using Cassette
Cassette.@context Ctx # Context type that the pass and the overdub call below dispatch on
contains_branch(ir::Core.CodeInfo) = any(contains_branch, ir.code)
contains_branch(ex::Expr) = ex.head == :gotoifnot # Only conditional jumps; lowered IR can also jump via `GotoNode`
contains_branch(_) = false
branch_target(ex) = ex.args[2] # Index of the statement that the `gotoifnot` jumps to
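# `contains_branch` above only matches `Expr(:gotoifnot, cond, dest)`, the lowered form used by
# the Julia versions Cassette targeted. As far as I know, Julia 1.6 and later lower conditional
# jumps to `Core.GotoIfNot` nodes (fields `cond` and `dest`) instead, so those helpers find no
# branches there. A minimal version-tolerant sketch that also slices out the body executed when
# the condition holds; `is_branch`, `branch_dest`, `branchy` and `nobranch` are hypothetical
# names, and unconditional `GotoNode` jumps are deliberately not detected.
is_branch(stmt::Expr) = stmt.head === :gotoifnot                         # older lowered form
is_branch(stmt) = isdefined(Core, :GotoIfNot) && stmt isa Core.GotoIfNot  # newer lowered form
branch_dest(stmt::Expr) = stmt.args[2]  # statement index of the jump target (older form)
branch_dest(stmt) = stmt.dest           # statement index of the jump target (newer form)

branchy(y) = y > 0 ? y^2 : -y^2
nobranch(y) = y + 1.0

lowered = first(code_lowered(branchy, (Float64,)))  # a `Core.CodeInfo`
jump_idx = findfirst(is_branch, lowered.code)       # index of the conditional jump
# Statements run when the condition holds: from just after the jump up to the jump target
then_body = lowered.code[jump_idx+1:branch_dest(lowered.code[jump_idx])-1]
findfirst(is_branch, first(code_lowered(nobranch, (Float64,))).code)  # nothing: no branch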
"My custom compiler pass" | |
function mapif(::Type{<:Ctx}, reflection::Cassette.Reflection)
    ir = reflection.code_info
    any(x -> x <: Particles, reflection.signature.parameters) || (return ir) # No `Particles` in this call's signature
    contains_branch(ir) || (return ir) # No branch, so leave the code alone
    stmtcount = function (stmt, i)
        contains_branch(stmt) || (return nothing) # `nothing` leaves the statement untouched
        return 1 # QUESTION: One function call replaces the branch
    end
    newstmts = function (stmt, i)
        # The branch body starts one statement after the `gotoifnot` and ends one before the branch target
        branch_body = ir.code[i+1:branch_target(stmt)-1]
        @show branch_body
        # TODO: put the branch body into a function that is mapped over the particles, and remove the statements that were moved into it
        [stmt] # Must be a vector whose length equals the count returned by `stmtcount`
    end
    Cassette.insert_statements!(ir.code, ir.codelocs, stmtcount, newstmts) # QUESTION: Is it right to pass in the entire ir.code so that all SSAValues are updated?
    ir
end
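# The `Cassette.insert_statements!` call relies on a replace-or-keep contract. A toy, array-only
# sketch of that contract as the attempt above uses it: when `stmtcount` returns `nothing` the
# statement is kept, and when it returns an integer `n` the statement is replaced by the `n`
# statements returned by `newstmts`. `toy_insert_statements` is a hypothetical name. Unlike the
# real function, it does not renumber `SSAValue`s, jump targets or `codelocs`, which is exactly
# the bookkeeping a real rewrite of the branch body has to get right.
function toy_insert_statements(code::Vector, stmtcount, newstmts)
    out = Any[]
    for (i, stmt) in enumerate(code)
        n = stmtcount(stmt, i)
        if n === nothing
            push!(out, stmt)  # keep the statement as is
        else
            replacement = newstmts(stmt, i)
            length(replacement) == n ||
                throw(ArgumentError("newstmts returned $(length(replacement)) statements, expected $n"))
            append!(out, replacement)  # splice in the replacement statements
        end
    end
    return out
end

# Usage: replace the second statement by two new ones
toy_insert_statements([:a, :b, :c],
                      (stmt, i) -> stmt === :b ? 2 : nothing,
                      (stmt, i) -> [:b1, :b2])  # Any[:a, :b1, :b2, :c]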
mapifpass = Cassette.@pass mapif | |
ctx = Ctx(pass=mapifpass) | |
Cassette.overdub(ctx, negsquare, p) # Still throws the TypeError, since `newstmts` returns the branch unchanged
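# The first check in `mapif` looks at the parameters of the call signature, whose first entry is
# the type of the called function. A small sketch of that check on plain tuple types, outside
# Cassette; `has_particles` and the `Tuple{...}` signatures below are hypothetical examples,
# and the struct is repeated identically so that the snippet stands alone.
struct Particles{T} <: Real
    particles::Vector{T}
end

# `T isa Type` skips entries such as `Vararg` that are not types
has_particles(sig::DataType) = any(T -> T isa Type && T <: Particles, sig.parameters)

has_particles(Tuple{typeof(sin), Particles{Float64}})  # true: Particles{Float64} <: Particles
has_particles(Tuple{typeof(sin), Float64})             # false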