# Cartesian indexing and iteration
#
# GitHub Gist Jutho/832f3f4aee84cf927a53 by @Jutho (last active August 29, 2015 14:09).
import Base: start, done, next, getindex, setindex!
import Base: @nref, @ncall, @nif, @nexprs
export eachelement, eachindex, linearindexing, LinearFast
# Traits for linear indexing.
#
# `linearindexing(A)` returns a singleton trait value that `start` dispatches on
# to choose the iteration state: `LinearFast()` means iterate with a plain Int
# linear index, `LinearSlow()` (the fallback for every AbstractArray) means
# iterate with an N-dimensional CartesianIndex.
abstract LinearIndexing
immutable LinearFast <: LinearIndexing end
immutable LinearSlow <: LinearIndexing end

linearindexing(::AbstractArray) = LinearSlow()
# Array types known to support a linear Int index.
for fasttype in (:Array, :BitArray, :Range)
    @eval linearindexing(::$fasttype) = LinearFast()
end
# Abstract supertype of the concrete N-dimensional index types
# `CartesianIndex_N` (fields I_1..I_N::Int) that `gen_cartesian` generates.
abstract CartesianIndex{N} # the state for all multidimensional iterators
# Abstract supertype of the generated `IndexIterator_N` types, each of which
# wraps the extent (`dims::CartesianIndex_N`) whose indices it visits.
abstract IndexIterator{N} # Iterator that visits the index associated with each element
# `CartesianIndex(t::NTuple{N,Int})` constructs the concrete `CartesianIndex_N`.
# At staging time `gen_cartesian(N)` makes sure that type exists (defining it on
# first use); the staged body is then a direct call to its tuple constructor.
stagedfunction Base.call{N}(::Type{CartesianIndex}, index::NTuple{N,Int})
    concrete = gen_cartesian(N)[1]
    return :($concrete(index))
end
# `IndexIterator(dims::NTuple{N,Int})` constructs the concrete `IndexIterator_N`
# over the extent `dims`. At staging time `gen_cartesian(N)` makes sure that type
# exists; the staged body is then a direct call to its tuple constructor.
stagedfunction Base.call{N}(::Type{IndexIterator}, index::NTuple{N,Int})
    concrete = gen_cartesian(N)[2]
    return :($concrete(index))
end
# Code generator for the concrete N-dimensional index and iterator types.
#
# `implemented` is captured by the `let` and records every N whose types and
# methods have already been `eval`ed, so the code for each N is generated once.
let implemented = IntSet()
    global gen_cartesian
    # gen_cartesian(N, with_shared) -> (indextype::Symbol, itertype::Symbol)
    #
    # Ensure that the types `CartesianIndex_N <: CartesianIndex{N}` (fields
    # I_1..I_N::Int) and `IndexIterator_N <: IndexIterator{N}` (field
    # `dims::CartesianIndex_N`) exist, together with their tuple constructors and
    # getindex/setindex!/start/next methods. Return the symbols naming both types.
    #
    # N           : number of dimensions; must be >= 1, because the generated
    #               `||`, `@nif` and `@ncall` expressions need at least one dimension.
    # with_shared : when true, also define SharedArray getindex/setindex! methods
    #               for CartesianIndex_N (default: true on Unix).
    #
    # Throws ArgumentError if N < 1.
    function gen_cartesian(N::Int, with_shared=Base.is_unix(OS_NAME))
        N >= 1 || throw(ArgumentError("gen_cartesian: N must be >= 1, got $N"))
        # Create the types
        indextype = symbol("CartesianIndex_$N")
        itertype = symbol("IndexIterator_$N")
        if !in(N, implemented)
            # (named `fnames` to avoid shadowing Base.fieldnames)
            fnames = [symbol("I_$i") for i = 1:N]
            fields = [Expr(:(::), fnames[i], :Int) for i = 1:N]
            # immutable CartesianIndex_N <: CartesianIndex{N}; I_1::Int; ... end
            extype = Expr(:type, false, Expr(:(<:), indextype, Expr(:curly, :CartesianIndex, N)), Expr(:block, fields...))
            exindices = Expr[:(index[$i]) for i = 1:N]
            onesN = ones(Int, N)
            # Start state for an empty extent: every coordinate is typemax(Int), so
            # `done` (last coordinate > last extent) is satisfied immediately.
            infsN = fill(typemax(Int), N)
            anyzero = Expr(:(||), [:(iter.dims.$(fnames[i]) == 0) for i = 1:N]...)
            # Some necessary ambiguity resolution
            exrange = N != 1 ? nothing : quote
                next(R::StepRange, I::CartesianIndex_1) = R[I.I_1], CartesianIndex_1(I.I_1+1)
                next{T}(R::UnitRange{T}, I::CartesianIndex_1) = R[I.I_1], CartesianIndex_1(I.I_1+1)
            end
            exshared = !with_shared ? nothing : quote
                getindex{T}(S::SharedArray{T,$N}, I::$indextype) = S.s[I]
                setindex!{T}(S::SharedArray{T,$N}, v, I::$indextype) = S.s[I] = v
            end
            totalex = quote
                # type definition
                $extype
                # extra constructor from tuple
                $indextype(index::NTuple{$N,Int}) = $indextype($(exindices...))
                immutable $itertype <: IndexIterator{$N}
                    dims::$indextype
                end
                $itertype(dims::NTuple{$N,Int})=$itertype($indextype(dims))
                # getindex and setindex!
                $exshared
                getindex{T}(A::AbstractArray{T,$N}, index::$indextype) = @nref $N A d->getfield(index,d)
                setindex!{T}(A::AbstractArray{T,$N}, v, index::$indextype) = (@nref $N A d->getfield(index,d)) = v
                # next iteration: advance the first coordinate that is still below
                # its extent and reset all earlier coordinates to 1; the last
                # coordinate is always incremented when no earlier one can be
                # (overflowing it is what `done` detects).
                $exrange
                @inline function next{T}(A::AbstractArray{T,$N}, state::$indextype)
                    @inbounds v = A[state]
                    newstate = @nif $N d->(getfield(state,d) < size(A, d)) d->(@ncall($N, $indextype, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1)))
                    v, newstate
                end
                @inline function next(iter::$itertype, state::$indextype)
                    # iter.dims is a CartesianIndex_N, which has no getindex
                    # method, so its coordinates are read with getfield.
                    newstate = @nif $N d->(getfield(state,d) < getfield(iter.dims,d)) d->(@ncall($N, $indextype, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1)))
                    state, newstate
                end
                # start
                start(iter::$itertype) = $anyzero ? $indextype($(infsN...)) : $indextype($(onesN...))
            end
            eval(totalex)
            push!(implemented, N)
        end
        return indextype, itertype
    end
end
# Iteration entry points.
#
# `eachindex(A)` returns an IndexIterator over the Cartesian indices of the
# extent `size(A)`; `eachelement(A)` iterates over the elements of `A`, which is
# simply `A` itself.
function eachindex(A::AbstractArray)
    extent = size(A)
    return IndexIterator(extent)
end

function eachelement(A::AbstractArray)
    return A
end
# start iteration
#
# `start(A)` chooses the iteration state by dispatching on the tuple
# (array, linearindexing(array)):
#   * LinearFast -> a plain Int linear index, beginning at 1;
#   * LinearSlow -> an N-dimensional CartesianIndex, beginning at (1, ..., 1).
# For an empty LinearSlow array every coordinate starts at typemax(Int) instead,
# so the `done` test (last coordinate > size(A, N)) holds from the start.
start(A::AbstractArray) = start((A, linearindexing(A)))
start(::(AbstractArray,LinearFast)) = 1
function start{T,N}(AT::(AbstractArray{T,N},LinearSlow))
    arr = AT[1]
    firstcoord = isempty(arr) ? typemax(Int) : 1
    return CartesianIndex(ntuple(N, n->firstcoord))
end
# Termination tests for Cartesian iteration: iteration over an N-dimensional
# object is finished once the last coordinate of the state exceeds the extent
# of the last dimension.
#
# Ambiguity resolution: the range-specific one-dimensional methods below return
# the same result as the generic AbstractArray method with N = 1. They
# presumably exist to break dispatch ties with Base's own `done` methods for
# these range types.
for rangetype in (:StepRange, :UnitRange, :FloatRange)
    @eval done(R::$rangetype, I::CartesianIndex{1}) = getfield(I, 1) > length(R)
end

function done{T,N}(A::AbstractArray{T,N}, I::CartesianIndex{N})
    lastcoord = getfield(I, N)
    return lastcoord > size(A, N)
end

function done{N}(iter::IndexIterator{N}, I::CartesianIndex{N})
    lastcoord = getfield(I, N)
    return lastcoord > getfield(iter.dims, N)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment