Skip to content

Instantly share code, notes, and snippets.

@tomasaschan
Last active August 12, 2016 08:02
Show Gist options
  • Save tomasaschan/e501e1079d947d699b71941d93b7113e to your computer and use it in GitHub Desktop.
Save tomasaschan/e501e1079d947d699b71941d93b7113e to your computer and use it in GitHub Desktop.
Alternative to @eval in @generated functions
# Readability alias: a single dimension value (bound or step) is either a literal Int
# or a Symbol that is later replaced by the keyword argument of the same name.
# Written as a `const` binding rather than with the `typealias` keyword: both create
# the same constant alias, but `typealias` was removed from later Julia versions,
# whereas this form is accepted by old and new versions alike.
const DimPoint = Union{Int,Symbol}
# expand_dims normalizes dimension specs to the canonical (lower, step, upper) form.
# A single dimension may be written as `upper`, `(lower, upper)` or
# `(lower, step, upper)`; a missing lower bound or step defaults to 1.
function expand_dims(upper::DimPoint)
    return (1, 1, upper)
end
function expand_dims(unit::NTuple{2,DimPoint})
    lower, upper = unit
    return (lower, 1, upper)
end
function expand_dims(full::NTuple{3,DimPoint})
    return full
end
# Any other argument is treated as a collection of dimension specs and normalized
# element by element, e.g. (:a, (2,5)) -> ((1,1,:a), (2,1,5)).
# NOTE(review): a collection that is itself a 2- or 3-tuple of plain Int/Symbol
# values, such as (3, :n), appears to match the single-dimension methods above
# rather than this one -- confirm against callers.
expand_dims(dims) = map(expand_dims, dims)
# realize_dim resolves one dimension value into a Nullable{DimPoint}; realize_dims
# uses it to replace the symbols with values from the provided kwargs.
# An Int is wrapped as-is; a Symbol is looked up among the keyword arguments via
# get_kwarg_value.
function realize_dim(dim::Int; kwargs...)
    return Nullable{DimPoint}(dim)
end
function realize_dim(dim::Symbol; kwargs...)
    return get_kwarg_value(kwargs, dim)
end
# realize_dims turns a collection of dimension specs into concrete
# (lower, step, upper) triples, replacing every Symbol with the value of the keyword
# argument of the same name.
#
# Arguments:
#   dims      - collection (normally a tuple) of dimension specs; each entry is
#               `upper`, `(lower, upper)` or `(lower, step, upper)`
#   kwargs... - values for the Symbols that appear in `dims`
# Returns the realized triples, in the same order as `dims`.
# Raises an error when a Symbol in `dims` has no matching keyword argument.
function realize_dims(dims; kwargs...)
    # Expand entry by entry instead of calling expand_dims(dims) on the whole
    # collection: a collection such as (3, :n) is itself a 2-tuple of Int/Symbol
    # values, so expand_dims(dims) would dispatch to the single-dimension
    # (lower, upper) method and collapse two dimensions into one bogus triple.
    expanded = map(expand_dims, dims)
    map(expanded) do dim
        map(dim) do d
            realized = realize_dim(d; kwargs...)
            isnull(realized) && error("Could not realize the variable dimension $d; no value given. (Dimspec: $dims)")
            get(realized)
        end
    end
end
# Look up the keyword argument called `name` in `kwargs`, a collection of
# (key, value) entries. Returns a Nullable{DimPoint} holding the value of the first
# entry whose key equals `name`, or an empty Nullable{DimPoint} when there is none.
function get_kwarg_value(kwargs, name::Symbol)
    for kwarg in kwargs
        if kwarg[1] == name
            return Nullable{DimPoint}(kwarg[2])
        end
    end
    return Nullable{DimPoint}()
end
# FooArray plays the role of your FastArray: an AbstractArray{T,N} wrapper whose third
# type parameter `dims` stores the realized dims information in the type itself, so
# that methods can dispatch on it.
#   T    - element type
#   N    - number of dimensions
#   dims - realized dims; the constructor defined below passes the result of
#          realize_dims, i.e. (lower, step, upper) triples
immutable FooArray{T, N, dims} <: AbstractArray{T, N}
data::Array{T,N} # backing storage
end
# Construct a zero-filled FooArray with element type T.
#
# Arguments:
#   T         - element type of the array
#   dims...   - one spec per dimension: `upper`, `(lower, upper)` or
#               `(lower, step, upper)`, where each value is an Int or a Symbol
#   kwargs... - values for the Symbols used in `dims`
# The realized (lower, step, upper) triples become the `dims` type parameter, and the
# size of each dimension is the length of the range lower:step:upper.
function FooArray{T}(::Type{T}, dims...; kwargs...)
    realized = realize_dims(dims; kwargs...)
    sizes = map(realized) do dim
        length(dim[1]:dim[2]:dim[3])
    end
    # Allocate the storage with eltype T directly; `zeros(sizes)` would build a
    # Float64 array that then has to be converted to Array{T,N} by the constructor.
    FooArray{T,length(sizes),realized}(zeros(T, sizes))
end
# Rather than generating a new type for each variant, `generator` captures a dimension
# spec and hands back a constructor-like function: calling that function with an
# element type plus keyword values for the spec's Symbols forwards everything to
# FooArray, which figures out the correct arguments.
function generator(dims...)
    generate{T}(::Type{T}; kwargs...) = FooArray(T, dims...; kwargs...)
    return generate
end
# Minimal AbstractArray methods, only here so that the results display nicely; both
# simply forward to the wrapped `data` array.
# getindex in particular will of course need another implementation that takes dims
# into account.
function Base.size(A::FooArray)
    return size(A.data)
end
function Base.getindex(A::FooArray, i...)
    return A.data[i...]
end
# Usage examples.
# Generator for the spec (3, :n); the value of :n is supplied when an array is built.
# NOTE(review): (3, :n) is itself a 2-tuple of Int/Symbol values -- confirm that it is
# expanded as two dimensions and not as a single (lower, upper) pair.
const A_3xn = generator(3, :n)
a_3xn = A_3xn(Float64; n = 20) # intended: a 3x20 FooArray (3 -> 1:1:3, n -> 1:1:20)
# Spec mixing an upper-only dim, a (lower, upper) dim and a (lower, step, upper) dim.
const A_ax2to5xnstrided = generator(:a, (2,5), (1,:n,10))
a_ax2to5xnstrided = A_ax2to5xnstrided(Int; a = 3, n = 2) # A 3x4x5 FooArray (1:1:3, 2:1:5, 1:2:10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment