Skip to content

Instantly share code, notes, and snippets.

@gmarkall
Created August 8, 2022 09:04
Show Gist options
  • Save gmarkall/03e657fa950697d59495846a1dea8b78 to your computer and use it in GitHub Desktop.
diff --git a/numba/core/extending.py b/numba/core/extending.py
index 9d005fe74..b42442a38 100644
--- a/numba/core/extending.py
+++ b/numba/core/extending.py
@@ -155,8 +155,10 @@ def register_jitable(*args, **kwargs):
def wrap(fn):
# It is just a wrapper for @overload
inline = kwargs.pop('inline', 'never')
+ target = kwargs.pop('target', 'cpu')
- @overload(fn, jit_options=kwargs, inline=inline, strict=False)
+ @overload(fn, jit_options=kwargs, inline=inline, strict=False,
+ target=target)
def ov_wrap(*args, **kwargs):
return fn
return fn
diff --git a/numba/cpython/unicode.py b/numba/cpython/unicode.py
index a14f31910..c0874cf80 100644
--- a/numba/cpython/unicode.py
+++ b/numba/cpython/unicode.py
@@ -280,7 +280,7 @@ def _empty_string(kind, length, is_ascii=0):
# Disable RefCt for performance.
-@register_jitable(_nrt=False)
+@register_jitable(_nrt=False, target='generic')
def _get_code_point(a, i):
if a._kind == PY_UNICODE_1BYTE_KIND:
return deref_uint8(a._data, i)
@@ -384,7 +384,7 @@ def _kind_to_byte_width(kind):
raise AssertionError("Unexpected unicode encoding encountered")
-@register_jitable(_nrt=False)
+@register_jitable(_nrt=False, target='generic')
def _cmp_region(a, a_offset, b, b_offset, n):
if n == 0:
return 0
@@ -439,7 +439,7 @@ def unicode_str(s):
return lambda s: s
-@overload(len)
+@overload(len, target='generic')
def unicode_len(s):
if isinstance(s, types.UnicodeType):
def len_impl(s):
@@ -447,7 +447,7 @@ def unicode_len(s):
return len_impl
-@overload(operator.eq)
+@overload(operator.eq, target='generic')
def unicode_eq(a, b):
if not (a.is_internal and b.is_internal):
return
@@ -473,8 +473,8 @@ def unicode_eq(a, b):
else:
return False
# the str() is for UnicodeCharSeq, it's a nop else
- a = str(a)
- b = str(b)
+ #a = str(a)
+ #b = str(b)
if len(a) != len(b):
return False
return _cmp_region(a, 0, b, 0, len(a)) == 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment