Skip to content

Instantly share code, notes, and snippets.

@c42f
Created November 9, 2017 10:30
Show Gist options
  • Save c42f/f84871b2e8a7cffae253859728be9817 to your computer and use it in GitHub Desktop.
Save c42f/f84871b2e8a7cffae253859728be9817 to your computer and use it in GitHub Desktop.
unroll macro
"""
Loop unrolling for `for` loops and one-dimensional generators.

`@unroll` expands into an optionally-generated (`if @generated`) expression:
when the enclosing function is compiled from its generator, the loop body is
emitted once per iteration with the loop variable bound to a literal value.
The iteration range is evaluated inside the generator, so it may only refer to
the function's static (type) parameters, such as `N` in `1:N`.
"""
module Unroll

"""
    Vec{N,T}(data::NTuple{N,T})

Minimal fixed-size vector backed by a tuple, used to demonstrate `@unroll`.
"""
struct Vec{N,T} <: AbstractVector{T}
    data::NTuple{N,T}
end

# The length is part of the type, so report it straight from the parameter.
Base.size(::Vec{N}) where {N} = (N,)
Base.getindex(v::Vec, i::Int) = v.data[i]

# Split one iteration specification such as `i = 1:N` (the parser normalizes
# `i in 1:N` to this `:(=)` form) into the loop variable and the iterator
# expression. Anything else -- several specs `i = a, j = b` (a `:block`) or a
# filtered generator spec (a `:filter`) -- cannot be unrolled by this macro and
# is rejected instead of being silently mis-destructured.
function _split_iterspec(spec)
    if !(spec isa Expr && spec.head == :(=) && length(spec.args) == 2)
        throw(ArgumentError(
            "`@unroll` only supports a single `var = iterator` specification, got `$spec`"))
    end
    return spec.args[1], spec.args[2]
end

# Build the expansion of `@unroll for var = itr; body; end`.
#
# Generated branch: `itr` is evaluated in the generator, and one
# `let var = value; body; end` is emitted per element. The `let` gives each copy
# of the body its own scope, as a `for` iteration would, so neither the loop
# variable nor body-local variables leak into (or clobber) the enclosing
# function. The block ends in `nothing` because that is the value of a `for`
# loop.
#
# Fallback branch: the original loop, unchanged. Julia may execute either
# branch, so both must have the same semantics.
function _unroll_for(ex)
    @assert ex.head == :for
    var, itr = _split_iterspec(ex.args[1])
    loop_var = QuoteNode(var)
    body = QuoteNode(ex.args[2])
    # Manual hygiene: the expansion is escaped wholesale, so the helper
    # variables used inside the generator get unique gensym'd names instead.
    _exprs = gensym("exprs")
    _loop_var = gensym("loop_var")
    quote
        if @generated
            # The loop bounds are only known while the `@generated` body is
            # being evaluated, so the unrolling happens here, not in the macro.
            $_exprs = Expr(:block)
            for $_loop_var in $itr
                # QuoteNode keeps Symbol/Expr elements as literal values
                # rather than letting them be read as code.
                push!($_exprs.args,
                      Expr(:let, Expr(:(=), $loop_var, QuoteNode($_loop_var)), $body))
            end
            push!($_exprs.args, nothing)
            $_exprs
        else
            $ex
        end
    end
end

# Build the expansion of `@unroll(body for var = itr)`.
#
# Generated branch: a tuple expression with one `let var = value; body; end`
# element per iteration value (`let` mirrors the generator's own per-element
# scope). Fallback branch: materialize the generator and splat it into a tuple.
function _unroll_generator(ex)
    @assert ex.head == :generator
    if length(ex.args) != 2
        throw(ArgumentError("`@unroll` only supports 1D generators for now"))
    end
    var, itr = _split_iterspec(ex.args[2])
    loop_var = QuoteNode(var)
    body = QuoteNode(ex.args[1])
    # Manual hygiene: see `_unroll_for`.
    _exprs = gensym("exprs")
    _loop_var = gensym("loop_var")
    quote
        if @generated
            $_exprs = Expr(:tuple)
            for $_loop_var in $itr
                push!($_exprs.args,
                      Expr(:let, Expr(:(=), $loop_var, QuoteNode($_loop_var)), $body))
            end
            $_exprs
        else
            # Only used when Julia does not compile the generated branch;
            # collecting into an array first is slow but yields the same tuple.
            tuple(collect($ex)...)
        end
    end
end

"""
    @unroll for var = itr
        body
    end
    @unroll(body for var = itr)

Unroll a `for` loop (the result is `nothing`) or a one-dimensional generator
(the result is a tuple) inside a function whose iteration range `itr` depends
only on the function's static parameters, e.g. `1:N` for `v::Vec{N}`.

Throws `ArgumentError` at macro-expansion time for anything other than a
`for` loop or a generator with a single, unfiltered `var = itr` specification.
"""
macro unroll(ex)
    if !(ex isa Expr)
        throw(ArgumentError("`@unroll` only supports for loops and generators"))
    end
    expansion = if ex.head == :for
        _unroll_for(ex)
    elseif ex.head == :generator
        _unroll_generator(ex)
    else
        throw(ArgumentError("`@unroll` only supports for loops and generators"))
    end
    return esc(expansion)
end

"""
    mysum(v::Vec)

Sum the elements of `v` with a fully unrolled loop.
"""
function mysum(v::Vec{N}) where {N}
    x = zero(eltype(v))
    @unroll for i = 1:N
        x += v[i]
    end
    return x
end

"""
    v1 + v2

Element-wise sum of two `Vec`s of equal length, built from an unrolled generator.
"""
function Base.:+(v1::Vec{N}, v2::Vec{N}) where {N}
    return Vec(@unroll(v1[i] + v2[i] for i = 1:N))
end

@show mysum(Vec((1,2,3)))
@show Vec((1,2,3)) + Vec((-1,0,1))
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment