Last active
April 1, 2020 16:40
-
-
Save jpivarski/7bc83e5aa70d5e3dd8483eb49800885c to your computer and use it in GitHub Desktop.
GrowableBuffer in Numba
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import operator | |
import numpy | |
# Note: these are pre-0.49 locations; things move around in Numba 0.49. | |
import numba | |
import numba.typing.arraydecl | |
# First, let's define the class in Python. | |
class GrowableBuffer:
    """Append-only NumPy buffer that reallocates geometrically when full.

    ``dtype`` is the element dtype, ``initial`` the number of slots reserved
    up front (must be > 0) and ``resize`` the growth factor applied whenever
    the reservation is exhausted (must be > 1.0).

    Only the first ``len(self)`` slots hold appended values; the remainder of
    ``_buffer`` is uninitialised reserve space.
    """

    def __init__(self, dtype, initial=1024, resize=1.5):
        # Kept as asserts so existing callers see the same AssertionError.
        assert initial > 0
        assert resize > 1.0
        self._initial = initial
        self._resize = resize
        self._buffer = numpy.empty(initial, dtype=dtype)
        # The length is in an array so that we can update it in-place in
        # lowered code.
        self._length = numpy.array([0], dtype=numpy.intp)

    def __str__(self):
        return str(self.__array__())

    def __repr__(self):
        return "growable({0})".format(str(self))

    def __len__(self):
        """Return the number of appended items (not the reserved size)."""
        return self._length[0]

    def __getitem__(self, where):
        """Index the valid part of the buffer.

        Indexing goes through the trimmed view so that negative indexes and
        slices never reach the uninitialised reserve, matching what the
        lowered ``__getitem__`` does.
        """
        return self.__array__()[where]

    def __array__(self):
        """Return a view of the first ``len(self)`` items of the buffer."""
        return self._buffer[:self._length[0]]

    @property
    def reserved(self):
        """Number of slots currently allocated, filled or not."""
        return len(self._buffer)

    def _ensure_reserved(self):
        """Grow ``_buffer`` until it has room for at least one more item."""
        # This is called infrequently enough that we can have the lowered
        # code call this Python function. That way, we don't have to
        # reproduce this logic in lowered code.
        while self._length[0] >= len(self._buffer):
            reservation = int(numpy.ceil(len(self._buffer) * self._resize))
            newbuffer = numpy.empty(reservation, dtype=self._buffer.dtype)
            # Copy the old contents; the tail of newbuffer stays uninitialised.
            newbuffer[:len(self._buffer)] = self._buffer
            self._buffer = newbuffer

    def append(self, what):
        """Store ``what`` in the next free slot, growing the buffer if full."""
        # This is the logic we will have to reproduce in the lowered code
        # because it's called frequently.
        if self._length[0] >= len(self._buffer):
            self._ensure_reserved()
        self._buffer[self._length[0]] = what
        self._length[0] += 1
# To start Numbafying this class, we define a Type. This is everything we | |
# need to know at compile-time. | |
class GrowableBufferType(numba.types.Type):
    """Compile-time Numba type of a GrowableBuffer.

    The type is distinguished by the element dtype (int64, float32, etc.),
    which is embedded in the type name and kept as ``self.dtype``.
    """

    def __init__(self, dtype):
        type_name = "GrowableBufferType({0})".format(dtype.name)
        super(GrowableBufferType, self).__init__(name=type_name)
        self.dtype = dtype

    @property
    def buffertype(self):
        """Numba array type of the underlying buffer, derived from the dtype."""
        ndim = 1       # one-dimensional
        layout = "C"   # C-contiguous
        return numba.types.Array(self.dtype, ndim, layout)
# Next, we have to identify the Type from a Python instance. | |
@numba.extending.typeof_impl.register(GrowableBuffer)
def typeof_GrowableBuffer(growablebuffer, c):
    """Return the GrowableBufferType for a Python GrowableBuffer instance.

    The element type is taken from the dtype of the instance's ``_buffer``.
    """
    element_type = numba.from_dtype(growablebuffer._buffer.dtype)
    return GrowableBufferType(element_type)
# Next, we define what information is available at runtime. | |
# A model is a copy-by-value struct, and is therefore immutable. | |
# However, we need to update the length with every call to 'append', and we
# need to update the buffer whenever the reservation changes. | |
# So we do it with pointers: we'll allocate the "buffer" pointer and fill it | |
# ourselves, but "length" will point to the one-element array from the | |
# Python object. The "pyobj" is a reference-counted pointer to the | |
# original Python object. | |
@numba.extending.register_model(GrowableBufferType)
class GrowableBufferModel(numba.datamodel.models.StructModel):
    """Run-time layout of a lowered GrowableBuffer: two pointers and a pyobject."""

    def __init__(self, dmm, fe_type):
        # Pointer to a lowered array struct of the buffer's array type.
        buffer_pointer = numba.types.CPointer(fe_type.buffertype)
        # Pointer to an intp holding the current length.
        length_pointer = numba.types.CPointer(numba.intp)
        members = [
            ("buffer", buffer_pointer),
            ("length", length_pointer),
            ("pyobj", numba.types.pyobject),
        ]
        super(GrowableBufferModel, self).__init__(dmm, fe_type, members)
# "Unboxing" means converting a Python object into a lowered model. | |
# This function generates LLVM assembly to do the transformation. | |
@numba.extending.unbox(GrowableBufferType)
def unbox_GrowableBuffer(typ, obj, c):
    """Emit the code that turns a Python GrowableBuffer into its lowered model.

    Args:
        typ: the GrowableBufferType being unboxed.
        obj: the Python object whose "_buffer" and "_length" are read.
        c: the unboxing context; its builder, context and pyapi are used.

    Returns:
        A NativeValue wrapping the filled-in model and an error flag.
    """
    # To build the lowered model, we have to extract some attributes from
    # the Python object "obj". These are Python-C API calls (through c.pyapi).
    buffer_obj = c.pyapi.object_getattr_string(obj, "_buffer")
    length_obj = c.pyapi.object_getattr_string(obj, "_length")
    # obj._length.ctypes.data: the address of the one-element length array,
    # as a Python integer.
    ctypes_obj = c.pyapi.object_getattr_string(length_obj, "ctypes")
    lenptr_obj = c.pyapi.object_getattr_string(ctypes_obj, "data")

    # A proxy helps us generate LLVM assembly for getting or setting model
    # attributes. If constructed without "value" (as below), we *set* values.
    proxy = c.context.make_helper(c.builder, typ)

    # For the "buffer" model attribute, we generate an instruction to allocate
    # memory to hold a lowered NumPy array object. "alloca_once" returns a
    # pointer, so we're setting proxy.buffer to a pointer.
    proxy.buffer = numba.cgutils.alloca_once(c.builder,
        c.context.get_value_type(typ.buffertype))

    # builder.store(value, pointer) assigns to the newly allocated memory.
    c.builder.store(c.pyapi.to_native_value(typ.buffertype, buffer_obj).value,
                    proxy.buffer)

    # The "length" is a pointer, too, but instead of allocating space for an
    # integer (numba.intp, which is ssize_t), we'll take the already allocated
    # "length" array in the GrowableBuffer Python object (which has room for
    # one integer).
    proxy.length = c.builder.inttoptr(
        c.pyapi.number_as_ssize_t(lenptr_obj),
        c.context.get_value_type(numba.types.CPointer(numba.intp)))

    # Assign the Python object to this model.
    proxy.pyobj = obj

    # Turn the proxy into a value. (The underscored function is necessary.)
    out = proxy._getvalue()

    # Since this model contains a Python reference, ask Numba's RunTime (NRT)
    # to reference-count it.
    if c.context.enable_nrt:
        c.context.nrt.incref(c.builder, typ, out)

    # All the Python objects we created in this function have to be decrefed.
    # NOTE(review): buffer_obj is decrefed twice although only one reference
    # is acquired by the getattr above. The same pattern appears where the
    # buffer is re-read after a resize, so it looks deliberate — presumably
    # it balances a reference taken by to_native_value. Confirm before
    # changing.
    c.pyapi.decref(buffer_obj)
    c.pyapi.decref(buffer_obj)
    c.pyapi.decref(length_obj)
    c.pyapi.decref(ctypes_obj)
    c.pyapi.decref(lenptr_obj)

    # Check for an error and return.
    is_error = numba.cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
    return numba.extending.NativeValue(out, is_error)
# "Boxing" means converting a lowered model into a Python object. | |
# Since we have a reference to the Python object, we just have to return that.
@numba.extending.box(GrowableBufferType)
def box_GrowableBuffer(typ, val, c):
    """Emit the code that converts a lowered GrowableBuffer back to Python.

    The model already carries the original Python object in its "pyobj"
    field, so boxing amounts to handing that object back.
    """
    # Built *with* a value this time: the helper is used to read fields.
    model = c.context.make_helper(c.builder, typ, value=val)

    # The caller receives a new reference, so bump the count before returning.
    c.pyapi.incref(model.pyobj)
    return model.pyobj
# Define the type for __getitem__. | |
@numba.typing.templates.infer_global(operator.getitem)
class type_getitem(numba.typing.templates.AbstractTemplate):
    """Typing rule for ``growablebuffer[where]``."""

    def generic(self, args, kwargs):
        # If this raises an error or returns None, Numba will swallow the
        # error and keep checking other possible types.
        if len(args) != 2 or len(kwargs) != 0:
            return None

        growabletype, wheretype = args
        if not isinstance(growabletype, GrowableBufferType):
            return None

        # The index type is deliberately unconstrained: an integer, a slice,
        # an array, or anything else NumPy __getitem__ takes. The result type
        # is whatever Numba's own NumPy indexing rules give for the
        # underlying buffer array.
        indexed = numba.typing.arraydecl.get_array_index_type(
            growabletype.buffertype, wheretype)
        outtype = indexed.result

        # A Signature is built by calling the return type on the arguments.
        return outtype(growabletype, wheretype)
# The lowering function for __getitem__ with an integer or slice. | |
@numba.extending.lower_builtin(operator.getitem, GrowableBufferType,
                               numba.types.Integer)
@numba.extending.lower_builtin(operator.getitem, GrowableBufferType,
                               numba.types.SliceType)
def lower_getitem_int_slice(context, builder, sig, args):
    """Lower ``growablebuffer[where]`` for integer and slice indexes."""
    buffer_type, where_type = sig.args
    buffer_val, where_val = args
    array_type = buffer_type.buffertype

    model = context.make_helper(builder, buffer_type, buffer_val)

    # Trim the buffer to the length of its valid part.
    whole_array = builder.load(model.buffer)
    valid_length = builder.load(model.length)
    valid_array = trim(context, builder, array_type, whole_array, valid_length)

    # Delegate to Numba's NumPy __getitem__ for an integer or slice index,
    # retyped so that it sees a plain array as its first argument.
    array_sig = sig.return_type(array_type, where_type)
    return numba.targets.arrayobj.getitem_arraynd_intp(
        context, builder, array_sig, (valid_array, where_val))
# The lowering function for __getitem__ with any other type. | |
@numba.extending.lower_builtin(operator.getitem, GrowableBufferType,
                               numba.types.Any)
def lower_getitem_array(context, builder, sig, args):
    """Lower ``growablebuffer[where]`` for index types other than int/slice."""
    buffer_type, where_type = sig.args
    buffer_val, where_val = args
    array_type = buffer_type.buffertype

    model = context.make_helper(builder, buffer_type, buffer_val)

    # Trim the buffer to the length of its valid part.
    whole_array = builder.load(model.buffer)
    valid_length = builder.load(model.length)
    valid_array = trim(context, builder, array_type, whole_array, valid_length)

    # Delegate to Numba's fancy-indexing __getitem__ for NumPy arrays,
    # retyped so that it sees a plain array as its first argument.
    array_sig = sig.return_type(array_type, where_type)
    return numba.targets.arrayobj.fancy_getitem_array(
        context, builder, array_sig, (valid_array, where_val))
# This "trim" function uses Numba's __getitem__ again. | |
def trim(context, builder, buffertype, bufferarray, length):
    """Emit code for ``bufferarray[0:length:1]`` and return the lowered result.

    This reuses Numba's NumPy __getitem__ with a slice built by hand.
    """
    window = context.make_helper(builder, numba.types.slice2_type)
    first = context.get_constant(numba.intp, 0)
    stride = context.get_constant(numba.intp, 1)
    window.start = first
    window.stop = length
    window.step = stride

    # Array in, array of the same type out, indexed by the two-argument slice.
    slicing_sig = buffertype(buffertype, numba.types.slice2_type)
    return numba.targets.arrayobj.getitem_arraynd_intp(
        context, builder, slicing_sig, (bufferarray, window._getvalue()))
# Define the type for __len__. | |
@numba.typing.templates.infer_global(len)
class type_len(numba.typing.templates.AbstractTemplate):
    """Typing rule for ``len(growablebuffer)``."""

    def generic(self, args, kwargs):
        if len(args) != 1 or len(kwargs) != 0:
            return None
        (growabletype,) = args
        if not isinstance(growabletype, GrowableBufferType):
            return None
        # This one is simple: take a GrowableBuffer in, return an intp.
        return numba.intp(growabletype)
# The lowering function for __len__ is easy enough to do it directly. | |
@numba.extending.lower_builtin(len, GrowableBufferType)
def lower_len(context, builder, sig, args):
    """Lower ``len(growablebuffer)`` by reading through the "length" pointer."""
    (buffer_type,) = sig.args
    (buffer_val,) = args

    # Wrap the incoming value so its fields can be read.
    model = context.make_helper(builder, buffer_type, buffer_val)

    # Dereference the "length" pointer; the loaded intp is the result.
    length_pointer = model.length
    return builder.load(length_pointer)
# Define the type for attributes and methods. | |
@numba.typing.templates.infer_getattr
class GrowableBuffer_attrib(numba.typing.templates.AttributeTemplate):
    """Typing rules for attributes and methods of a GrowableBuffer."""

    key = GrowableBufferType

    def generic_resolve(self, growablebuffertype, attr):
        # Plain attributes resolve to a Type (not a Signature); anything
        # not listed here resolves to None.
        if attr == "_buffer":
            return growablebuffertype.buffertype
        if attr == "reserved":
            return numba.intp
        return None

    # The methods could be defined with generic_resolve, but it's easier to
    # use the bound_function decorator. For the methods, we return Signatures.

    @numba.typing.templates.bound_function("_ensure_reserved")
    def resolve__ensure_reserved(self, growablebuffertype, args, kwargs):
        # No arguments in, None out.
        takes_nothing = len(args) == 0 and len(kwargs) == 0
        return numba.types.none() if takes_nothing else None

    @numba.typing.templates.bound_function("append")
    def resolve_append(self, growablebuffertype, args, kwargs):
        # Exactly one positional numeric argument in, None out.
        if len(args) != 1 or len(kwargs) != 0:
            return None
        (valuetype,) = args
        if not isinstance(valuetype, numba.types.Number):
            return None
        return numba.types.none(valuetype)
# The lowering function for all attributes. | |
@numba.extending.lower_getattr_generic(GrowableBufferType)
def lower_getattr_generic(context, builder,
                          growablebuffertype, growablebufferval,
                          attr):
    """Lower attribute access on a GrowableBuffer ("_buffer" and "reserved")."""
    model = context.make_helper(builder, growablebuffertype,
                                value=growablebufferval)

    if attr == "_buffer":
        # Dereference the "buffer" pointer to get the lowered array itself.
        return builder.load(model.buffer)

    if attr == "reserved":
        # Reuse Numba's __len__ implementation for NumPy arrays on the
        # untrimmed buffer.
        len_sig = numba.types.intp(growablebuffertype.buffertype)
        len_args = (builder.load(model.buffer),)
        return numba.targets.arrayobj.array_len(context, builder,
                                                len_sig, len_args)

    return None
# The lowering function for the _ensure_reserved method. For this one, | |
# we don't want to reimplement the logic in lowered Numba.
# Registered for the receiver only: the typing signature of _ensure_reserved
# takes no arguments, and the body below unpacks exactly one entry from
# sig.args, so an extra Integer argument type here could never match a call.
@numba.extending.lower_builtin("_ensure_reserved", GrowableBufferType)
def lower__ensure_reserved(context, builder, sig, args):
    """Lower ``growablebuffer._ensure_reserved()`` by calling back into Python.

    The generated code acquires the GIL, calls the Python method of the same
    name on the stored "pyobj", then re-reads "_buffer" and stores its lowered
    form through the model's "buffer" pointer so later loads see the new
    array.

    Args:
        context, builder: Numba lowering context and LLVM IR builder.
        sig: call signature; its only argument is the GrowableBufferType.
        args: one-element tuple holding the lowered GrowableBuffer value.

    Returns:
        The lowered equivalent of None.
    """
    growablebuffertype, = sig.args
    growablebufferval, = args

    proxy = context.make_helper(builder, growablebuffertype,
                                value=growablebufferval)

    # To call Python from a lowered function, we need to get a Python API.
    pyapi = context.get_python_api(builder)

    # And this means acquiring the Global Interpreter Lock (GIL).
    gil = pyapi.gil_ensure()

    # Call the Python function, holding a reference to the object for the
    # duration of the call.
    pyapi.incref(proxy.pyobj)
    none_obj = pyapi.call_method(proxy.pyobj, "_ensure_reserved", ())
    # NOTE(review): the result of call_method is not checked; if the Python
    # method raised, none_obj is presumably NULL here — confirm whether an
    # error path is needed.

    # Since this has changed the buffer, we need to replace the lowered
    # NumPy array in our "_buffer" pointer with a new one. This is a
    # subset of the unboxing code.
    newbuffer_obj = pyapi.object_getattr_string(proxy.pyobj, "_buffer")
    newbufferval = pyapi.to_native_value(growablebuffertype.buffertype,
                                         newbuffer_obj).value

    # Assign it to the pointer!
    builder.store(newbufferval, proxy.buffer)

    # Decrement all those Python objects!
    # NOTE(review): newbuffer_obj is decrefed twice although only one
    # reference is acquired by the getattr above. The unboxing code does the
    # same with its buffer object, so it looks deliberate — presumably it
    # balances a reference taken by to_native_value. Confirm before changing.
    pyapi.decref(newbuffer_obj)
    pyapi.decref(newbuffer_obj)
    pyapi.decref(proxy.pyobj)
    pyapi.decref(none_obj)

    # Release the GIL!
    pyapi.gil_release(gil)

    # This function returns the lowered equivalent of None.
    return context.get_dummy_value()
# The lowering function for the append method. Unlike _ensure_reserved, this | |
# one is called frequently and has to be fast. We will *not* defer to the | |
# Python implementation (or acquire the GIL, or anything like that). | |
@numba.extending.lower_builtin("append",
                               GrowableBufferType, numba.types.Number)
def lower_append(context, builder, sig, args):
    """Lower ``growablebuffer.append(number)`` without calling into Python.

    The generated code compares the current length with the size of the
    buffer, calls the _ensure_reserved lowering only when the buffer is full,
    writes the number at index "length" and stores length + 1 back through the
    "length" pointer.

    Args:
        context, builder: Numba lowering context and LLVM IR builder.
        sig: call signature with argument types (GrowableBufferType, Number).
        args: the lowered GrowableBuffer value and the number to append.

    Returns:
        The lowered equivalent of None.
    """
    growablebuffertype, numbertype = sig.args
    growablebufferval, numberval = args

    proxy = context.make_helper(builder, growablebuffertype,
                                value=growablebufferval)

    # Get the current length and the size of the buffer to see if we have
    # to call _ensure_reserved.
    lengthval = builder.load(proxy.length)
    reservedval = numba.targets.arrayobj.array_len(context, builder,
        numba.types.intp(growablebuffertype.buffertype),
        (builder.load(proxy.buffer),))

    # LLVM for an "if" statement can be generated using a context manager.
    # (Remember that this function doesn't *run* append, it generates the
    # code for append.)
    #
    # builder.icmp_signed(">=", ...) generates the code for the predicate.
    #
    # likely=False is a compiler hint that the predicate is rarely true.
    #
    # Compile-time control flow always goes through the "with" body, but
    # run-time control flow rarely enters the "if" body that is generated.
    with builder.if_then(builder.icmp_signed(">=", lengthval, reservedval),
                         likely=False):
        # The lowering function is called directly as a Python function,
        # with the receiver-only signature it expects.
        ensure_sig = numba.types.none(growablebuffertype)
        lower__ensure_reserved(context, builder,
                               ensure_sig, (growablebufferval,))

    # We have to make the proxy again so that we get the post-updated buffer
    # in case the _ensure_reserved was called.
    newproxy = context.make_helper(builder, growablebuffertype,
                                   value=growablebufferval)

    # Now call Numba's __setitem__ on the buffer array to write a new value
    # at the current "length". lengthval was loaded before the possible
    # resize and is reused here as the index.
    setitem_sig = numba.types.none(growablebuffertype.buffertype,
                                   numba.intp,
                                   numbertype)
    numba.targets.arrayobj.setitem_array(context, builder,
        setitem_sig, (builder.load(newproxy.buffer), lengthval, numberval))

    # Add one to the length and store it in its place.
    builder.store(builder.add(lengthval, context.get_constant(numba.intp, 1)),
                  newproxy.length)

    # Return None.
    return context.get_dummy_value()
############################################################ tests | |
# Use a GrowableBuffer in Python. | |
# Shared fixture for the checks below: a 10-slot float buffer holding three
# values.
buf = GrowableBuffer(float, initial=10)
buf.append(1.1)
buf.append(2.2)
buf.append(3.3)

# Get another reference to it so we can check its reference count.
tmp = buf._buffer

# Baseline reference counts, as reported by sys.getrefcount, for the buffer
# object and its array.
assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)
# Accepts a GrowableBuffer but ignores it, so only unboxing is exercised.
@numba.njit
def test1(x):
    return 3.14

# Test 1: unboxing doesn't crash.
test1(buf)

# Keep calling it and ensure that the reference counts don't grow.
for i in range(10):
    test1(buf)
    assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)
# Returns its argument, so the buffer is unboxed and then boxed again.
@numba.njit
def test2(x):
    return x

# Test 2: unboxing and boxing doesn't crash and returns a usable object.
assert numpy.asarray(test2(buf)).tolist() == [1.1, 2.2, 3.3]

# Keep calling it and ensure that the reference counts don't grow.
for i in range(10):
    test2(buf)
    assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)
# Returns the same buffer twice, i.e. boxes it twice per call.
@numba.njit
def test3(x):
    return x, x

# Test 3: do that returning two references for every one that goes in, to
# make sure that the above didn't pass by accident.
for i in range(10):
    test3(buf)
    assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)
# Indexes the buffer inside compiled code; "i" may be an integer, a slice or
# an array, each of which is tried below.
@numba.njit
def test4(x, i):
    return x[i]

# Test 4: verify that __getitem__ works.
assert test4(buf, 0) == 1.1
assert test4(buf, 1) == 2.2
assert test4(buf, 2) == 3.3

# Repeated indexing must not change the reference counts.
for i in range(10):
    test4(buf, 0)
    assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)

# ... for integers
assert test4(buf, 1) == 2.2
assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)

# ... for slices (the open-ended slice stops at the three appended values)
assert test4(buf, slice(1, None)).tolist() == [2.2, 3.3]
assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)

# ... for arrays
assert test4(buf, numpy.array([2, 1, 1, 0])).tolist() == [3.3, 2.2, 2.2, 1.1]
assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)
# Reads both the number of appended items and the allocated size.
@numba.njit
def test5(x):
    return len(x), x.reserved

# Test 5: verify that __len__ works and the "reserved" property works.
# Three values have been appended to a buffer created with initial=10.
assert test5(buf) == (3, 10)
# Appends two values per call from compiled code.
@numba.njit
def test6(x):
    x.append(4.4)
    x.append(5.5)

# Test 6: verify that we can append to the GrowableBuffer.
assert numpy.asarray(buf).tolist() == [1.1, 2.2, 3.3]
test6(buf)
assert numpy.asarray(buf).tolist() == [1.1, 2.2, 3.3, 4.4, 5.5]

# Go through several resizings (the 10-slot buffer ends up holding 65
# values) to make sure it doesn't crash and none of the reference counts
# grow.
tmp = buf._buffer
for i in range(30):
    test6(buf)
    # Presumably 3 while buf still uses this array and 2 right after a
    # resize, when only "tmp" still refers to the old array.
    assert sys.getrefcount(tmp) in (2, 3)
    # Track whichever array the buffer uses now.
    tmp = buf._buffer
assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)

# The final value should have a lot of 4.4's and 5.5's in it: one pair from
# the first call plus thirty pairs from the loop.
assert numpy.asarray(buf).tolist() == [1.1, 2.2, 3.3, 4.4, 5.5] + [4.4, 5.5]*30
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment