Skip to content

Instantly share code, notes, and snippets.

@SqrtRyan
Created March 28, 2020 01:56
Show Gist options
  • Select an option

  • Save SqrtRyan/828c10a142a7afe86d7b5fd5751eaca7 to your computer and use it in GitHub Desktop.

Select an option

Save SqrtRyan/828c10a142a7afe86d7b5fd5751eaca7 to your computer and use it in GitHub Desktop.
#TAG Strassen's Algorithm
from rp import * #pip install rp
import numba #pip install numba
def strassen(A, B, size_threshold=2**7):
    """Multiply two square matrices using Strassen's algorithm.

    Both inputs are converted with ``np.asarray`` and validated with
    ``matrix_size``; they must have the same side length N. Matrices with
    ``N <= size_threshold`` are handed to ``naive_matrix_multiply``; larger
    ones are split into four quadrants and combined from seven recursive
    products.

    Side effects:
        Increments the module-level counters ``total_strassen_additions`` and
        ``total_strassen_multiplications``. Both must already exist at module
        level before this function is called. At the recursion cutoff the
        schoolbook counts are charged (N**3 multiplications and N**3 - N**2
        additions); every recursive level adds 18 quadrant-sized matrix
        additions/subtractions.

    Args:
        A: Left operand, array-like.
        B: Right operand, array-like.
        size_threshold: Side length at or below which recursion stops.
            A larger value is faster; 1 gives "pure" Strassen, which is what
            you want when counting scalar operations. Values below 1 are
            treated as 1 because a 1x1 matrix cannot be split.

    Returns:
        The N x N product of A and B as a NumPy array.
    """
    # See https://4.bp.blogspot.com/-vIUyUdAtSpo/V-N3jvhf4RI/AAAAAAAAibE/85PpZt3rS7MYNq17WNrTshpW01D5Ad-bQCLcB/s1600/strassen%2Balgorithm.GIF
    global total_strassen_additions, total_strassen_multiplications
    A = np.asarray(A)
    B = np.asarray(B)
    N = assert_equality(matrix_size(A), matrix_size(B))

    # Recursion cutoff: never try to split a 1x1 matrix, whatever the threshold.
    if N <= max(size_threshold, 1):
        total_strassen_multiplications += N**3
        total_strassen_additions += N**3 - N**2
        return naive_matrix_multiply(A, B)

    half = N // 2
    A11, A12 = A[:half, :half], A[:half, half:]
    A21, A22 = A[half:, :half], A[half:, half:]
    B11, B12 = B[:half, :half], B[:half, half:]
    B21, B22 = B[half:, :half], B[half:, half:]

    # Seven recursive products; their operands use 10 quadrant additions/subtractions.
    P1 = strassen(A11 + A22, B11 + B22, size_threshold)
    P2 = strassen(A21 + A22, B11, size_threshold)
    P3 = strassen(A11, B12 - B22, size_threshold)
    P4 = strassen(A22, B21 - B11, size_threshold)
    P5 = strassen(A11 + A12, B22, size_threshold)
    P6 = strassen(A21 - A11, B11 + B12, size_threshold)
    P7 = strassen(A12 - A22, B21 + B22, size_threshold)
    total_strassen_additions += 10 * half**2

    # Recombining the products takes 8 more quadrant additions/subtractions.
    C11 = P1 + P4 - P5 + P7
    C12 = P3 + P5
    C21 = P2 + P4
    C22 = P1 - P2 + P3 + P6
    total_strassen_additions += 8 * half**2

    return np.block([[C11, C12], [C21, C22]])
def is_power_of_two(n):
    """Return True if ``n`` is a positive integral power of two (1, 2, 4, ...).

    The check is done with exact integer arithmetic. Repeatedly halving with
    float division is inexact for large integers (for example ``2**53 + 1``
    rounds to ``2**52`` after one division) and never terminates for
    infinity.

    Args:
        n: A number. Integral floats such as ``8.0`` are accepted.

    Returns:
        True when ``n`` equals ``2**k`` for some integer ``k >= 0``;
        False otherwise, including for zero, negatives, non-integral values,
        NaN and infinity.
    """
    try:
        whole = int(n)
    except (ValueError, OverflowError):
        # int() rejects NaN (ValueError) and infinity (OverflowError).
        return False
    if whole != n:
        # Non-integral values can never be a power of two >= 1.
        return False
    # A positive power of two has exactly one bit set, so clearing the lowest
    # set bit must leave zero.
    return whole >= 1 and whole & (whole - 1) == 0
def matrix_size(matrix):
    """Return the side length N of a square matrix whose N is a power of two.

    Args:
        matrix: An object exposing a ``shape`` attribute (e.g. a NumPy array).

    Returns:
        The side length N.

    Raises:
        AssertionError: If the input is not two-dimensional or N is not a
            power of two. These are raised explicitly rather than with
            ``assert`` so the validation still runs under ``python -O``.
        The square check is delegated to ``assert_equality`` (defined outside
        this block), which is expected to raise when the dimensions differ.
    """
    shape = tuple(matrix.shape)
    if len(shape) != 2:
        raise AssertionError("Expected a 2-D matrix but got shape %r" % (shape,))
    N = assert_equality(*shape)  # Assert it's a square matrix
    if not is_power_of_two(N):
        raise AssertionError("Matrix side length %r is not a power of two" % (N,))
    return N
def naive_matrix_multiply(A, B):
    """Return the matrix product of A and B using the ``@`` operator."""
    product = A @ B
    return product
#THIS CODE IS TOO SLOW FOR N=2^12. I'VE LOST PATIENCE. IT IS CORRECT THOUGH (Uncomment the following function to verify for yourself. It will override the previous definition)
# @numba.njit # Use numba to speed this code up by a LOT (on N=8, decreased runtime from 27 seconds to .02 seconds)
# def naive_matrix_multiply(A,B):
# N=len(A)
# C=np.zeros((N,N))
# for rownum in range(N):
# for colnum in range(N):
# C[rownum][colnum]=np.dot(A[rownum],B[:,colnum])
# return C
# Running scalar-operation counters: incremented by strassen() and reset by test().
total_strassen_additions = total_strassen_multiplications = 0
def test(K):
    """Benchmark and verify strassen() on random 2**K x 2**K matrices.

    Resets the module-level Strassen operation counters, multiplies two
    random matrices with NumPy's ``@`` (the reference result), with
    ``strassen`` and with ``naive_matrix_multiply``, checks both against the
    reference with ``np.allclose``, and prints timings and operation counts.

    Each timer is stopped before the result is verified, so the reported
    times cover only the multiplication itself and are comparable across the
    three methods.

    Args:
        K: Exponent of the matrix side length; matrices are 2**K x 2**K.

    Raises:
        AssertionError: If a result does not match the NumPy reference.
    """
    global total_strassen_additions, total_strassen_multiplications
    total_strassen_additions = total_strassen_multiplications = 0
    print("Testing for square matrices of size 2^" + str(K) + ":")
    N = 2**K
    A = np.random.rand(N, N)
    B = np.random.rand(N, N)

    # NOTE(review): tic()/toc() are defined outside this block; toc() is
    # presumably the number of seconds since the last tic() - confirm.
    tic()
    C = A @ B  # Built-in matrix multiplication, used as the reference result
    numpy_seconds = toc()
    print('\tNumpy time:', numpy_seconds, 'seconds')  # Numpy's builtin matrix multiplication is blazingly fast...

    tic()
    C_strassen = strassen(A, B)
    strassen_seconds = toc()  # Stop the clock before verifying
    assert np.allclose(C, C_strassen)  # Make sure C_strassen gave the right answer
    print('\tStrassen time:', strassen_seconds, 'seconds')
    print("\tTotal number of additions for strassen:", total_strassen_additions)
    print("\tTotal number of multiplications for strassen:", total_strassen_multiplications)

    tic()
    C_naive = naive_matrix_multiply(A, B)
    naive_seconds = toc()  # Stop the clock before verifying
    # Schoolbook operation counts for an N x N product.
    total_naive_additions = N**3 - N**2
    total_naive_multiplications = N**3
    assert np.allclose(C, C_naive)  # Make sure C_naive gave the right answer
    print('\tNaive time:', naive_seconds, 'seconds')
    print("\tTotal number of additions for naive:", total_naive_additions)
    print("\tTotal number of multiplications for naive:", total_naive_multiplications)
# Run the benchmark when this file is executed: 2**10 = 1024 and 2**12 = 4096 sided matrices.
test(K=10)
test(K=12)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment