Created January 13, 2023 17:20
Save Tokazama/eba2474f13754cbb3b9fadf5151ebf48 to your computer and use it in GitHub Desktop.
Reference types from A(tomic) to S(tatic)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module RefTypes | |
import .Base.Sys: ARCH, WORD_SIZE | |
# Public API: the four reference flavours ...
export AtomicRef, ImmutableRef, MutableRef, StaticRef
# ... the in-place read-modify-write helpers ...
export add!, sub!, and!, nand!, or!, xor!, max!, min!, swap!
# ... and increment/decrement (pure and in-place).
export inc, inc!, dec, dec!
# Pass a `Bool` through unchanged; any other value (e.g. `missing`) has no
# method and raises a `MethodError`. Used to force comparison results on
# references to be plain `Bool`s.
function _bool(x::Bool)
    return x
end

# Return `value` as-is, but only when it is an instance of `T`; any other
# combination is rejected by dispatch with a `MethodError`.
function _rettype(::Type{T}, value::T) where {T}
    return value
end
"""
    StaticRef{T,value}()
    StaticRef{T}(value)
    StaticRef{T}()
    StaticRef(value)

Reference whose contents live entirely in the type domain: `value` is the second
type parameter and the struct has no fields. `value` must therefore be usable as a
type parameter (e.g. an `isbits` value). `StaticRef{T}(value)` converts `value`
to `T` first; `StaticRef{T}()` uses `T()` and so needs a zero-argument
constructor for `T`.
"""
struct StaticRef{T,value} <: Ref{T}
    # The type assertion rejects a parameter `value` that is not a `T`.
    function StaticRef{T,value}() where {T,value}
        return new{T,value::T}()
    end

    # Lift an already-typed value into the type parameter.
    function StaticRef{T}(value::T) where {T}
        return new{T,value}()
    end

    # Anything else is converted to `T` before lifting.
    function StaticRef{T}(value) where {T}
        converted = convert(T, value)
        return StaticRef{T}(converted)
    end

    function StaticRef{T}() where {T}
        return StaticRef{T}(T())
    end

    # Infer `T` from the argument.
    function StaticRef(value::T) where {T}
        return StaticRef{T}(value)
    end
end
"""
    ImmutableRef{T}(v)
    ImmutableRef{T}()
    ImmutableRef(v)

Immutable reference holding a single `value::T`. `ImmutableRef{T}(v)` converts
`v` to `T` on construction; `ImmutableRef{T}()` stores `T()` and so needs a
zero-argument constructor for `T`.
"""
struct ImmutableRef{T} <: Ref{T}
    value::T

    # `new` performs the conversion of `v` to the field type `T`.
    function ImmutableRef{T}(v) where {T}
        return new{T}(v)
    end

    function ImmutableRef{T}() where {T}
        return ImmutableRef{T}(T())
    end

    # Infer `T` from the argument.
    function ImmutableRef(v::T) where {T}
        return ImmutableRef{T}(v)
    end
end
"""
    MutableRef{T}(v)
    MutableRef{T}()
    MutableRef(v)

Mutable (heap-allocated) reference holding a single, non-atomic `value::T`.
`MutableRef{T}(v)` converts `v` to `T`; `MutableRef{T}()` stores `T()` and so
needs a zero-argument constructor for `T`.
"""
mutable struct MutableRef{T} <: Ref{T}
    value::T

    # `new` performs the conversion of `v` to the field type `T`.
    function MutableRef{T}(v) where {T}
        return new{T}(v)
    end

    function MutableRef{T}() where {T}
        return MutableRef{T}(T())
    end

    # Infer `T` from the argument.
    function MutableRef(v::T) where {T}
        return MutableRef{T}(v)
    end
end
# Element types that get native LLVM atomics, filtered per platform:
# - 128-bit atomics do not exist on AArch32 (`arm*`).
# - 128-bit types are omitted on 32-bit x86 (`i686`) and on ppc64le.
# - LLVM doesn't currently support atomics on floats for ppc64le.
#   C++20 is adding limited support for atomics on float, but as of
#   now Clang does not support that yet.
const inttypes = let arch = Base.Sys.ARCH
    if arch === :i686 || startswith(string(arch), "arm") ||
       arch === :powerpc64le || arch === :ppc64le
        (Int8, Int16, Int32, Int64,
         UInt8, UInt16, UInt32, UInt64)
    else
        (Int8, Int16, Int32, Int64, Int128,
         UInt8, UInt16, UInt32, UInt64, UInt128)
    end
end
const floattypes = (Float16, Float32, Float64)
const arithmetictypes = (inttypes..., floattypes...)
# TODO: Support Ptr
const atomictypes = let arch = Base.Sys.ARCH
    (arch === :powerpc64le || arch === :ppc64le) ? (inttypes...,) : (arithmetictypes...,)
end

# Union views of the tuples above, for use in method signatures.
const IntTypes = Union{inttypes...}
const FloatTypes = Union{floattypes...}
const ArithmeticTypes = Union{arithmetictypes...}
const AtomicRefTypes = Union{atomictypes...}
"""
    AtomicRef{T}(value)
    AtomicRef{T}()
    AtomicRef(value)

Mutable reference whose single field is declared `@atomic`. `AtomicRef{T}(value)`
converts `value` to `T`, like `ImmutableRef{T}` and `MutableRef{T}`;
`AtomicRef{T}()` stores `T()` and so needs a zero-argument constructor for `T`.
"""
mutable struct AtomicRef{T} <: Ref{T}
    @atomic value::T

    # Accept any value and let `new` convert it to `T`. Requiring an exact `T`
    # made `AtomicRef{T}(r[])` (as used by `convert(AtomicRef{T}, r)` during
    # promotion) throw a MethodError whenever `r` held a different type.
    AtomicRef{T}(value) where {T} = new{T}(value)
    AtomicRef{T}() where {T} = AtomicRef{T}(T())
    # Infer `T` from the argument.
    AtomicRef(value::T) where {T} = AtomicRef{T}(value)
end
# Raw pointer to the `AtomicRef` object, retyped as `Ptr{T}`. The llvmcall
# accessors in this module treat it as the address of the `value` field and
# keep `x` rooted with `GC.@preserve` while the pointer is in use.
Base.unsafe_convert(::Type{Ptr{T}}, x::AtomicRef{T}) where {T} =
    Ptr{T}(pointer_from_objref(x))
# `x[] = v` for a `v` of another type: convert to the element type, then
# defer to the exactly-typed `setindex!` method.
function Base.setindex!(x::AtomicRef{T}, v) where {T}
    converted = convert(T, v)
    return setindex!(x, converted)
end
# LLVM IR type name for each Julia element type used in the generated atomics.
# Signed and unsigned integers of one width share an LLVM integer type; LLVM
# expresses signedness in the instruction (e.g. `umax` vs `max`) instead.
const llvmtypes = IdDict{Any,String}(
    (T => "i$(8 * sizeof(T))" for T in (Int8, UInt8, Int16, UInt16, Int32, UInt32,
                                          Int64, UInt64, Int128, UInt128))...,
    Float16 => "half",
    Float32 => "float",
    Float64 => "double",
)
# Same-width integer type used to reinterpret/bitcast values of type `T`:
# integer types map to themselves, IEEE floats to the equal-size signed integer.
inttype(::Type{T}) where {T<:Integer} = T
for (F, I) in ((Float16, Int16), (Float32, Int32), (Float64, Int64))
    @eval inttype(::Type{$F}) = $I
end
import ..Base.gc_alignment | |
# Native atomic accessors and read-modify-write operations for every type in
# `atomictypes`, emitted as LLVM IR through `Base.llvmcall`.
#
# All atomic operations have acquire and/or release semantics, depending on
# whether they load or store values: loads are `acquire`, stores `release`,
# cmpxchg / atomicrmw are `acq_rel` (cmpxchg failure ordering: `acquire`).
# Most of the time, this is what one wants anyway, and it's only moderately
# expensive on most hardware.
#
# Each method obtains a raw pointer with `Base.unsafe_convert(Ptr{T}, x)`, keeps
# `x` rooted with `GC.@preserve` while the IR runs, and turns the pointer
# argument (passed as a `WORD_SIZE`-bit integer) back into a typed LLVM
# pointer with `inttoptr`.
for typ in atomictypes
    lt = llvmtypes[typ]            # LLVM type of the stored value
    ilt = llvmtypes[inttype(typ)]  # LLVM integer type of the same width
    rt = "$lt, $lt*"               # "<value type>, <pointer type>" operands of `load`

    # x[]: atomic acquire load.
    @eval Base.getindex(x::AtomicRef{$typ}) =
        GC.@preserve x Base.llvmcall($"""
            %ptr = inttoptr i$(WORD_SIZE) %0 to $lt*
            %rv = load atomic $rt %ptr acquire, align $(gc_alignment(typ))
            ret $lt %rv
            """, $typ, Tuple{Ptr{$typ}}, Base.unsafe_convert(Ptr{$typ}, x))

    # x[] = v: atomic release store.
    @eval Base.setindex!(x::AtomicRef{$typ}, v::$typ) =
        GC.@preserve x Base.llvmcall($"""
            %ptr = inttoptr i$(WORD_SIZE) %0 to $lt*
            store atomic $lt %1, $lt* %ptr release, align $(gc_alignment(typ))
            ret void
            """, Cvoid, Tuple{Ptr{$typ}, $typ}, Base.unsafe_convert(Ptr{$typ}, x), v)

    # cas!(x, cmp, new) returns the value found in `x`; it succeeded (i.e. it
    # stored `new`) if and only if the result is `cmp`.
    if typ <: Integer
        @eval cas!(x::AtomicRef{$typ}, cmp::$typ, new::$typ) =
            GC.@preserve x Base.llvmcall($"""
                %ptr = inttoptr i$(WORD_SIZE) %0 to $lt*
                %rs = cmpxchg $lt* %ptr, $lt %1, $lt %2 acq_rel acquire
                %rv = extractvalue { $lt, i1 } %rs, 0
                ret $lt %rv
                """, $typ, Tuple{Ptr{$typ},$typ,$typ},
                Base.unsafe_convert(Ptr{$typ}, x), cmp, new)
    else
        # Floats: bitcast to the same-width integer, cmpxchg the bit patterns,
        # and bitcast the observed value back.
        @eval cas!(x::AtomicRef{$typ}, cmp::$typ, new::$typ) =
            GC.@preserve x Base.llvmcall($"""
                %iptr = inttoptr i$WORD_SIZE %0 to $ilt*
                %icmp = bitcast $lt %1 to $ilt
                %inew = bitcast $lt %2 to $ilt
                %irs = cmpxchg $ilt* %iptr, $ilt %icmp, $ilt %inew acq_rel acquire
                %irv = extractvalue { $ilt, i1 } %irs, 0
                %rv = bitcast $ilt %irv to $lt
                ret $lt %rv
                """, $typ, Tuple{Ptr{$typ},$typ,$typ},
                Base.unsafe_convert(Ptr{$typ}, x), cmp, new)
    end

    # Read-modify-write operations (`add!`, `sub!`, `xchg!`, `and!`, `nand!`,
    # `or!`, `xor!`, `max!`, `min!`), each returning the previous value of `x`.
    arithmetic_ops = [:add, :sub]
    for rmwop in [arithmetic_ops..., :xchg, :and, :nand, :or, :xor, :max, :min]
        rmw = string(rmwop)
        fn = Symbol(rmw, "!")  # Julia-side name, fixed before the "u" prefix below
        if (rmw == "max" || rmw == "min") && typ <: Unsigned
            # LLVM distinguishes signedness in the operation, not the integer type.
            rmw = "u" * rmw
        end
        if rmwop in arithmetic_ops && !(typ <: ArithmeticTypes) continue end
        if typ <: Integer
            @eval $fn(x::AtomicRef{$typ}, v::$typ) =
                GC.@preserve x Base.llvmcall($"""
                    %ptr = inttoptr i$WORD_SIZE %0 to $lt*
                    %rv = atomicrmw $rmw $lt* %ptr, $lt %1 acq_rel
                    ret $lt %rv
                    """, $typ, Tuple{Ptr{$typ}, $typ}, Base.unsafe_convert(Ptr{$typ}, x), v)
        else
            # Floats only get `xchg!` here, performed on the integer bit pattern;
            # other float RMW operations are not emitted by this loop.
            rmwop === :xchg || continue
            @eval $fn(x::AtomicRef{$typ}, v::$typ) =
                GC.@preserve x Base.llvmcall($"""
                    %iptr = inttoptr i$WORD_SIZE %0 to $ilt*
                    %ival = bitcast $lt %1 to $ilt
                    %irv = atomicrmw $rmw $ilt* %iptr, $ilt %ival acq_rel
                    %rv = bitcast $ilt %irv to $lt
                    ret $lt %rv
                    """, $typ, Tuple{Ptr{$typ}, $typ}, Base.unsafe_convert(Ptr{$typ}, x), v)
        end
    end
end
# Provide atomic floating-point add!/sub!/max!/min! as compare-and-swap loops
# on `cas!`. Only float types that are in `atomictypes` get these methods:
# `cas!` is generated only for `atomictypes`, so on targets where floats are
# excluded (ppc64le) a method here would just throw a MethodError. Leaving
# those types out lets them fall back to the generic `modifyfield!`-based
# methods instead.
const opnames = Dict{Symbol, Symbol}(:+ => :add, :- => :sub)
for typ in atomictypes
    typ <: FloatTypes || continue
    IT = inttype(typ)
    for op in [:+, :-, :max, :min]
        opname = get(opnames, op, op)
        # Returns the value held before the successful update.
        @eval function $(Symbol(opname, "!"))(var::AtomicRef{$typ}, val::$typ)
            old = var[]
            while true
                new = $op(old, val)
                cmp = old
                old = cas!(var, cmp, new)
                # Compare bit patterns rather than `==`: with `==`, NaN would never
                # match (infinite loop) and -0.0 would falsely match 0.0.
                reinterpret($IT, old) == reinterpret($IT, cmp) && return old
                # Temporary solution before we have gc transition support in codegen.
                ccall(:jl_gc_safepoint, Cvoid, ())
            end
        end
    end
end
# Any of the four reference flavours defined in this module, with element type `T`.
const RefType{T} = Union{
    AtomicRef{T},
    MutableRef{T},
    ImmutableRef{T},
    StaticRef{T},
}
# Promotion between reference flavours. The result flavour follows a fixed
# precedence, ImmutableRef > AtomicRef > MutableRef > StaticRef, with element
# type `promote_type(X, Y)`. Each unordered pair is defined exactly once:
# `promote_type` consults both `promote_rule(A, B)` and `promote_rule(B, A)`.
# If they return different types, it promotes those results again. The
# previous `promote_rule(MutableRef, ImmutableRef) -> MutableRef`, which
# contradicted the ImmutableRef rule, therefore recursed forever.

# ImmutableRef wins against every other flavour.
function Base.promote_rule(::Type{ImmutableRef{X}}, ::Type{<:StaticRef{Y}}) where {X,Y}
    ImmutableRef{promote_type(X, Y)}
end
function Base.promote_rule(::Type{ImmutableRef{X}}, ::Type{AtomicRef{Y}}) where {X,Y}
    ImmutableRef{promote_type(X, Y)}
end
function Base.promote_rule(::Type{ImmutableRef{X}}, ::Type{MutableRef{Y}}) where {X,Y}
    ImmutableRef{promote_type(X, Y)}
end
# MutableRef beats only StaticRef; AtomicRef beats MutableRef.
function Base.promote_rule(::Type{MutableRef{X}}, ::Type{<:StaticRef{Y}}) where {X,Y}
    MutableRef{promote_type(X, Y)}
end
function Base.promote_rule(::Type{MutableRef{X}}, ::Type{AtomicRef{Y}}) where {X,Y}
    AtomicRef{promote_type(X, Y)}
end
# AtomicRef beats StaticRef.
function Base.promote_rule(::Type{AtomicRef{X}}, ::Type{<:StaticRef{Y}}) where {X,Y}
    AtomicRef{promote_type(X, Y)}
end
# Conversions between reference flavours copy the current contents (`v[]`)
# into a new reference of the requested flavour and element type. The identity
# methods are spelled out, presumably to avoid ambiguity between the `v::Ref`
# methods and Base's `convert(::Type{T}, x::T) where {T}`.
Base.convert(::Type{ImmutableRef{T}}, v::ImmutableRef{T}) where {T} = v
Base.convert(::Type{ImmutableRef{T}}, v::Ref) where {T} = ImmutableRef{T}(v[])
Base.convert(::Type{MutableRef{T}}, v::MutableRef{T}) where {T} = v
Base.convert(::Type{MutableRef{T}}, v::Ref) where {T} = MutableRef{T}(v[])
Base.convert(::Type{AtomicRef{T}}, v::AtomicRef{T}) where {T} = v
Base.convert(::Type{AtomicRef{T}}, v::Ref) where {T} = AtomicRef{T}(v[])

# StaticRef: `convert` must return an instance of the requested type. The old
# `Type{<:StaticRef{T}}` methods matched concrete targets such as
# `StaticRef{Int,3}` yet could return a `StaticRef{Int,4}`. Targets with an
# unbound `value` accept any contents. Targets with a fixed `value` accept
# contents equal to it and throw an `InexactError` otherwise.
Base.convert(::Type{StaticRef{T}}, v::StaticRef{T}) where {T} = v
Base.convert(::Type{StaticRef{T}}, v::Ref) where {T} = StaticRef{T}(v[])
Base.convert(::Type{StaticRef{T,value}}, v::StaticRef{T,value}) where {T,value} = v
function Base.convert(::Type{StaticRef{T,value}}, v::Ref) where {T,value}
    x = convert(T, v[])
    isequal(x, value) || throw(InexactError(:convert, StaticRef{T,value}, v))
    return StaticRef{T,value}()
end
Base.eltype(::Type{<:RefType{T}}) where {T} = T

# Ordering and equality on references compare their contents (`x[]`), so refs of
# different flavours holding equal values compare equal. `_bool` rejects
# non-`Bool` results (e.g. `missing`) with a `MethodError`.
for f in (:<, :<=, :>, :>=, :(==), :isequal, :isless)
    eval(:(Base.$(f)(x::RefType, y::RefType) = _bool(Base.$(f)(x[], y[]))))
end

# `isequal` is content-based, so `hash` must be too, or refs misbehave as
# Dict/Set keys. `isequal(x, y)` implies `isequal(x[], y[])`, which implies
# equal hashes here. The salt keeps ref hashes apart from hashes of bare
# contents. As with arrays, mutating a ref that is used as a key invalidates
# its hash.
Base.hash(x::RefType, h::UInt) = hash(x[], hash(:RefType, h))
# replace!(x, expected, desired): conditionally store `desired` into `x` via
# `replacefield!` and return the first component (the previously held value) of
# its result. Arguments of other types are converted to the element type first.
# Note: this is `RefTypes.replace!`, a separate function from `Base.replace!`.
replace!(x::Union{AtomicRef{T},MutableRef{T}}, expected, desired) where {T} =
    replace!(x, convert(T, expected), convert(T, desired))

# Plain (non-atomic) field for MutableRef.
replace!(x::MutableRef{T}, expected::T, desired::T) where {T} =
    getfield(replacefield!(x, 1, expected, desired), 1, false)

# `@atomic` field for AtomicRef: sequentially consistent ordering.
replace!(x::AtomicRef{T}, expected::T, desired::T) where {T} =
    getfield(replacefield!(x, 1, expected, desired, :sequentially_consistent), 1, false)
# swap!(x, newval): store `newval` into `x` and return the previously held value.
# A `newval` of another type is converted to the element type first.
swap!(x::Union{AtomicRef{T},MutableRef{T}}, newval) where {T} =
    swap!(x, convert(T, newval))

# Plain (non-atomic) field swap for MutableRef.
swap!(x::MutableRef{T}, newval::T) where {T} = swapfield!(x, 1, newval)

# Integer atomics use the llvmcall-based `xchg!`.
swap!(x::AtomicRef{T}, newval::T) where {T<:IntTypes} = xchg!(x, newval)

# Any other element type: sequentially consistent `swapfield!`.
swap!(x::AtomicRef{T}, newval::T) where {T} =
    swapfield!(x, 1, newval, :sequentially_consistent)
# modify!(r, op, x): set the contents of `r` to `op(r[], x)` via `modifyfield!`.
# Unlike the other in-place helpers, this returns the reference `r` itself.
modify!(r::MutableRef{T}, op, x) where {T} = (modifyfield!(r, 1, op, x); r)

# AtomicRef: same operation with sequentially consistent ordering.
modify!(r::AtomicRef{T}, op, x) where {T} =
    (modifyfield!(r, 1, op, x, :sequentially_consistent); r)
# sub!(x, y): subtract `y` from the contents of `x` in place and return the
# previous contents (first element of the `old => new` pair from `modifyfield!`).
# A `y` of another type is converted to the element type first.
sub!(x::Union{MutableRef{T},AtomicRef{T}}, y) where {T} = sub!(x, convert(T, y))

# MutableRef's field is not `@atomic`, so it must be modified without an
# ordering argument; passing `:sequentially_consistent` (as before) raised a
# ConcurrencyViolationError and also broke `dec!` on MutableRef.
function sub!(x::MutableRef{T}, y::T) where {T}
    return getfield(modifyfield!(x, 1, -, y), 1, false)
end

# AtomicRef's field is `@atomic`: use sequentially consistent ordering.
function sub!(x::AtomicRef{T}, y::T) where {T}
    return getfield(modifyfield!(x, 1, -, y, :sequentially_consistent), 1, false)
end
# add!(x, y): add `y` to the contents of `x` in place and return the previous
# contents (first element of the `old => new` pair from `modifyfield!`).
# A `y` of another type is converted to the element type first.
add!(x::Union{MutableRef{T},AtomicRef{T}}, y) where {T} = add!(x, convert(T, y))

# MutableRef: plain (non-atomic) field update.
function add!(x::MutableRef{T}, y::T) where {T}
    return getfield(modifyfield!(x, 1, +, y), 1, false)
end

# AtomicRef: sequentially consistent update. This previously called the
# nonexistent `getfield!` and threw an UndefVarError; it is reached for element
# types without a more specific (llvmcall-based) `add!` method.
function add!(x::AtomicRef{T}, y::T) where {T}
    return getfield(modifyfield!(x, 1, +, y, :sequentially_consistent), 1, false)
end
# or!(x, y): bitwise-or `y` into the contents of `x` in place and return the
# previous contents. A `y` of another type is converted to the element type first.
or!(x::Union{MutableRef{T},AtomicRef{T}}, y) where {T} = or!(x, convert(T, y))

# MutableRef: plain (non-atomic) field update.
function or!(x::MutableRef{T}, y::T) where {T}
    return getfield(modifyfield!(x, 1, |, y), 1, false)
end

# AtomicRef's field is `@atomic`, so an ordering is mandatory. Without it (as
# before) `modifyfield!` raises a ConcurrencyViolationError. This now matches
# the sequentially consistent ordering used by the other AtomicRef helpers.
function or!(x::AtomicRef{T}, y::T) where {T}
    return getfield(modifyfield!(x, 1, |, y, :sequentially_consistent), 1, false)
end
# xor!(x, y): bitwise-xor `y` into the contents of `x`; returns the previous
# contents. A `y` of another type is converted to the element type first.
xor!(x::Union{MutableRef{T},AtomicRef{T}}, y) where {T} = xor!(x, convert(T, y))
xor!(x::MutableRef{T}, y::T) where {T} = first(modifyfield!(x, 1, ⊻, y))
xor!(x::AtomicRef{T}, y::T) where {T} =
    first(modifyfield!(x, 1, ⊻, y, :sequentially_consistent))
# and!(x, y): bitwise-and `y` into the contents of `x`; returns the previous
# contents. A `y` of another type is converted to the element type first.
and!(x::Union{MutableRef{T},AtomicRef{T}}, y) where {T} = and!(x, convert(T, y))
and!(x::MutableRef{T}, y::T) where {T} = first(modifyfield!(x, 1, &, y))
and!(x::AtomicRef{T}, y::T) where {T} =
    first(modifyfield!(x, 1, &, y, :sequentially_consistent))
# nand!(x, y): replace the contents of `x` with `⊼(x[], y)`; returns the previous
# contents. A `y` of another type is converted to the element type first.
nand!(x::Union{MutableRef{T},AtomicRef{T}}, y) where {T} = nand!(x, convert(T, y))
nand!(x::MutableRef{T}, y::T) where {T} = first(modifyfield!(x, 1, ⊼, y))
nand!(x::AtomicRef{T}, y::T) where {T} =
    first(modifyfield!(x, 1, ⊼, y, :sequentially_consistent))
# max!(x, y): replace the contents of `x` with `max(x[], y)`; returns the previous
# contents. A `y` of another type is converted to the element type first.
max!(x::Union{MutableRef{T},AtomicRef{T}}, y) where {T} = max!(x, convert(T, y))
max!(x::MutableRef{T}, y::T) where {T} = first(modifyfield!(x, 1, max, y))
max!(x::AtomicRef{T}, y::T) where {T} =
    first(modifyfield!(x, 1, max, y, :sequentially_consistent))
# min!(x, y): replace the contents of `x` with `min(x[], y)`; returns the previous
# contents. A `y` of another type is converted to the element type first.
min!(x::Union{MutableRef{T},AtomicRef{T}}, y) where {T} = min!(x, convert(T, y))
min!(x::MutableRef{T}, y::T) where {T} = first(modifyfield!(x, 1, min, y))
min!(x::AtomicRef{T}, y::T) where {T} =
    first(modifyfield!(x, 1, min, y, :sequentially_consistent))
# inc(x): return `x + one(T)` without mutating anything.
function inc(x::T) where {T}
    return x + one(T)
end

# dec(x): return `x - one(T)` without mutating anything.
function dec(x::T) where {T}
    return x - one(T)
end

# inc!(x): add one to the contents of the reference via `add!`, returning
# whatever `add!` returns.
function inc!(x::Union{AtomicRef{T},MutableRef{T}}) where {T}
    return add!(x, one(T))
end

# dec!(x): subtract one from the contents of the reference via `sub!`,
# returning whatever `sub!` returns.
function dec!(x::Union{AtomicRef{T},MutableRef{T}}) where {T}
    return sub!(x, one(T))
end
# x[] for each flavour. StaticRef reads its value from the type parameter;
# AtomicRef loads its `@atomic` field with sequentially consistent ordering.
Base.getindex(::StaticRef{T,value}) where {T,value} = _rettype(T, value)
Base.getindex(x::ImmutableRef{T}) where {T} = getfield(x, :value)
Base.getindex(x::MutableRef{T}) where {T} = getfield(x, :value)
Base.getindex(x::AtomicRef{T}) where {T} = getfield(x, :value, :sequentially_consistent)

# x[] = newval for the mutable flavours; other value types are converted to
# the element type and re-dispatched to the typed methods below.
Base.setindex!(x::Union{MutableRef{T},AtomicRef{T}}, newval) where {T} =
    setindex!(x, convert(T, newval))
Base.setindex!(x::MutableRef{T}, newval::T) where {T} = setfield!(x, :value, newval)
Base.setindex!(x::AtomicRef{T}, newval::T) where {T} =
    setfield!(x, :value, newval, :sequentially_consistent)
end |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.