-
-
Save gmarkall/a32d155eb31c0449b9d9f4aad4dfe067 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numba import njit, f8 | |
from numba.typed import List | |
from numba.extending import models, register_model | |
class Interval:
    """
    A half-open interval ``[lo, hi)`` on the real number line.

    Plain-Python counterpart of the Numba ``IntervalType``; instances carry
    the two bounds as the attributes ``lo`` and ``hi``.
    """

    def __init__(self, lo, hi):
        # Store both bounds as public attributes.
        self.lo, self.hi = lo, hi

    def __repr__(self):
        bounds = (self.lo, self.hi)
        return 'Interval(%f, %f)' % bounds

    @property
    def width(self):
        """Length of the interval, computed as ``hi - lo``."""
        return self.hi - self.lo
from numba import types | |
class IntervalType(types.IterableType):
    """Numba type for ``Interval`` values, registered under the name 'Interval'."""

    def __init__(self):
        super().__init__(name='Interval')

    @property
    def iterator_type(self):
        # The companion iterable type owns the iterator type; delegate to it.
        iterable = IntervalTypeIterableType(self)
        return iterable.iterator_type
class IntervalTypeIterableType(types.SimpleIterableType):
    """Iterable view of an ``IntervalType`` whose items are ``float64``."""

    def __init__(self, parent):
        assert isinstance(parent, IntervalType)
        # Both attributes must be set before the iterator type is built,
        # because IntervalTypeIteratorType reads them from this object.
        self.parent = parent
        self.yield_type = types.float64
        super().__init__(
            f"Interval[{parent.name}]",
            IntervalTypeIteratorType(self),
        )
class IntervalTypeIteratorType(types.SimpleIteratorType):
    """Iterator type produced from an ``IntervalTypeIterableType``."""

    def __init__(self, iterable):
        self.iterable = iterable
        self.parent = iterable.parent
        item_type = iterable.yield_type
        super().__init__(f"iter[{iterable.parent}->{item_type}]", item_type)
@register_model(IntervalTypeIterableType)
@register_model(IntervalTypeIteratorType)
class IntervalIterModel(models.StructModel):
    """
    Data model shared by the interval iterable and iterator types.

    Members: ``parent`` (the interval being iterated) and ``index`` (a pointer
    to an ``intp``; presumably the iteration position -- nothing in this file
    reads it yet, TODO confirm once iternext is implemented).
    """

    def __init__(self, dmm, fe_type):
        parent_member = ('parent', fe_type.parent)
        index_member = ('index', types.EphemeralPointer(types.intp))
        super().__init__(dmm, fe_type, [parent_member, index_member])
from numba.extending import typeof_impl | |
@typeof_impl.register(Interval)
def typeof_index(val, c):
    """Report ``interval_type`` as the Numba type of any Python ``Interval``."""
    # ``interval_type`` is resolved when this is called, so defining it
    # just below is fine.
    return interval_type


# The Numba type instance used for every Interval value.
interval_type = IntervalType()
from numba.extending import as_numba_type | |
from numba.extending import type_callable | |
@type_callable(Interval)
def type_interval(context):
    """Provide typing for ``Interval(lo, hi)`` calls inside jitted code."""
    def typer(lo, hi):
        # Only float arguments are supported; returning None signals that
        # these argument types are not accepted.
        both_float = all(isinstance(arg, types.Float) for arg in (lo, hi))
        return interval_type if both_float else None
    return typer


as_numba_type.register(Interval, interval_type)
from numba.extending import models, register_model | |
@register_model(IntervalType)
class IntervalModel(models.StructModel):
    """Native layout of an Interval: a struct with ``float64`` fields lo, hi."""

    def __init__(self, dmm, fe_type):
        fields = [(field_name, types.float64) for field_name in ('lo', 'hi')]
        super().__init__(dmm, fe_type, fields)
from numba.extending import make_attribute_wrapper | |
# Expose the native struct fields ``lo`` and ``hi`` as attributes of the
# same name on IntervalType values in jitted code.
for _field_name in ('lo', 'hi'):
    make_attribute_wrapper(IntervalType, _field_name, _field_name)
del _field_name
from numba.extending import overload_attribute | |
@overload_attribute(IntervalType, "width")
def get_width(interval):
    """Implement ``interval.width`` in jitted code, matching ``Interval.width``."""
    def width_impl(interval):
        # Same formula as the pure-Python property: hi - lo.
        return interval.hi - interval.lo
    return width_impl
from numba.extending import lower_builtin | |
from numba.core import cgutils | |
@lower_builtin(Interval, types.Float, types.Float)
def impl_interval(context, builder, sig, args):
    """
    Lower a call ``Interval(lo, hi)`` in jitted code to a native struct.

    This registration matches any ``types.Float`` argument (float32, float64,
    ...), but the Interval struct fields are ``float64``. Each argument is
    therefore cast to ``float64`` before it is stored. Without the cast, a
    float32 argument would be a mismatched store. For float64 arguments the
    cast is a no-op.
    """
    lo, hi = (
        context.cast(builder, value, from_type, types.float64)
        for value, from_type in zip(args, sig.args)
    )
    interval = cgutils.create_struct_proxy(sig.return_type)(context, builder)
    interval.lo = lo
    interval.hi = hi
    return interval._getvalue()
from numba.extending import unbox, NativeValue | |
@unbox(IntervalType)
def unbox_interval(typ, obj, c):
    """
    Convert a Python ``Interval`` object into the native interval struct.

    Reads the ``lo`` and ``hi`` attributes of *obj* and converts each with
    ``float_as_double``. The returned NativeValue is flagged as an error if a
    Python exception is pending afterwards.
    """
    pyapi = c.pyapi
    field_names = ("lo", "hi")
    attr_objs = [pyapi.object_getattr_string(obj, name) for name in field_names]
    native = cgutils.create_struct_proxy(typ)(c.context, c.builder)
    for name, attr_obj in zip(field_names, attr_objs):
        setattr(native, name, pyapi.float_as_double(attr_obj))
    # The attribute lookups returned references that we own; release them.
    for attr_obj in attr_objs:
        pyapi.decref(attr_obj)
    # A failed lookup or conversion above leaves a Python error set.
    is_error = cgutils.is_not_null(c.builder, pyapi.err_occurred())
    return NativeValue(native._getvalue(), is_error=is_error)
from numba.extending import box | |
@box(IntervalType)
def box_interval(typ, val, c):
    """
    Convert a native interval struct back into a Python ``Interval`` object.

    Returns the result of calling ``Interval(lo, hi)`` with Python floats
    built from the struct fields.
    """
    pyapi = c.pyapi
    native = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
    bound_objs = [pyapi.float_from_double(native.lo),
                  pyapi.float_from_double(native.hi)]
    # Get a runtime handle on the Interval class (serialized into the
    # compiled code) so it can be called.
    cls_obj = pyapi.unserialize(pyapi.serialize_object(Interval))
    result = pyapi.call_function_objargs(cls_obj, tuple(bound_objs))
    # Release the temporaries; ``result`` holds its own references.
    for tmp in bound_objs + [cls_obj]:
        pyapi.decref(tmp)
    return result
from numba import jit | |
@jit(nopython=True)
def inside_interval(interval, x):
    """Return True when ``interval.lo <= x < interval.hi`` (half-open check)."""
    lower, upper = interval.lo, interval.hi
    return lower <= x and x < upper
@jit(nopython=True)
def interval_width(interval):
    """Return the width of *interval* via its jitted ``width`` attribute."""
    width = interval.width
    return width
@jit(nopython=True)
def sum_intervals(i, j):
    """Return a new Interval whose bounds are the sums of the bounds of i and j."""
    new_lo = i.lo + j.lo
    new_hi = i.hi + j.hi
    return Interval(new_lo, new_hi)
# Quick checks of the jitted helpers defined above. inside_interval returns
# a bool, so plain truthiness replaces the ``== True`` / ``== False``
# comparisons.
assert inside_interval(Interval(1.0, 5.0), 4)
assert not inside_interval(Interval(1.0, 5.0), 6)
print(interval_width(Interval(1.0, 6.0)))
print(sum_intervals(Interval(1.0, 6.0), Interval(1.0, 6.0)))
########### | |
## ^ Everything above is taken from https://numba.pydata.org/numba-doc/latest/extending/interval-example.html
## v Below: a test of implementing the "getiter" lowering for Interval
########### | |
from numba import f8 | |
from numba.extending import lower_builtin | |
@lower_builtin("getiter", interval_type)
def iterval_getiter(context, builder, sig, args):
    """
    Debugging stub for lowering ``iter(interval)`` / ``for x in interval``.

    It only reports that it was reached, then raises to stop lowering; no
    real iterator value is built yet.
    """
    print("THIS SHOULD GET PRINTED!!", args[0])
    # Raise explicitly. ``assert False`` is stripped under ``python -O``, and
    # this function would then silently return None as the lowered value.
    raise NotImplementedError("getiter lowering for Interval is not implemented")
@jit(nopython=True)
def iter_interval(i, j):
    """
    Iterate over ``Interval(i, j)`` in jitted code and print each item.

    This exercises the "getiter" lowering registered for ``interval_type``.
    The bounds are cast with ``f8`` because the ``Interval`` constructor is
    only typed for float arguments.
    """
    # Use a distinct loop variable. Reusing ``i`` shadowed the integer
    # argument with the float64 items produced by the iterator.
    for item in Interval(f8(i), f8(j)):
        print(item)


iter_interval(1, 10)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.