Last active
October 31, 2021 12:53
-
-
Save smontanaro/80f788a506d2f41156dae779562fd08d to your computer and use it in GitHub Desktop.
Simple multi-threaded implementation of matrix multiplication in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""multi-threaded matrix multiply, demonstrating no-GIL benefits. | |
No help, crappy CLI. Run something like this: | |
python3 matmul.py NTHREADS SIZE | |
NTHREADS is (wait for it) the number of threads to spin up. | |
SIZE is the number of cells in your A array. Your B array will always | |
be twice that size. | |
Array shape is computed by... Well, just read the code. It's the first | |
thing I thought of. | |
""" | |
import queue | |
import sys | |
import threading | |
import numpy as np | |
# Unthreaded versions (no Thread or Queue usage) | |
def vecmul(a, b):
    """Return the dot product of *a* and *b*, computed with a plain Python loop.

    The two sequences must have equal length; this is checked with ``assert``
    (message: the tuple of both lengths). The accumulator starts at ``0.0``,
    so two empty sequences give ``0.0``.
    """
    len_a, len_b = len(a), len(b)
    assert len_a == len_b, (len_a, len_b)
    total = 0.0
    # Accumulate strictly left to right, one product at a time.
    for x, y in zip(a, b):
        total += x * y
    return total
def matmul(a, b):
    """Matrix multiply 2-D arrays *a* and *b* without threads.

    Prints the operand and result shapes (and the result cell count), then
    fills each cell ``[i, j]`` with ``vecmul`` of row ``i`` of *a* and
    column ``j`` of *b*.

    Returns a new float64 array of shape ``(a.shape[0], b.shape[1])``.
    """
    # Allocate the output once (the original allocated it twice).
    result = np.zeros((a.shape[0], b.shape[1]))
    print("a:", a.shape, "b:", b.shape, "result:", result.shape, "->",
          result.shape[0] * result.shape[1])
    # Iterating a 2-D array yields its rows; rows of b.T are columns of b.
    rows = list(a)
    cols = list(b.T)
    for i, row in enumerate(rows):
        for j, col in enumerate(cols):
            result[i, j] = vecmul(row, col)
    return result
# Original threaded versions | |
def vecmul_t(qin, qout):
    """Worker loop: turn ``(a, b, i, j)`` jobs from *qin* into results on *qout*.

    Each job produces ``(vecmul(a, b), i, j)`` on *qout*. A ``None`` job is
    the stop signal: the worker then posts ``None`` to *qout*, so the
    consumer can count finished workers, and returns.
    """
    # iter(callable, sentinel) keeps calling qin.get() until it yields None.
    for job in iter(qin.get, None):
        row, col, i, j = job
        qout.put((vecmul(row, col), i, j))
    qout.put(None)
def matmul_t(a, b, qin, qout, nthreads):
    """Matrix multiply 2-D arrays *a* and *b* using worker threads.

    Prints the operand and result shapes (and the result cell count).

    Work protocol:

    - Enqueues one ``(row, col, i, j)`` job per result cell on *qin*, then
      *nthreads* ``None`` stop sentinels.
    - Collects ``(value, i, j)`` results from *qout* until *nthreads*
      ``None`` sentinels arrive.

    Expects *nthreads* workers to already be consuming *qin* and echoing
    ``None`` to *qout* when they stop, as ``vecmul_t`` does.

    Returns a new float64 array of shape ``(a.shape[0], b.shape[1])``.

    Raises:
        ValueError: if *nthreads* < 1. Without this check the collection loop
        would wait forever for sentinels that no worker will ever send.
    """
    if nthreads < 1:
        raise ValueError(f"nthreads must be >= 1, got {nthreads}")
    result = np.zeros((a.shape[0], b.shape[1]))
    print("a:", a.shape, "b:", b.shape, "result:", result.shape, "->",
          result.shape[0] * result.shape[1])
    # Iterating a 2-D array yields its rows; rows of b.T are columns of b.
    rows = list(a)
    cols = list(b.T)
    for i, row in enumerate(rows):
        for j, col in enumerate(cols):
            qin.put((row, col, i, j))
    # One sentinel per worker so every worker sees a stop signal.
    for _ in range(nthreads):
        qin.put(None)
    finished = 0
    while finished < nthreads:
        item = qout.get()
        if item is None:
            finished += 1
            continue
        val, i, j = item
        result[i, j] = val
    return result
def prime_factors(n):
    """Return the prime factors of positive integer *n* in ascending order.

    Repeated factors appear once per multiplicity, e.g.
    ``prime_factors(12) == [2, 2, 3]``; ``prime_factors(1) == []``.

    Raises:
        ValueError: if *n* < 1. For ``n == 0`` the halving loop would never
        terminate, since ``0 % 2 == 0`` always holds.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    factors = []
    while n % 2 == 0:
        factors.append(2)
        n //= 2
    # Trial division by odd candidates up to sqrt of the remaining n.
    divisor = 3
    while divisor * divisor <= n:
        while n % divisor == 0:
            factors.append(divisor)
            n //= divisor
        divisor += 2
    # Whatever remains above 1 has no factor <= its sqrt, so it is prime.
    if n > 1:
        factors.append(n)
    return factors
def test():
    """Benchmark driver: multiply two random matrices, optionally with threads.

    Usage: ``python3 matmul.py NTHREADS SIZE``

    - NTHREADS == 0 runs the unthreaded ``matmul``.
    - NTHREADS > 0 starts that many daemon threads running ``vecmul_t`` and
      calls ``matmul_t``.

    Matrix shapes:

    - A has SIZE cells, with shape ``(off_dim, common_dim)``.
    - B has ``2 * SIZE`` cells, with shape ``(common_dim, 2 * off_dim)``.
    - ``common_dim`` is the product of SIZE's largest prime factors, taken
      one at a time until its square reaches SIZE.
    - ``off_dim`` is ``SIZE // common_dim``.

    The product is neither checked nor returned.

    Exits via ``sys.exit`` with a usage message if the arguments are
    missing, are not integers, NTHREADS < 0, or SIZE < 10.
    """
    usage = "usage: matmul.py NTHREADS SIZE  (NTHREADS >= 0, SIZE >= 10)"
    try:
        nthreads = int(sys.argv[1])
        size = int(sys.argv[2])
    except (IndexError, ValueError):
        sys.exit(usage)
    # SIZE >= 10 is the original author's guard against problems determining
    # the array dims. Negative NTHREADS would make matmul_t wait forever.
    # These are explicit checks (not assert) so they survive ``python -O``.
    if size < 10 or nthreads < 0:
        sys.exit(usage)

    factors = prime_factors(size)
    common_dim = 1
    while common_dim * common_dim < size:
        common_dim *= factors.pop()
    off_dim = size // common_dim
    a = np.random.random(size).reshape(off_dim, common_dim)
    b = np.random.random(size * 2).reshape(common_dim, off_dim * 2)

    # To verify a result: np.allclose(res, np.matmul(a, b)).
    if nthreads == 0:
        # Straightforward single-threaded matrix multiply.
        matmul(a, b)
        return

    threads = []
    qin = queue.Queue()
    qout = queue.Queue()
    for _ in range(nthreads):
        t = threading.Thread(target=vecmul_t, args=(qin, qout), daemon=True)
        threads.append(t)
        t.start()
    matmul_t(a, b, qin, qout, nthreads)
    # matmul_t returns only after every worker has posted its stop sentinel,
    # which is a worker's last action, so these joins finish promptly.
    for t in threads:
        t.join()
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not on import.
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment