@mcg1969
Last active March 23, 2017 05:33
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A simple way to unify GUFunc signatures"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from numba.npyufunc.sigparse import parse_signature\n",
"from numba.sigutils import normalize_signature"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"GUFuncs currently need two signatures: a type signature for the JIT (e.g. `f8[:,:], f8, f8[:]`) and a NumPy gufunc layout signature (e.g. `(s,t), () -> ()`) for, well, the gufunc processing."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"((array(float64, 2d, A), float64, array(float64, 1d, A)),\n",
" ([('s', 't'), ()], [()]))"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inp = normalize_signature('f8[:,:], f8, f8[:]')[0]\n",
"otp = parse_signature('(s,t), () -> ()')\n",
"inp, otp"
]
},
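{
"cell_type": "markdown",
"metadata": {},
"source": [
"It sure would be nice if we could do both in one fell swoop, with a single string such as `f8[s,t], f8 -> f8` in which each array axis is named. It turned out to be relatively simple to get most of the way there: `eval` the string in a namespace that resolves Numba type names normally and records every other name as a dimension."
]
},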
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from numba import types\n",
"\n",
"class SigDict(object):\n",
"    # Namespace for eval(): known Numba type names (f8, i4, ...) resolve\n",
"    # normally; any other name is recorded, in order, as a dimension name.\n",
"    def __init__(self, mod):\n",
"        self.mod = mod\n",
"        self.dims = []\n",
"\n",
"    def __getitem__(self, symbol):\n",
"        if symbol in self.mod:\n",
"            return self.mod[symbol]\n",
"        else:\n",
"            self.dims.append(symbol)\n",
"            # A full slice makes f8[s,t] evaluate exactly like f8[:,:]\n",
"            return slice(None)\n",
"\n",
"def parse_gufunc_half(sig, output=False):\n",
"    # Evaluate one side of the '->' (trusted input only: this uses eval)\n",
"    sig_dict = SigDict(types.__dict__)\n",
"    sig = eval(sig, {}, sig_dict)\n",
"    if not isinstance(sig, tuple): sig = sig,\n",
"    dims = sig_dict.dims\n",
"    sigtoks = []\n",
"    gutoks = []\n",
"    for tok in sig:\n",
"        # Scalars have no ndim; arrays consume one dimension name per axis\n",
"        ndim = getattr(tok, 'ndim', 0)\n",
"        gutoks.append(tuple(dims[:ndim]))\n",
"        # Scalar gufunc outputs are declared as 1-d arrays in the type signature\n",
"        if output and ndim == 0:\n",
"            tok = tok[:]\n",
"        sigtoks.append(tok)\n",
"        dims = dims[ndim:]\n",
"    return sigtoks, gutoks\n",
"\n",
"def parse_gufunc_signature(sig):\n",
"    # Returns (JIT argument types, (input layouts, output layouts))\n",
"    inputs, outputs = sig.split('->')\n",
"    inputs, inp_gu = parse_gufunc_half(inputs)\n",
"    outputs, otp_gu = parse_gufunc_half(outputs, True)\n",
"    return tuple(inputs + outputs), (inp_gu, otp_gu)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"((array(float64, 2d, A), float64, array(float64, 1d, A)),\n",
" ([('s', 't'), ()], [()]))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inp, otp = parse_gufunc_signature('f8[s,t], f8 -> f8')\n",
"inp, otp"
]
},
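{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unfortunately these signatures cannot be handed to the @guvectorize decorator as they are, because it expects the layout signature as a string. The second is easily converted back; the first, a tuple of Numba types, would take a bit more effort to turn into a string (though `mcguvec` below simply passes the tuple through)."
]
},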
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"'(s,t), () -> ()'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"' -> '.join(', '.join('(%s)' % ','.join(y) for y in x) for x in otp)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Missing:\n",
"- Full verification of the gufunc dimension signature\n",
"- C/Fortran array ordering support"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numba as nb\n",
"\n",
"def mcguvec(sig, **kwargs):\n",
"    # One string, e.g. 'f8[m,n], f8[n,p] -> f8[m,p]', yields both signatures\n",
"    inp, otp = parse_gufunc_signature(sig)\n",
"    # Turn the layout back into its string form, e.g. '(m,n), (n,p) -> (m,p)'\n",
"    otp = ' -> '.join(', '.join('(%s)' % ','.join(y) for y in x) for x in otp)\n",
"    return nb.guvectorize([inp], otp, **kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"@mcguvec('f8[m,n], f8[n,p] -> f8[m,p]')\n",
"def mymult(X, Y, Z):\n",
"    # Core (m,n) x (n,p) product; leading loop axes are broadcast by the gufunc\n",
"    for i in range(X.shape[0]):\n",
"        for j in range(Y.shape[1]):\n",
"            Z[i,j] = np.dot(X[i,:], Y[:,j])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(2, 3, 5)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"X = np.random.randn(2,3,4)\n",
"Y = np.random.randn(4,5)\n",
"mymult(X,Y).shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [conda env:_testy]",
"language": "python",
"name": "conda-env-_testy-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
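The NumPy layout half of a gufunc signature is a tiny grammar of its own: parenthesised, comma-separated dimension names, with inputs and outputs separated by `->`. Below is a minimal pure-Python sketch of a parser for it that returns the same nested structure `parse_signature` printed in the notebook. `parse_layout` is a hypothetical helper, not Numba's implementation. It assumes the simple form used here and rejects NumPy's optional (`?`) and fixed-size core dimensions rather than handling them.

```python
import re

# Matches one parenthesised group such as "(s,t)" or "()".
_GROUP = re.compile(r"\(([^()]*)\)")


def _parse_groups(text):
    """Parse '(s,t), ()' into [('s', 't'), ()]."""
    groups = []
    for body in _GROUP.findall(text):
        names = tuple(name.strip() for name in body.split(",") if name.strip())
        # Only plain identifiers are accepted; 'n?' or numeric sizes are rejected.
        for name in names:
            if not name.isidentifier():
                raise ValueError("unsupported dimension name: %r" % name)
        groups.append(names)
    return groups


def parse_layout(layout):
    """Hypothetical parser for a simple gufunc layout like '(s,t), () -> ()'."""
    inputs, outputs = layout.split("->")
    return _parse_groups(inputs), _parse_groups(outputs)


print(parse_layout("(s,t), () -> ()"))
```

Scalars show up as empty tuples, which is why the layout for `f8[s,t], f8 -> f8` reads `(s,t), () -> ()`.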
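The core trick in `parse_gufunc_half` is that `eval` accepts any mapping as its local namespace, so every name lookup can be intercepted. Known type names resolve normally. Every other name is recorded, in evaluation order, as a dimension and replaced by `slice(None)`, which turns `f8[s,t]` into the same subscript as `f8[:,:]`. The sketch below reproduces that mechanism without Numba. `FakeType`, `KNOWN`, `DimRecorder`, `parse_half` and `parse_gufunc` are hypothetical stand-ins for `numba.types` and the notebook's functions, so this shows the idea rather than producing real Numba types.

```python
class FakeType:
    """Hypothetical stand-in for a Numba type: a name plus a number of array axes."""

    def __init__(self, name, ndim=0):
        self.name = name
        self.ndim = ndim

    def __getitem__(self, index):
        # f8[s, t] arrives as a tuple of slices, f8[:] as a single slice.
        ndim = len(index) if isinstance(index, tuple) else 1
        return FakeType(self.name, ndim)

    def __eq__(self, other):
        return isinstance(other, FakeType) and (self.name, self.ndim) == (other.name, other.ndim)

    def __repr__(self):
        return self.name + ("[%s]" % ",".join(":" * self.ndim) if self.ndim else "")


# Hypothetical type namespace standing in for numba.types.__dict__.
KNOWN = {name: FakeType(name) for name in ("f4", "f8", "i4", "i8")}


class DimRecorder:
    """eval() namespace: known names resolve normally, unknown names become dimensions."""

    def __init__(self, known):
        self.known = known
        self.dims = []

    def __getitem__(self, name):
        if name in self.known:
            return self.known[name]
        self.dims.append(name)  # recorded in left-to-right evaluation order
        return slice(None)      # so f8[s, t] is evaluated like f8[:, :]


def parse_half(text, output=False):
    recorder = DimRecorder(KNOWN)
    toks = eval(text.strip(), {}, recorder)  # trusted input only
    if not isinstance(toks, tuple):
        toks = (toks,)
    dims = recorder.dims
    types, layout = [], []
    for tok in toks:
        layout.append(tuple(dims[:tok.ndim]))  # each array consumes one name per axis
        dims = dims[tok.ndim:]
        if output and tok.ndim == 0:
            tok = tok[:]  # scalar outputs become 1-d arrays, as in the notebook
        types.append(tok)
    return types, layout


def parse_gufunc(sig):
    ins, outs = sig.split("->")
    in_types, in_layout = parse_half(ins)
    out_types, out_layout = parse_half(outs, output=True)
    return tuple(in_types + out_types), (in_layout, out_layout)


print(parse_gufunc("f8[s,t], f8 -> f8"))
```

One consequence of this design, shared with the notebook's version: a dimension whose name happens to match a known type name (say, a dimension called `f8`) is silently taken as the type. And because the string goes through `eval`, only trusted signatures should be fed to it.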
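NumPy's gufunc machinery treats the trailing axes of each argument as the core dimensions named in the layout and broadcasts the remaining leading "loop" axes. That is why, in the `mymult` example, a `(2,3,4)` array times a `(4,5)` matrix came back as `(2,3,5)`. The sketch below is a pure-Python illustration of those shape rules. It also covers one piece of the dimension checking listed as missing: each dimension name must bind to a single size across all inputs. `gufunc_output_shapes` and `broadcast_shapes` are hypothetical helpers. The sketch assumes every output dimension also appears in some input, and it ignores NumPy's optional and fixed-size dimensions.

```python
def broadcast_shapes(shapes):
    """Broadcast loop shapes with NumPy's right-aligned rules."""
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for axis in range(-ndim, 0):
        size = 1
        for shape in shapes:
            if len(shape) >= -axis and shape[axis] != 1:
                if size not in (1, shape[axis]):
                    raise ValueError("loop dimensions do not broadcast: %r" % (shapes,))
                size = shape[axis]
        result.append(size)
    return tuple(result)


def gufunc_output_shapes(layout, in_shapes):
    """Hypothetical helper: output shapes for a simple gufunc layout, with checks."""
    in_layout, out_layout = layout
    if len(in_shapes) != len(in_layout):
        raise ValueError("expected %d inputs, got %d" % (len(in_layout), len(in_shapes)))
    sizes = {}        # core dimension name -> bound size
    loop_shapes = []  # leading (non-core) axes of each input
    for shape, core in zip(in_shapes, in_layout):
        shape = tuple(shape)
        split = len(shape) - len(core)
        if split < 0:
            raise ValueError("shape %r has fewer axes than core %r" % (shape, core))
        loop_shapes.append(shape[:split])
        for name, size in zip(core, shape[split:]):
            if sizes.setdefault(name, size) != size:
                raise ValueError("dimension %r is both %d and %d" % (name, sizes[name], size))
    loop = broadcast_shapes(loop_shapes)
    out_shapes = []
    for core in out_layout:
        missing = [name for name in core if name not in sizes]
        if missing:
            raise ValueError("output dimensions %r appear in no input" % (missing,))
        out_shapes.append(loop + tuple(sizes[name] for name in core))
    return out_shapes


layout = ([("m", "n"), ("n", "p")], [("m", "p")])
print(gufunc_output_shapes(layout, [(2, 3, 4), (4, 5)]))
```

With the `mymult` layout, the loop axis `2` comes only from `X`, and `m=3`, `n=4`, `p=5` are bound from the trailing axes, giving the same `(2, 3, 5)` the notebook reports. Passing a `(3, 5)` matrix instead would bind `n` to both 4 and 3 and raise.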