Created
April 23, 2020 18:36
-
-
Save MasonProtter/ea7588a191cf1ef388af3f26c0105c53 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
    Syms

Small symbolic-tracing prototype built on IRTools.

`sym(f, args...)` runs `f` with every call rewritten to go through `_sym`.
When a call receives a `Symbolic` argument and the callee is registered as a
primitive (see `isprimitive`), the call is recorded as a `SymExpr` instead of
being evaluated.
"""
module Syms

using IRTools: @dynamo, argument!, IR, isexpr

export sym, Sym

#--------------------------------------------------------------------------------
# Set up some symbolic types
#--------------------------------------------------------------------------------

"""
    Symbolic{T}

Supertype of all symbolic values. A `Symbolic{T}` is meant to act as if it
were `<: T` while tracing.
"""
abstract type Symbolic{T} end

"""
    Sym{T}(name::Symbol)

A named symbolic variable standing in for a value of type `T`.
"""
struct Sym{T} <: Symbolic{T}
    name::Symbol
end

# Qualified explicitly: this adds a method to the existing `Symbol` constructor
# rather than relying on implicit extension of an un-imported type.
Base.Symbol(s::Sym) = s.name

"""
    SymExpr{T}(op, args)

A recorded call `op(args...)` whose result stands in for a value of type `T`.
"""
struct SymExpr{T} <: Symbolic{T}
    op
    args::Vector{Any}
end

#--------------------------------------------------------------------------------
# Pretty printing
#--------------------------------------------------------------------------------

# Printed as `name::T`.
function Base.show(io::IO, s::Sym{T}) where {T}
    print(io, s.name, "::$T")
    return nothing
end

# `expr!(x, Ts)` lowers `x` to something that can be placed inside an `Expr`,
# collecting every `Sym` encountered into `Ts` along the way.

# A recorded call becomes a `:call` expression; operator and arguments are
# lowered recursively.
function expr!(se::SymExpr, Ts)
    op = expr!(se.op, Ts)
    args = map(arg -> expr!(arg, Ts), se.args)
    return Expr(:call, op, args...)
end

# Anything without a more specific method is embedded unchanged.
expr!(x, Ts) = x

# Functions are shown by name.
expr!(f::Function, Ts) = Symbol(f)

# Symbolic variables are shown by name and remembered in `Ts`.
function expr!(s::Sym, Ts)
    if s ∉ Ts
        push!(Ts, s)
    end
    return s.name
end

# Printed as `(expr) :: T where {a::A, b::B}`, listing each `Sym` in the expression.
function Base.show(io::IO, se::SymExpr{T}) where {T}
    syms = Set{Sym}()
    ex = expr!(se, syms)
    # `repr` of a call expression is ":(...)"; drop only the leading ':'
    # (ASCII, so byte index 2 is always a valid character boundary).
    print(io, repr(ex)[2:end], " :: $T", " where {")
    # Join the symbols directly instead of slicing the printed form of the Set,
    # which depended on the exact "Set(Any[" prefix and on byte offsets.
    join(io, syms, ", ")
    print(io, "}")
    return nothing
end

#--------------------------------------------------------------------------------
# Set up the IRTools pass
#--------------------------------------------------------------------------------

# Map a symbolic type to the type it stands in for; other types map to themselves.
sym_substitute(::Type{Sym{T}}) where {T} = T
sym_substitute(::Type{SymExpr{T}}) where {T} = T
sym_substitute(::Type{T}) where {T} = T

# Call `f` using the IR looked up for the argument types in the tuple type `T`,
# while actually passing `args...`. `argument!(ir, at = 2)` inserts an extra
# argument into that IR at position 2, corresponding to the additional `T`
# parameter in this function's own signature.
@dynamo function sneakyinvoke(f, ::Type{T}, args...) where T<:Tuple
    ir = IR(f, T.parameters...)
    argument!(ir, at = 2)
    return ir
end

# Tracing entry point: every `:call` statement in the IR of the called function
# is rewritten to `_sym(f, args...)`. When no IR is available the dynamo returns
# `nothing`.
@dynamo function sym(args...)
    ir = IR(args...)
    ir === nothing && return nothing
    for (x, st) in ir
        isexpr(st.expr, :call) || continue
        ir[x] = Expr(:call, _sym, st.expr.args...)
    end
    return ir
end

# Intrinsics are called directly.
_sym(f::Core.IntrinsicFunction, args...) = f(args...)

# Zero-argument calls cannot involve symbolic values; keep tracing into them.
_sym(f) = sym(f)

# Call interception used by `sym`:
#   * no symbolic argument      -> keep tracing into `f`;
#   * symbolic + primitive `f`  -> record the call as a `SymExpr` whose type
#                                  parameter is the inferred return type of `f`
#                                  on the substituted argument types;
#   * symbolic + other `f`      -> trace the method `f` would use for the
#                                  substituted argument types.
function _sym(f, args...)
    if !any(arg -> arg isa Symbolic, args)
        return sym(f, args...)
    end
    argsT′ = Tuple{sym_substitute.(typeof.(args))...}
    if isprimitive(f)
        # Inference is only needed here, so it is not run for the other branch.
        rt = Core.Compiler.return_type(f, argsT′)
        return SymExpr{rt}(f, Any[args...])
    else
        return sym(sneakyinvoke, f, argsT′, args...)
    end
end

# If isprimitive(f) == true, then inside a pass, we won't recurse into the insides of f.
# Primitives are stopping points for us
for f in [:+, :-, :*, :/, :^, :exp, :log,
          :sin, :cos, :tan, :asin, :acos, :atan,
          :sinh, :cosh, :tanh, :asinh, :acosh, :atanh, :adjoint]
    @eval isprimitive(::typeof($f)) = true
end
isprimitive(::Any) = false

end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment