Skip to content

Instantly share code, notes, and snippets.

@Steboss89
Created October 21, 2022 11:16
Show Gist options
  • Save Steboss89/aa76f0a1990b3d4676ac1f4a9fc3cba1 to your computer and use it in GitHub Desktop.
Save Steboss89/aa76f0a1990b3d4676ac1f4a9fc3cba1 to your computer and use it in GitHub Desktop.
def _zero_pad(M, size):
    """Return a copy of 2-D ``M`` zero-padded at the bottom/right to (size, size)."""
    padded = np.zeros((size, size), dtype=M.dtype)
    padded[:M.shape[0], :M.shape[1]] = M
    return padded


def strassen(A, B, threshold=None):
    """
    Multiply two matrices with Strassen's divide-and-conquer algorithm.

    Each recursion level splits both operands into 2x2 blocks and builds the
    product from 7 block multiplications instead of the naive 8. Once the
    matrix is small enough the product is delegated to ``np.matmul``.

    Parameters
    ----------
    A: np.array: block matrix A (anything accepted by ``np.asarray``)
    B: np.array: block matrix B (anything accepted by ``np.asarray``)
    threshold: int, optional: when ``len(A)`` is at or below this value the
        product is computed directly with ``np.matmul``. Defaults to the
        module-level ``THRESHOLD``.

    Return
    ------
    C: np.array: product matrix C = A @ B, with the same dtype ``np.matmul``
        would produce for these inputs.

    Raises
    ------
    ValueError: if, above the threshold, the operands are not 2-D or their
        inner dimensions do not agree.
    """
    if threshold is None:
        threshold = THRESHOLD
    A = np.asarray(A)
    B = np.asarray(B)

    current_size = len(A)
    # A 1x1 (or empty) matrix cannot be split any further, so always stop
    # there even when threshold < 1; otherwise the recursion never ends.
    if current_size <= max(threshold, 1):
        return np.matmul(A, B)

    if A.ndim != 2 or B.ndim != 2:
        raise ValueError(
            f"strassen expects 2-D matrices, got shapes {A.shape} and {B.shape}"
        )
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            f"inner dimensions do not match: {A.shape} @ {B.shape}"
        )

    # The 2x2 block split needs square operands of even size. Zero-padding
    # both to the same even square size leaves A @ B untouched in the
    # top-left corner (the padded rows/columns only ever meet zeros), so we
    # pad, recurse once, and slice the true product back out.
    out_rows, out_cols = A.shape[0], B.shape[1]
    size = max(A.shape[0], A.shape[1], B.shape[1])
    size += size % 2
    if A.shape != (size, size) or B.shape != (size, size):
        C = strassen(_zero_pad(A, size), _zero_pad(B, size), threshold=threshold)
        return C[:out_rows, :out_cols]

    half = size // 2
    # extract the four blocks of A ...
    a11, a12 = A[:half, :half], A[:half, half:]   # top left, top right
    a21, a22 = A[half:, :half], A[half:, half:]   # bottom left, bottom right
    # ... and the same for B
    b11, b12 = B[:half, :half], B[:half, half:]
    b21, b22 = B[half:, :half], B[half:, half:]

    # the seven Strassen products
    prod1 = strassen(a11 + a22, b11 + b22, threshold=threshold)
    prod2 = strassen(a21 + a22, b11, threshold=threshold)
    prod3 = strassen(a11, b12 - b22, threshold=threshold)
    prod4 = strassen(a22, b21 - b11, threshold=threshold)
    prod5 = strassen(a11 + a12, b22, threshold=threshold)
    prod6 = strassen(a21 - a11, b11 + b12, threshold=threshold)
    prod7 = strassen(a12 - a22, b21 + b22, threshold=threshold)

    # combine into the blocks of C (same summation order as the original)
    c11 = prod1 + prod4 + prod7 - prod5
    c12 = prod3 + prod5
    c21 = prod2 + prod4
    c22 = prod1 + prod3 + prod6 - prod2

    # np.block keeps the blocks' common dtype (int/complex stay int/complex)
    # instead of forcing float64 like a preallocated np.zeros would.
    return np.block([[c11, c12], [c21, c22]])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment