Last active
March 7, 2021 09:03
-
-
Save thautwarm/62c738ac352ff44592776df26bf0c3b7 to your computer and use it in GitHub Desktop.
[julia] Convenient interface for C interop (zero-cost get-element-pointer, type-checked ccall, etc.)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
    CPointer

Convenient pointer operations for C interop.

`CPtr{T}` is a pointer-sized primitive type that points at a value of type `T`.
Property access on a `CPtr` yields a `CPtr` to the named field. The field lookup
and byte offset are resolved when the accessor is generated, so navigating nested
structs compiles to a constant offset addition (a C-style "get element pointer").

```julia
struct C
    a :: Cint
end

struct A
    a :: Cint
    b :: Cdouble
    c :: C
end

a :: Ptr{A}
cp = to_cptr(a)   :: CPtr{A}
cp.b              :: CPtr{Cdouble}
cp.b[]            :: Cdouble
cp.b[] = Cdouble(1.0)
cp.c.a[] = 1      # the stored value is converted to the field type (Cint)
```

A struct field declared as `CFuncType{Args, Ret}` holds a C function pointer.
Calling the `CPtr` that points at such a field performs a type-checked `ccall`:

```julia
struct Callbacks
    cb :: CFuncType{Tuple{Cint}, Cint}
end

p :: CPtr{Callbacks}
p.cb(Cint(3))     # argument types must match Tuple{Cint}
```

All operations are unsafe: nothing here keeps the pointed-to memory alive.
"""
module CPointer

export CPtr, to_cptr, to_stdptr, CFuncType

"""
    CPtr{T}

Pointer-sized primitive type pointing at a value of type `T`. `p.field` gives a
`CPtr` to that field, `p[]` / `p[i]` load (1-based), and `p[] = v` / `p[i] = v`
store after converting `v` to `T`.
"""
primitive type CPtr{A} sizeof(Ptr{Nothing}) * 8 end

# Prints e.g. `Int32* (0x0000000000001000)`. `UInt` (rather than `UInt64`) has
# the same width as a pointer on every platform, so this also works on 32-bit.
function Base.show(io::IO, p::CPtr{A}) where A
    Base.show(io, A)
    print(io, "* (")
    Base.show(io, reinterpret(UInt, p))
    print(io, ")")
    return nothing
end

# `convert` must return an instance of the requested type, so the source pointer
# is reinterpreted to the target element type (as Base's `convert(Ptr{T}, ::Ptr)`
# does). The identity method covers the overlap with Base's `convert(T, ::T)`.
# The `Ptr` method only accepts a `CPtr`, so Base's own `convert(Ptr{T}, x)`
# methods are not overridden for unrelated argument types.
@inline Base.convert(::Type{CPtr{A}}, p::CPtr{A}) where A = p
@inline Base.convert(::Type{CPtr{A}}, a) where A = reinterpret(CPtr{A}, to_cptr(a))
@inline Base.convert(::Type{Ptr{A}}, p::CPtr) where A = reinterpret(Ptr{A}, p)

@inline to_cptr(p::CPtr{A}) where A = p
@inline to_cptr(p::Ptr{A}) where A = reinterpret(CPtr{A}, p)
@inline to_stdptr(p::CPtr{A}) where A = reinterpret(Ptr{A}, p)
@inline to_stdptr(p::Ptr{A}) where A = p

# UNSAFE: the returned pointer does not keep the object alive; root the object
# (e.g. with `GC.@preserve obj`) for as long as the pointer is used.
# Base's public conversions are used instead of reading object internals:
# `unsafe_convert` handles element types that are not stored inline, and
# `pointer(::Array)` does not depend on Julia's internal array layout.
@inline to_cptr(p::Base.RefValue{A}) where A =
    reinterpret(CPtr{A}, Base.unsafe_convert(Ptr{A}, p))
@inline to_cptr(p::Array{A}) where A = reinterpret(CPtr{A}, pointer(p))

# Return a `CPtr` to field `S` of the `A` that `p` points at. Only concrete types
# have a fixed field layout, so anything else is reported as having no such field.
# The error is returned as an expression (not thrown by the generator) so it is
# raised when the access actually runs.
@inline @generated function _get(p::CPtr{A}, ::Val{S}) where {A, S}
    i = isconcretetype(A) ? findfirst(==(S), fieldnames(A)) : nothing
    if i === nothing
        msg = "$A has no field $S."
        return :($error($msg))
    end
    off = fieldoffset(A, i)
    T = fieldtype(A, i)
    off == 0 && return :($reinterpret($CPtr{$T}, p))
    return :($reinterpret($CPtr{$T}, $reinterpret(UInt, p) + $off))
end

@inline Base.getproperty(p::CPtr, s::Symbol) = _get(p, Val(s))

# Loads and stores; indices are 1-based, as in `unsafe_load` / `unsafe_store!`.
@inline Base.getindex(p::CPtr, i::Integer) = unsafe_load(to_stdptr(p), i)
@inline Base.getindex(p::CPtr) = unsafe_load(to_stdptr(p))
# Values are converted to the pointee type `A`, so `cp.a[] = 1` works when the
# field `a` is a `Cint`.
@inline Base.setindex!(p::CPtr{A}, v, i::Integer) where A =
    unsafe_store!(to_stdptr(p), convert(A, v), i)
@inline Base.setindex!(p::CPtr{A}, v) where A =
    unsafe_store!(to_stdptr(p), convert(A, v))

"""
    CFuncType{Args<:Tuple, Ret}

A C function pointer taking arguments of types `Args` and returning `Ret`. Use it
as the type of a function-pointer field in a C struct. Calling a
`CPtr{CFuncType{Args, Ret}}` loads the stored pointer and `ccall`s it.
"""
struct CFuncType{Args <: Tuple, Ret}
    # A concrete `Ptr{Cvoid}` (not the abstract `Ptr`) keeps this type `isbits`
    # and exactly one pointer wide, matching a C function-pointer field.
    f :: Ptr{Cvoid}
end

# Type-checked `ccall` through the function pointer stored at `p`. The check is
# strict: each argument's type must be a subtype of the declared one. For example,
# an `Int` is rejected where a `Cint` is expected. A mismatch raises an
# `ErrorException` when the call runs.
@generated function (p::CPtr{CFuncType{Args, Ret}})(args...) where {Args <: Tuple, Ret}
    params = Tuple(Args.parameters)
    if !(Tuple{args...} <: Args)
        msg = "$(CFuncType{Args, Ret}) expects argument types $params, got $args."
        return :($error($msg))
    end
    # `ccall` needs a literal tuple *expression* of argument types, plus the
    # argument *values*. Inside the generator, `args` is bound to their types.
    argtypes = Expr(:tuple, params...)
    argvals = [:(args[$i]) for i in eachindex(args)]
    return quote
        fptr = $unsafe_load($to_stdptr(p)).f
        ccall(fptr, $Ret, $argtypes, $(argvals...))
    end
end

end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment