Created
October 26, 2016 14:32
-
-
Save rutj3/bbf2137dd85c9abc8b555195f10886d3 to your computer and use it in GitHub Desktop.
numba_speedup
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import itertools\n", | |
"import sys\n", | |
"from numba import njit, jit\n", | |
"import math\n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Original Python" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"def fast_glynn_perm(M):\n", | |
" row_comb = []\n", | |
" for column in range(len(M[0])):\n", | |
" t = 0\n", | |
" for row in M:\n", | |
" t += row[column]\n", | |
" row_comb.append(t)\n", | |
" \n", | |
" n=len(M)\n", | |
"\n", | |
" total = 0\n", | |
" old_grey = 0\n", | |
" sign = +1\n", | |
"\n", | |
"# binary_power_dict = {2**i:i for i in range(n)}\n", | |
" num_loops = 2**(n-1)\n", | |
"\n", | |
" for bin_index in range(1, num_loops + 1):\n", | |
" i_product = 1\n", | |
" for row_num in row_comb:\n", | |
" i_product = i_product * row_num\n", | |
" total += sign * i_product\n", | |
"# total += sign * reduce(mul, row_comb)\n", | |
"\n", | |
" new_grey = bin_index^(bin_index//2)\n", | |
" grey_diff = old_grey ^ new_grey\n", | |
"\n", | |
" grey_diff_index = grey_diff.bit_length()-1\n", | |
"# grey_diff_index = binary_power_dict[grey_diff]\n", | |
"\n", | |
" new_vector = M[grey_diff_index]\n", | |
" #direction = 2 * cmp(old_grey,new_grey)\n", | |
" direction = 2 * (old_grey > new_grey) - (old_grey < new_grey)\n", | |
"\n", | |
" for i in range(n):\n", | |
" row_comb[i] += new_vector[i] * direction\n", | |
"\n", | |
" sign = -sign\n", | |
" old_grey = new_grey\n", | |
"\n", | |
" return total/num_loops " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"@njit\n", | |
"def bitcounter(n):\n", | |
" # from http://stackoverflow.com/a/2654159/1755432\n", | |
" if n == 0:\n", | |
" return 0\n", | |
" else:\n", | |
" return int(np.floor(np.log2(n)) + 1)\n", | |
" \n", | |
"@njit\n", | |
"def fast_glynn_perm_numba(M):\n", | |
" \n", | |
" ys, xs = M.shape\n", | |
" \n", | |
" row_comb = np.zeros((ys), dtype=M.dtype)\n", | |
" \n", | |
" for column in range(xs):\n", | |
" for row in range(ys):\n", | |
" row_comb[column] += M[row, column]\n", | |
" \n", | |
" n=len(M)\n", | |
"\n", | |
" total = 0\n", | |
" old_grey = 0\n", | |
" sign = +1\n", | |
"\n", | |
" num_loops = 2**(n-1)\n", | |
"\n", | |
" for bin_index in range(1, num_loops + 1):\n", | |
" i_product = 1\n", | |
" for row_num in row_comb:\n", | |
" i_product = i_product * row_num\n", | |
" total += sign * i_product\n", | |
"\n", | |
" new_grey = bin_index^(bin_index//2)\n", | |
" grey_diff = old_grey^new_grey\n", | |
" \n", | |
" grey_diff_index = bitcounter(grey_diff)-1\n", | |
"\n", | |
" new_vector = M[grey_diff_index]\n", | |
" \n", | |
" direction = 2 * (old_grey > new_grey) - (old_grey < new_grey)\n", | |
"\n", | |
" for i in range(n):\n", | |
" row_comb[i] += new_vector[i] * direction\n", | |
"\n", | |
" sign = -sign\n", | |
" old_grey = new_grey\n", | |
"\n", | |
" return total/num_loops\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Numba, Python\n", | |
" -19106.52, -19106.52\n", | |
" -29196.57, -29196.57\n", | |
" -11757.61, -11757.61\n", | |
" -7109.50, -7109.50\n", | |
" -32504.01, -32504.01\n", | |
" -23125.93, -23125.93\n", | |
" -10681.19, -10681.19\n", | |
" -41466.23, -41466.23\n", | |
" -16332.61, -16332.61\n", | |
" -8473.75, -8473.75\n" | |
] | |
} | |
], | |
"source": [ | |
"print('Numba, Python')\n", | |
"for i in range(10):\n", | |
" \n", | |
" p = 6\n", | |
" A = np.random.rand(p,p)\n", | |
"\n", | |
" print('%10.2f, %10.2f' % (fast_glynn_perm_numba(A), fast_glynn_perm(A)))\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"100000 loops, best of 3: 3.55 µs per loop\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit \n", | |
"fast_glynn_perm_numba(A)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1000 loops, best of 3: 265 µs per loop\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"fast_glynn_perm(A)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.4.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment