Last active: February 1, 2022 14:59
-
-
Save jkrumbiegel/f91b6d539f317eaa9f7499dd4b2a55a0 to your computer and use it in GitHub Desktop.
Playing around with tagged unions.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
    struct_reinterpret(::Type{T}, x) -> T

Reinterpret the in-memory bytes of the bits value `x` as a value of type `T`.

Only the first `sizeof(T)` bytes of `x` are read; any trailing bytes are ignored.

Throws an `ArgumentError` in these cases:
- `T` is not a bits type.
- `typeof(x)` is not a bits type.
- `T` is larger than `x`, which would read past the end of `x`'s storage.
"""
function struct_reinterpret(::Type{T}, x)::T where {T}
    S = typeof(x)
    isbitstype(T) || throw(ArgumentError("target type $T is not a bits type"))
    isbitstype(S) || throw(ArgumentError("source type $S is not a bits type"))
    sizeof(T) <= sizeof(S) || throw(ArgumentError(
        "cannot read a $(sizeof(T))-byte $T from a $(sizeof(S))-byte $S"))
    # Box `x` in a Ref so it has an addressable location, and keep that box
    # alive while the raw pointer is dereferenced.
    ref = Base.RefValue(x)
    return GC.@preserve ref unsafe_load(Ptr{T}(Base.unsafe_convert(Ptr{S}, ref)))
end
"""
    TaggedUnion

Hand-rolled tagged union over the types returned by `types(TaggedUnion)`.

# Fields
- `type`: zero-based index of the stored value's type within `types(TaggedUnion)`.
- `bytes`: raw storage. The first `sizeof(T)` bytes hold the value; the rest are zero.
"""
struct TaggedUnion
    type::UInt8
    bytes::NTuple{8, UInt8} # 8 == maximum(sizeof, types(TaggedUnion)); the macro version computes this automatically
end

# Supported payload types. A type's position in this tuple, minus one, is its tag.
types(::Type{TaggedUnion}) = (Int64, Float64, Float32)

"""
    TaggedUnion(x)

Wrap `x` in a `TaggedUnion`, tagging it with the index of `typeof(x)` in
`types(TaggedUnion)`.

Throws `ErrorException("Type not valid.")` if `typeof(x)` is not one of those types.
"""
function TaggedUnion(x)
    supported = types(TaggedUnion)
    i = findfirst(==(typeof(x)), supported)
    i === nothing && error("Type not valid.")
    T = supported[i]
    NT = fieldtype(TaggedUnion, :bytes)
    # Start from an all-zero buffer so the bytes past sizeof(T) are deterministic.
    # Otherwise equal payloads could differ under ===, isequal and hash.
    ref = Base.RefValue{NT}(ntuple(k -> 0x00, fieldcount(NT)))
    GC.@preserve ref begin
        ptr = Ptr{T}(Base.unsafe_convert(Ptr{NT}, ref))
        unsafe_store!(ptr, x)
    end
    return TaggedUnion(i - 1, ref[])
end
# One sample value for each supported variant: Int64, Float64, Float32.
m1, m2, m3 = TaggedUnion(1), TaggedUnion(1.0), TaggedUnion(1f0)
# Hand-written tag dispatch. The `TaggedUnions` module below generates the
# equivalent `if`/`elseif` chain with a macro.
"""
    do_something(f, m::TaggedUnion)

Decode the payload of `m` according to its tag and return `f(payload)`.

Tags `0`, `1` and `2` select `Int64`, `Float64` and `Float32`, matching the order
of `types(TaggedUnion)`. Any other tag can only come from the raw two-argument
constructor. It throws an `ErrorException` that names the bad tag.
"""
function do_something(f, m::TaggedUnion)
    if m.type == 0
        return f(struct_reinterpret(Int64, m.bytes))
    elseif m.type == 1
        return f(struct_reinterpret(Float64, m.bytes))
    elseif m.type == 2
        return f(struct_reinterpret(Float32, m.bytes))
    end
    error("invalid TaggedUnion tag: ", Int(m.type))
end
# Display a value as `TaggedUnion: ` followed by the `show` output of its
# decoded payload.
function Base.show(io::IO, m::TaggedUnion)
    print(io, "TaggedUnion: ")
    do_something(m) do payload
        show(io, payload)
    end
end
# Label a value by its concrete numeric type. Used as the callback in the
# `do_something` timings below.
f(::Int64) = :was_an_int64
f(::Float64) = :was_a_float64
f(::Float32) = :was_a_float32
# Time the hand-written dispatcher on a single value of each variant. The first
# measurement likely includes compiling `do_something`, so rerun for steady-state numbers.
for m in (m1, m2, m3)
    @time do_something(f, m)
end

# 3000 values: 1000 Int64, then 1000 Float64, then 1000 Float32.
m_arr = collect(TaggedUnion(x) for x in (1, 2.0, 3f0) for _ in 1:1000)
@time do_something.(f, m_arr)
"""
    TaggedUnions

Macro-generated version of a tagged union.

`@TaggedUnion Name (T1, T2, ...)` defines an immutable struct `Name`. It holds one
value of any listed bits type in a byte buffer sized for the largest member,
plus a `UInt8` tag.
"""
module TaggedUnions

"""
    types(::Type{U})

Return the tuple of member types of a tagged union type `U` generated by
`@TaggedUnion`. A member's position in the tuple, minus one, is its tag.
"""
function types end

"""
    do_something(f, m)

Decode the payload of the tagged union value `m` and return `f(payload)`.
`@TaggedUnion` generates one method per union type.
"""
function do_something end

"""
    struct_reinterpret(::Type{T}, x) -> T

Reinterpret the in-memory bytes of the bits value `x` as a value of type `T`.

Only the first `sizeof(T)` bytes of `x` are read. Throws an `ArgumentError` if
either type is not a bits type, or if `T` is larger than `x`.
"""
function struct_reinterpret(::Type{T}, x)::T where {T}
    S = typeof(x)
    isbitstype(T) || throw(ArgumentError("target type $T is not a bits type"))
    isbitstype(S) || throw(ArgumentError("source type $S is not a bits type"))
    sizeof(T) <= sizeof(S) || throw(ArgumentError(
        "cannot read a $(sizeof(T))-byte $T from a $(sizeof(S))-byte $S"))
    ref = Base.RefValue(x)
    return GC.@preserve ref unsafe_load(Ptr{T}(Base.unsafe_convert(Ptr{S}, ref)))
end

"""
    make_if_block(types) -> Expr

Build the body of a generated `do_something` method.

The body is an `if`/`elseif` chain. When `m.type == i - 1` it returns
`f(struct_reinterpret(types[i], m.bytes))`. If no tag matches, it falls through
to an `error` that names the tag.

The expression refers to `f` and `m`, so it must be spliced where those names are
bound. Elements of `types` are spliced verbatim; pass escaped expressions to
resolve them in the caller's module.
"""
function make_if_block(types)
    members = collect(types)
    isempty(members) &&
        throw(ArgumentError("a tagged union needs at least one member type"))
    # Build from the last branch outward, so the fallback ends up as the final
    # `else` clause of the innermost `elseif`.
    block = :(error("invalid tag: ", Int(m.type)))
    for i in length(members):-1:1
        head = i == 1 ? :if : :elseif
        cond = :(m.type == $(i - 1))
        body = :(return f(struct_reinterpret($(members[i]), m.bytes)))
        block = Expr(head, cond, body, block)
    end
    return block
end

"""
    @TaggedUnion Name (T1, T2, ...)

Define an immutable struct `Name` that stores a value of one of the bits types
`T1, T2, ...` (at most 256). The macro also defines:

- `Name(x)`: wraps `x`. Errors with `"Type not valid."` if `typeof(x)` is not a member.
- `TaggedUnions.types(Name)`: returns the member-type tuple.
- `TaggedUnions.do_something(f, m::Name)`: returns `f(payload)`.
- A `show` method that prints `Name(payload)`.

Evaluates to the new type.
"""
macro TaggedUnion(name::Symbol, types_expr::Expr)
    Meta.isexpr(types_expr, :tuple) || throw(ArgumentError(
        "expected a tuple of types such as (Int64, Float64), got $(types_expr)"))
    structname = esc(name)
    # Escape member types so that user-defined types resolve in the caller's
    # module rather than inside TaggedUnions.
    dispatch = make_if_block(map(esc, types_expr.args))
    quote
        typs = $(esc(types_expr))
        all(isbitstype, typs) ||
            throw(ArgumentError("all tagged union member types must be bits types"))
        length(typs) <= 256 ||
            throw(ArgumentError("a UInt8 tag supports at most 256 member types"))
        sz = maximum(sizeof, typs)
        struct $structname
            type::UInt8
            bytes::NTuple{sz, UInt8}
        end
        function $structname(x)
            supported = types($structname)
            i = findfirst(==(typeof(x)), supported)
            i === nothing && error("Type not valid.")
            T = supported[i]
            NT = fieldtype($structname, :bytes)
            # Zero the buffer first so bytes past sizeof(T) are deterministic
            # and equal payloads compare equal under === and hash.
            ref = Base.RefValue{NT}(ntuple(k -> 0x00, fieldcount(NT)))
            GC.@preserve ref begin
                ptr = Ptr{T}(Base.unsafe_convert(Ptr{NT}, ref))
                unsafe_store!(ptr, x)
            end
            return $structname(i - 1, ref[])
        end
        types(::Type{$structname}) = $(esc(types_expr))
        function Base.show(io::IO, m::$structname)
            print(io, $(string(name)), "(")
            do_something(x -> show(io, x), m)
            print(io, ")")
        end
        function do_something(f, m::$structname)
            $dispatch
        end
        $structname
    end
end

end
# Define a tagged union type `MyType2` with the macro, then build one value.
# NOTE(review): on 64-bit platforms `Int === Int64`, so the fourth member
# duplicates the first. The constructor picks the first match, so tag 3 is never produced.
@TaggedUnions.TaggedUnion(MyType2, (Int64, Float64, Float32, Int))
MyType2(1f0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.