Last active
August 12, 2016 08:02
-
-
Save tomasaschan/e501e1079d947d699b71941d93b7113e to your computer and use it in GitHub Desktop.
Alternative to @eval in @generated functions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A single dimension endpoint/step: either a literal Int or a Symbol that is filled in
# later from keyword arguments. Purely a readability alias.
const DimPoint = Union{Int,Symbol}
# expand_dims normalises dimension specs to the full (lower, step, upper) form:
#   upper                -> (1, 1, upper)
#   (lower, upper)       -> (lower, 1, upper)
#   (lower, step, upper) -> unchanged
# Any other argument is treated as a collection of specs, each of which is normalised.
expand_dims(upper::DimPoint) = (1, 1, upper)
expand_dims(unit::NTuple{2,DimPoint}) = (first(unit), 1, last(unit))
expand_dims(full::NTuple{3,DimPoint}) = full
# Collection fallback.
# NOTE: a collection whose entries are all scalars and whose length is 2 or 3
# (e.g. `(3, :n)`) is itself an NTuple{2|3,DimPoint}. It therefore dispatches to one of
# the single-spec methods above and is not treated as a collection.
expand_dims(dims) = map(expand_dims, dims)
# realize_dim resolves one entry of a (lower, step, upper) spec to a Nullable{DimPoint}.
# Ints are wrapped as-is. Symbols are looked up among the keyword arguments, and the
# result is null when no keyword of that name was given.
realize_dim(dim::Int; kwargs...) = Nullable{DimPoint}(dim)
realize_dim(dim::Symbol; kwargs...) = get_kwarg_value(kwargs, dim)

# realize_dims expands every dimension spec in `dims` to (lower, step, upper) form and
# replaces each Symbol in it with the value of the matching keyword argument.
# Returns a tuple holding one (lower, step, upper) tuple per dimension.
# Calls `error` (ErrorException) if a Symbol has no matching keyword argument.
function realize_dims(dims; kwargs...)
    # Expand each spec on its own rather than calling expand_dims on the whole `dims`
    # tuple. Tuple types are covariant, so a spec list such as (3, :n) or (:a, :b, :c)
    # is itself an NTuple{2|3,DimPoint}. Dispatch would then read it as a single
    # (lower, upper) / (lower, step, upper) spec instead of as 2 or 3 dimensions.
    expanded = map(expand_dims, dims)
    return map(expanded) do dim
        map(dim) do d
            realized = realize_dim(d; kwargs...)
            !isnull(realized) || error("Could not realize the variable dimension $d; no value given. (Dimspec: $dims)")
            get(realized)
        end
    end
end
# Find the first keyword argument called `name` and wrap its value in a
# Nullable{DimPoint}. Returns an empty (null) Nullable{DimPoint} when there is no such
# keyword.
function get_kwarg_value(kwargs, name::Symbol)
    for (key, value) in kwargs
        key == name && return Nullable{DimPoint}(value)
    end
    return Nullable{DimPoint}()
end
# FooArray (a stand-in for "FastArray") wraps a plain Array. Its third type parameter `D`
# records the realized dimension specs so that methods can dispatch on them.
# The FooArray(T, dims...) outer constructor sets `D` to a tuple of
# (lower, step, upper) tuples.
immutable FooArray{T, N, D} <: AbstractArray{T, N}
    data::Array{T, N}    # backing storage, indexed positionally
end
# Outer constructor: build a zero-filled FooArray with element type T from dimension
# specs. Each element of `dims` is `upper`, `(lower, upper)` or `(lower, step, upper)`.
# Any entry may be a Symbol whose value comes from the keyword arguments.
# The size along each dimension is the number of points in lower:step:upper. The
# realized specs are stored as the type's third parameter.
function FooArray{T}(::Type{T}, dims...; kwargs...)
    realized = realize_dims(dims; kwargs...)
    sizes = map(dim -> length(dim[1]:dim[2]:dim[3]), realized)
    # Allocate directly with element type T. zeros(sizes) would build a Float64 array
    # that then has to be converted to Array{T,N}: extra work, and it fails for
    # element types that define zero but have no conversion from Float64.
    return FooArray{T,length(sizes),realized}(zeros(T, sizes))
end
# Rather than generating a new type for every dimension layout, capture the layout
# `dims` and return a constructor-like function generate(T; kwargs...). That function
# builds FooArray(T, dims...; kwargs...), with symbolic dims filled in from kwargs.
function generator(dims...)
    generate{T}(::Type{T}; kwargs...) = FooArray(T, dims...; kwargs...)
    return generate
end
# Minimal AbstractArray interface so FooArray values display and index like their
# backing Array.
# NOTE: getindex forwards positional (1-based) indices to `data` and ignores the
# lower/step bounds stored in the type parameters. A dims-aware implementation is
# still needed.
function Base.size(A::FooArray)
    return size(A.data)
end
Base.getindex(A::FooArray, i...) = A.data[i...]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example 1: 3 fixed rows and a symbolic column count `n`, given at construction time.
const A_3xn = generator(3, :n)
a_3xn = A_3xn(Float64; n = 20)    # A 3x20 FooArray

# Example 2: mixed specs: `a` points, the range 2:5 (4 points), and 1:n:10 with n = 2
# (1, 3, 5, 7, 9: 5 points).
const A_ax2to5xnstrided = generator(:a, (2, 5), (1, :n, 10))
a_ax2to5xnstrided = A_ax2to5xnstrided(Int; a = 3, n = 2)    # A 3x4x5 FooArray
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment