Skip to content

Instantly share code, notes, and snippets.

@jkrumbiegel
Last active February 1, 2022 14:59
Show Gist options
  • Save jkrumbiegel/f91b6d539f317eaa9f7499dd4b2a55a0 to your computer and use it in GitHub Desktop.
Save jkrumbiegel/f91b6d539f317eaa9f7499dd4b2a55a0 to your computer and use it in GitHub Desktop.
playing around with tagged unions
"""
    struct_reinterpret(::Type{T}, x) -> T

Reinterpret the leading `sizeof(T)` bytes of the isbits value `x` as a value of type `T`.

# Arguments
- `T`: isbits type to decode.
- `x`: isbits value (for example a tuple of `UInt8`) providing the raw bytes.

# Throws
- `ArgumentError` if `T` or `typeof(x)` is not an isbits type, or if `T` is larger than
  `x` (the load would otherwise read past the end of `x`).
"""
function struct_reinterpret(::Type{T}, x)::T where T
    S = typeof(x)
    (isbitstype(T) && isbitstype(S)) ||
        throw(ArgumentError("struct_reinterpret requires isbits types, got $S -> $T"))
    sizeof(T) <= sizeof(S) ||
        throw(ArgumentError("cannot read $(sizeof(T)) bytes of $T from $(sizeof(S))-byte $S"))
    # Box `x` so it has an address, then load a `T` from that address.
    ref = Base.RefValue{S}(x)
    # `ref` must stay rooted while the raw pointer derived from it is in use.
    return GC.@preserve ref unsafe_load(Ptr{T}(Base.unsafe_convert(Ptr{S}, ref)))
end
"""
    TaggedUnion

Fixed-size container holding one value of any type listed in `types(TaggedUnion)`.

# Fields
- `type`: zero-based index of the stored value's type within `types(TaggedUnion)`.
- `bytes`: raw storage; the value occupies the leading bytes, the remainder is zero.
"""
struct TaggedUnion
    type::UInt8
    bytes::NTuple{8, UInt8} # would be the max size of all types used, determined automatically
end

"""
    types(::Type{TaggedUnion})

Return the tuple of payload types a `TaggedUnion` can hold, in tag order.
"""
types(::Type{TaggedUnion}) = (Int64, Float64, Float32)

"""
    TaggedUnion(x)

Wrap `x` in a `TaggedUnion`.

Throws an `ErrorException` ("Type not valid.") when `typeof(x)` is not one of
`types(TaggedUnion)`.
"""
function TaggedUnion(x)
    i = findfirst(==(typeof(x)), types(TaggedUnion))
    i === nothing && error("Type not valid.")
    NT = fieldtype(TaggedUnion, :bytes)
    # Start from an all-zero buffer: payloads narrower than the buffer would otherwise
    # leave uninitialised padding, making equal values compare (and hash) differently.
    ref = Base.RefValue{NT}(ntuple(k -> 0x00, Val(fieldcount(NT))))
    GC.@preserve ref begin
        ptr = Base.unsafe_convert(Ptr{NT}, ref)
        # Write the payload's bytes at the start of the buffer.
        unsafe_store!(Ptr{typeof(x)}(ptr), x)
    end
    return TaggedUnion(i - 1, ref[])
end
# One sample value per supported payload type: integer, Float64 and Float32 literals.
m1, m2, m3 = map(TaggedUnion, (1, 1.0, 1f0))
# should be macro generated
"""
    do_something(f, m::TaggedUnion)

Decode the payload of `m` as the type selected by `m.type` and return `f(payload)`.

Tags `0`, `1` and `2` select `Int64`, `Float64` and `Float32` respectively. Throws an
`ErrorException` naming the offending tag when `m.type` is none of these.
"""
function do_something(f, m::TaggedUnion)
    if m.type == 0
        return f(struct_reinterpret(Int64, m.bytes))
    elseif m.type == 1
        return f(struct_reinterpret(Float64, m.bytes))
    elseif m.type == 2
        return f(struct_reinterpret(Float32, m.bytes))
    end
    # No branch matched: the tag does not correspond to any handled payload type.
    error("TaggedUnion has invalid type tag ", Int(m.type), ".")
end
"""
    Base.show(io::IO, m::TaggedUnion)

Print `TaggedUnion: ` followed by the `show` representation of the decoded payload.
"""
function Base.show(io::IO, m::TaggedUnion)
    print(io, "TaggedUnion: ")
    # The do-block is passed to `do_something` as its first argument (the callback).
    do_something(m) do payload
        show(io, payload)
    end
end
# Demo callback: one method per payload type, each returning a symbol naming the type
# it was dispatched on.
function f(::Int64)
    return :was_an_int64
end

function f(::Float64)
    return :was_a_float64
end

function f(::Float32)
    return :was_a_float32
end
# Time one `do_something` call per sample value (`m1`, `m2`, `m3`).
# NOTE(review): `@time` on a first call presumably includes compilation time, so the
# first figure is likely not comparable with the following ones — confirm by re-running.
@time do_something(f, m1)
@time do_something(f, m2)
@time do_something(f, m3)
# 3000 elements: 1000 copies of TaggedUnion(1), then of TaggedUnion(2.0), then of
# TaggedUnion(3f0) (the `x` loop is the outer one).
m_arr = [TaggedUnion(x) for x in (1, 2.0, 3f0) for _ in 1:1000]
# Broadcast the dispatch over the whole array and time it.
@time do_something.(f, m_arr)
"""
    TaggedUnions

Macro-generated tagged unions: `@TaggedUnion Name (T1, T2, ...)` defines an isbits struct
`Name` that stores one value of any of the listed isbits types inline, together with
`TaggedUnions.types` and `TaggedUnions.do_something` methods for it.
"""
module TaggedUnions

"""
    types(::Type{U})

Return the tuple of payload types of the tagged union type `U`, in tag order (a stored
tag `k` refers to `types(U)[k + 1]`). Methods are added by `@TaggedUnion`.
"""
function types end

"""
    do_something(f, m)

Decode the payload of the tagged union value `m` as the type selected by its tag and
return `f(payload)`. Methods are added by `@TaggedUnion`.
"""
function do_something end

"""
    struct_reinterpret(::Type{T}, x) -> T

Reinterpret the leading `sizeof(T)` bytes of the isbits value `x` as a value of type `T`.

Throws an `ArgumentError` if `T` or `typeof(x)` is not an isbits type, or if `T` is larger
than `x` (the load would otherwise read past the end of `x`).
"""
function struct_reinterpret(::Type{T}, x)::T where T
    S = typeof(x)
    (isbitstype(T) && isbitstype(S)) ||
        throw(ArgumentError("struct_reinterpret requires isbits types, got $S -> $T"))
    sizeof(T) <= sizeof(S) ||
        throw(ArgumentError("cannot read $(sizeof(T)) bytes of $T from $(sizeof(S))-byte $S"))
    ref = Base.RefValue{S}(x)
    # `ref` must stay rooted while the raw pointer derived from it is in use.
    return GC.@preserve ref unsafe_load(Ptr{T}(Base.unsafe_convert(Ptr{S}, ref)))
end

# All-zero byte buffer of the given tuple type, so payload padding is deterministic.
zero_bytes(::Type{NTuple{N, UInt8}}) where {N} = ntuple(k -> 0x00, Val(N))

"""
    make_if_block(types) -> Expr

Build the `if`/`elseif` chain used as the body of a generated `do_something(f, m)` method.

Branch `i` tests `m.type == i - 1` and returns `f` applied to `m.bytes` reinterpreted as
`types[i]`; the final `else` raises an error naming the unknown tag. The entries of
`types` are spliced in verbatim, so the caller decides whether they are escaped.

Throws an `ArgumentError` when `types` is empty.
"""
function make_if_block(types)
    isempty(types) && throw(ArgumentError("at least one type is required"))
    # Build from the innermost `else` outwards so each branch wraps the rest of the chain.
    chain = :(error("TaggedUnion has invalid type tag ", Int(m.type), "."))
    for i in length(types):-1:1
        head = i == 1 ? :if : :elseif
        chain = Expr(
            head,
            :(m.type == $(i - 1)),
            :(return f(struct_reinterpret($(types[i]), m.bytes))),
            chain,
        )
    end
    return chain
end

"""
    @TaggedUnion Name (T1, T2, ...)

Define a struct `Name` able to hold one value of any of the listed isbits types.

The expansion defines, in the calling module, the struct `Name` (a `UInt8` tag plus a
byte buffer as large as the biggest listed type) and the constructor `Name(x)`, and adds
methods for `TaggedUnions.types`, `TaggedUnions.do_something` and `Base.show`. The type
names are resolved in the calling module, so user-defined isbits types are supported.
The macro call evaluates to `Name`.

# Throws
- `ArgumentError` at expansion time if the second argument is not a tuple expression, is
  empty, or lists more than 256 types (the tag is a `UInt8`).
- `ArgumentError` when the expansion is evaluated if any listed type is not an isbits type.
"""
macro TaggedUnion(name::Symbol, types_expr::Expr)
    types_expr.head === :tuple ||
        throw(ArgumentError("@TaggedUnion expects a tuple of types, e.g. (Int64, Float64)"))
    typelist = types_expr.args
    isempty(typelist) && throw(ArgumentError("@TaggedUnion needs at least one type"))
    length(typelist) <= 256 ||
        throw(ArgumentError("@TaggedUnion supports at most 256 types (the tag is a UInt8)"))

    structname = esc(name)
    typs = esc(types_expr)
    label = string(name)
    # Escape each type so it is looked up in the caller's module, not in TaggedUnions.
    dispatch = make_if_block(map(esc, typelist))

    quote
        all(isbitstype, $typs) ||
            throw(ArgumentError("@TaggedUnion: every member type must be an isbits type"))

        struct $structname
            type::UInt8
            bytes::NTuple{maximum(sizeof, $typs), UInt8}
        end

        # Qualified names attach these methods to the module's functions; unqualified
        # definitions would be renamed by macro hygiene and be unreachable for callers.
        TaggedUnions.types(::Type{$structname}) = $typs

        function $structname(x)
            i = findfirst(==(typeof(x)), TaggedUnions.types($structname))
            i === nothing && error("Type not valid.")
            NT = fieldtype($structname, :bytes)
            # Zero-filled buffer: bytes beyond the payload must not be left uninitialised.
            ref = Base.RefValue{NT}(zero_bytes(NT))
            GC.@preserve ref begin
                ptr = Base.unsafe_convert(Ptr{NT}, ref)
                unsafe_store!(Ptr{typeof(x)}(ptr), x)
            end
            return $structname(i - 1, ref[])
        end

        function Base.show(io::IO, m::$structname)
            print(io, $label, "(")
            TaggedUnions.do_something(x -> show(io, x), m)
            print(io, ")")
            return nothing
        end

        function TaggedUnions.do_something(f, m::$structname)
            $dispatch
        end

        $structname
    end
end

end
# Generate `MyType2` over four listed types. NOTE(review): where `Int` is an alias for
# `Int64`, the last entry repeats the first and its tag is presumably never chosen.
TaggedUnions.@TaggedUnion MyType2 (Int64, Float64, Float32, Int)
# Wrap a Float32 in the generated union.
MyType2(1f0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment