Last active September 7, 2018 08:30
Save JonnoFTW/e9ced26d2c523d1937927dc0e59d3f89 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Rows: 11, cols=1 44 Bytes
CL Took: 0:00:00.003606
NP Took: 0:00:00.000021
Rows: 10, cols=3 120 Bytes
CL Took: 0:00:00.000312
NP Took: 0:00:00.000012
Rows: 10000, cols=1000 40.0 MB
CL Took: 0:00:00.005533
NP Took: 0:00:00.017710
Rows: 50000, cols=1200 240.0 MB
CL Took: 0:00:00.028232
NP Took: 0:00:00.098110
Rows: 1200, cols=50000 240.0 MB
CL Took: 0:00:00.025554
NP Took: 0:00:00.064121
Rows: 784, cols=60000 188.2 MB
CL Took: 0:00:00.019778
NP Took: 0:00:00.048747
Rows: 60000, cols=784 188.2 MB
CL Took: 0:00:00.023760
NP Took: 0:00:00.097961
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time
from datetime import datetime, timedelta

import humanize
import numpy as np
import pyopencl as cl
from pyopencl import array
from pyopencl import cltypes
# Run everything on the first device of the first OpenCL platform found.
device = cl.get_platforms()[0].get_devices()[0]
ctx = cl.Context(devices=[device])
# Single command queue used for every upload and kernel launch in this script.
queue = cl.CommandQueue(ctx)
# Print plain decimals (no scientific notation) in the verbose matrix dumps.
np.set_printoptions(suppress=True)
# OpenCL kernel that exchanges pairs of rows of a row-major float matrix in place.
# Launch with global size (num_cols, num_pairs): dim 0 must equal the row length,
# because the kernel derives the row stride from get_global_size(0).
#
# `swaps` is __global rather than __constant: constant buffers are capped at
# CL_DEVICE_MAX_CONSTANT_BUFFER_SIZE (as small as 64 KB), and a 60000-row swap
# list is already 240 KB. Offsets use size_t so that `row * num_cols` cannot
# overflow a 32-bit int on matrices with more than 2^31 elements.
src = """
__kernel void shuffle_data(
    __global float* data,        // row-major matrix, shuffled in place
    __global const uint* swaps   // row indices; entries 2k and 2k+1 are exchanged
) {
    const size_t pair = get_global_id(1);
    const size_t col = get_global_id(0);
    const size_t num_cols = get_global_size(0);
    const size_t idx1 = num_cols * swaps[2 * pair] + col;
    const size_t idx2 = num_cols * swaps[2 * pair + 1] + col;
    const float tmp = data[idx2];
    data[idx2] = data[idx1];
    data[idx1] = tmp;
}
"""
shuffle_prog = cl.Program(ctx, src).build()
shuffle_krnl = shuffle_prog.shuffle_data
def read_only_arr(numbytes):
    """Allocate a ``numbytes``-byte device buffer in ``ctx`` flagged READ_ONLY.

    Used as the ``allocator`` for ``array.to_device`` so that the swap
    indices uploaded by ``shuffle`` live in a read-only buffer.
    """
    read_only = cl.mem_flags.READ_ONLY
    return cl.Buffer(ctx, read_only, size=numbytes)
def shuffle(x_data, rows, cols):
    """Randomly swap rows of a device matrix in place.

    A random permutation of the row indices is uploaded, and consecutive
    entries ``(2k, 2k + 1)`` name two rows that ``shuffle_krnl`` exchanges.
    Every row therefore moves at most once, so the result is a random set of
    disjoint row swaps, not a uniformly random permutation. With an odd row
    count, one row is left where it is.

    :param x_data: ``cl.Buffer`` holding a row-major float matrix of shape
        ``(rows, cols)``; modified in place.
    :param rows: number of rows in ``x_data``.
    :param cols: number of columns (row length) in ``x_data``.
    :return: device array of the permuted row indices; entries ``2k`` and
        ``2k + 1`` are the rows that were exchanged.
    """
    swaps_np = np.arange(rows, dtype=cltypes.uint)
    np.random.shuffle(swaps_np)
    swaps_g = array.to_device(queue, swaps_np, allocator=read_only_arr)
    num_pairs = rows // 2
    # OpenCL 1.x rejects a zero-sized NDRange (CL_INVALID_GLOBAL_WORK_SIZE),
    # and with fewer than two rows or no columns there is nothing to swap.
    if num_pairs and cols:
        event = shuffle_krnl(queue, (cols, num_pairs), None, x_data, swaps_g.data)
        event.wait()
    return swaps_g
def test_shuffle(rows, cols, verbose=False):
    """Benchmark ``shuffle`` on a ``rows`` x ``cols`` float matrix against NumPy.

    Prints the matrix size, then the time taken by the OpenCL row shuffle
    (swap-index upload + kernel launch + wait) and by ``np.random.shuffle``
    on the host copy. With ``verbose``, the matrix is printed before and
    after the OpenCL shuffle, followed by the swap indices.
    """
    x_data_np = np.arange(rows * cols, dtype=cltypes.float).reshape(rows, cols)
    x_data = array.to_device(queue, x_data_np)
    if verbose:
        print("Before:")
        print("X:")
        for idx, row in enumerate(x_data_np):
            print(idx, row)
    print(f"Rows: {rows}, cols={cols} {humanize.naturalsize(x_data_np.nbytes)}")
    # perf_counter is monotonic and high-resolution, unlike the wall clock
    # datetime.now(); these timings are in the microsecond range.
    # Wrapping the result in timedelta keeps the printed "H:MM:SS.ffffff" format.
    start = time.perf_counter()
    swaps = shuffle(x_data.data, rows, cols)
    print("\tCL Took:", timedelta(seconds=time.perf_counter() - start))
    if verbose:
        print("After:")
        for idx, row in enumerate(x_data.get()):
            print(idx, row)
        print("swaps:")
        print(swaps.get())
    start = time.perf_counter()
    np.random.shuffle(x_data_np)
    print("\tNP Took: ", timedelta(seconds=time.perf_counter() - start))
if __name__ == "__main__":
    # Small sanity shapes first, then large matrices in tall and wide orientations.
    benchmark_shapes = (
        (11, 1),
        (10, 3),
        (10000, 1000),
        (50000, 1200),
        (1200, 50000),
        (784, 60000),
        (60000, 784),
    )
    for n_rows, n_cols in benchmark_shapes:
        test_shuffle(n_rows, n_cols)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.