Skip to content

Instantly share code, notes, and snippets.

@gmarkall
Created February 18, 2021 15:42
Show Gist options
  • Select an option

  • Save gmarkall/ea64d302482701cdf53f8464eb8c707e to your computer and use it in GitHub Desktop.

Select an option

Save gmarkall/ea64d302482701cdf53f8464eb8c707e to your computer and use it in GitHub Desktop.
Adding a CUDA pipeline to Numba
diff --git a/numba/core/compiler.py b/numba/core/compiler.py
index a0c03fdb3..e47270a2f 100644
--- a/numba/core/compiler.py
+++ b/numba/core/compiler.py
@@ -28,7 +28,8 @@ from numba.core.typed_passes import (NopythonTypeInference, AnnotateTypes,
                                      NopythonRewrites, PreParforPass,
                                      ParforPass, DumpParforDiagnostics,
                                      IRLegalization, NoPythonBackend,
-                                     InlineOverloads, PreLowerStripPhis)
+                                     InlineOverloads, PreLowerStripPhis,
+                                     NativeLowering)

 from numba.core.object_mode_passes import (ObjectModeFrontEnd,
                                            ObjectModeBackEnd)
@@ -476,6 +477,7 @@ class DefaultPassBuilder(object):
                     "ensure IR is legal prior to lowering")

         # lower
+        pm.add_pass(NativeLowering, "native lowering")
         pm.add_pass(NoPythonBackend, "nopython mode backend")
         pm.add_pass(DumpParforDiagnostics, "dump parfor diagnostics")
         pm.finalize()
diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py
index 937fe0bb9..0aa848d36 100644
--- a/numba/core/typed_passes.py
+++ b/numba/core/typed_passes.py
@@ -358,8 +358,15 @@ class NativeLowering(LoweringPass):
         LoweringPass.__init__(self)

     def run_pass(self, state):
-        targetctx = state.targetctx
+        if state.library is None:
+            codegen = state.targetctx.codegen()
+            state.library = codegen.create_library(state.func_id.func_qualname)
+            # Enable object caching upfront, so that the library can
+            # be later serialized.
+            state.library.enable_object_caching()
+
         library = state.library
+        targetctx = state.targetctx
         interp = state.func_ir  # why is it called this?!
         typemap = state.typemap
         restype = state.return_type
@@ -452,14 +459,5 @@ class NoPythonBackend(LoweringPass):
         """
         Back-end: Generate LLVM IR from Numba IR, compile to machine code
         """
-        if state.library is None:
-            codegen = state.targetctx.codegen()
-            state.library = codegen.create_library(state.func_id.func_qualname)
-            # Enable object caching upfront, so that the library can
-            # be later serialized.
-            state.library.enable_object_caching()
-
-        # TODO: Pull this out into the pipeline
-        NativeLowering().run_pass(state)
         lowered = state['cr']
         signature = typing.signature(state.return_type, *state.args)
diff --git a/numba/cuda/compiler.py b/numba/cuda/compiler.py
index 887d930e4..5060fce59 100644
--- a/numba/cuda/compiler.py
+++ b/numba/cuda/compiler.py
@@ -14,9 +14,14 @@ from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate
 from numba.core import (types, typing, utils, funcdesc, serialize, config,
                         compiler, sigutils)
 from numba.core.typeconv.rules import default_type_manager
+from numba.core.compiler import (CompilerBase, DefaultPassBuilder,
+                                 compile_result)
 from numba.core.compiler_lock import global_compiler_lock
+from numba.core.compiler_machinery import (LoweringPass, PassManager,
+                                           register_pass)
 from numba.core.dispatcher import OmittedArg
 from numba.core.errors import NumbaDeprecationWarning
+from numba.core.typed_passes import IRLegalization, NativeLowering
 from numba.core.typing.typeof import Purpose, typeof
 from warnings import warn
 import numba
@@ -28,6 +33,75 @@ from .api import get_current_device
 from .args import wrap_arg


+@register_pass(mutates_CFG=True, analysis_only=False)
+class CUDABackend(LoweringPass):
+
+    _name = "cuda_backend"
+
+    def __init__(self):
+        LoweringPass.__init__(self)
+
+    def run_pass(self, state):
+        """
+        Back-end: Packages lowering output in a compile result
+        """
+        lowered = state['cr']
+        signature = typing.signature(state.return_type, *state.args)
+
+        state.cr = compile_result(
+            typing_context=state.typingctx,
+            target_context=state.targetctx,
+            typing_error=state.status.fail_reason,
+            type_annotation=state.type_annotation,
+            library=state.library,
+            call_helper=lowered.call_helper,
+            signature=signature,
+            fndesc=lowered.fndesc,
+        )
+        return True
+
+
+class CUDACompiler(CompilerBase):
+    """
+    Compiler for the CUDA target. Reuses the default untyped and typed
+    pipelines, but lowers with NativeLowering followed by CUDABackend
+    instead of the default NoPythonBackend.
+    """
+    def define_pipelines(self):
+        dpb = DefaultPassBuilder
+        pm = PassManager('cuda')
+
+        untyped_passes = dpb.define_untyped_pipeline(self.state)
+        pm.passes.extend(untyped_passes.passes)
+
+        typed_passes = dpb.define_typed_pipeline(self.state)
+        pm.passes.extend(typed_passes.passes)
+
+        lowering_passes = self.define_cuda_lowering_pipeline(self.state)
+        pm.passes.extend(lowering_passes.passes)
+
+        pm.finalize()
+        return [pm]
+
+    def define_cuda_lowering_pipeline(self, state):
+        """
+        Legalize the IR, lower it with NativeLowering, then package the
+        result with CUDABackend. ``state`` is accepted for parity with the
+        DefaultPassBuilder pipeline builders but is not used.
+        """
+        pm = PassManager('cuda_lowering')
+        # legalise
+        pm.add_pass(IRLegalization,
+                    "ensure IR is legal prior to lowering")
+
+        # lower
+        pm.add_pass(NativeLowering, "native lowering")
+        pm.add_pass(CUDABackend, "cuda backend")
+
+        pm.finalize()
+        return pm
+
+
 @global_compiler_lock
 def compile_cuda(pyfunc, return_type, args, debug=False, inline=False):
     from .descriptor import cuda_target
@@ -50,11 +124,11 @@ def compile_cuda(pyfunc, return_type, args, debug=False, inline=False):
                                   args=args,
                                   return_type=return_type,
                                   flags=flags,
-                                  locals={})
+                                  locals={},
+                                  pipeline_class=CUDACompiler)

     library = cres.library
     library.finalize()
-
     return cres


@@ -90,6 +164,7 @@ def compile_ptx(pyfunc, args, debug=False, device=False, fastmath=False,
     """
     cres = compile_cuda(pyfunc, None, args, debug=debug)
     resty = cres.signature.return_type
+
     if device:
         llvm_modules = cres.library.modules
     else:
diff --git a/numba/tests/test_errorhandling.py b/numba/tests/test_errorhandling.py
index 83f06211b..96db57a8d 100644
--- a/numba/tests/test_errorhandling.py
+++ b/numba/tests/test_errorhandling.py
@@ -15,7 +15,7 @@ from numba.core.compiler import CompilerBase
 from numba.core.untyped_passes import (TranslateByteCode, FixupArgs,
                                        IRProcessing,)
 from numba.core.typed_passes import (NopythonTypeInference, DeadCodeElimination,
-                                     NoPythonBackend)
+                                     NoPythonBackend, NativeLowering)
 from numba.core.compiler_machinery import PassManager
 from numba.core.types.functions import _err_reasons as error_reasons

@@ -110,6 +110,7 @@ class TestMiscErrorHandling(unittest.TestCase):
                 pm.add_pass(DeadCodeElimination, "DCE")
                 # typing
                 pm.add_pass(NopythonTypeInference, "nopython frontend")
+                pm.add_pass(NativeLowering, "native lowering")
                 pm.add_pass(NoPythonBackend, "nopython mode backend")
                 pm.finalize()
                 return [pm]
diff --git a/numba/tests/test_inlining.py b/numba/tests/test_inlining.py
index 76a7565fa..1b4012545 100644
--- a/numba/tests/test_inlining.py
+++ b/numba/tests/test_inlining.py
@@ -98,6 +98,7 @@ def gen_pipeline(state, test_pass):
     pm.add_pass(PreserveIR, "preserve IR")

     # lower
+    pm.add_pass(NativeLowering, "native lowering")
     pm.add_pass(NoPythonBackend, "nopython mode backend")
     pm.add_pass(DumpParforDiagnostics, "dump parfor diagnostics")
     return pm
diff --git a/numba/tests/test_mixed_tuple_unroller.py b/numba/tests/test_mixed_tuple_unroller.py
index ad6ddad18..7dd314a16 100644
--- a/numba/tests/test_mixed_tuple_unroller.py
+++ b/numba/tests/test_mixed_tuple_unroller.py
@@ -15,7 +15,8 @@ from numba.core.untyped_passes import (FixupArgs, TranslateByteCode,
                                        SimplifyCFG, IterLoopCanonicalization,
                                        LiteralUnroll, PreserveIR)
 from numba.core.typed_passes import (NopythonTypeInference, IRLegalization,
-                                     NoPythonBackend, PartialTypeInference)
+                                     NoPythonBackend, PartialTypeInference,
+                                     NativeLowering)
 from numba.core.ir_utils import (compute_cfg_from_blocks, flatten_labels)
 from numba.core.types.functions import _header_lead

@@ -108,6 +109,7 @@ class TestLoopCanonicalisation(MemoryLeakMixin, TestCase):
                 pm.add_pass(PreserveIR, "save IR for later inspection")

                 # lower
+                pm.add_pass(NativeLowering, "native lowering")
                 pm.add_pass(NoPythonBackend, "nopython mode backend")

                 # finalise the contents
@@ -1859,6 +1861,7 @@ class CapturingCompiler(CompilerBase):
                  "ensure IR is legal prior to lowering")

         # lower
+        add_pass(NativeLowering, "native lowering")
         add_pass(NoPythonBackend, "nopython mode backend")
         pm.finalize()
         return [pm]
diff --git a/numba/tests/test_remove_dead.py b/numba/tests/test_remove_dead.py
index a061cdb1e..ec888cadc 100644
--- a/numba/tests/test_remove_dead.py
+++ b/numba/tests/test_remove_dead.py
@@ -288,6 +288,7 @@ class TestRemoveDead(unittest.TestCase):
                 pm.add_pass(AnnotateTypes, "annotate types")

                 # lower
+                pm.add_pass(NativeLowering, "native lowering")
                 pm.add_pass(NoPythonBackend, "nopython mode backend")
                 pm.finalize()
                 return [pm]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment