Last active: August 29, 2015, 14:09
-
-
Save Jutho/832f3f4aee84cf927a53 to your computer and use it in GitHub Desktop.
Cartesian indexing and iteration
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Base: start, done, next, getindex, setindex! | |
import Base: @nref, @ncall, @nif, @nexprs | |
export eachelement, eachindex, linearindexing, LinearFast | |
# Linear-indexing traits. `start(A)` (further down in this file) dispatches on
# `linearindexing(A)`:
#   LinearFast -> arrays are iterated with a single Int state;
#   LinearSlow -> arrays are iterated with a CartesianIndex state.
abstract LinearIndexing
immutable LinearFast <: LinearIndexing end
immutable LinearSlow <: LinearIndexing end

# Default for any AbstractArray: iterate with a CartesianIndex.
linearindexing(::AbstractArray) = LinearSlow()

# Array types opted in to integer-state iteration.
for ArrT in (:Array, :BitArray, :Range)
    @eval linearindexing(::$ArrT) = LinearFast()
end
# Abstract parent of the generated concrete CartesianIndex_N types. Values of
# these types are the iteration state of every multidimensional iterator here.
abstract CartesianIndex{N} <: Any
# Abstract parent of the generated IndexIterator_N types. These iterate over the
# Cartesian index of each position rather than over element values.
abstract IndexIterator{N} <: Any
# `CartesianIndex((i_1, ..., i_N))` builds the concrete CartesianIndex_N. The
# type is generated on first use through gen_cartesian(N). The staged body runs
# once per N and returns a direct constructor call on the generated type.
stagedfunction Base.call{N}(::Type{CartesianIndex}, index::NTuple{N,Int})
    concrete_index = gen_cartesian(N)[1]
    return :($concrete_index(index))
end
# `IndexIterator((d_1, ..., d_N))` builds the concrete IndexIterator_N with
# extents `d_1, ..., d_N`. The type is generated on first use through
# gen_cartesian(N). The staged body returns a direct constructor call.
stagedfunction Base.call{N}(::Type{IndexIterator}, index::NTuple{N,Int})
    concrete_iter = gen_cartesian(N)[2]
    return :($concrete_iter(index))
end
# Lazily define, once per dimensionality N, the concrete types
# `CartesianIndex_N <: CartesianIndex{N}` and `IndexIterator_N <: IndexIterator{N}`,
# together with their getindex/setindex!/start/next methods. The let-bound
# `implemented` set records which N have already been evaluated, so later calls
# for the same N only return the type names.
let implemented = IntSet()
    global gen_cartesian
    # gen_cartesian(N, with_shared) -> (indextype::Symbol, itertype::Symbol)
    #
    #   N           : number of dimensions.
    #   with_shared : when true, also define getindex/setindex! for SharedArray,
    #                 forwarding to its field `s`. Defaults to true on Unix.
    #
    # Returns the symbols naming the concrete index and iterator types. The
    # definitions themselves are `eval`ed into the enclosing module on the
    # first call for each N.
    function gen_cartesian(N::Int, with_shared=Base.is_unix(OS_NAME))
        # Create the types
        indextype = symbol("CartesianIndex_$N")
        itertype = symbol("IndexIterator_$N")
        if !in(N,implemented)
            # Fields I_1::Int, ..., I_N::Int of the concrete index type
            fieldnames = [symbol("I_$i") for i = 1:N]
            fields = [Expr(:(::), fieldnames[i], :Int) for i = 1:N]
            # AST for `immutable CartesianIndex_N <: CartesianIndex{N} ... end`
            # (the `false` flag marks the type as not mutable)
            extype = Expr(:type, false, Expr(:(<:), indextype, Expr(:curly, :CartesianIndex, N)), Expr(:block, fields...))
            exindices = Expr[:(index[$i]) for i = 1:N]
            onesN = ones(Int, N)
            infsN = fill(typemax(Int), N)
            # `iter.dims.I_1 == 0 || ... || iter.dims.I_N == 0`
            anyzero = Expr(:(||), [:(iter.dims.$(fieldnames[i]) == 0) for i = 1:N]...)
            # Some necessary ambiguity resolution (N == 1 only). Presumably these
            # disambiguate the generic AbstractVector `next` below against
            # Base's StepRange/UnitRange `next` methods; TODO confirm.
            exrange = N != 1 ? nothing : quote
                next(R::StepRange, I::CartesianIndex_1) = R[I.I_1], CartesianIndex_1(I.I_1+1)
                next{T}(R::UnitRange{T}, I::CartesianIndex_1) = R[I.I_1], CartesianIndex_1(I.I_1+1)
            end
            # SharedArray indexing forwards to the wrapped array `S.s`
            exshared = !with_shared ? nothing : quote
                getindex{T}(S::SharedArray{T,$N}, I::$indextype) = S.s[I]
                setindex!{T}(S::SharedArray{T,$N}, v, I::$indextype) = S.s[I] = v
            end
            totalex = quote
                # type definition
                $extype
                # extra constructor from tuple
                $indextype(index::NTuple{$N,Int}) = $indextype($(exindices...))
                immutable $itertype <: IndexIterator{$N}
                    dims::$indextype
                end
                $itertype(dims::NTuple{$N,Int})=$itertype($indextype(dims))
                # getindex and setindex!
                $exshared
                getindex{T}(A::AbstractArray{T,$N}, index::$indextype) = @nref $N A d->getfield(index,d)
                setindex!{T}(A::AbstractArray{T,$N}, v, index::$indextype) = (@nref $N A d->getfield(index,d)) = v
                # next iteration
                $exrange
                # Advance with the first dimension varying fastest. Increment the
                # lowest dimension that is still below its extent and reset all
                # lower dimensions to 1. The last dimension is incremented
                # unconditionally, and done() detects it exceeding its extent.
                @inline function next{T}(A::AbstractArray{T,$N}, state::$indextype)
                    @inbounds v = A[state]
                    newstate = @nif $N d->(getfield(state,d) < size(A, d)) d->(@ncall($N, $indextype, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1)))
                    v, newstate
                end
                # Same advancement rule, with the extents taken from iter.dims.
                # iter.dims is a CartesianIndex_N, not a tuple, so its components
                # must be read with getfield. Indexing it (iter.dims[d]) has no
                # getindex method and would throw.
                @inline function next(iter::$itertype, state::$indextype)
                    newstate = @nif $N d->(getfield(state,d) < getfield(iter.dims,d)) d->(@ncall($N, $indextype, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1)))
                    state, newstate
                end
                # start: if any extent is zero, begin at an all-typemax(Int)
                # index so that iteration is immediately done
                start(iter::$itertype) = $anyzero ? $indextype($(infsN...)) : $indextype($(onesN...))
            end
            eval(totalex)
            push!(implemented,N)
        end
        return indextype, itertype
    end
end
# Iteration entry points.
# eachindex(A): an IndexIterator over the extents size(A), yielding the
# CartesianIndex of every position in A.
function eachindex(A::AbstractArray)
    extents = size(A)
    return IndexIterator(extents)
end
# eachelement(A): A itself, whose own iteration visits the element values.
function eachelement(A::AbstractArray)
    return A
end
# Initial iteration state for arrays, chosen by the linear-indexing trait.
start(A::AbstractArray) = start((A, linearindexing(A)))
# LinearFast: iterate with a plain Int index beginning at 1.
start(::(AbstractArray,LinearFast)) = 1
# LinearSlow: iterate with a CartesianIndex beginning at (1, ..., 1). An empty
# array instead starts at (typemax(Int), ..., typemax(Int)), which the
# CartesianIndex `done` method already treats as finished.
function start{T,N}(AT::(AbstractArray{T,N},LinearSlow))
    arr = AT[1]
    initial = isempty(arr) ? typemax(Int) : 1
    return CartesianIndex(ntuple(N, n -> initial))
end
# Termination tests for CartesianIndex iteration states. Iteration is finished
# once the last component of the state exceeds the extent of the last dimension.

# Ambiguity resolution: 1-d CartesianIndex states on Base range types.
for RangeT in (:StepRange, :UnitRange, :FloatRange)
    @eval done(r::$RangeT, state::CartesianIndex{1}) = getfield(state, 1) > length(r)
end
done{T,N}(A::AbstractArray{T,N}, state::CartesianIndex{N}) = getfield(state, N) > size(A, N)
done{N}(iter::IndexIterator{N}, state::CartesianIndex{N}) = getfield(state, N) > getfield(iter.dims, N)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.