Created December 30, 2019 12:39
-
-
Save jw3126/8ce0d6b7cdbbcc83486e86198648934c to your computer and use it in GitHub Desktop.
layered array
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Revise | |
using ArgCheck | |
"""
    LayeredArray(layers)

Read-only array built from a non-empty collection of arrays that all share the
element type and dimensionality of the first one.

An element is looked up in `layers` in iteration order and taken from the first
layer whose `CartesianIndices` contain the requested index. The last layer is
the fallback layer: it alone defines the size and axes of the `LayeredArray`.

Throws `ArgumentError` (via `@argcheck`) if `layers` is empty, if its first
entry is not an `AbstractArray`, or if any layer is not an
`AbstractArray{T, N}` with `T`/`N` taken from the first layer.
"""
struct LayeredArray{T, N, L} <: AbstractArray{T,N}
    layers::L

    function LayeredArray(layers)
        @argcheck !isempty(layers)
        @argcheck first(layers) isa AbstractArray
        head = first(layers)
        # Element type and rank are pinned by the first layer; every other
        # layer must match them exactly.
        T, N = eltype(head), ndims(head)
        foreach(layers) do layer
            @argcheck layer isa AbstractArray{T, N}
        end
        return new{T,N,typeof(layers)}(layers)
    end
end
"""
    fallbacklayer(l::LayeredArray)

Return the last entry of `l.layers`. This layer supplies the size and axes of
`l`, and it is the last one consulted when an element is looked up.
"""
function fallbacklayer(l::LayeredArray)
    layers = l.layers
    return last(layers)
end
# The shape of a LayeredArray is exactly the shape of its fallback layer.
function Base.size(l::LayeredArray)
    fallback = fallbacklayer(l)
    return size(fallback)
end
# Axes are forwarded from the fallback layer, so offset axes carry over as-is.
function Base.axes(l::LayeredArray)
    fallback = fallbacklayer(l)
    return axes(fallback)
end
# Lookups test a full CartesianIndex against each layer's indices, so the
# array advertises Cartesian (not linear) indexing.
function Base.IndexStyle(::Type{<:LayeredArray})
    return Base.IndexCartesian()
end
# Scalar lookup by CartesianIndex: return `ci`'s element from the first layer
# (in iteration order of `o.layers`) whose indices contain `ci`.
#
# `ci` is first checked against `axes(o)`, i.e. the fallback layer's axes.
# Without this check, an index covered only by an earlier, larger layer but
# outside `size(o)` would return a value instead of throwing a `BoundsError`.
# `@inline` + `@boundscheck` let callers elide the check under `@inbounds`.
@inline function Base.getindex(o::LayeredArray, ci::CartesianIndex)
    @boundscheck checkbounds(o, ci)
    layer = get_containing_layer(o, ci)
    return layer[ci]
end
"""
    get_containing_layer(o::LayeredArray, ci::CartesianIndex)

Return the first layer of `o.layers`, in iteration order, whose
`CartesianIndices` contain `ci`. Throw `BoundsError(o, ci)` if no layer does.
"""
function get_containing_layer(o::LayeredArray, ci::CartesianIndex)
    for candidate in o.layers
        ci in CartesianIndices(candidate) && return candidate
    end
    throw(BoundsError(o, ci))
end
# Canonical scalar indexing method for an `IndexCartesian` array: exactly `N`
# `Int` subscripts. Base's generic `getindex(::AbstractArray, I...)` lowers all
# other index forms onto this method: linear indices on N-d arrays, other
# integer types, ranges, colons and index vectors.
#
# Do not widen this signature to a catch-all `inds...`. Non-scalar indices
# would then become `CartesianIndices`, be fed back into `getindex` on this
# array, and recurse until the stack overflows (e.g. `l[1:2]`).
@inline function Base.getindex(l::LayeredArray{<:Any,N}, I::Vararg{Int,N}) where {N}
    # Bounds are those of `axes(l)`, i.e. of the fallback layer.
    @boundscheck checkbounds(l, I...)
    # Already checked above, so the CartesianIndex method may skip its own check.
    return @inbounds l[CartesianIndex(I)]
end
using Test

@testset "LayeredArray" begin
    l = LayeredArray([[10, 2], [1, 6, 3]])
    # Shape comes from the last (fallback) layer.
    @test size(l) == (3,)
    # Earlier layers take precedence wherever they have an index...
    @test l[1] == 10
    @test l[2] == 2
    # ...and the fallback layer supplies the remaining elements.
    @test l[3] == 3
    # `@inferred` alone only checks the inferred return type; wrapping it in
    # `@test` also asserts that the comparison is actually `true`.
    @test @inferred(l == [10, 2, 3])
    @test collect(l) == [10, 2, 3]
    @test_throws BoundsError l[4]
    @test_throws ArgumentError LayeredArray(Vector{Int}[])
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.