Skip to content

Instantly share code, notes, and snippets.

@sn1p3r46
Last active January 18, 2018 00:24
Show Gist options
  • Save sn1p3r46/f5f2b36880a5a79023cee78a60e93c39 to your computer and use it in GitHub Desktop.
Save sn1p3r46/f5f2b36880a5a79023cee78a60e93c39 to your computer and use it in GitHub Desktop.
Recursive Matrix Multiplication
#!/usr/bin/python3
# ref http://www.cs.mcgill.ca/~pnguyen/251F09/matrix-mult.pdf
import numpy as np
def matrix_mul(A, B):
    """Multiply two matrices with recursive (divide-and-conquer) blocking.

    Each operand is split into four quadrants and the product is assembled
    from eight recursive sub-products::

        C11 = A11 B11 + A12 B21      C12 = A11 B12 + A12 B22
        C21 = A21 B11 + A22 B21      C22 = A21 B12 + A22 B22

    Parameters
    ----------
    A : array_like, shape (n, k)
        Left operand.
    B : array_like, shape (k, m)
        Right operand.

    Returns
    -------
    numpy.ndarray, shape (n, m)
        The matrix product, in the dtype NumPy promotes ``A`` and ``B`` to
        (integer inputs stay integer, complex inputs stay complex).

    Raises
    ------
    ValueError
        If either operand is not 2-D or the inner dimensions differ.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    if A.ndim != 2 or B.ndim != 2 or A.shape[1] != B.shape[0]:
        raise ValueError(
            "matrix_mul needs 2-D operands with matching inner dimensions, "
            "got shapes {} and {}".format(A.shape, B.shape)
        )

    rows, inner = A.shape
    cols = B.shape[1]

    # Base case: once any dimension is 0 or 1 it cannot be halved into two
    # non-empty parts, so hand the (tiny) product to NumPy directly.  This
    # also guarantees termination for empty and odd-sized inputs.
    if min(rows, inner, cols) <= 1:
        return A.dot(B)

    # Split points; for odd sizes the trailing block is one element larger.
    r, k, c = rows // 2, inner // 2, cols // 2

    # Allocate with the promoted dtype so integer/complex data is not
    # silently coerced to float64.
    C = np.empty((rows, cols), dtype=np.result_type(A, B))
    C[:r, :c] = matrix_mul(A[:r, :k], B[:k, :c]) + matrix_mul(A[:r, k:], B[k:, :c])
    C[:r, c:] = matrix_mul(A[:r, :k], B[:k, c:]) + matrix_mul(A[:r, k:], B[k:, c:])
    C[r:, :c] = matrix_mul(A[r:, :k], B[:k, :c]) + matrix_mul(A[r:, k:], B[k:, :c])
    C[r:, c:] = matrix_mul(A[r:, :k], B[:k, c:]) + matrix_mul(A[r:, k:], B[k:, c:])
    return C
if __name__ == "__main__":
    # Demo: square an 8x8 matrix holding 1..64 in row-major order.
    size = 8
    demo = np.arange(1, size * size + 1).reshape(size, size)
    print(demo)
    product = matrix_mul(demo, demo)
    print(product)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment