Last active
March 23, 2017 05:33
-
-
Save mcg1969/694204e08df3a7f9e327e4791af7a721 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## A simple way to unify GUFunc signatures" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from numba.npyufunc.sigparse import parse_signature\n", | |
"from numba.sigutils import normalize_signature" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"GUFuncs currently need two signatures: a type-shape signature for the JIT, and a NumPy GUFunc signature for, well, the GUFunc processing." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((array(float64, 2d, A), float64, array(float64, 1d, A)),\n", | |
" ([('s', 't'), ()], [()]))" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Status quo: the numba types and the NumPy gufunc layout are written and\n", | |
"# parsed separately; [0] keeps the argument-type tuple.\n", | |
"inp = normalize_signature('f8[:,:], f8, f8[:]')[0]\n", | |
"otp = parse_signature('(s,t), () -> ()')\n", | |
"inp, otp" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"It sure would be nice if we could do this in one fell swoop. It turned out to be relatively simple to get most of the way there." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from numba import types\n", | |
"\n", | |
"class SigDict(object):\n", | |
" def __init__(self, mod):\n", | |
" self.mod = mod\n", | |
" self.dims = []\n", | |
" def __getitem__(self, symbol):\n", | |
" if symbol in self.mod:\n", | |
" return self.mod[symbol]\n", | |
" else:\n", | |
" self.dims.append(symbol)\n", | |
" return slice(None)\n", | |
" \n", | |
"def parse_gufunc_half(sig, output=False):\n", | |
" sig_dict = SigDict(types.__dict__)\n", | |
" sig = eval(sig, {}, sig_dict)\n", | |
" if not isinstance(sig, tuple): sig = sig,\n", | |
" dims = sig_dict.dims\n", | |
" sigtoks = []\n", | |
" gutoks = []\n", | |
" for tok in sig:\n", | |
" ndim = getattr(tok, 'ndim', 0)\n", | |
" gutoks.append(tuple(dims[:ndim]))\n", | |
" if output and ndim == 0:\n", | |
" tok = tok[:]\n", | |
" sigtoks.append(tok)\n", | |
" dims = dims[ndim:]\n", | |
" return sigtoks, gutoks\n", | |
" \n", | |
"def parse_gufunc_signature(sig):\n", | |
" inputs, outputs = sig.split('->')\n", | |
" inputs, inp_gu = parse_gufunc_half(inputs)\n", | |
" outputs, otp_gu = parse_gufunc_half(outputs, True)\n", | |
" return tuple(inputs + outputs), (inp_gu, otp_gu)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((array(float64, 2d, A), float64, array(float64, 1d, A)),\n", | |
" ([('s', 't'), ()], [()]))" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# One unified string yields the same pair as the two-signature cell above.\n", | |
"inp, otp = parse_gufunc_signature('f8[s,t], f8 -> f8')\n", | |
"inp, otp" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Unfortunately, these signatures cannot be passed directly to Numba, because the `@guvectorize` decorator requires strings. The second is easily converted; the first would take a bit more effort." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'(s,t), () -> ()'" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"' -> '.join(', '.join('(%s)' % ','.join(y) for y in x) for x in otp)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Missing:\n", | |
"- Full verification of the gufunc dimension signature\n", | |
"- C/Fortran array ordering support" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"import numba as nb\n", | |
"def mcguvec(sig, **kwargs):\n", | |
" inp, otp = parse_gufunc_signature(sig)\n", | |
" otp = ' -> '.join(', '.join('(%s)' % ','.join(y) for y in x) for x in otp)\n", | |
" return nb.guvectorize([inp], otp, **kwargs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"@mcguvec('f8[m,n], f8[n,p] -> f8[m,p]')\n", | |
"def mymult(X, Y, Z):\n", | |
" for i in range(X.shape[0]):\n", | |
" for j in range(Y.shape[1]):\n", | |
" Z[i,j] = np.dot(X[i,:], Y[:,j])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(2, 3, 5)" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"# X holds two 3x4 matrices; Y (4x5) is broadcast against X's leading axis,\n", | |
"# so the result stacks two 3x5 products.\n", | |
"X = np.random.randn(2,3,4)\n", | |
"Y = np.random.randn(4,5)\n", | |
"mymult(X,Y).shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"anaconda-cloud": {}, | |
"kernelspec": { | |
"display_name": "Python [conda env:_testy]", | |
"language": "python", | |
"name": "conda-env-_testy-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment