Created
August 21, 2019 07:37
-
-
Save antoine-levitt/565750d1e7a323330e20e7b58e55c895 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import LinearAlgebra.BLAS | |
const libblas = Base.libblas_name | |
const liblapack = Base.liblapack_name | |
import LinearAlgebra | |
import LinearAlgebra: BlasReal, BlasComplex, BlasFloat, BlasInt, DimensionMismatch, checksquare, stride1, chkstride1, axpy! | |
import Libdl | |
# Decode a BLAS transposition flag: returns `true` for "no transpose" ('N') and
# `false` for transpose / conjugate transpose ('T' / 'C').  Both letter cases are
# accepted, matching the flag values documented for the reference xGEMM routines.
# Any other character throws an ArgumentError here instead of being forwarded to
# the BLAS routine.
function _isnotrans(trans::Char, argname::AbstractString)
    trans in ('N', 'n') && return true
    trans in ('T', 't', 'C', 'c') && return false
    throw(ArgumentError("$argname must be one of 'N', 'T' or 'C', got '$trans'"))
end

# Generate `gemm!` / `gemm` methods for each supported element type.  The real
# types call the regular dgemm_/sgemm_ symbols, while the complex types call the
# zgemm3m_/cgemm3m_ symbols.
# NOTE(review): the *gemm3m_ symbols are presumably a library extension rather
# than part of reference BLAS -- confirm that the library named by `libblas`
# exports them.
for (gemm, elty) in
        ((:dgemm_, :Float64),
         (:sgemm_, :Float32),
         (:zgemm3m_, :ComplexF64),
         (:cgemm3m_, :ComplexF32))
    @eval begin
        #       SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
        # *     .. Scalar Arguments ..
        #       DOUBLE PRECISION ALPHA,BETA
        #       INTEGER K,LDA,LDB,LDC,M,N
        #       CHARACTER TRANSA,TRANSB
        # *     .. Array Arguments ..
        #       DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)

        # In-place product: overwrites `C` with `alpha*op(A)*op(B) + beta*C`
        # (per the xGEMM contract quoted above) and returns `C`.
        # `transA` / `transB` select `op` for `A` / `B`.
        # Throws ArgumentError for an unrecognised transposition flag and
        # DimensionMismatch when the operand shapes are incompatible.
        function gemm!(transA::Char, transB::Char, alpha::($elty),
                       A::AbstractVecOrMat{$elty}, B::AbstractVecOrMat{$elty},
                       beta::($elty), C::AbstractVecOrMat{$elty})
            notransA = _isnotrans(transA, "transA")
            notransB = _isnotrans(transB, "transB")
            # Shapes of op(A) (m x ka) and op(B) (kb x n).
            m = size(A, notransA ? 1 : 2)
            ka = size(A, notransA ? 2 : 1)
            kb = size(B, notransB ? 1 : 2)
            n = size(B, notransB ? 2 : 1)
            if ka != kb || m != size(C,1) || n != size(C,2)
                throw(DimensionMismatch("A has size ($m,$ka), B has size ($kb,$n), C has size $(size(C))"))
            end
            # All three arrays must have unit stride along their first dimension.
            chkstride1(A)
            chkstride1(B)
            chkstride1(C)
            # Scalars are passed by reference; each leading dimension is the
            # array's second stride, clamped to at least 1.
            ccall((BLAS.@blasfunc($gemm), libblas), Cvoid,
                  (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
                   Ref{BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{BlasInt},
                   Ptr{$elty}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
                   Ref{BlasInt}),
                  transA, transB, m, n,
                  ka, alpha, A, max(1,stride(A,2)),
                  B, max(1,stride(B,2)), beta, C,
                  max(1,stride(C,2)))
            return C
        end

        # Allocating product: returns `alpha*op(A)*op(B)` in a fresh array
        # obtained from `similar(B, ...)` with the shape of the result.
        function gemm(transA::Char, transB::Char, alpha::($elty),
                      A::AbstractMatrix{$elty}, B::AbstractMatrix{$elty})
            m = size(A, _isnotrans(transA, "transA") ? 1 : 2)
            n = size(B, _isnotrans(transB, "transB") ? 2 : 1)
            return gemm!(transA, transB, alpha, A, B, zero($elty),
                         similar(B, $elty, (m, n)))
        end

        # Allocating product with `alpha` fixed to one: returns `op(A)*op(B)`.
        function gemm(transA::Char, transB::Char,
                      A::AbstractMatrix{$elty}, B::AbstractMatrix{$elty})
            return gemm(transA, transB, one($elty), A, B)
        end
    end
end
# Time the `gemm!` defined in this file against `BLAS.gemm!` on square complex
# matrices whose side length runs over the powers of two 2, 4, ..., 1024.
# NOTE(review): `@btime` is not brought into scope anywhere in this file; it is
# presumably the BenchmarkTools macro, so that package must be loaded first.
for p in 1:10
    N = 2^p
    # Fresh N x N complex matrix with independent normal real/imaginary parts.
    draw() = randn(N, N) + im * randn(N, N)
    A = draw()
    B = draw()
    C = draw()
    println(N)
    @btime gemm!('N', 'N', one(ComplexF64), $A, $B, zero(ComplexF64), $C)
    @btime BLAS.gemm!('N', 'N', one(ComplexF64), $A, $B, zero(ComplexF64), $C)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment