@loganbvh
Last active May 17, 2023 19:08
Fast matmul with numba
import numba
import numpy as np

@numba.njit(fastmath=True, parallel=True)
def fast_matmul(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Performs ``A @ B`` for 2D matrices efficiently using numba."""
    # I have found that pre-allocating ``out`` and passing it in as an argument
    # does not speed things up very much.
    # On my M1 MacBook with 10 cores, this function is faster and has lower CPU
    # utilization than A @ B using numpy/BLAS. It is also slightly faster than
    # jax.numpy.matmul(A, B) running on the CPU.
    # Your mileage may vary...
    assert A.ndim == 2, A.ndim
    assert B.ndim == 2, B.ndim
    assert A.shape[1] == B.shape[0], f"{A.shape[1]} != {B.shape[0]}"
    out = np.empty((A.shape[0], B.shape[1]), dtype=A.dtype)
    # The outer loop is parallelized across rows of ``A`` via ``prange``.
    for i in numba.prange(A.shape[0]):
        for j in range(B.shape[1]):
            # Accumulate into a scalar; ``0.0`` assumes a floating-point dtype.
            tmp = 0.0
            for k in range(B.shape[0]):
                tmp += A[i, k] * B[k, j]
            out[i, j] = tmp
    return out