Created March 7, 2014 15:41
-
-
Save carlobaldassi/9413784 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Base: @ngenerate, @nloops, @nref, @ncall, @ntuple, @nif, @nexprs | |
import Base: start, done, next | |
# Generate one concrete iterator-state type per dimensionality n = 1:10; for n = 3:
#   immutable SubInd_3 <: SubInds{3}
#       I_1::Int
#       I_2::Int
#       I_3::Int
#   end
# Field I_d holds the current subscript along dimension d. These immutables are used
# as the start/next/done states when iterating a SubArray.
abstract SubInds{N}
for n in 1:10
    typename = symbol(string("SubInd_", n))
    fieldexprs = Expr[]
    for k in 1:n
        # builds the field declaration `I_k::Int`
        push!(fieldexprs, :($(symbol(string("I_", k)))::Int))
    end
    @eval immutable $typename <: SubInds{$n}
        $(fieldexprs...)
    end
end
# Initial iteration state for an N-dimensional SubArray: every subscript is 1, i.e. the
# element visited first (dimension 1 varies fastest in `next`).
@ngenerate N SubInds{N} function start{T,N}(A::SubArray{T,N})
    # `done` only inspects the last subscript. If A is empty because some *earlier*
    # dimension has length 0 (e.g. size (0, 3)), an all-ones state would not be "done",
    # and `next` would read the nonexistent element A[1,...,1] under @inbounds.
    # Such arrays therefore start with the last subscript already past size(A, N).
    ilast = isempty(A) ? size(A, N) + 1 : 1
    @ncall N (@nexprs 1 x->SubInd_{N}) d->(d == N ? ilast : 1)
end
# Return the element at the current subscripts together with the advanced state.
@ngenerate N (eltype(A), typeof(I)) function next{T,N}(A::SubArray{T,N}, I::SubInds{N})
    @inbounds begin
        # Element at subscripts (I_1, ..., I_N).
        val = @nref N A d->getfield(I,d)
        # Odometer step: take the first dimension d whose subscript is below its extent,
        # reset every lower subscript to 1, bump I_d, and keep the higher subscripts.
        # Dimension N is the fallback branch and is bumped unconditionally, which is what
        # eventually pushes I_N past size(A, N) so that `done` becomes true.
        succ = @nif N d->(getfield(I,d) < A.dims[d]) d->(@ncall(N, (@nexprs 1 x->SubInd_{N}), k->(k < d ? 1 : k == d ? getfield(I,k)+1 : getfield(I,k))))
    end
    return (val::T, succ)
end
# Iteration is finished once the last subscript I_N has moved past size(A, N).
function done{T,N}(A::SubArray{T,N}, I::SubInds{N})
    lastsub = getfield(I, N)
    return lastsub > size(A, N)
end
## Test functions

# Sum all elements of A with N nested loops generated by @nloops
# (standard Cartesian iteration, no start/next/done involved).
@ngenerate N Float64 function subit_cart{T,N}(A::SubArray{T,N})
    acc = 0.0
    @nloops N j A begin
        @inbounds acc = acc + (@nref N A j)
    end
    return acc
end
# Sum all elements of A through a plain `for x in A` loop, i.e. via the
# start/next/done iteration protocol without manually inlining those calls.
function subit_iter(A::SubArray)
    total = 0.0
    @inbounds for elem in A
        total = total + elem
    end
    return total
end
# Sum all elements of A with the start/next/done logic manually inlined. The state
# starts with every subscript at 1, and each step advances it the same way `next`
# does. The loop guard mirrors `done`: it keeps going while the last subscript is
# <= the extent along dimension N.
@ngenerate N Float64 function subit_iter_in{T,N}(A::SubArray{T,N})
    s = 0.0
    # The loop guard only checks the last subscript. An array that is empty because an
    # earlier dimension has length 0 (e.g. size (0, 3)) would otherwise pass the guard
    # and read A[1,...,1] out of bounds under @inbounds.
    isempty(A) && return s
    I = @ncall N (@nexprs 1 x->SubInd_{N}) d->1
    @inbounds while getfield(I, N) <= A.dims[N]
        # Current element at subscripts (I_1, ..., I_N).
        x = @nref N A d->getfield(I,d)
        # Odometer step: bump the first dimension whose subscript is below its extent,
        # reset lower subscripts to 1, keep higher ones; dimension N is the fallback.
        I = @nif N d->(getfield(I,d) < A.dims[d]) d->(@ncall(N, (@nexprs 1 x->SubInd_{N}), k->(k>d ? getfield(I,k) : k==d ? getfield(I,k)+1 : 1)))
        s += x
    end
    return s
end
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.