Created
March 3, 2016 11:31
-
-
Save pitrou/e12353a5839fb60d3fc1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numba as nb | |
import numpy as np | |
from time import process_time as clock | |
class Timer:
    """Context manager that measures the CPU time spent inside a ``with`` block.

    The measurement uses ``clock`` (``time.process_time`` in this file).
    After the block exits, ``start``, ``end`` and ``interval`` (all in
    seconds) are available as attributes on the instance.

    If a truthy *title* is given, a message is printed when the block is
    entered and another one, including the elapsed time, when it is left.
    Without a title the timer is silent.
    """

    def __init__(self, title=None):
        self.title = title

    def __enter__(self):
        """Announce the start (if titled), record the start time, return self."""
        if self.title:
            print('Beginning {0}'.format(self.title))
        self.start = clock()
        return self

    def __exit__(self, *exc_info):
        """Record the end time and the elapsed interval; report it if titled.

        Returns None, so an exception raised inside the block propagates.
        """
        self.end = clock()
        self.interval = self.end - self.start
        if not self.title:
            # Untitled timers stay quiet; callers read ``interval`` directly.
            return
        print('{1} took {0:0.4f} seconds'.format(self.interval, self.title))
@nb.jit(nopython=True)
def insertion_sort(A, low, high):
    """
    Sort the slice A[low:high + 1] in place using insertion sort.

    Both bounds are inclusive: ``low`` is the first index that is sorted
    and ``high`` the last one.  Nothing is returned.
    """
    for pos in range(low + 1, high + 1):
        item = A[pos]
        # Walk left from ``pos``, shifting every element greater than
        # ``item`` one slot to the right until its place is found.
        slot = pos
        while slot > low:
            left = A[slot - 1]
            if not (item < left):
                break
            A[slot] = left
            slot -= 1
        A[slot] = item
@nb.jit(nopython=True)
def merge2(x):
    """
    Return a sorted copy of the array ``x`` using a bottom-up merge sort.

    ``x`` itself is not modified: the work happens on a copy and on a
    scratch buffer of the same shape and dtype.  Runs of 25 elements are
    first sorted with ``insertion_sort``; sorted runs are then merged
    pairwise with ``_merge2``, doubling the run length on every pass until
    a single run covers the whole array.
    """
    n = x.shape[0]
    src = x.copy()
    dst = np.zeros_like(src)

    # Phase 1: insertion-sort fixed-size runs (the last one may be shorter).
    run = 25
    for lo in range(0, n, run):
        insertion_sort(src, lo, min(lo + run, n) - 1)

    # Phase 2: merge neighbouring runs, bottom-up.
    while run < n:
        for lo in range(0, n, 2 * run):
            # Clamp both boundaries so that the trailing, possibly
            # incomplete pair of runs stays inside the array.
            mid = min(lo + run, n)
            hi = min(mid + run, n)
            _merge2(src, dst, lo, mid, hi)
        # The buffer just written holds the longer sorted runs; it becomes
        # the source of the next pass and the old source is reused as the
        # scratch buffer.
        src, dst = dst, src
        run *= 2
    return src
@nb.jit(nopython=True)
def _merge2(src_arr, tgt_arr, istart, imid, iend):
    """
    Merge two adjacent sorted runs of ``src_arr`` into ``tgt_arr``.

    The left run is ``src_arr[istart:imid]`` and the right run is
    ``src_arr[imid:iend]``.  The merged result is written to
    ``tgt_arr[istart:iend]``.  On ties the element of the left run is
    emitted first.  Either run may be empty (e.g. ``imid == iend`` for a
    trailing run without a partner); its counterpart is then copied over
    unchanged.

    Every element is loaded only after its index has been checked against
    the end of its run, so the function never reads ``src_arr[iend]``,
    which may lie one past the end of the array.
    """
    i0 = istart
    i1 = imid
    ipos = istart

    if i0 < imid and i1 < iend:
        # Keep the current head of each run in a local so that only the
        # run that advanced needs to be re-read on each iteration.
        v0 = src_arr[i0]
        v1 = src_arr[i1]
        while True:
            if v0 <= v1:
                tgt_arr[ipos] = v0
                ipos += 1
                i0 += 1
                if i0 >= imid:
                    # Left run exhausted: do not read past its end.
                    break
                v0 = src_arr[i0]
            else:
                tgt_arr[ipos] = v1
                ipos += 1
                i1 += 1
                if i1 >= iend:
                    # Right run exhausted: do not read past its end.
                    break
                v1 = src_arr[i1]

    # Copy whatever is left of the run that was not exhausted.
    while i0 < imid:
        tgt_arr[ipos] = src_arr[i0]
        ipos += 1
        i0 += 1
    while i1 < iend:
        tgt_arr[ipos] = src_arr[i1]
        ipos += 1
        i1 += 1
def test_merge_multi():
    """
    Benchmark ``merge2`` against ``np.sort(kind='merge')`` and check results.

    A fixed-seed array of random int32 values is generated once.  For 30
    logarithmically spaced sizes between 20 and 500000, a prefix of that
    array is sorted with both implementations, the ratio of the two CPU
    times is printed, and the outputs are asserted to be equal.

    Raises:
        AssertionError: if ``merge2`` and NumPy disagree for some size.
    """
    np.random.seed(42)
    n0 = 20
    n1 = 500000
    nsteps = 30
    # randint's upper bound is exclusive, so n1 + 1 keeps the inclusive
    # [0, n1] range of the deprecated np.random.random_integers.
    src = np.random.randint(0, n1 + 1, size=n1).astype(np.int32)
    # JIT warmup, so that compilation time is not included in the timings
    merge2(src[:2])
    for n in np.logspace(np.log10(n0), np.log10(n1), nsteps, dtype=np.intc):
        x = src[:n]
        with Timer() as t0:
            r = merge2(x)
        with Timer() as t1:
            e = np.sort(x, kind='merge')
        # For tiny inputs the measured NumPy time can be exactly 0 when the
        # clock resolution is coarse; report NaN instead of raising
        # ZeroDivisionError.
        if t1.interval > 0:
            ratio = t0.interval / t1.interval
        else:
            ratio = float('nan')
        print('n = %6s => nb/np duration %.2f' % (n, ratio))
        np.testing.assert_equal(e, r)


# Run the benchmark when the file is executed.
test_merge_multi()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment