Created November 7, 2014 11:10
-
-
Save gmarkall/1564a9e92ff3392c35af to your computer and use it in GitHub Desktop.
Extending the Numba frontend with an interval type - corrected ordering
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Define a new class
class Interval(object):
    '''A half-open interval ``[lo, hi)`` on the real number line.

    A plain Python value holder for the two bounds. No validation is
    performed, so any pair of values is stored exactly as given.
    '''

    def __init__(self, lo, hi):
        # Store the lower bound first, then the upper bound.
        self.lo, self.hi = lo, hi

    def __repr__(self):
        # Fixed-point ``%f`` rendering of both bounds (six decimals),
        # e.g. ``Interval(0.000000, 2.000000)``.
        bounds = {'lo': self.lo, 'hi': self.hi}
        return 'Interval(%(lo)f, %(hi)f)' % bounds
# Create a registry for our set of types.
# Per the comment further down this file, the registry is installed into the
# typing context only after the signatures have been registered on it.
from numba.typing.templates import (
    AttributeTemplate,
    ConcreteTemplate,
    signature,
    Registry,
)

registry = Registry()

# Shorthand aliases for the registry's registration methods;
# ``builtin_attr`` is used as a class decorator later in this file.
builtin, builtin_attr, builtin_global = (
    registry.register,
    registry.register_attr,
    registry.register_global,
)
# Creating a New Numba Type
from numba.types import float32, bool_, Type


class IntervalType(Type):
    '''Numba type used to describe ``Interval`` values in typed code.'''

    def __init__(self):
        # The type is always named 'Interval'; the name is handed to the
        # ``Type`` base-class initialiser.
        type_name = 'Interval'
        super(IntervalType, self).__init__(name=type_name)


# Instance used as the attribute-template key and in jit signatures.
interval_type = IntervalType()
# Adding an Attribute Value Type Signature
# NOTE(review): float32 and AttributeTemplate are already imported earlier in
# this file, and impl_attribute is not used in the visible code; the imports
# are kept so no module-level name disappears.
from numba.types import float32
from numba.typing.templates import AttributeTemplate
from numba.targets.imputils import impl_attribute


@builtin_attr
class IntervalAttributes(AttributeTemplate):
    '''Typing template for attribute access on ``interval_type`` values.'''

    key = interval_type

    # We will store the interval bounds as 32-bit floats
    _attributes = {'lo': float32, 'hi': float32}

    def generic_resolve(self, value, attr):
        '''Return the Numba type of attribute ``attr`` (``lo`` or ``hi``).

        Any other attribute name raises ``KeyError`` from the lookup.
        NOTE(review): presumably the typing machinery reports that as an
        unknown-attribute error -- confirm for the Numba version in use.
        '''
        attr_type = self._attributes[attr]
        return attr_type
# The registry needs installing in the typing context after signatures have been
# registered
# NOTE(review): presumably ``install`` takes the templates the registry holds
# at call time, which is why this must run after the ``@builtin_attr``
# registration earlier in the file -- confirm against the Numba version used.
from numba.targets.registry import target_registry
# Assuming the CPU target
target = target_registry['cpu']
# Make this file's typing templates visible to the CPU target's typing context.
target.targetdescr.typing_context.install(registry)
# jit functions
from numba import njit


# Note that this will fail in the backend because it doesn't know how to
# convert an Interval type to native values.
# NOTE(review): with an explicit signature Numba compiles eagerly, so the
# failure is presumably raised here at decoration time -- confirm.
@njit(bool_(interval_type, float32))
def inside(interval, x):
    '''Return whether ``x`` lies in the half-open range ``[lo, hi)``.'''
    lower = interval.lo
    upper = interval.hi
    return lower <= x < upper
# jit function call
# NOTE(review): the comment on ``inside`` says compiling it for Interval
# arguments fails in the backend, so this is presumably never reached.
i = Interval(lo=0, hi=2)
print(inside(i, 1))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.