Last active
March 31, 2020 12:15
-
-
Save yangyushi/00bf3f394d1d5a3dd8735422d1ad773c to your computer and use it in GitHub Desktop.
This gist shows that `numba.njit`-accelerated code is faster than its `np.einsum` equivalent for composing and applying large numbers of rotation matrices.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from time import perf_counter, time

import numpy as np
from numba import njit, prange
@njit(parallel=True)
def ndot1(A, B):
    """
    Batched matrix-matrix product of two equally long stacks of matrices.

    Calculate C with
        C[i, j, k] = sum_q( A[i, j, q] · B[i, q, k] )
    which is equivalent to ``np.einsum('ijq,iqk->ijk', A, B)``; for stacks
    of rotation matrices this composes the rotations pairwise.

    Args:
        A (np.ndarray): shape is (N, M, P), typically (N, 3, 3)
        B (np.ndarray): shape is (N, P, K), typically (N, 3, 3)

    Return:
        np.ndarray: float64 array, shape is (N, M, K)

    Raises:
        ValueError: if the batch size or the contracted dimension of B
            does not match A.
    """
    n, m, p = A.shape
    k_dim = B.shape[2]
    # numba does not bounds-check array indexing by default, so mismatched
    # shapes would silently read/write out of bounds; validate up front.
    if B.shape[0] != n or B.shape[1] != p:
        raise ValueError("ndot1: A must be (N, M, P) and B must be (N, P, K)")
    # Every element is assigned below, so no zero-initialisation is needed.
    C = np.empty((n, m, k_dim))
    for i in prange(n):
        for j in range(m):
            for k in range(k_dim):
                # Sum in a local scalar (same order as summing into C
                # directly) instead of re-reading/writing C for each term.
                acc = 0.0
                for q in range(p):
                    acc += A[i, j, q] * B[i, q, k]
                C[i, j, k] = acc
    return C
@njit(parallel=True)
def ndot2(A, B):
    """
    Batched matrix-vector product of a stack of matrices and a stack of vectors.

    Calculate C with
        C[i, j] = sum_q( A[i, j, q] · B[i, q] )
    which is equivalent to ``np.einsum('ijq,iq->ij', A, B)``; for rotation
    matrices this applies the i-th rotation to the i-th vector.

    Args:
        A (np.ndarray): shape is (N, M, P), typically (N, 3, 3)
        B (np.ndarray): shape is (N, P), typically (N, 3)

    Return:
        np.ndarray: float64 array, shape is (N, M)

    Raises:
        ValueError: if the batch size or the contracted dimension of B
            does not match A.
    """
    n, m, p = A.shape
    # numba does not bounds-check array indexing by default, so mismatched
    # shapes would silently read/write out of bounds; validate up front.
    if B.shape[0] != n or B.shape[1] != p:
        raise ValueError("ndot2: A must be (N, M, P) and B must be (N, P)")
    # Output has A's row count (not B.shape), which matters when M != P.
    # Every element is assigned below, so no zero-initialisation is needed.
    C = np.empty((n, m))
    for i in prange(n):
        for j in range(m):
            # Sum in a local scalar (same order as summing into C directly).
            acc = 0.0
            for q in range(p):
                acc += A[i, j, q] * B[i, q]
            C[i, j] = acc
    return C
def _benchmark(title, subscripts, ndot, a_shape, b_shape):
    """Time ``np.einsum`` against a numba kernel on the same random inputs.

    Prints *title*, then the elapsed seconds of ``np.einsum(subscripts, a, b)``
    and of ``ndot(a, b)``, and finally checks that both results agree.

    Args:
        title (str): heading printed before the timings.
        subscripts (str): einsum subscripts equivalent to ``ndot``.
        ndot (callable): numba-compiled kernel called as ``ndot(a, b)``.
        a_shape (tuple): shape of the random array ``a``.
        b_shape (tuple): shape of the random array ``b``.

    Raises:
        AssertionError: if the einsum and numba results are not close.
    """
    print(title)
    a = np.random.random(a_shape)
    b = np.random.random(b_shape)
    # The first call triggers numba's JIT compilation; keep it out of the
    # timed region so only execution speed is measured.
    ndot(a, b)
    # perf_counter is monotonic and high-resolution, unlike time.time.
    t0 = perf_counter()
    expected = np.einsum(subscripts, a, b)
    print('einsum: ', perf_counter() - t0)
    t0 = perf_counter()
    result = ndot(a, b)
    print('numba: ', perf_counter() - t0)
    # Same criterion as np.allclose(expected, result), but it is not
    # stripped by ``python -O`` and it reports mismatches on failure.
    np.testing.assert_allclose(expected, result, rtol=1e-5, atol=1e-8)


# Each run allocates its own arrays (hundreds of MB each), which are freed
# when the run returns, before the next run allocates.
_benchmark("Composing Rotations", 'ijq,iqk->ijk', ndot1,
           (5000000, 3, 3), (5000000, 3, 3))
_benchmark("\nApplying Rotations", 'ijq,iq->ij', ndot2,
           (10000000, 3, 3), (10000000, 3))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment