Created
October 21, 2022 11:16
-
-
Save Steboss89/aa76f0a1990b3d4676ac1f4a9fc3cba1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def strassen(A, B, threshold=None):
    """
    Multiply two square matrices with Strassen's divide-and-conquer algorithm.

    Parameters
    ----------
    A: np.array: square matrix of shape (n, n)
    B: np.array: square matrix of shape (n, n)
    threshold: int, optional
        Problems with ``len(A) <= threshold`` are multiplied directly with
        ``np.matmul`` instead of being split further. Defaults to the
        module-level ``THRESHOLD`` constant (which must then be defined).

    Return
    ------
    C: np.array: product matrix C = A @ B of shape (n, n). Its dtype follows
        NumPy's promotion of A and B (e.g. integer inputs give an integer
        result), exactly as ``np.matmul`` does in the base case.

    Raises
    ------
    ValueError
        If the inputs are above the threshold but are not two square
        matrices of the same shape.
    """
    if threshold is None:
        threshold = THRESHOLD
    current_size = len(A)
    # Never split below 1x1: with threshold <= 0 a 1x1 problem would
    # otherwise be padded and split forever.
    if current_size <= max(threshold, 1):
        # below the threshold, plain matrix multiplication is faster
        return np.matmul(A, B)

    A = np.asarray(A)
    B = np.asarray(B)
    if A.shape != (current_size, current_size) or B.shape != (current_size, current_size):
        raise ValueError(
            "strassen expects two square matrices of the same size, "
            f"got shapes {A.shape} and {B.shape}"
        )

    if current_size % 2:
        # Odd size: the four quadrants would have different shapes. Pad one
        # zero row and column, multiply, and drop them again; the zero padding
        # does not affect the top-left n x n block of the product.
        pad = ((0, 1), (0, 1))
        product = strassen(np.pad(A, pad), np.pad(B, pad), threshold)
        return product[:current_size, :current_size]

    h = current_size // 2
    # extract the four blocks of A and of B (top left, top right,
    # bottom left, bottom right)
    a11, a12, a21, a22 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
    b11, b12, b21, b22 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]

    # the seven Strassen products (instead of eight for the naive split)
    prod1 = strassen(a11 + a22, b11 + b22, threshold)
    prod2 = strassen(a21 + a22, b11, threshold)
    prod3 = strassen(a11, b12 - b22, threshold)
    prod4 = strassen(a22, b21 - b11, threshold)
    prod5 = strassen(a11 + a12, b22, threshold)
    prod6 = strassen(a21 - a11, b11 + b12, threshold)
    prod7 = strassen(a12 - a22, b21 + b22, threshold)

    # combine the products into the blocks of C
    c11 = prod1 + prod4 + prod7 - prod5
    c12 = prod3 + prod5
    c21 = prod2 + prod4
    c22 = prod1 + prod3 + prod6 - prod2

    # np.block keeps the dtype of the products. The previous np.zeros buffer
    # forced float64, which dropped imaginary parts of complex inputs and lost
    # precision for large integers.
    return np.block([[c11, c12], [c21, c22]])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment