Created June 23, 2021 21:51
-
-
Save Keno/907c7ce6a7393f8d4224a4ac24c68b12 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#= | |
using Revise | |
Revise.track(Core.Compiler) | |
=# | |
using Symbolics | |
"""
    bar(a, b)

Return the product `a * b`.
"""
function bar(a, b)
    return a * b
end

"""
    foo(a, b, c)

Return `bar(a, b) + c` (i.e. `a * b + c`); a small scalar target for symbolic inference.
"""
foo(a, b, c) = bar(a, b) + c
"""
    matmul(a, b)

Return `a * b` (matrix product for matrices); target for symbolic-dimension inference.
"""
matmul(a, b) = a * b
"""
    concat(a, b)

Return `vcat(a, b)`; target for symbolic-dimension inference of vertical concatenation.
"""
concat(a, b) = vcat(a, b)
using Cthulhu | |
using Symbolics | |
"""
    SymbolicInt(sym)

Payload of a `CustomLattice` element standing for an `Int` whose value is the symbolic
expression `sym`.
"""
struct SymbolicInt
    sym::Symbolics.Symbolic{Int}
end

# Display a symbolic integer as its underlying symbolic expression.
function Base.show(io::IO, x::SymbolicInt)
    return show(io, x.sym)
end
"""
    SymbolicDimArray(dims)

Payload of a `CustomLattice` element standing for an array whose dimensions are tracked
in the lattice. `dims` is normally a `PartialStruct` over an `NTuple{N,Int}` whose fields
are the per-dimension lattice elements (e.g. symbolic `Int`s).
"""
struct SymbolicDimArray
    dims::Any  # deliberately untyped: holds a compiler lattice element, not a runtime value
end
using Cthulhu: CthulhuInterpreter, get_specialization, | |
InferenceResult, InferenceState | |
using Core.Compiler: CustomLattice, LatticeCallbacks, typeinf, Builtin, Const, | |
getfield_tfunc, ⊑, PartialStruct, abstract_call_gf_by_type, AbstractInterpreter, | |
tuple_tfunc, CallMeta, nfields_tfunc, LatticeUnion, tmerge, instanceof_tfunc, | |
widenconst | |
using Base.Experimental: @opaque | |
# Print a custom lattice element as `CustomLattice(<payload>)`.
# Printing directly into `io` (instead of first building a `String` with `string(...)`)
# renders the payload with `io`'s context properties and avoids an intermediate allocation.
# NOTE(review): this is type piracy (neither `show` nor `CustomLattice` is owned here);
# tolerable in a prototype script, but it should move to wherever `CustomLattice` lives.
function Base.show(io::IO, cl::CustomLattice)
    return print(io, "CustomLattice(", cl.payload, ")")
end
# Print a symbolic-dimension array payload as `D(d1,d2,...)`, unwrapping custom lattice
# elements to their payloads so symbolic dimensions print as their expressions.
# If `dims` is not a `PartialStruct` it is printed whole as `D(<dims>)`.
function Base.show(io::IO, a::SymbolicDimArray)
    dims = a.dims
    if !isa(dims, PartialStruct)
        # Fixed: the original misspelled `print` as `prin` and did not return here, so it
        # then tried to read `dims.fields` from a non-`PartialStruct`.
        print(io, "D(", dims, ")")
        return nothing
    end
    print(io, "D(")
    join(io, (isa(x, CustomLattice) ? x.payload : x for x in dims.fields), ',')
    print(io, ")")
    return nothing
end
# Symbolic integer variables used as values/dimensions in the inference examples below.
@syms a::Int
@syms b::Int
@syms c::Int
"""
    symbolic_tmerge(a, b)

Join (`tmerge`) of two lattice elements where at least one involves the symbolic lattice.

- A `Const` joined with a symbolic integer yields `LatticeUnion(a, b)`.
- Two symbolic integers yield `a` when their expressions are `isequal`, otherwise
  `LatticeUnion(a, b)`.
- Two symbolic-dimension arrays yield a new array element whose dimensions are the
  per-dimension `tmerge` of the inputs (ranks must agree), keeping `a`'s declared type.

Any other combination prints the inputs and throws.
"""
function symbolic_tmerge(a, b)
    is_symint(x) = isa(x, CustomLattice) && isa(x.payload, SymbolicInt)
    is_symdims(x) = isa(x, CustomLattice) && isa(x.payload, SymbolicDimArray)

    # Constant vs. symbolic integer, in either order.
    if (isa(a, Const) && is_symint(b)) || (is_symint(a) && isa(b, Const))
        return LatticeUnion(a, b)
    end

    # Two symbolic integers: identical expressions collapse, distinct ones form a union.
    if is_symint(a) && is_symint(b)
        return isequal(a.payload.sym, b.payload.sym) ? a : LatticeUnion(a, b)
    end

    # Two symbolic-dimension arrays: merge dimension by dimension.
    if is_symdims(a) && is_symdims(b)
        ap = a.payload
        bp = b.payload
        @assert isa(ap.dims, PartialStruct)
        @assert isa(bp.dims, PartialStruct)
        apd = ap.dims
        bpd = bp.dims
        @assert length(apd.fields) == length(bpd.fields)
        declared = a.typ
        dimtuple = Tuple{ntuple(_ -> Int, length(apd.fields))...}
        merged = Any[tmerge(x, y) for (x, y) in zip(apd.fields, bpd.fields)]
        return CustomLattice(declared, SymbolicDimArray(PartialStruct(dimtuple, merged)),
                             callbacks)
    end

    @show (a, b)
    error()
end
# Hooks that connect the symbolic lattice to the experimental custom-lattice support in
# Core.Compiler. Each hook is an opaque closure forwarding through `invokelatest`, so the
# handler functions (several of which are defined further down this file) are looked up at
# call time — presumably also so they can be redefined interactively (e.g. under Revise).
# NOTE(review): the hook order must match what `LatticeCallbacks` expects; from the handler
# names it is presumably (⊑, meet-with-bound, join, builtin tfunc) — confirm on the branch.
callbacks = LatticeCallbacks(
    (@opaque (x::Any, y) -> Base.invokelatest(⊑ₛ, x, y)),
    (@opaque (x::Any, y) -> Base.invokelatest(symbolic_tmeetbound, x, y)),
    (@opaque (x::Any, y) -> Base.invokelatest(symbolic_tmerge, x, y)),
    (@opaque (fn::Builtin, argv::Vector{Any}) -> Base.invokelatest(symbolic_tfunc, fn, argv)),
)
"""
    abstract_call_gf_by_type(interp::CthulhuInterpreter, f, fargs, argtypes, atype, sv,
                             max_methods)

Intercept `Array{...}(undef, dims...)` constructor calls whose dimension arguments are all
`⊑ Int` and include at least one custom (symbolic) element, returning a
`SymbolicDimArray`-backed `CustomLattice` element that records the dimensions. All other
calls are forwarded to the `AbstractInterpreter` implementation.
"""
function Core.Compiler.abstract_call_gf_by_type(interp::CthulhuInterpreter, @nospecialize(f),
        fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, @nospecialize(atype),
        sv::InferenceState,
        # Qualified: `InferenceParams` is not imported into this module, so the bare name
        # raised `UndefVarError` whenever this default was used.
        max_methods::Int = Core.Compiler.InferenceParams(interp).MAX_METHODS)
    # argtypes[1] is the callee, argtypes[2] the `undef` marker, argtypes[3:end] the dims.
    # The length check avoids a BoundsError on argument-less constructor calls.
    if isa(f, Type) && f <: Array && length(argtypes) >= 2 && argtypes[2] == Const(undef)
        dimtypes = argtypes[3:end]
        if all(x -> x ⊑ Int, dimtypes) && any(x -> isa(x, CustomLattice), dimtypes)
            dims = tuple_tfunc(dimtypes)
            return CallMeta(CustomLattice(f, SymbolicDimArray(dims), callbacks), false)
        end
    end
    return invoke(Core.Compiler.abstract_call_gf_by_type,
                  Tuple{AbstractInterpreter, Any, Union{Nothing,Vector{Any}}, Vector{Any},
                        Any, InferenceState, Int},
                  interp, f, fargs, argtypes, atype, sv, max_methods)
end
"""
    ⊑ₛ(a, b)

Partial order `a ⊑ b` for the symbolic lattice, where `a` is a `CustomLattice` element.

- If `b` is a type, `a ⊑ b` iff `a`'s declared type is a subtype of `b`.
- Symbolic integers are ordered only with themselves (`isequal` expressions).
- Symbolic-dimension arrays compare their dimension tuples with `⊑`.
- Anything else is treated as incomparable (`false`).
"""
function ⊑ₛ(a, b)
    @assert isa(a, CustomLattice)
    @Core.Main.Base.show (a, b)
    # A custom element refines its declared type, so it sits below every supertype of it.
    # (The original tested `isa(b, Int64)` — whether `b` is an integer *value* — so
    # `x ⊑ Int` was never true.)
    isa(b, Type) && return a.typ <: b
    # Beyond plain types, only another custom element can be above `a`.
    isa(b, CustomLattice) || return false
    ap = a.payload
    bp = b.payload
    if isa(ap, SymbolicInt)
        return isa(bp, SymbolicInt) && isequal(ap.sym, bp.sym)
    elseif isa(ap, SymbolicDimArray)  # was `else isa(...)`, which discarded the test
        isa(bp, SymbolicDimArray) || return false
        return ⊑(ap.dims, bp.dims)
    end
    return false
end
"""
    symbolic_tfunc(f, args)

Return the lattice element produced by calling the builtin/intrinsic `f` on arguments
whose lattice elements are `args`. As in the `arraylen` case below, `args[1]` holds the
callee and the call arguments start at `args[2]`.

Supported:
- `mul_int`, `add_int`, `sub_int`: combine `Const` / symbolic-integer operands
  symbolically. A concrete `Int` result becomes `Const`; otherwise a new symbolic integer
  element is returned. Any other operand kind widens to `Any`.
- `Core.arraysize(A, d)`: the `d`-th dimension element of a symbolic-dimension array.
- `arraylen(A)`: the product of a symbolic-dimension array's dimensions.
- `typeassert(x, T)`: returns `x` unchanged when its declared type already satisfies `T`.

Anything else prints diagnostics and throws an `ErrorException`.
"""
function symbolic_tfunc(f, args)
    @show f === Core.Intrinsics.mul_int
    # `sub_int` is admitted here too: the original guard omitted it, leaving the
    # subtraction case below unreachable.
    if f === Core.Intrinsics.mul_int || f === Core.Intrinsics.add_int ||
       f === Core.Intrinsics.sub_int
        a2 = args[2]
        a3 = args[3]
        # Only constants and symbolic integers can be combined symbolically.
        if !(isa(a2, CustomLattice) && isa(a2.payload, SymbolicInt)) && !isa(a2, Const)
            return Any
        end
        if !(isa(a3, CustomLattice) && isa(a3.payload, SymbolicInt)) && !isa(a3, Const)
            return Any
        end
        a2v = isa(a2, Const) ? a2.val : a2.payload.sym
        a3v = isa(a3, Const) ? a3.val : a3.payload.sym
        rva = if f === Core.Intrinsics.mul_int
            a2v * a3v
        elseif f === Core.Intrinsics.add_int
            a2v + a3v
        else  # f === Core.Intrinsics.sub_int, guaranteed by the guard above
            a2v - a3v
        end
        @show (a2v, a3v, rva)
        # Two constants fold to a concrete Int; anything involving a symbol stays symbolic.
        if isa(rva, Int)
            return Const(rva)
        else
            return CustomLattice(Int, SymbolicInt(rva), callbacks)
        end
    elseif f === Core.arraysize
        a2 = args[2]
        a3 = args[3]
        # Without a symbolic array and a constant dimension index, fall back to plain Int.
        isa(a2, CustomLattice) || return Int
        isa(a3, Const) || return Int
        a2p = a2.payload
        if isa(a2p, SymbolicDimArray)
            return getfield_tfunc(a2p.dims, a3)
        else
            error()
        end
    elseif f === Core.Intrinsics.arraylen
        a2 = args[2]
        isa(a2, CustomLattice) || return Int
        a2p = a2.payload
        if isa(a2p, SymbolicDimArray)
            @assert isa(a2p.dims, PartialStruct)
            nf = nfields_tfunc(a2p.dims)
            @assert isa(nf, Const)
            flds = map(1:nf.val) do i
                getfield_tfunc(a2p.dims, Const(i))
            end
            length(flds) == 1 && return flds[1]
            # Length is the product of all dimensions, folded through this same tfunc.
            return foldl(flds) do acc, d
                symbolic_tfunc(Core.Intrinsics.mul_int,
                               Any[Const(Core.Intrinsics.mul_int), acc, d])
            end
        else
            error()
        end
    elseif f === typeassert
        a2 = args[2]
        a3 = args[3]
        t = instanceof_tfunc(a3)[1]
        # Only custom elements carry a declared `typ`; guard before reading it.
        if isa(a2, CustomLattice) && a2.typ <: widenconst(t)
            return a2
        end
        error()
    else
        @show f
        @show args
        error()
    end
end
"""
    symbolic_tmeetbound(a, b)

Meet of two lattice elements where at least one involves the symbolic lattice.

- `Any` is the identity.
- Two custom elements: `Union{}` when their declared types are disjoint; `a` when both
  are symbolic integers with `isequal` expressions.
- A custom element against a plain bound: the custom element when its declared type
  already lies within a type bound, `Union{}` when the bound and its type do not meet.

Unhandled combinations print the inputs and throw.
"""
function symbolic_tmeetbound(a, b)
    b === Any && return a
    a === Any && return b
    if isa(a, CustomLattice) && isa(b, CustomLattice)
        # Checked first: the original called `tmeet(a, b.typ)` with `a` itself custom
        # before ever reaching this case.
        typeintersect(a.typ, b.typ) === Union{} && return Union{}
        ap = a.payload
        bp = b.payload
        if isa(ap, SymbolicInt) && isa(bp, SymbolicInt) && isequal(ap.sym, bp.sym)
            return a
        end
    elseif isa(b, CustomLattice)
        # A type bound containing `b`'s declared type leaves `b` unchanged (this also covers
        # the original `a === b.typ` case).
        isa(a, Type) && b.typ <: a && return b
        Core.Compiler.tmeet(a, b.typ) === Union{} && return Union{}
    elseif isa(a, CustomLattice)
        # Mirror of the previous case with the custom element on the left.
        isa(b, Type) && a.typ <: b && return a
        Core.Compiler.tmeet(b, a.typ) === Union{} && return Union{}
    end
    @show (a, b)
    error()
end
# Custom lattice elements standing for the symbolic integers `a`, `b`, `c` (from `@syms`).
avar, bvar, cvar = map(s -> CustomLattice(Int, SymbolicInt(s), callbacks), (a, b, c))
"""
    _run_symbolic_inference(f, sig, argtypes)

Infer `f` specialized on the argument types `sig` (a `Tuple` of the argument types, not
including `typeof(f)`), seeding the arguments with the lattice elements `argtypes`.
Prints the inferred code and the interpreter's messages; returns `interp.msgs`.
"""
function _run_symbolic_inference(f, sig, argtypes)
    mi = get_specialization(f, sig)
    interp = CthulhuInterpreter()
    result = InferenceResult(mi, Any[typeof(f), argtypes...])
    # NOTE(review): the `true` is presumably the `cached` flag — confirm for this branch.
    frame = InferenceState(result, true, interp)
    typeinf(interp, frame)
    @show frame.src
    return @show interp.msgs
end

"""
    _symbolic_array(T, dims...)

Build a custom lattice element for an array of type `T` whose dimensions are the lattice
elements `dims` (typically symbolic integers).
"""
function _symbolic_array(T, dims...)
    dimtuple = PartialStruct(Tuple{ntuple(_ -> Int, length(dims))...}, Any[dims...])
    return CustomLattice(T, SymbolicDimArray(dimtuple), callbacks)
end

# Scalar example: foo(a, b, c) = a * b + c with symbolic a, b, c.
run_1() = _run_symbolic_inference(foo, Tuple{Int, Int, Int}, (avar, bvar, cvar))

# Matrix product of an (a × b) and a (b × c) matrix.
aavar = _symbolic_array(Array{Float64, 2}, avar, bvar)
bbvar = _symbolic_array(Array{Float64, 2}, bvar, cvar)
run_2() = _run_symbolic_inference(matmul, Tuple{Array{Float64, 2}, Array{Float64, 2}},
                                  (aavar, bbvar))

# Vertical concatenation of an (a × b) and a (c × b) matrix.
aavar2 = _symbolic_array(Array{Float64, 2}, avar, bvar)
bbvar2 = _symbolic_array(Array{Float64, 2}, cvar, bvar)
run_3() = _run_symbolic_inference(concat, Tuple{Array{Float64, 2}, Array{Float64, 2}},
                                  (aavar2, bbvar2))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.