Skip to content

Instantly share code, notes, and snippets.

@yangyushi
Last active March 31, 2020 12:15
Show Gist options
  • Save yangyushi/00bf3f394d1d5a3dd8735422d1ad773c to your computer and use it in GitHub Desktop.
Save yangyushi/00bf3f394d1d5a3dd8735422d1ad773c to your computer and use it in GitHub Desktop.
This gist shows that `numba.njit`-accelerated code is faster than its `einsum` equivalent for composing/applying a large number of rotation matrices.
from time import perf_counter, time

import numpy as np
from numba import njit, prange
@njit(parallel=True)
def ndot1(A, B):
    """
    Multiply two stacks of 3x3 matrices, pair by pair.

    For every index i along the first axis the result satisfies
        C[i, j, k] = sum_q( A[i, j, q] * B[i, q, k] )

    Args:
        A (np.ndarray): shape is (N, 3, 3)
        B (np.ndarray): shape is (N, 3, 3)

    Return:
        np.ndarray: shape is (N, 3, 3)
    """
    count = A.shape[0]
    out = np.zeros(A.shape)
    # Iteration n only ever writes to out[n], so the outer loop can be
    # distributed with prange without any cross-iteration dependency.
    for n in prange(count):
        for row in range(3):
            for col in range(3):
                # Accumulate in a scalar, starting from zero and adding the
                # products in order of q, then store the finished sum once.
                acc = 0.0
                for q in range(3):
                    acc += A[n, row, q] * B[n, q, col]
                out[n, row, col] = acc
    return out
@njit(parallel=True)
def ndot2(A, B):
    """
    Apply a stack of 3x3 matrices to a stack of 3-vectors, pair by pair.

    For every index i along the first axis the result satisfies
        C[i, j] = sum_q( A[i, j, q] * B[i, q] )

    Args:
        A (np.ndarray): shape is (N, 3, 3)
        B (np.ndarray): shape is (N, 3)

    Return:
        np.ndarray: shape is (N, 3)
    """
    count = A.shape[0]
    out = np.zeros(B.shape)
    # Iteration n only ever writes to out[n], so the outer loop can be
    # distributed with prange without any cross-iteration dependency.
    for n in prange(count):
        for row in range(3):
            # Accumulate in a scalar, starting from zero and adding the
            # products in order of q, then store the finished sum once.
            acc = 0.0
            for q in range(3):
                acc += A[n, row, q] * B[n, q]
            out[n, row] = acc
    return out
# Benchmark script: compares np.einsum against the numba kernels above.
# Intervals are measured with perf_counter(), the clock intended for timing
# short durations; time() is the wall clock and can jump if the system clock
# is adjusted.
# NOTE: the inputs are uniform random arrays, not true rotation matrices; the
# arithmetic performed is the same.

# --- Benchmark 1: batched (3, 3) x (3, 3) products ---
print("Composing Rotations")
A = np.random.random((5000000, 3, 3))
B = np.random.random((5000000, 3, 3))
# Warm-up call: numba compiles on first invocation, so run once up front to
# keep the JIT compilation cost out of the timed section below.
C2 = ndot1(A, B)
t0 = perf_counter()
C1 = np.einsum('ijq,iqk->ijk', A, B)
print('einsum: ', perf_counter() - t0)
t0 = perf_counter()
C2 = ndot1(A, B)
print('numba: ', perf_counter() - t0)
# Both implementations must agree before the timings mean anything.
assert np.allclose(C1, C2)

# --- Benchmark 2: batched (3, 3) x (3,) products ---
print("\nApplying Rotations")
A = np.random.random((10000000, 3, 3))
B = np.random.random((10000000, 3))
# Warm-up call for ndot2, for the same reason as above.
C2 = ndot2(A, B)
t0 = perf_counter()
C1 = np.einsum('ijq,iq->ij', A, B)
print('einsum: ', perf_counter() - t0)
t0 = perf_counter()
C2 = ndot2(A, B)
print('numba: ', perf_counter() - t0)
assert np.allclose(C1, C2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment