Skip to content

Instantly share code, notes, and snippets.

@matthieubulte
Last active January 14, 2025 12:45
Show Gist options
  • Save matthieubulte/ef0a5edf6edbc544c476072e6d481be7 to your computer and use it in GitHub Desktop.
Save matthieubulte/ef0a5edf6edbc544c476072e6d481be7 to your computer and use it in GitHub Desktop.
matmul.py
######################################################################## PREP
import ctypes

import jax
import jax.numpy as jnp
import numpy as np
from numpy.ctypeslib import ndpointer
from scipy.linalg.blas import sgemm
# Side length of the square matrices used by every benchmark in this file.
DIM = 4096

# Load Apple's Accelerate framework so its CBLAS entry points can be called
# directly through ctypes. The path only exists on macOS.
accelerate = ctypes.cdll.LoadLibrary(
    "/System/Library/Frameworks/Accelerate.framework/Accelerate"
)

# Single-precision general matrix multiply: C = alpha * op(A) @ op(B) + beta * C.
a_sgemm = accelerate.cblas_sgemm
# cblas_sgemm returns void; without this ctypes assumes an int return value.
a_sgemm.restype = None
a_sgemm.argtypes = [
    ctypes.c_int,  # storage order (row- or column-major)
    ctypes.c_int,  # transpose flag for A
    ctypes.c_int,  # transpose flag for B
    ctypes.c_int,  # M: rows of op(A) and of C
    ctypes.c_int,  # N: columns of op(B) and of C
    ctypes.c_int,  # K: columns of op(A) / rows of op(B)
    ctypes.c_float,  # alpha
    ndpointer(dtype=np.float32),  # A
    ctypes.c_int,  # lda: leading dimension of A
    ndpointer(dtype=np.float32),  # B
    ctypes.c_int,  # ldb: leading dimension of B
    ctypes.c_float,  # beta
    ndpointer(dtype=np.float32),  # C (written in place)
    ctypes.c_int,  # ldc: leading dimension of C
]
# Constants
# Integer codes handed to cblas_sgemm for its storage-order and transpose
# arguments.
# NOTE(review): values presumably mirror the CBLAS_ORDER / CBLAS_TRANSPOSE
# enums declared in cblas.h -- confirm against the Accelerate headers.
CblasRowMajor = 101
CblasColMajor = 102  # not referenced elsewhere in the visible code
CblasNoTrans = 111
CblasTrans = 112  # not referenced elsewhere in the visible code
def matmul_accelerate(A, B, C):
    """Compute ``C = A @ B`` in place through Accelerate's ``cblas_sgemm``.

    The matrix sizes and leading dimensions are taken from the arrays
    themselves, so any compatible shapes are accepted (not only
    ``DIM x DIM``).

    Args:
        A: 2-D, C-contiguous float32 array of shape ``(m, k)``.
        B: 2-D, C-contiguous float32 array of shape ``(k, n)``.
        C: 2-D, C-contiguous float32 array of shape ``(m, n)``; its contents
            are overwritten (``beta`` is 0.0, so prior values are ignored).

    Returns:
        The same ``C`` object that was passed in, now holding the product.

    Raises:
        ValueError: If the shapes are incompatible or an operand is not
            C-contiguous.
    """
    if A.ndim != 2 or B.ndim != 2 or C.ndim != 2:
        raise ValueError("matmul_accelerate expects 2-D arrays")

    m, k = A.shape
    k_b, n = B.shape
    if k != k_b:
        raise ValueError(
            f"inner dimensions differ: A is {A.shape}, B is {B.shape}"
        )
    if C.shape != (m, n):
        raise ValueError(
            f"output has shape {C.shape}, expected {(m, n)}"
        )

    # The call below declares row-major storage, so a Fortran-ordered or
    # strided operand would be read with the wrong layout without any error.
    for label, operand in (("A", A), ("B", B), ("C", C)):
        if not operand.flags.c_contiguous:
            raise ValueError(f"{label} must be C-contiguous")

    a_sgemm(
        CblasRowMajor,
        CblasNoTrans,
        CblasNoTrans,
        m,
        n,
        k,
        1.0,  # alpha
        A,
        k,  # lda: row stride of a row-major (m, k) array
        B,
        n,  # ldb
        0.0,  # beta: discard whatever C held before
        C,
        n,  # ldc
    )
    return C
# Reference operands: two random float32 matrices plus a zeroed output buffer,
# all in NumPy's default row-major layout.
A, B = (np.random.rand(DIM, DIM).astype(np.float32) for _ in range(2))
C = np.zeros(A.shape, dtype=A.dtype)

# Column-major (Fortran-order) copies for the SciPy BLAS wrapper.
bA, bB = (np.asfortranarray(mat, dtype=np.float32) for mat in (A, B))
bC = np.zeros(bA.shape, dtype=bA.dtype, order="F")

# Row-major (C-contiguous) operands for the direct Accelerate call.
cA, cB = (np.ascontiguousarray(mat, dtype=np.float32) for mat in (A, B))
cC = np.zeros(cA.shape, dtype=cA.dtype)

# Operands handed to JAX ahead of time via device_put.
jA, jB = (jax.device_put(mat) for mat in (A, B))
######################################################################## EVAL
# IPython session transcript: each %timeit line is followed by the output
# recorded for it.

# NumPy matmul operator.
%timeit A @ B
446 ms ± 30.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# NumPy matmul writing into the preallocated buffer C.
%timeit np.matmul(A, B, out=C)
441 ms ± 25.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# SciPy's BLAS sgemm wrapper on the Fortran-ordered operands.
# NOTE(review): beta=1.0 with c=bC looks like it accumulates into bC on every
# run, whereas matmul_accelerate passes beta=0.0 -- confirm this is intended.
%timeit sgemm(1.0, bA, bB, beta=1.0, c=bC, trans_a=0, trans_b=0, overwrite_c=1)
460 ms ± 37.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Direct cblas_sgemm call into Accelerate through ctypes.
%timeit matmul_accelerate(cA, cB, cC)
168 ms ± 11.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# JAX matmul; block_until_ready() is called on the result inside the timed statement.
%timeit jnp.matmul(jA, jB).block_until_ready()
352 ms ± 24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment