Created February 18, 2021 15:42
-
-
Save gmarkall/ea64d302482701cdf53f8464eb8c707e to your computer and use it in GitHub Desktop.
Adding a CUDA pipeline to Numba
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/numba/core/compiler.py b/numba/core/compiler.py | |
| index a0c03fdb3..e47270a2f 100644 | |
| --- a/numba/core/compiler.py | |
| +++ b/numba/core/compiler.py | |
| @@ -28,7 +28,8 @@ from numba.core.typed_passes import (NopythonTypeInference, AnnotateTypes, | |
| NopythonRewrites, PreParforPass, | |
| ParforPass, DumpParforDiagnostics, | |
| IRLegalization, NoPythonBackend, | |
| - InlineOverloads, PreLowerStripPhis) | |
| + InlineOverloads, PreLowerStripPhis, | |
| + NativeLowering) | |
| from numba.core.object_mode_passes import (ObjectModeFrontEnd, | |
| ObjectModeBackEnd) | |
| @@ -476,6 +477,7 @@ class DefaultPassBuilder(object): | |
| "ensure IR is legal prior to lowering") | |
| # lower | |
| + pm.add_pass(NativeLowering, "native lowering") | |
| pm.add_pass(NoPythonBackend, "nopython mode backend") | |
| pm.add_pass(DumpParforDiagnostics, "dump parfor diagnostics") | |
| pm.finalize() | |
| diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py | |
| index 937fe0bb9..0aa848d36 100644 | |
| --- a/numba/core/typed_passes.py | |
| +++ b/numba/core/typed_passes.py | |
| @@ -358,8 +358,15 @@ class NativeLowering(LoweringPass): | |
| LoweringPass.__init__(self) | |
| def run_pass(self, state): | |
| - targetctx = state.targetctx | |
| + if state.library is None: | |
| + codegen = state.targetctx.codegen() | |
| + state.library = codegen.create_library(state.func_id.func_qualname) | |
| + # Enable object caching upfront, so that the library can | |
| + # be later serialized. | |
| + state.library.enable_object_caching() | |
| + | |
| library = state.library | |
| + targetctx = state.targetctx | |
| interp = state.func_ir # why is it called this?! | |
| typemap = state.typemap | |
| restype = state.return_type | |
| @@ -452,15 +459,6 @@ class NoPythonBackend(LoweringPass): | |
| """ | |
| Back-end: Generate LLVM IR from Numba IR, compile to machine code | |
| """ | |
| - if state.library is None: | |
| - codegen = state.targetctx.codegen() | |
| - state.library = codegen.create_library(state.func_id.func_qualname) | |
| - # Enable object caching upfront, so that the library can | |
| - # be later serialized. | |
| - state.library.enable_object_caching() | |
| - | |
| - # TODO: Pull this out into the pipeline | |
| - NativeLowering().run_pass(state) | |
| lowered = state['cr'] | |
| signature = typing.signature(state.return_type, *state.args) | |
| diff --git a/numba/cuda/compiler.py b/numba/cuda/compiler.py | |
| index 887d930e4..5060fce59 100644 | |
| --- a/numba/cuda/compiler.py | |
| +++ b/numba/cuda/compiler.py | |
| @@ -14,9 +14,14 @@ from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate | |
| from numba.core import (types, typing, utils, funcdesc, serialize, config, | |
| compiler, sigutils) | |
| from numba.core.typeconv.rules import default_type_manager | |
| +from numba.core.compiler import (CompilerBase, DefaultPassBuilder, | |
| + compile_result) | |
| from numba.core.compiler_lock import global_compiler_lock | |
| +from numba.core.compiler_machinery import (LoweringPass, PassManager, | |
| + register_pass) | |
| from numba.core.dispatcher import OmittedArg | |
| from numba.core.errors import NumbaDeprecationWarning | |
| +from numba.core.typed_passes import IRLegalization, NativeLowering | |
| from numba.core.typing.typeof import Purpose, typeof | |
| from warnings import warn | |
| import numba | |
| @@ -28,6 +33,65 @@ from .api import get_current_device | |
| from .args import wrap_arg | |
| +@register_pass(mutates_CFG=True, analysis_only=False) | |
| +class CUDABackend(LoweringPass): | |
| + | |
| + _name = "cuda_backend" | |
| + | |
| + def __init__(self): | |
| + LoweringPass.__init__(self) | |
| + | |
| + def run_pass(self, state): | |
| + """ | |
| + Back-end: Packages lowering output in a compile result | |
| + """ | |
| + lowered = state['cr'] | |
| + signature = typing.signature(state.return_type, *state.args) | |
| + | |
| + state.cr = compile_result( | |
| + typing_context=state.typingctx, | |
| + target_context=state.targetctx, | |
| + typing_error=state.status.fail_reason, | |
| + type_annotation=state.type_annotation, | |
| + library=state.library, | |
| + call_helper=lowered.call_helper, | |
| + signature=signature, | |
| + fndesc=lowered.fndesc, | |
| + ) | |
| + return True | |
| + | |
| + | |
| +class CUDACompiler(CompilerBase): | |
| + def define_pipelines(self): | |
| + dpb = DefaultPassBuilder | |
| + pm = PassManager('cuda') | |
| + | |
| + untyped_passes = dpb.define_untyped_pipeline(self.state) | |
| + pm.passes.extend(untyped_passes.passes) | |
| + | |
| + typed_passes = dpb.define_typed_pipeline(self.state) | |
| + pm.passes.extend(typed_passes.passes) | |
| + | |
| + lowering_passes = self.define_cuda_lowering_pipeline(self.state) | |
| + pm.passes.extend(lowering_passes.passes) | |
| + | |
| + pm.finalize() | |
| + return [pm] | |
| + | |
| + def define_cuda_lowering_pipeline(self, state): | |
| + pm = PassManager('cuda_lowering') | |
| + # legalise | |
| + pm.add_pass(IRLegalization, | |
| + "ensure IR is legal prior to lowering") | |
| + | |
| + # lower | |
| + pm.add_pass(NativeLowering, "native lowering") | |
| + pm.add_pass(CUDABackend, "cuda backend") | |
| + | |
| + pm.finalize() | |
| + return pm | |
| + | |
| + | |
| @global_compiler_lock | |
| def compile_cuda(pyfunc, return_type, args, debug=False, inline=False): | |
| from .descriptor import cuda_target | |
| @@ -50,11 +114,11 @@ def compile_cuda(pyfunc, return_type, args, debug=False, inline=False): | |
| args=args, | |
| return_type=return_type, | |
| flags=flags, | |
| - locals={}) | |
| + locals={}, | |
| + pipeline_class=CUDACompiler) | |
| library = cres.library | |
| library.finalize() | |
| - | |
| return cres | |
| @@ -90,6 +154,7 @@ def compile_ptx(pyfunc, args, debug=False, device=False, fastmath=False, | |
| """ | |
| cres = compile_cuda(pyfunc, None, args, debug=debug) | |
| resty = cres.signature.return_type | |
| + | |
| if device: | |
| llvm_modules = cres.library.modules | |
| else: | |
| diff --git a/numba/tests/test_errorhandling.py b/numba/tests/test_errorhandling.py | |
| index 83f06211b..96db57a8d 100644 | |
| --- a/numba/tests/test_errorhandling.py | |
| +++ b/numba/tests/test_errorhandling.py | |
| @@ -15,7 +15,7 @@ from numba.core.compiler import CompilerBase | |
| from numba.core.untyped_passes import (TranslateByteCode, FixupArgs, | |
| IRProcessing,) | |
| from numba.core.typed_passes import (NopythonTypeInference, DeadCodeElimination, | |
| - NoPythonBackend) | |
| + NoPythonBackend, NativeLowering) | |
| from numba.core.compiler_machinery import PassManager | |
| from numba.core.types.functions import _err_reasons as error_reasons | |
| @@ -110,6 +110,7 @@ class TestMiscErrorHandling(unittest.TestCase): | |
| pm.add_pass(DeadCodeElimination, "DCE") | |
| # typing | |
| pm.add_pass(NopythonTypeInference, "nopython frontend") | |
| + pm.add_pass(NativeLowering, "native lowering") | |
| pm.add_pass(NoPythonBackend, "nopython mode backend") | |
| pm.finalize() | |
| return [pm] | |
| diff --git a/numba/tests/test_inlining.py b/numba/tests/test_inlining.py | |
| index 76a7565fa..1b4012545 100644 | |
| --- a/numba/tests/test_inlining.py | |
| +++ b/numba/tests/test_inlining.py | |
| @@ -98,6 +98,7 @@ def gen_pipeline(state, test_pass): | |
| pm.add_pass(PreserveIR, "preserve IR") | |
| # lower | |
| + pm.add_pass(NativeLowering, "native lowering") | |
| pm.add_pass(NoPythonBackend, "nopython mode backend") | |
| pm.add_pass(DumpParforDiagnostics, "dump parfor diagnostics") | |
| return pm | |
| diff --git a/numba/tests/test_mixed_tuple_unroller.py b/numba/tests/test_mixed_tuple_unroller.py | |
| index ad6ddad18..7dd314a16 100644 | |
| --- a/numba/tests/test_mixed_tuple_unroller.py | |
| +++ b/numba/tests/test_mixed_tuple_unroller.py | |
| @@ -15,7 +15,8 @@ from numba.core.untyped_passes import (FixupArgs, TranslateByteCode, | |
| SimplifyCFG, IterLoopCanonicalization, | |
| LiteralUnroll, PreserveIR) | |
| from numba.core.typed_passes import (NopythonTypeInference, IRLegalization, | |
| - NoPythonBackend, PartialTypeInference) | |
| + NoPythonBackend, PartialTypeInference, | |
| + NativeLowering) | |
| from numba.core.ir_utils import (compute_cfg_from_blocks, flatten_labels) | |
| from numba.core.types.functions import _header_lead | |
| @@ -108,6 +109,7 @@ class TestLoopCanonicalisation(MemoryLeakMixin, TestCase): | |
| pm.add_pass(PreserveIR, "save IR for later inspection") | |
| # lower | |
| + pm.add_pass(NativeLowering, "native lowering") | |
| pm.add_pass(NoPythonBackend, "nopython mode backend") | |
| # finalise the contents | |
| @@ -1859,6 +1861,7 @@ class CapturingCompiler(CompilerBase): | |
| "ensure IR is legal prior to lowering") | |
| # lower | |
| + add_pass(NativeLowering, "native lowering") | |
| add_pass(NoPythonBackend, "nopython mode backend") | |
| pm.finalize() | |
| return [pm] | |
| diff --git a/numba/tests/test_remove_dead.py b/numba/tests/test_remove_dead.py | |
| index a061cdb1e..ec888cadc 100644 | |
| --- a/numba/tests/test_remove_dead.py | |
| +++ b/numba/tests/test_remove_dead.py | |
| @@ -288,6 +288,7 @@ class TestRemoveDead(unittest.TestCase): | |
| pm.add_pass(AnnotateTypes, "annotate types") | |
| # lower | |
| + pm.add_pass(NativeLowering, "native lowering") | |
| pm.add_pass(NoPythonBackend, "nopython mode backend") | |
| pm.finalize() | |
| return [pm] |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.