Created
March 28, 2020 01:56
-
-
Save SqrtRyan/828c10a142a7afe86d7b5fd5751eaca7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #TAG Strassen's Algorithm | |
| from rp import * #pip install rp | |
| import numba #pip install numba | |
def strassen(A, B, size_threshold=2**7):
    """Multiply two square power-of-two matrices with Strassen's algorithm.

    Each operand is split into four N/2 x N/2 quadrants. The seven Strassen
    products P1..P7 are formed recursively and recombined into C = A @ B.
    Once N <= size_threshold (or N == 1) the recursion stops and
    naive_matrix_multiply is used for that sub-problem.

    Side effects: the module-level counters total_strassen_additions and
    total_strassen_multiplications are incremented by the number of scalar
    additions/subtractions and multiplications performed. A base case is
    counted as a textbook triple loop (N**3 multiplications, N**3 - N**2
    additions). The counters are NOT reset here; the caller does that.

    Parameters
    ----------
    A, B : array-like
        Converted with np.asarray. Both must be square with the same
        power-of-two side length (enforced via matrix_size / assert_equality).
    size_threshold : int, optional
        Side length at or below which the naive product is used.
        - 1 gives "pure" Strassen, recursing down to 1x1 scalars. Use this to
          measure the true Strassen multiplication count.
        - A large value such as 2**11 is much faster in practice.
        - Values < 1 behave like 1.

    Returns
    -------
    numpy.ndarray
        The N x N product A @ B.
    """
    # See https://4.bp.blogspot.com/-vIUyUdAtSpo/V-N3jvhf4RI/AAAAAAAAibE/85PpZt3rS7MYNq17WNrTshpW01D5Ad-bQCLcB/s1600/strassen%2Balgorithm.GIF
    global total_strassen_additions, total_strassen_multiplications
    A = np.asarray(A)
    B = np.asarray(B)
    N = assert_equality(matrix_size(A), matrix_size(B))

    # Always stop at 1x1. With size_threshold < 1 the recursion would
    # otherwise try to split a 1x1 matrix into empty 0x0 quadrants.
    if N == 1 or N <= size_threshold:
        # Triple-loop cost: N**3 products, and N-1 additions for each of the
        # N**2 output entries.
        total_strassen_multiplications += N**3
        total_strassen_additions += N**3 - N**2
        return naive_matrix_multiply(A, B)

    h = N // 2
    A11, A12 = A[:h, :h], A[:h, h:]
    A21, A22 = A[h:, :h], A[h:, h:]
    B11, B12 = B[:h, :h], B[:h, h:]
    B21, B22 = B[h:, :h], B[h:, h:]

    def recurse(X, Y):
        # Propagate the caller's threshold to every level of the recursion.
        return strassen(X, Y, size_threshold=size_threshold)

    # Seven half-size products instead of the eight a blockwise product needs.
    P1 = recurse(A11 + A22, B11 + B22)
    P2 = recurse(A21 + A22, B11)
    P3 = recurse(A11, B12 - B22)
    P4 = recurse(A22, B21 - B11)
    P5 = recurse(A11 + A12, B22)
    P6 = recurse(A21 - A11, B11 + B12)
    P7 = recurse(A12 - A22, B21 + B22)
    # Forming the operands above took 10 quadrant additions/subtractions.
    total_strassen_additions += 10 * h**2

    C11 = P1 + P4 - P5 + P7
    C12 = P3 + P5
    C21 = P2 + P4
    C22 = P1 - P2 + P3 + P6
    # Recombining took 3 + 1 + 1 + 3 = 8 quadrant additions/subtractions.
    total_strassen_additions += 8 * h**2

    return np.block([[C11, C12], [C21, C22]])
def is_power_of_two(n):
    """Return True iff n == 2**k for some integer k >= 0.

    Integers (Python ints, bools and numpy integer scalars) are tested
    exactly with a bit trick. The previous float-halving loop had three
    problems:
    - it lost precision above 2**53 (2**60 + 1 was reported as a power of two);
    - it raised OverflowError for very large ints;
    - it looped forever on float('inf').

    Floats are still accepted. A finite float >= 1 is an exact power of two
    iff its math.frexp mantissa is exactly 0.5.
    """
    import numbers

    if isinstance(n, numbers.Integral):
        n = int(n)
        return n > 0 and (n & (n - 1)) == 0

    import math

    # `not (n >= 1)` also rejects NaN, for which every comparison is False.
    if not (n >= 1) or math.isinf(n):
        return False
    mantissa, _exponent = math.frexp(n)
    return mantissa == 0.5
def matrix_size(matrix):
    """Return the side length N of a square N x N matrix whose N is a power of two.

    strassen() relies on this to guarantee that every recursive split yields
    equal-sized quadrants.

    Parameters
    ----------
    matrix : object with a ``.shape`` attribute (a numpy array in this file)

    Raises
    ------
    ValueError
        If the matrix is not 2-D, not square, or N is not a power of two.
        Explicit checks are used instead of ``assert`` so validation still
        happens under ``python -O``.
    """
    shape = matrix.shape
    if len(shape) != 2:
        raise ValueError(f"expected a 2-D matrix, got shape {shape}")
    rows, cols = shape
    if rows != cols:
        raise ValueError(f"expected a square matrix, got shape {shape}")
    if not is_power_of_two(rows):
        raise ValueError(f"matrix side length {rows} is not a power of two")
    return rows
def naive_matrix_multiply(A, B):
    """Return the matrix product A @ B.

    Delegates to the ``@`` (matmul) operator of the operands. strassen() uses
    this as its base case, but accounts for it as if it were an N**3 triple
    loop.
    """
    return A @ B


# Explicit triple-loop reference implementation. Uncomment it to override the
# definition above (e.g. to check the result by hand). It is far too slow for
# N=2**12, even with numba.
# @numba.njit  # numba speeds this up a LOT (author measured 27 s -> 0.02 s)
# def naive_matrix_multiply(A,B):
#     N=len(A)
#     C=np.zeros((N,N))
#     for rownum in range(N):
#         for colnum in range(N):
#             total=0.0
#             for index in range(N):
#                 total+=A[rownum,index]*B[index,colnum]
#             C[rownum,colnum]=total
#     return C
# Module-level running totals of the scalar additions and multiplications
# counted by strassen(); test() zeroes both before each benchmark run.
total_strassen_additions = total_strassen_multiplications = 0
def test(K):
    """Benchmark numpy, Strassen and naive multiplication on random 2**K x 2**K matrices.

    Resets the module-level Strassen counters, then prints:
    - wall-clock times measured with tic()/toc();
    - the scalar addition/multiplication counts accumulated by strassen();
    - the closed-form naive counts (N**3 multiplications, N**3 - N**2
      additions).

    Both results are checked against numpy's A @ B with np.allclose.
    Each timer is stopped BEFORE that check, so verification time is not
    included in the reported figure.
    """
    global total_strassen_additions, total_strassen_multiplications
    total_strassen_additions = total_strassen_multiplications = 0
    print(f"Testing for square matrices of size 2^{K}:")
    N = 2**K
    A = np.random.rand(N, N)
    B = np.random.rand(N, N)

    tic()
    C = A @ B  # Builtin matrix multiplication (reference result)
    print('\tNumpy time:', toc(), 'seconds')  # Numpy's builtin matrix multiplication is blazingly fast...

    tic()
    C_strassen = strassen(A, B)
    strassen_time = toc()  # stop the clock before verifying
    assert np.allclose(C, C_strassen)  # Make sure Strassen gave the right answer
    print('\tStrassen time:', strassen_time, 'seconds')
    print("\tTotal number of additions for strassen:", total_strassen_additions)
    print("\tTotal number of multiplications for strassen:", total_strassen_multiplications)

    tic()
    C_naive = naive_matrix_multiply(A, B)
    naive_time = toc()  # stop the clock before verifying
    assert np.allclose(C, C_naive)  # Make sure the naive product gave the right answer
    total_naive_additions = N**3 - N**2
    total_naive_multiplications = N**3
    print('\tNaive time:', naive_time, 'seconds')
    print("\tTotal number of additions for naive:", total_naive_additions)
    print("\tTotal number of multiplications for naive:", total_naive_multiplications)
# Run the benchmark at 2**10 first, then at the much slower 2**12.
for _exponent in (10, 12):
    test(K=_exponent)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment