Last active
August 8, 2022 09:59
-
-
Save gmarkall/23c0d5e1e879a117bd84bb95a2d8f1c8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Implements unicode equality for the CUDA target
from numba import cuda, types | |
from numba.core.extending import overload | |
from numba.core.pythonapi import (PY_UNICODE_1BYTE_KIND, | |
PY_UNICODE_2BYTE_KIND, | |
PY_UNICODE_4BYTE_KIND) | |
from numba.cpython.unicode import deref_uint8, deref_uint16, deref_uint32 | |
import numpy as np | |
import operator | |
# Copied / modified from numba/cpython/unicode.py
@overload(len, target='cuda')
def unicode_len(s):
    """Offer a ``len()`` implementation for unicode strings on the CUDA target.

    An implementation is returned only when ``s`` is typed as
    ``types.UnicodeType``; for every other type nothing is returned, so this
    overload does not apply.
    """
    if not isinstance(s, types.UnicodeType):
        return None

    def _length_of(s):
        # The length is read directly from the string's ``_length`` field.
        return s._length

    return _length_of
def get_code_point(a, i):
    """Placeholder with no Python-level behaviour (returns ``None``).

    It exists so that an implementation can be attached to this name with
    ``@overload(get_code_point, target='cuda')``.
    """
@overload(get_code_point, target='cuda')
def get_code_point_ol(a, i):
    """CUDA overload for ``get_code_point``.

    The returned implementation reads element ``i`` of ``a._data`` with the
    dereference helper that matches ``a._kind`` (1-, 2- or 4-byte). It yields
    0 when the kind is none of those three.
    """
    def _read_code_point(a, i):
        kind = a._kind
        if kind == PY_UNICODE_1BYTE_KIND:
            return deref_uint8(a._data, i)
        if kind == PY_UNICODE_2BYTE_KIND:
            return deref_uint16(a._data, i)
        if kind == PY_UNICODE_4BYTE_KIND:
            return deref_uint32(a._data, i)
        # There is also a wchar kind, but it is one of the kinds above, so it
        # is skipped in this example.
        return 0

    return _read_code_point
def cmp_region(a, a_offset, b, b_offset, n):
    """Placeholder with no Python-level behaviour (returns ``None``).

    It exists so that an implementation can be attached to this name with
    ``@overload(cmp_region, target='cuda')``.
    """
@overload(cmp_region, target='cuda')
def cmp_region_ol(a, a_offset, b, b_offset, n):
    """CUDA overload for ``cmp_region``.

    The returned implementation compares ``n`` code points of ``a`` starting
    at ``a_offset`` with ``n`` code points of ``b`` starting at ``b_offset``.

    Result of the implementation:
        * 0  when ``n`` is 0 or all compared code points are equal,
        * -1 when the region runs past the end of ``a`` or the first
          differing code point of ``a`` is smaller,
        * 1  when the region runs past the end of ``b`` or the first
          differing code point of ``a`` is larger.
    """
    def _compare_regions(a, a_offset, b, b_offset, n):
        # An empty region always compares equal.
        if n == 0:
            return 0
        # A region that does not fit inside a string decides the result
        # before any code point is read; ``a`` is checked first.
        if a_offset + n > a._length:
            return -1
        if b_offset + n > b._length:
            return 1

        for k in range(n):
            left = get_code_point(a, a_offset + k)
            right = get_code_point(b, b_offset + k)
            if left < right:
                return -1
            if left > right:
                return 1
        return 0

    return _compare_regions
@overload(operator.eq, target='cuda')
def unicode_eq(a, b):
    """CUDA overload of ``==`` for unicode-like operands.

    No implementation is offered unless both operand types report
    ``is_internal``. ``Optional`` operands are unwrapped for the type check
    only; ``None`` values are handled at run time by the implementation.

    * Both operands unicode-like: compare lengths, then code points.
    * Exactly one operand unicode-like: the comparison is always ``False``.
    * Neither operand unicode-like: nothing is returned.
    """
    if not (a.is_internal and b.is_internal):
        return

    # Look through Optional wrappers to find the underlying types.
    check_a = a.type if isinstance(a, types.Optional) else a
    check_b = b.type if isinstance(b, types.Optional) else b

    accept = (types.UnicodeType, types.StringLiteral, types.UnicodeCharSeq)
    a_unicode = isinstance(check_a, accept)
    b_unicode = isinstance(check_b, accept)

    if a_unicode and b_unicode:
        def _eq_both_unicode(a, b):
            # Handle Optionals at run time: two Nones are equal, a single
            # None never equals a string.
            a_none = a is None
            b_none = b is None
            if a_none or b_none:
                return a_none and b_none
            # The original CPython-target code calls str() on both operands
            # here for UnicodeCharSeq (a no-op otherwise); that is left out
            # for CUDA to avoid implementing str().
            if len(a) != len(b):
                return False
            return cmp_region(a, 0, b, 0, len(a)) == 0

        return _eq_both_unicode

    if a_unicode ^ b_unicode:
        # Only one of the operands is unicode: everything compares False.
        def _eq_mixed(a, b):
            return False

        return _eq_mixed
@cuda.jit
def find_fruit(arr, string):
    """Return the position of the first element of ``arr`` equal to ``string``.

    ``arr`` is iterated in order and each element is compared with ``==``.
    Returns the zero-based index of the first match, or -1 when no element
    matches.
    """
    # NOTE(review): this function returns a value and is called from
    # ``kernel``, so it is used as a device function rather than launched
    # directly - confirm whether ``device=True`` should be made explicit.
    y = 0
    for x in arr:
        if x == string:
            # Returning ends the loop; no ``break`` is needed after it.
            return y
        y += 1
    return -1
@cuda.jit
def kernel(loc):
    """Store in ``loc[()]`` the index of 'banana' within a fixed fruit tuple.

    The lookup is delegated to ``find_fruit``.
    """
    basket = ('apple', 'banana', 'cherry')
    wanted = 'banana'
    loc[()] = find_fruit(basket, wanted)
# Zero-dimensional int64 buffer that receives the kernel's result. It is
# left uninitialised because the kernel writes its single element.
c1 = np.empty((), dtype=np.int64)
# Launch with the configuration [1, 1].
kernel[1, 1](c1)
print(c1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment