# Julia macros for splitting, combining, and extending struct definitions.
# Source: GitHub gist devmotion/5fef9f5a80398c5dc89511e06afc2a03
# (last active August 2, 2018, 23:49).
# Note: GitHub flagged this file as possibly containing hidden or bidirectional
# Unicode text; review it in an editor that reveals such characters.
using MacroTools: namify | |
using Base.Meta: isexpr | |
## Utilities
# Split a `struct` / `mutable struct` definition expression into a
# Dict{Symbol,Any}.
#
# The head is decomposed by `splitstructhead` (keys `:name`, and optionally
# `:params` and `:supertype`). The mutability flag (first argument of the
# `:struct` expression) is stored under `:mutable`. The field declarations
# found by `gatherfields` in the body are stored under `:fields`.
# Throws an error if `structdef` is not a 3-argument `:struct` expression.
function splitstruct(structdef)
    isexpr(structdef, :struct, 3) || error("Not a type definition:", structdef)
    ismut, head, body = structdef.args

    parts = splitstructhead(head)
    parts[:mutable] = ismut
    parts[:fields] = gatherfields(body)
    return parts
end
# Split struct head
#
# Decompose the head of a struct definition (`Name`, `Name{P...}`,
# `Name <: S` or `Name{P...} <: S`) into a Dict{Symbol,Any} with:
#   :name      - the type name (must be a Symbol),
#   :params    - vector of parameter expressions, bounds included
#                (only present if the head has parameters),
#   :supertype - the supertype expression (only present if annotated).
# Throws an error if the name part is not a Symbol.
function splitstructhead(structhead)
    # Handle supertype annotations
    if isexpr(structhead, :<:, 2)
        name_param, stype = structhead.args
    else
        name_param, stype = structhead, nothing
    end

    # Handle parameters; `Name{}` (a curly with no parameters) is not split
    # here and is rejected by the Symbol check below.
    if isexpr(name_param, :curly) && length(name_param.args) > 1
        name, params = name_param.args[1], name_param.args[2:end]
    else
        name, params = name_param, nothing
    end
    isa(name, Symbol) || error("Not a head of a type definition:", structhead)

    dict = Dict{Symbol,Any}(:name => name)
    # Test for `nothing` by identity: idiomatic and always yields a Bool.
    params !== nothing && (dict[:params] = params)
    stype !== nothing && (dict[:supertype] = stype)
    return dict
end
# Collect all field declarations of a struct body.
#
# Returns a `Vector{Any}` containing each typed field (`name::Type`, kept as
# the whole `::` expression) and each untyped field (a bare Symbol), in source
# order. Inner constructors and other method definitions in the body, written
# either as long-form `function ... end` or short-form `f(args) = body`, are
# skipped. Without this, their argument lists and bodies would be mistaken for
# fields.
gatherfields(ex) = _gatherfields!([], ex)

# Fallback: literals, line number nodes, docstrings etc. contribute no fields.
_gatherfields!(fields, ex) = fields
# A bare symbol is an untyped field.
_gatherfields!(fields, ex::Symbol) = push!(fields, ex)
function _gatherfields!(fields, ex::Expr)
    if isexpr(ex, Symbol("::"), 2)
        # Typed field declaration: keep the whole `name::Type` expression.
        push!(fields, ex)
    elseif _isfunctiondef(ex)
        # Inner constructor / method definition: not a field.
        fields
    else
        # Container expression (e.g. the body `block`): search its arguments.
        for arg in ex.args
            _gatherfields!(fields, arg)
        end
        fields
    end
end

# Return `true` if `ex` is a long-form (`function`) or short-form
# (`f(args) = body`, possibly with `where`) function definition.
_isfunctiondef(ex::Expr) =
    isexpr(ex, :function) ||
    (isexpr(ex, :(=)) && (isexpr(ex.args[1], :call) || isexpr(ex.args[1], :where)))
# Combine struct definition
#
# Build a struct definition expression from a Dict shaped like the output of
# `splitstruct`. It is a `mutable struct` when `dict[:mutable]` is true and a
# plain `struct` otherwise. The head is produced by `combinestructhead`, and
# the body holds `dict[:fields]`, followed by `dict[:inner]` (an inner
# constructor expression) if that key is present.
function combinestruct(dict::Dict)
    head = combinestructhead(dict)
    fields = dict[:fields]

    def = if dict[:mutable]
        :(mutable struct $head
              $(fields...)
          end)
    else
        :(struct $head
              $(fields...)
          end)
    end

    # Append the inner constructor, if any, to the struct body block.
    haskey(dict, :inner) && push!(def.args[3].args, dict[:inner])
    return def
end
# Combine struct head
#
# Rebuild a struct head `Name <: S` or `Name{P...} <: S` from a Dict with the
# key `:name`, an optional `:params` vector and an optional `:supertype`.
# A missing supertype defaults to `Any`. A missing or empty parameter list
# yields an unparameterized name.
function combinestructhead(dict::Dict)
    params = get(dict, :params, [])
    typename = isempty(params) ? dict[:name] : Expr(:curly, dict[:name], params...)
    return Expr(:<:, typename, get(dict, :supertype, :Any))
end
# Add inner constructor
#
# Store in `dict[:inner]` an inner constructor that takes a prefix of the
# fields as arguments and passes their names to `new`:
#   n >= 0 : the first `n` fields (all fields if `n` exceeds their number),
#   n <  0 : all fields except the last `-n` (none if `-n` exceeds their number).
# Fields not passed to `new` are left uninitialized (Julia's incomplete
# initialization via `new`).
# For a parametric struct the constructor is written
# `Name{P...}(fields...) where {P...}` and calls `new{P...}`. The parameter
# names are taken with `namify`, which drops the bounds.
# Returns `dict`.
function inner!(dict::Dict, n::Int)
    allfields = dict[:fields]
    nall = length(allfields)
    # Clamp the number of constructor arguments to 0:nall.
    nkeep = n < 0 ? max(nall + n, 0) : min(n, nall)
    fields = allfields[1:nkeep]
    argnames = namify.(fields)

    # Obtain parameters without supertypes
    paramnames = [namify(p) for p in get(dict, :params, [])]
    name = dict[:name]

    if isempty(paramnames)
        dict[:inner] = :(function $name($(fields...))
                             new($(argnames...))
                         end)
    else
        dict[:inner] = :(function $name{$(paramnames...)}($(fields...)) where {$(paramnames...)}
                             new{$(paramnames...)}($(argnames...))
                         end)
    end
    return dict
end
## Extend struct definition
#
# Merge a template (a Dict shaped like the output of `splitstruct`) into
# `dict`, in place:
#   - template parameters are appended after the parameters of `dict`;
#   - the template supertype is used only if `dict` has none;
#   - template fields are appended after the fields of `dict`.
# Parameters and fields are copied into vectors owned by `dict`, never
# aliased. Templates registered by `@base` / `@base_inner` are embedded as
# Dict literals and returned as the same object on every `_template` call, so
# aliasing would let later changes to `dict` corrupt the template.
# Returns `dict`.
function extend!(dict::Dict, template::Dict)
    # Merge parameters
    if haskey(template, :params)
        append!(get!(dict, :params, []), template[:params])
    end

    # Merge supertypes
    if !haskey(dict, :supertype) && haskey(template, :supertype)
        dict[:supertype] = template[:supertype]
    end

    # Merge fields. Appending into `dict`'s own vector also covers the case
    # where `dict` has no fields yet; the previous code rebound only a local
    # there and silently dropped the template fields.
    append!(get!(dict, :fields, []), get(template, :fields, []))

    return dict
end
# MACROS
## Add inner constructor
#
# `@add_inner n structdef`: emit `structdef` with an inner constructor over a
# prefix of its fields added. See `inner!` for the meaning of `n`.
macro add_inner(n::Int, structdef::Expr)
    return esc(add_inner(n, structdef))
end

# Expression-level implementation of `@add_inner`: split the definition, add
# the inner constructor and recombine it into a struct expression.
function add_inner(n::Int, structdef::Expr)
    return combinestruct(inner!(splitstruct(structdef), n))
end
## Struct template
# Default template: an empty Dict, returned when no template is registered
# under a name.
# `extend` / `extend_inner` look templates up by passing the *type*
# `Val{name}`, and `@base` / `@base_inner` register methods on
# `Type{Val{name}}`. The fallback must therefore also accept `Type{<:Val}`;
# a method on `Val` instances alone is never reached by those lookups. The
# instance method is kept for backward compatibility.
_template(::Val) = Dict()
_template(::Type{<:Val}) = Dict()
# `@base structdef`: define the struct unchanged and register its split form
# as a template under its name, for later use by `@extend` / `@extend_inner`.
macro base(structdef::Expr)
    return esc(base(structdef))
end

# Expression-level implementation of `@base`. Returns a block containing
# `structdef`, followed by a method `_template(::Type{Val{name}})` that
# returns the `splitstruct` Dict of the definition, embedded as a literal.
function base(structdef::Expr)
    template = splitstruct(structdef)
    key = Type{Val{template[:name]}}
    return :($structdef; _template(::$key) = $template)
end
# `@base_inner n structdef`: like `@base`, but the emitted struct also gets an
# inner constructor over a prefix of its fields (see `inner!` for `n`).
macro base_inner(n::Int, structdef::Expr)
    return esc(base_inner(n, structdef))
end

# Expression-level implementation of `@base_inner`. The registered template
# is the split Dict after `inner!` has been applied to it.
function base_inner(n::Int, structdef::Expr)
    template = inner!(splitstruct(structdef), n)
    structexpr = combinestruct(template)
    key = Type{Val{template[:name]}}
    return :($structexpr; _template(::$key) = $template)
end
## Extend struct definition
#
# `@extend Template structdef`: emit `structdef` merged with the template
# registered under the name `Template` (parameters, supertype and fields are
# merged by `extend!`).
macro extend(template::Symbol, structdef::Expr)
    return esc(extend(template, structdef))
end

# Expression-level implementation of `@extend`.
function extend(template::Symbol, structdef::Expr)
    parts = splitstruct(structdef)
    extend!(parts, _template(Val{template}))
    return combinestruct(parts)
end
# `@extend_inner Template n structdef`: like `@extend`, then add an inner
# constructor over a prefix of the merged fields (see `inner!` for `n`).
macro extend_inner(template::Symbol, n::Int, structdef::Expr)
    return esc(extend_inner(template, n, structdef))
end

# Expression-level implementation of `@extend_inner`: split, merge the
# template, add the inner constructor, recombine.
function extend_inner(template::Symbol, n::Int, structdef::Expr)
    merged = extend!(splitstruct(structdef), _template(Val{template}))
    return combinestruct(inner!(merged, n))
end
# EXAMPLE
using DiffEqBase | |
using OrdinaryDiffEq: OrdinaryDiffEqAlgorithm | |
# Define `ODEIntegrator` and register it as a template. Its inner constructor
# takes every field except the last two (`fsalfirst`, `fsallast`).
@base_inner -2 mutable struct ODEIntegrator{algType<:OrdinaryDiffEqAlgorithm,uType,tType,pType,eigenType,QT,tdirType,ksEltype,SolType,F,CacheType,O,FSALType} <: DiffEqBase.AbstractODEIntegrator
    sol::SolType
    u::uType
    k::ksEltype
    t::tType
    dt::tType
    f::F
    p::pType
    uprev::uType
    uprev2::uType
    tprev::tType
    alg::algType
    dtcache::tType
    dtchangeable::Bool
    dtpropose::tType
    tdir::tdirType
    eigen_est::eigenType
    EEst::QT
    qold::QT
    q11::QT
    erracc::QT
    dtacc::tType
    success_iter::Int
    iter::Int
    saveiter::Int
    saveiter_dense::Int
    cache::CacheType
    kshortsize::Int
    force_stepfail::Bool
    last_stepfail::Bool
    just_hit_tstop::Bool
    event_last_time::Bool
    accept_step::Bool
    isout::Bool
    reeval_fsal::Bool
    u_modified::Bool
    opts::O
    fsalfirst::FSALType
    fsallast::FSALType
end
# Print the expansion of extending the `ODEIntegrator` template.
# `DDEIntegrator` gets its own type parameters and fields, followed by those
# of `ODEIntegrator`, plus an inner constructor over all but the last two
# merged fields.
# Only the expansion is shown, so `AbstractDDEIntegrator` and `Discontinuity`
# need not be defined here (presumably they come from DelayDiffEq — confirm).
@show @macroexpand(@extend_inner ODEIntegrator -2 mutable struct DDEIntegrator{absType,relType,residType,IType,NType,tstopsType} <: AbstractDDEIntegrator
    prev_idx::Int
    prev2_idx::Int
    fixedpoint_abstol::absType
    fixedpoint_reltol::relType
    resid::residType # This would have to resize for resizing DDE to work
    fixedpoint_norm::NType
    max_fixedpoint_iters::Int
    saveat::tstopsType
    tracked_discontinuities::Vector{Discontinuity{tType}}
    integrator::IType
end)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment