Skip to content

Instantly share code, notes, and snippets.

@gmarkall
gmarkall / pq.py
Created July 30, 2021 10:15
PQ implementation modified
import itertools
import numba as nb
from numba.experimental import jitclass
from typing import List, Tuple, Dict
from heapq import heappush, heappop
# @jitclass
class PurePythonPriorityQueue:
@gmarkall
gmarkall / cuda_demo.py
Created August 3, 2021 17:59
CUDA demo presented at the 2021-08-03 Numba meeting (not executable, was modified to exemplify various things)
import math
from numba import cuda, njit, objmode
from time import perf_counter
import numpy as np
import cupy as cp
@njit
Iterating
Stream <CUDA stream 93950825260944 on <CUDA context c_void_p(93950819544272) of device 0>> done
Stream <CUDA stream 93950825975488 on <CUDA context c_void_p(93950819544272) of device 0>> done
# Works in conjunction with https://github.com/numba/numba/pull/7453
from numba import cuda
import asyncio
async def f():
s1 = cuda.stream()
s2 = cuda.stream()
@gmarkall
gmarkall / valgrind_test_ufuncs.log
Created October 14, 2021 19:52
Running test_ufuncs under valgrind with NumPy 1.21 and Numba PR #7483
$ PYTHONMALLOC=malloc valgrind-numba python -m numba.runtests numba.tests.test_ufuncs
==7578== Memcheck, a memory error detector
==7578== Copyright (C) 2002-2017, and GNU GPL'd, by Julian Seward et al.
==7578== Using Valgrind-3.17.0 and LibVEX; rerun with -h for copyright info
==7578== Command: python -m numba.runtests numba.tests.test_ufuncs
==7578==
==7579== Warning: invalid file descriptor 1024 in syscall close()
==7579== Warning: invalid file descriptor 1025 in syscall close()
==7579== Warning: invalid file descriptor 1026 in syscall close()
==7579== Warning: invalid file descriptor 1027 in syscall close()
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
from numba import njit, f8
from numba.typed import List
from numba.extending import models, register_model
class Interval(object):
"""
A half-open interval on the real number line.
"""
diff --git a/numba/core/extending.py b/numba/core/extending.py
index 9d005fe74..b42442a38 100644
--- a/numba/core/extending.py
+++ b/numba/core/extending.py
@@ -155,8 +155,10 @@ def register_jitable(*args, **kwargs):
def wrap(fn):
# It is just a wrapper for @overload
inline = kwargs.pop('inline', 'never')
+ target = kwargs.pop('target', 'cpu')
# Implements unicode equality for the CUDA target
from numba import cuda, types
from numba.core.extending import overload
from numba.core.pythonapi import (PY_UNICODE_1BYTE_KIND,
PY_UNICODE_2BYTE_KIND,
PY_UNICODE_4BYTE_KIND)
from numba.cpython.unicode import deref_uint8, deref_uint16, deref_uint32
import numpy as np
import operator
# Use with https://github.com/gmarkall/numba/tree/cuda-linker-options
from numba import cuda, float32, void
def axpy(r, a, x, y):
start = cuda.grid(1)
step = cuda.gridsize(1)
for i in range(start, len(r), step):