Skip to content

Instantly share code, notes, and snippets.

@MasonProtter
Created April 23, 2020 18:36
Show Gist options
  • Save MasonProtter/ea7588a191cf1ef388af3f26c0105c53 to your computer and use it in GitHub Desktop.
Save MasonProtter/ea7588a191cf1ef388af3f26c0105c53 to your computer and use it in GitHub Desktop.
module Syms
export sym, Sym
#--------------------------------------------------------------------------------
# Set up some symbolic types
#--------------------------------------------------------------------------------
"""
    Symbolic{T}

Supertype of all symbolic values. The parameter `T` is the type the symbolic value
stands in for.
"""
abstract type Symbolic{T} end # Symbolic{T} will act like it is <: T

"""
    Sym{T}(name::Symbol)

A named symbolic variable standing in for a value of type `T`.
"""
struct Sym{T} <: Symbolic{T}
    name::Symbol
end

# Qualify the constructor explicitly so this method clearly extends `Base.Symbol`
# instead of relying on implicit extension of an un-imported type's constructor.
Base.Symbol(s::Sym) = s.name

"""
    SymExpr{T}(op, args)

A symbolic call `op(args...)` standing in for a value of type `T`.
"""
struct SymExpr{T} <: Symbolic{T}
    op                  # the callee of the recorded call (left untyped)
    args::Vector{Any}   # the call's arguments; may mix symbolic and ordinary values
end
#--------------------------------------------------------------------------------
# Pretty printing
#--------------------------------------------------------------------------------
# Display a symbolic variable as `name::T`.
function Base.show(io::IO, s::Sym{T}) where {T}
    label = string(s.name, "::", T)
    return print(io, label)
end
# `expr!(x, Ts)` lowers `x` to something that can be placed inside an `Expr`,
# recording every `Sym` it encounters in the collection `Ts`.

# Anything that is not symbolic or a function is passed through unchanged.
expr!(x, Ts) = x

# Functions are represented by their name.
expr!(f::Function, Ts) = Symbol(f)

# A symbolic variable is recorded in `Ts` (once) and represented by its name.
function expr!(s::Sym{T}, Ts) where {T}
    s in Ts || push!(Ts, s)
    return s.name
end

# A symbolic call becomes a `:call` expression; the callee is lowered first,
# then each argument in order.
function expr!(se::SymExpr, Ts)
    callee = expr!(se.op, Ts)
    operands = map(arg -> expr!(arg, Ts), se.args)
    return Expr(:call, callee, operands...)
end
# Display a symbolic expression as `(expr) :: T where {a::A, b::B, ...}`, listing every
# symbolic variable that occurs in the expression.
function Base.show(io::IO, se::SymExpr{T}) where {T}
    syms = Set()
    ex = expr!(se, syms)
    # `repr` of an `Expr` looks like `:(...)`; dropping the leading ':' (one ASCII byte)
    # leaves the parenthesised expression.
    body = repr(ex)[2:end]
    # Join the variables directly rather than slicing `repr(syms)` at fixed byte offsets:
    # that depended on the exact `Set(Any[` prefix Base prints and could throw a
    # `StringIndexError` when the last printed character before `]` was multi-byte.
    vars = join(syms, ", ")
    return print(io, body, " :: $T", " where {", vars, "}")
end
#--------------------------------------------------------------------------------
# Set up the IRTools pass
#--------------------------------------------------------------------------------
# `sym_substitute` maps an argument type to the type used for method lookup and
# inference: symbolic wrappers are replaced by the type they stand in for, and every
# other type is returned as is.
function sym_substitute(::Type{T}) where {T}
    return T
end
function sym_substitute(::Type{Sym{T}}) where {T}
    return T
end
function sym_substitute(::Type{SymExpr{T}}) where {T}
    return T
end
using IRTools: @dynamo, argument!, IR, isexpr
# Dynamo that builds its IR from the signature tuple type `T` instead of from the
# types of `args`: `IR(f, T.parameters...)` looks up `f` with the element types of `T`.
# `argument!(ir, at = 2)` then adds one argument to that IR at position 2; presumably
# this makes room for the extra `::Type{T}` argument so the remaining arguments line up
# with `args...` — TODO confirm against the IRTools documentation.
# NOTE(review): unlike `sym`, there is no check for `IR` returning `nothing` here —
# confirm callers only reach this with functions that have IR available.
@dynamo function sneakyinvoke(f, ::Type{T}, args...) where T<:Tuple
ir = IR(f, T.parameters...)
argument!(ir, at = 2)
return ir
end
# The symbolic tracing pass: every `:call` statement `f(xs...)` in the IR of the traced
# function is rewritten to `_sym(f, xs...)`, so `_sym` decides how each call is handled.
@dynamo function sym(args...)
    ir = IR(args...)
    # No IR available: return nothing and leave the fallback to the dynamo machinery.
    # Compare by identity; `== nothing` can be intercepted by user-defined `==` methods.
    ir === nothing && return
    for (x, st) in ir
        # Only call statements are redirected; every other statement is left as is.
        isexpr(st.expr, :call) || continue
        ir[x] = Expr(:call, _sym, st.expr.args...)
    end
    return ir
end
# `_sym(f, args...)` is what every call is redirected to inside the `sym` pass.

# Intrinsic functions are called directly instead of being passed through `sym`.
_sym(f::Core.IntrinsicFunction, args...) = f(args...)

# A call without arguments has nothing symbolic to look at: keep running the pass.
_sym(f) = sym(f)

# General case:
#   * no symbolic argument       -> keep running the pass on the ordinary call;
#   * symbolic args, primitive f -> record the call as a `SymExpr`, typed with the
#                                   inferred return type for the substituted arguments;
#   * symbolic args, other f     -> run the pass on the method of `f` selected by the
#                                   substituted argument types (via `sneakyinvoke`).
function _sym(f, args...)
    argsT = typeof.(args)
    any(T -> T <: Symbolic, argsT) || return sym(f, args...)

    # Signature with every symbolic type replaced by the type it stands in for.
    argsT′ = Tuple{sym_substitute.(argsT)...}
    if isprimitive(f)
        # Inference is only needed on this branch, so it is not run for non-primitives.
        rt = Core.Compiler.return_type(f, argsT′)
        # Build the `Vector{Any}` directly; `[args...]` would first create a vector with
        # a promoted element type that then has to be converted for the `args` field.
        return SymExpr{rt}(f, Any[args...])
    end
    return sym(sneakyinvoke, f, argsT′, args...)
end
# Primitives are the stopping points of a pass: when `isprimitive(f)` is `true`, the pass
# does not recurse into the body of `f`.
# Everything is non-primitive unless listed below.
isprimitive(::Any) = false
for prim in (
    :+, :-, :*, :/, :^, :exp, :log,
    :sin, :cos, :tan, :asin, :acos, :atan,
    :sinh, :cosh, :tanh, :asinh, :acosh, :atanh, :adjoint,
)
    @eval isprimitive(::typeof($prim)) = true
end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment