Last active
January 17, 2020 09:28
-
-
Save luk-f-a/6d0eef6534a8aae27a6650093321d833 to your computer and use it in GitHub Desktop.
Class overload proposal for numba-scipy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Current state\n", | |
"\n", | |
"Overloading a Python class requires several steps that deal directly with typing, lowering, boxing and unboxing.\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Creating a new Numba type**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from numba import types\n", | |
"\n", | |
"class IntervalType(types.Type):\n", | |
"    \"\"\"Numba type representing instances of the pure-Python `Interval` class.\"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        # All Interval values share one Numba type, so the name is fixed.\n", | |
"        super().__init__(name='Interval')\n", | |
"\n", | |
"\n", | |
"# Singleton instance returned by the typing hooks in the following cells.\n", | |
"interval_type = IntervalType()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Type inference for Python values**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from numba.extending import typeof_impl\n", | |
"\n", | |
"# NOTE: `Interval` is only defined in a later cell of this notebook;\n", | |
"# that cell must be run before this one.\n", | |
"@typeof_impl.register(Interval)\n", | |
"def typeof_index(val, c):\n", | |
"    \"\"\"Type every Python `Interval` object as the singleton `interval_type`.\n", | |
"\n", | |
"    `val` and `c` are not inspected: all Interval values share one Numba type.\n", | |
"    \"\"\"\n", | |
"    return interval_type" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Type inference for operations**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from numba.extending import type_callable\n", | |
"\n", | |
"@type_callable(Interval)\n", | |
"def type_interval(context):\n", | |
"    \"\"\"Typing hook for calls to `Interval(lo, hi)` inside jitted code.\"\"\"\n", | |
"    def typer(lo, hi):\n", | |
"        # Only (float, float) is supported; returning None means the call\n", | |
"        # is not supported for these argument types.\n", | |
"        both_float = all(isinstance(arg, types.Float) for arg in (lo, hi))\n", | |
"        return interval_type if both_float else None\n", | |
"    return typer" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Defining the data model**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from numba.extending import models, register_model\n", | |
"\n", | |
"\n", | |
"@register_model(IntervalType)\n", | |
"class IntervalModel(models.StructModel):\n", | |
"    \"\"\"Native data layout of an Interval: a struct with two float64 members.\"\"\"\n", | |
"\n", | |
"    def __init__(self, dmm, fe_type):\n", | |
"        members = [\n", | |
"            ('lo', types.float64),\n", | |
"            ('hi', types.float64),\n", | |
"        ]\n", | |
"        models.StructModel.__init__(self, dmm, fe_type, members)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Exposing data model attributes**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from numba.extending import make_attribute_wrapper\n", | |
"\n", | |
"# Expose the struct members 'lo' and 'hi' under the same attribute names.\n", | |
"make_attribute_wrapper(IntervalType, 'lo', 'lo')\n", | |
"make_attribute_wrapper(IntervalType, 'hi', 'hi')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Exposing a property**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from numba.extending import overload_attribute\n", | |
"\n", | |
"\n", | |
"@overload_attribute(IntervalType, \"width\")\n", | |
"def get_width(interval):\n", | |
"    \"\"\"Return the jitted implementation of the `Interval.width` property.\"\"\"\n", | |
"    def getter(interval):\n", | |
"        return interval.hi - interval.lo\n", | |
"    return getter" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Implementing the constructor**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from numba import cgutils\n", | |
"from numba.extending import lower_builtin\n", | |
"\n", | |
"\n", | |
"@lower_builtin(Interval, types.Float, types.Float)\n", | |
"def impl_interval(context, builder, sig, args):\n", | |
"    \"\"\"Lower `Interval(lo, hi)` by filling a new struct with the two bounds.\n", | |
"\n", | |
"    Returns the LLVM struct value of type `sig.return_type`.\n", | |
"    \"\"\"\n", | |
"    typ = sig.return_type\n", | |
"    lo_ty, hi_ty = sig.args\n", | |
"    lo, hi = args\n", | |
"    # The typer accepts any float width but both struct members are float64,\n", | |
"    # so convert (e.g. float32 -> float64) before storing; no-op for float64.\n", | |
"    lo = context.cast(builder, lo, lo_ty, types.float64)\n", | |
"    hi = context.cast(builder, hi, hi_ty, types.float64)\n", | |
"    interval = cgutils.create_struct_proxy(typ)(context, builder)\n", | |
"    interval.lo = lo\n", | |
"    interval.hi = hi\n", | |
"    return interval._getvalue()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Problem \n", | |
"\n", | |
"These steps require advanced knowledge of Numba. `numba-scipy` requires hundreds of overloads, so we need to make this process more accessible\n", | |
"in order to reach a wider pool of contributors." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Proposal using Jitclass" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numba\n", | |
"from numba import types\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"There is a Python class that we want to \"overload\", i.e. provide a version that can run in jitted code, including jit-transparency for references." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Interval:\n", | |
"    \"\"\"\n", | |
"    A half-open interval [lo, hi) on the real number line.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, lo, hi):\n", | |
"        self.lo, self.hi = lo, hi\n", | |
"\n", | |
"    def __repr__(self):\n", | |
"        bounds = (self.lo, self.hi)\n", | |
"        return 'Interval(%f, %f)' % bounds\n", | |
"\n", | |
"    @property\n", | |
"    def width(self):\n", | |
"        \"\"\"Length of the interval, computed as hi - lo.\"\"\"\n", | |
"        return self.hi - self.lo" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**First step**: the user would provide a jitclass that has the desired behaviour.\n", | |
"\n", | |
"In this case it's identical to the original, but for more complex objects it could be a subset of the original pure-Python implementation." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@numba.jitclass(spec=[('lo', types.float64), ('hi', types.float64)])\n", | |
"class IntervalJit(object):\n", | |
"    \"\"\"\n", | |
"    Jitted counterpart of `Interval`: a half-open interval [lo, hi) on the\n", | |
"    real number line, with both bounds stored as float64.\n", | |
"    \"\"\"\n", | |
"    def __init__(self, lo, hi):\n", | |
"        # Attribute names must match the jitclass spec above.\n", | |
"        self.lo = lo\n", | |
"        self.hi = hi\n", | |
"\n", | |
"    def __repr__(self):\n", | |
"        return 'Interval(%f, %f)' % (self.lo, self.hi)\n", | |
"\n", | |
"    @property\n", | |
"    def width(self):\n", | |
"        # Same formula as Interval.width so both implementations agree.\n", | |
"        return self.hi - self.lo" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Second step**: declare the Jitclass as overloading the Python class.\n", | |
"\n", | |
"`overload_pyclass` to be provided by `numba.extending` or `numba_scipy.extending`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# NOTE: `overload_pyclass` is defined in the \"Additions\" cell further down;\n", | |
"# that cell must be run before this one.\n", | |
"overload_pyclass(pyclass=Interval, jitclass=IntervalJit)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Third step**: None. Sit back and enjoy your overloaded `Interval`" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
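"**Additions to** `numba.extending` or `numba_scipy.extending` **(everything below is in a very early alpha state)**. \n", | |
"\n", | |
"Not everything below is new code; some are existing functions that I have brought into the notebook. This cell must be run before the `overload_pyclass(Interval, IntervalJit)` call above." | |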
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def overload_pyclass1(pyclass, jitclass):\n", | |
"    \"\"\"Register a `typeof` rule mapping `pyclass` objects to `jitclass` instances.\"\"\"\n", | |
"    from numba.extending import typeof_impl\n", | |
"\n", | |
"    def _typeof_pyclass(val, c):\n", | |
"        # Every instance of `pyclass` gets the same (jitclass) instance type.\n", | |
"        return jitclass.class_type.instance_type\n", | |
"\n", | |
"    typeof_impl.register(pyclass)(_typeof_pyclass)\n", | |
" \n", | |
" \n", | |
"def overload_pyclass2(pyclass, jitclass):\n", | |
"    \"\"\"Make the global name `pyclass` resolve to `jitclass.class_type` when typing.\n", | |
"\n", | |
"    The mapping is inserted into the CPU target's typing context, so references\n", | |
"    to `pyclass` inside jitted functions are typed as the jitclass class type.\n", | |
"    \"\"\"\n", | |
"    from numba.targets.registry import cpu_target\n", | |
"    cpu_target.typing_context.insert_global(pyclass, jitclass.class_type)\n", | |
"\n", | |
"\n", | |
" \n", | |
" \n", | |
"def _add_linking_libs(context, call):\n", | |
"    \"\"\"Register `call.libs` (if present and non-empty) with `context` for inlining.\n", | |
"\n", | |
"    Not called by the other functions in this cell.\n", | |
"    \"\"\"\n", | |
"    required_libs = getattr(call, \"libs\", ())\n", | |
"    if not required_libs:\n", | |
"        return\n", | |
"    context.add_linking_libs(required_libs)\n", | |
" \n", | |
"def imp_dtor(context, module, instance_type):\n", | |
"    \"\"\"Get or define the LLVM destructor for a jitclass instance's data.\n", | |
"\n", | |
"    The function is named `_Dtor.<instance_type.name>` and has the LLVM\n", | |
"    signature `void (voidptr, uintp, voidptr)`. Its body is only emitted the\n", | |
"    first time it is requested in `module`; later calls return the existing\n", | |
"    function. The body casts the first argument to a pointer to the instance's\n", | |
"    data struct and NRT-decrefs that struct's value.\n", | |
"\n", | |
"    Returns the llvmlite function object (used as the meminfo destructor).\n", | |
"    \"\"\"\n", | |
"    from llvmlite import ir as llvmir\n", | |
"    llvoidptr = context.get_value_type(types.voidptr)\n", | |
"    llsize = context.get_value_type(types.uintp)\n", | |
"    dtor_ftype = llvmir.FunctionType(llvmir.VoidType(),\n", | |
"                                     [llvoidptr, llsize, llvoidptr])\n", | |
"\n", | |
"    fname = \"_Dtor.{0}\".format(instance_type.name)\n", | |
"    dtor_fn = module.get_or_insert_function(dtor_ftype,\n", | |
"                                            name=fname)\n", | |
"    if dtor_fn.is_declaration:\n", | |
"        # Not yet defined in this module: emit the destructor body now.\n", | |
"        builder = llvmir.IRBuilder(dtor_fn.append_basic_block())\n", | |
"\n", | |
"        alloc_fe_type = instance_type.get_data_type()\n", | |
"        alloc_type = context.get_value_type(alloc_fe_type)\n", | |
"\n", | |
"        # First argument is the raw data pointer; view it as the data struct.\n", | |
"        ptr = builder.bitcast(dtor_fn.args[0], alloc_type.as_pointer())\n", | |
"        data = context.make_helper(builder, alloc_fe_type, ref=ptr)\n", | |
"\n", | |
"        # Release the NRT references held by the data struct's value.\n", | |
"        context.nrt.decref(builder, alloc_fe_type, data._getvalue())\n", | |
"\n", | |
"        builder.ret_void()\n", | |
"\n", | |
"    return dtor_fn\n", | |
"\n", | |
"def unbox_pyclass4(pyclass, jitclass):\n", | |
"    \"\"\"Register an unboxer that turns `pyclass` objects into `jitclass` instances.\n", | |
"\n", | |
"    Numba keeps a single unboxer per type class, so the currently registered\n", | |
"    `ClassInstanceType` unboxer is removed and chained: values of any *other*\n", | |
"    class-instance type are still handed to it, and only `jitclass`'s instance\n", | |
"    type is unboxed by the code below.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    pyclass : type\n", | |
"        Pure-Python class whose objects are passed into jitted code (not used\n", | |
"        by the generated code itself).\n", | |
"    jitclass : jitclass\n", | |
"        Jitted class whose instance type the `pyclass` objects are typed as.\n", | |
"    \"\"\"\n", | |
"    from numba.extending import unbox, NativeValue\n", | |
"    from numba import cgutils\n", | |
"    from numba.pythonapi import _unboxers\n", | |
"\n", | |
"    target_type = jitclass.class_type.instance_type\n", | |
"    # Take the existing unboxer out of the registry (Numba refuses duplicate\n", | |
"    # registrations) but keep it, so regular jitclass instances still unbox.\n", | |
"    previous_unboxer = _unboxers.functions.pop(types.ClassInstanceType, None)\n", | |
"\n", | |
"    @unbox(types.ClassInstanceType)\n", | |
"    def unbox_interval(typ, obj, c):\n", | |
"        \"\"\"\n", | |
"        Convert a `pyclass` object to a native `jitclass` instance.\n", | |
"        \"\"\"\n", | |
"        if typ != target_type and previous_unboxer is not None:\n", | |
"            return previous_unboxer(typ, obj, c)\n", | |
"\n", | |
"        # Read every declared member from the Python object. The values are not\n", | |
"        # stored yet (see TODO below), so release the new references right away\n", | |
"        # instead of leaking them; Py_DecRef tolerates NULL. A missing attribute\n", | |
"        # leaves a Python error set, which is reported through `is_error` below.\n", | |
"        for attr_name in typ.struct:\n", | |
"            attr_obj = c.pyapi.object_getattr_string(obj, attr_name)\n", | |
"            c.pyapi.decref(attr_obj)\n", | |
"\n", | |
"        # Allocate the instance\n", | |
"        inst_typ = typ\n", | |
"        context = c.context\n", | |
"        builder = c.builder\n", | |
"        alloc_type = context.get_data_type(inst_typ.get_data_type())\n", | |
"        alloc_size = context.get_abi_sizeof(alloc_type)\n", | |
"\n", | |
"        meminfo = context.nrt.meminfo_alloc_dtor(\n", | |
"            builder,\n", | |
"            context.get_constant(types.uintp, alloc_size),\n", | |
"            imp_dtor(context, builder.module, inst_typ),\n", | |
"        )\n", | |
"        data_pointer = context.nrt.meminfo_data(builder, meminfo)\n", | |
"        data_pointer = builder.bitcast(data_pointer,\n", | |
"                                       alloc_type.as_pointer())\n", | |
"\n", | |
"        # Nullify all data so the destructor never decrefs uninitialised memory\n", | |
"        builder.store(cgutils.get_null_value(alloc_type),\n", | |
"                      data_pointer)\n", | |
"\n", | |
"        inst_struct = context.make_helper(builder, inst_typ)\n", | |
"        inst_struct.meminfo = meminfo\n", | |
"        inst_struct.data = data_pointer\n", | |
"\n", | |
"        # TODO: fill attributes with actual values (members are currently zero).\n", | |
"        # IDEA: instead of doing this automatically, call an __unbox__ method in\n", | |
"        # the jitclass to allow user customization.\n", | |
"\n", | |
"        # Prepare return value\n", | |
"        ret = inst_struct._getvalue()\n", | |
"\n", | |
"        return NativeValue(ret, is_error=c.pyapi.c_api_error())\n", | |
" \n", | |
" \n", | |
"\n", | |
"def overload_pyclass(pyclass, jitclass):\n", | |
"    \"\"\"Use `jitclass` as the implementation of `pyclass` inside jitted code.\n", | |
"\n", | |
"    Runs, in order: the `typeof` registration for `pyclass` values, the\n", | |
"    global-name resolution of `pyclass` to the jitclass, and the unboxer.\n", | |
"    \"\"\"\n", | |
"    for register_step in (overload_pyclass1, overload_pyclass2, unbox_pyclass4):\n", | |
"        register_step(pyclass, jitclass)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Tests" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Create Test1: True\n", | |
"Create Test2: True\n", | |
"Unbox Test1 \n", | |
" \tUnboxing: True \n", | |
" \tMember transfer: False\n", | |
"Unbox Test2 \n", | |
" \tUnboxing: True \n", | |
" \tMember transfer: False\n", | |
"Property Test1 \n", | |
" \tUnboxing: True \n", | |
" \tProperty calculation: False\n", | |
"Box Test1 \n", | |
" \tBoxing: False\n" | |
] | |
} | |
], | |
"source": [ | |
"from numba import njit\n", | |
"\n", | |
"@njit\n", | |
"def create_interval1():\n", | |
"    a = Interval(2.1, 3.1)\n", | |
"    return a.lo, a.width\n", | |
"\n", | |
"temp = create_interval1()\n", | |
"print(\"Create Test1: \", temp==(2.1, 1.0))\n", | |
"\n", | |
"@njit\n", | |
"def create_interval2(i, j):\n", | |
"    a = Interval(i, j)\n", | |
"    return a.lo, a.width\n", | |
"\n", | |
"\n", | |
"temp = create_interval2(4.1,5.6)\n", | |
"print(\"Create Test2: \", temp==(4.1, 1.5))\n", | |
"\n", | |
"\n", | |
"inter = Interval(2.1, 3.1)\n", | |
"\n", | |
"@njit\n", | |
"def inside_interval1(interval):\n", | |
"    return interval.lo\n", | |
"\n", | |
"temp = inside_interval1(inter)\n", | |
"print('Unbox Test1', \"\\n \\tUnboxing: True\", \"\\n \\tMember transfer: \" + str(temp==2.1))\n", | |
"\n", | |
"@njit\n", | |
"def inside_interval2(interval, x):\n", | |
"    return interval.lo <= x < interval.hi\n", | |
"\n", | |
"temp = inside_interval2(inter, 2.5)\n", | |
"print('Unbox Test2', \"\\n \\tUnboxing: True\", \"\\n \\tMember transfer: \" + str(temp))\n", | |
"\n", | |
"\n", | |
"@njit\n", | |
"def interval_width(interval):\n", | |
"    return interval.width\n", | |
"\n", | |
"temp = interval_width(inter)\n", | |
"print('Property Test1', \"\\n \\tUnboxing: True\", \"\\n \\tProperty calculation: \" + str(temp==1.0))\n", | |
"\n", | |
"\n", | |
"@njit\n", | |
"def sum_intervals(i, j):\n", | |
"    # Return the new Interval so Numba has to box it back into a Python object;\n", | |
"    # the `else` branch below inspects its `lo`/`hi` attributes.\n", | |
"    return Interval(i.lo + j.lo, i.hi + j.hi)\n", | |
"\n", | |
"\n", | |
"try:\n", | |
"    temp = sum_intervals(inter, inter)\n", | |
"except Exception as exc:\n", | |
"    # Report the failure instead of silently swallowing every error.\n", | |
"    print('Box Test1', \"\\n \\tBoxing: False\", \"\\n \\tError: \" + repr(exc))\n", | |
"else:\n", | |
"    print('Box Test1', \"\\n \\tBoxing: True\", \"\\n \\tAttributes: \" + str(temp.lo==4.2) + \" \" + str(temp.hi==6.2))\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python (latest_versions)", | |
"language": "python", | |
"name": "latest_versions" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment