Created November 9, 2017 10:30
-
-
Save c42f/f84871b2e8a7cffae253859728be9817 to your computer and use it in GitHub Desktop.
unroll macro
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Unroll

"""
    Vec{N,T}(data::NTuple{N,T})

Minimal fixed-length vector backed by an `NTuple{N,T}`, used to demonstrate `@unroll`.
"""
struct Vec{N,T} <: AbstractVector{T}
    data::NTuple{N,T}
end

Base.size(v::Vec) = (length(v.data),)
Base.getindex(v::Vec, i::Int) = v.data[i]

# Split a single `var = iterator` iteration spec into `(var, iterator)`.
# Anything else (several comma-separated specs, `if` filters, ...) is rejected:
# blindly indexing into those forms would silently build a nonsensical expansion.
function _split_iteration(spec)
    if !(spec isa Expr && spec.head == :(=) && length(spec.args) == 2)
        throw(ArgumentError(
            "`@unroll` only supports a single `var = iterator` specification, got `$spec`"))
    end
    return spec.args[1], spec.args[2]
end

"""
    _unroll_for(ex)

Build the (unescaped) expansion of `@unroll` for the `for` loop `ex`.

The result is an `if @generated ... else ... end` block. The generated branch
evaluates the loop iterator at code-generation time and emits one
`var = value; body` pair per iteration. The fallback branch runs the original loop.
"""
function _unroll_for(ex)
    @assert ex.head == :for
    var, iterator = _split_iteration(ex.args[1])
    loop_var = QuoteNode(var)
    body = QuoteNode(ex.args[2])
    # Manual hygiene: `@unroll` escapes the whole expansion, so the temporaries
    # used while generating code get unique gensym names instead.
    _exprs = gensym("exprs")
    _loop_var = gensym("loop_var")
    quote
        if @generated
            # Build the unrolled loop here rather than in the macro: the loop
            # bounds (e.g. a static parameter `N`) are only available while the
            # `@generated` body is being evaluated.
            $_exprs = Expr(:block)
            for $_loop_var in $iterator
                push!($_exprs.args, Expr(:(=), $loop_var, $_loop_var))
                push!($_exprs.args, $body)
            end
            # Evaluate to `nothing`, like the `for` loop it replaces.
            push!($_exprs.args, nothing)
            $_exprs
        else
            # Fallback when the compiler does not run the generator: execute the
            # original loop unchanged so the result is still correct.
            $ex
        end
    end
end

"""
    _unroll_generator(ex)

Build the (unescaped) expansion of `@unroll` for the one-dimensional generator `ex`.

The generated branch produces a tuple literal with one element per iteration.
The fallback branch collects the generator and splats it into a tuple.
Throws `ArgumentError` for multi-dimensional or filtered generators.
"""
function _unroll_generator(ex)
    @assert ex.head == :generator
    length(ex.args) == 2 || throw(ArgumentError("Only 1D generators supported for now"))
    var, iterator = _split_iteration(ex.args[2])
    loop_var = QuoteNode(var)
    body = QuoteNode(ex.args[1])
    # Manual hygiene, as in `_unroll_for`.
    _exprs = gensym("exprs")
    _loop_var = gensym("loop_var")
    quote
        if @generated
            # The loop bounds are only known while the `@generated` body runs,
            # so the tuple literal is assembled there.
            $_exprs = Expr(:tuple)
            for $_loop_var in $iterator
                push!($_exprs.args,
                      Expr(:block, Expr(:(=), $loop_var, $_loop_var), $body))
            end
            $_exprs
        else
            # Correct but slow fallback: materializes an array, then splats it.
            tuple(collect($ex)...)
        end
    end
end

"""
    @unroll for var = iterator; body; end
    @unroll(body for var = iterator)

Unroll a `for` loop or a one-dimensional generator at compile time, using an
optionally-generated (`if @generated`) function body.

`iterator` must be computable when the generator runs, e.g. from a function's
static parameters such as `1:N`. In the unrolled form the loop variable is
assigned in the enclosing function's scope.

Throws `ArgumentError` for any other expression or unsupported iteration spec.
"""
macro unroll(ex)
    ex isa Expr ||
        throw(ArgumentError("`@unroll` only supports for loops and generators"))
    expansion = if ex.head == :for
        _unroll_for(ex)
    elseif ex.head == :generator
        _unroll_generator(ex)
    else
        throw(ArgumentError("`@unroll` only supports for loops and generators"))
    end
    return esc(expansion)
end

"""
    mysum(v::Vec{N}) where N

Sum the elements of `v` with an unrolled loop. Returns `zero(eltype(v))` for `N == 0`.
"""
function mysum(v::Vec{N}) where N
    x = zero(eltype(v))
    @unroll for i = 1:N
        x += v[i]
    end
    return x
end

"""
    +(v1::Vec{N}, v2::Vec{N}) where N

Element-wise sum of two equal-length `Vec`s, built with an unrolled generator.
"""
function Base.:+(v1::Vec{N}, v2::Vec{N}) where N
    return Vec(@unroll(v1[i] + v2[i] for i = 1:N))
end

@show mysum(Vec((1,2,3)))
@show Vec((1,2,3)) + Vec((-1,0,1))

end
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment