AMDGPUnative compiler/driver.jl WIP
# compiler driver and main interface
# (::CompilerJob)
const compile_hook = Ref{Union{Nothing,Function}}(nothing)
"""
compile(target::Symbol, agent::HSAAgent, f, tt, kernel=true;
libraries=true, optimize=true, strip=false, strict=true, ...)
Compile a function `f` invoked with types `tt` for agent `agent` to one of the
following formats as specified by the `target` argument: `:julia` for Julia
IR, `:llvm` for LLVM IR, `:gcn` for GCN assembly, and `:roc` for linked
objects. If the `kernel` flag is set, specialized code generation and
optimization for kernel functions is enabled.
The following keyword arguments are supported:
- `libraries`: link the ROCm device libraries (if required)
- `optimize`: optimize the code (default: true)
- `strip`: strip non-functional metadata and debug information (default: false)
- `strict`: perform code validation either as early or as late as possible
Other keyword arguments can be found in the documentation of [`rocfunction`](@ref).
"""
compile(target::Symbol, agent::HSAAgent, @nospecialize(f::Core.Function),
        @nospecialize(tt), kernel::Bool=true; libraries::Bool=true,
        optimize::Bool=true, strip::Bool=false, strict::Bool=true, kwargs...) =
    compile(target, CompilerJob(f, tt, agent, kernel; kwargs...);
            libraries=libraries, optimize=optimize, strip=strip, strict=strict)
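# A minimal usage sketch (not part of the driver): how the entry point above might be
# called for each target. `agent`, `kernel_fn` and `tt` are placeholders assumed to be
# supplied by the caller (e.g. an HSAAgent, a device function, and its argument tuple
# type); the bindings shown mirror the return values of `codegen` below.
#
#     mi        = compile(:julia, agent, kernel_fn, tt)  # Core.MethodInstance
#     ir, entry = compile(:llvm,  agent, kernel_fn, tt)  # LLVM module and entry function
#     asm, name = compile(:gcn,   agent, kernel_fn, tt)  # GCN assembly and kernel name
#     fun, mod  = compile(:roc,   agent, kernel_fn, tt)  # ROCFunction and ROCModule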
function compile(target::Symbol, job::CompilerJob;
                 libraries::Bool=true,
                 optimize::Bool=true, strip::Bool=false, strict::Bool=true)
    @debug "(Re)compiling function" job

    if compile_hook[] !== nothing
        global globalUnique
        previous_globalUnique = globalUnique

        compile_hook[](job)

        globalUnique = previous_globalUnique
    end

    return codegen(target, job;
                   libraries=libraries,
                   optimize=optimize, strip=strip, strict=strict)
end
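# A hedged sketch of how `compile_hook` could be used (illustrative only): reflection
# or debugging tools may install a hook to observe every job that gets (re)compiled,
# while the driver above saves and restores `globalUnique` so any extra codegen the
# hook triggers does not perturb generated symbol names.
#
#     compile_hook[] = job -> @info "Compiling kernel" job.f job.tt job.agent
#     # ... invoke one of the `compile` entry points above ...
#     compile_hook[] = nothing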
function codegen(target::Symbol, job::CompilerJob;
                 libraries::Bool=true, optimize::Bool=true,
                 strip::Bool=false, strict::Bool=true)
    ## Julia IR

    @timeit to[] "validation" check_method(job)

    @timeit to[] "Julia front-end" begin
        # get the method instance
        world = typemax(UInt)
        meth = which(job.f, job.tt)
        sig = Base.signature_type(job.f, job.tt)::Type
        # intersect the call signature with the method signature to recover the
        # matched type and the environment of bound type variables
        (ti, env) = ccall(:jl_type_intersection_with_env, Any,
                          (Any, Any), sig, meth.sig)::Core.SimpleVector
        if VERSION >= v"1.2.0-DEV.320"
            meth = Base.func_for_method_checked(meth, ti, env)
        else
            meth = Base.func_for_method_checked(meth, ti)
        end
        # look up (or create) the specialization of this method for these argument types
        method_instance = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
                                (Any, Any, Any, UInt), meth, ti, env, world)

        # unbound type variables cannot be compiled to GPU code
        for var in env
            if var isa TypeVar
                throw(KernelError(job, "method captures a typevar (you probably use an unbound type variable)"))
            end
        end
    end

    target == :julia && return method_instance
    ## LLVM IR

    # helpers: enumerate a module's function definitions, its undefined non-intrinsic
    # declarations, and check whether a library provides any of those missing functions
    defs(mod)  = filter(f -> !isdeclaration(f), collect(functions(mod)))
    decls(mod) = filter(f -> isdeclaration(f) && intrinsic_id(f) == 0,
                        collect(functions(mod)))
    need_library(lib) = any(f -> isdeclaration(f) &&
                                 intrinsic_id(f) == 0 &&
                                 haskey(functions(lib), LLVM.name(f)),
                            functions(ir))

    # always preload the runtime, and do so early; it cannot be part of any timing block
    # because it recurses into the compiler
    if libraries
        # FIXME: device_libs = load_device_libs(job.agent)
        runtime = load_runtime(job.agent)
        runtime_fns = LLVM.name.(defs(runtime))
    end
    @timeit to[] "LLVM middle-end" begin
        ir, kernel = @timeit to[] "IR generation" irgen(job, method_instance, world)
        # AG: mod, entry = irgen(job)

        if libraries
            undefined_fns = LLVM.name.(decls(ir))
            # FIXME: Remove this "__nv_" check
            if any(fn -> startswith(fn, "__nv_"), undefined_fns)
                libdevice = load_libdevice(job.cap)
                @timeit to[] "device library" link_libdevice!(job, ir, libdevice)
            end
            #= FIXME
            for lib in device_libs
                if need_library(lib)
                    link_device_lib!(job, mod, lib)
                end
            end
            link_oclc_defaults!(job, mod)
            =#
        end

        if optimize
            kernel = @timeit to[] "optimization" optimize!(job, ir, kernel)
        end

        if libraries
            undefined_fns = LLVM.name.(decls(ir))
            if any(fn -> fn in runtime_fns, undefined_fns)
                @timeit to[] "runtime library" link_library!(job, ir, runtime)
            end
        end

        if ccall(:jl_is_debugbuild, Cint, ()) == 1
            @timeit to[] "verification" verify(ir)
        end

        kernel_fn = LLVM.name(kernel)
    end
    if strict
        # NOTE: keep in sync with non-strict check below
        @timeit to[] "validation" begin
            check_invocation(job, kernel)
            check_ir(job, ir)
        end
    end

    if strip
        @timeit to[] "strip debug info" strip_debuginfo!(ir)
    end

    target == :llvm && return ir, kernel
    ## GCN machine code

    @timeit to[] "LLVM back-end" begin
        @timeit to[] "preparation" prepare_execution!(job, ir)

        asm = @timeit to[] "machine-code generation" mcgen(job, ir, kernel)
    end

    target == :gcn && return asm, kernel_fn
    ## CUDA objects

    if !strict
        # NOTE: keep in sync with strict check above
        @timeit to[] "validation" begin
            check_invocation(job, kernel)
            check_ir(job, ir)
        end
    end
    # FIXME: Pull the ld linking functionality into this code block?
    @timeit to[] "CUDA object generation" begin
        # enable debug options based on Julia's debug setting
        jit_options = Dict{Any,Any}()
        #= FIXME: Debug options
        if Base.JLOptions().debug_level == 1
            jit_options[CUDAdrv.GENERATE_LINE_INFO] = true
        elseif Base.JLOptions().debug_level >= 2
            jit_options[CUDAdrv.GENERATE_DEBUG_INFO] = true
        end
        =#

        # link the CUDA device library
        image = asm
        if libraries
            # linking the device runtime library requires use of the CUDA linker,
            # which in turn switches compilation to device relocatable code (-rdc) mode.
            #
            # even if not doing any actual calls that need -rdc (i.e., calls to the runtime
            # library), this significantly hurts performance, so don't do it unconditionally
            undefined_fns = LLVM.name.(decls(ir))
            intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail",
                             "__nvvm_reflect" #= TODO: should have been optimized away =#]
            if !isempty(setdiff(undefined_fns, intrinsic_fns))
                @timeit to[] "device runtime library" begin
                    linker = CUDAdrv.CuLink(jit_options)
                    CUDAdrv.add_file!(linker, libcudadevrt, CUDAdrv.LIBRARY)
                    CUDAdrv.add_data!(linker, kernel_fn, asm)
                    image = CUDAdrv.complete(linker)
                end
            end
        end

        @timeit to[] "compilation" begin
            roc_mod = ROCModule(image, jit_options)
            roc_fun = ROCFunction(roc_mod, kernel_fn)
        end
    end

    target == :roc && return roc_fun, roc_mod

    error("Unknown compilation target $target")
end