Last active
December 21, 2021 16:52
-
-
Save soypat/8dd5bca0a99a3defd858e58ea2af4946 to your computer and use it in GitHub Desktop.
Dgemm vector->matrix LAPACK translation corner case.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
! To compile program with LAPACK:
! gfortran main.f90 -L/home/pato/src/lia/lapack -llapack -lrefblas
!
! Reproduces a single DGEMM update on a sub-block of c, where the second
! operand is taken from the flat vector rhs and viewed by DGEMM as an
! mb-by-nb column-major matrix (leading dimension mb).
program main
  implicit none
  integer, parameter :: is = 2, js = 2, nb = 2, mb = 2
  integer, parameter :: m = 4, n = 4        ! constants for test.
  integer, parameter :: lda = m, ldc = m    ! leading dimensions of a and c
  double precision, parameter :: one = 1.0d0
  double precision :: rhs(8)
  double precision :: a(lda, m), c(ldc, n)
  external :: dgemm                         ! resolved at link time (see compile line above)

  rhs = [ 2.0d0, &
          1.0d0, &
         -1.0d0, &
          3.0d0, &
          3.0d0, &
          1.0d0, &
         -1.0d0, &
          2.0d0]

  ! reshape fills column by column: each source line below is one column of a.
  a = reshape([4.0d0, 0.0d0, 0.0d0, 0.0d0, &
               1.0d0, 3.0d0, 1.0d0, 0.0d0, &
               1.0d0, 4.0d0, 3.0d0, 0.0d0, &
               2.0d0, 1.0d0, 1.0d0, 6.0d0], [4, 4])

  ! Same layout: each source line below is one column of c.
  c = reshape([ 1.0d0, -1.0d0, -1.0d0, -1.0d0, &
                9.0d0,  2.0d0,  1.0d0,  1.0d0, &
                7.0d0, -1.0d0,  3.0d0, -1.0d0, &
               16.0d0, -1.0d0,  7.0d0, 20.0d0], [4, 4])

  ! c(1:is-1, js:js+nb-1) = c(...) - a(1:is-1, is:is+mb-1) * B,
  ! where B is the mb-by-nb matrix whose columns are consecutive
  ! mb-element chunks of rhs.
  call dgemm('N', 'N', is - 1, nb, mb, -one, a(1, is), lda, &
             rhs(1), mb, one, c(1, js), ldc)

  call printmat(c, m, n)
  ! call printgo(c, 4, 4)
  ! Output:
  ! []float64{1.0000, 6.0000, 5.0000, 16.0000, -1.0000, 2.0000, -1.0000, -1.0000, -1.0000, 1.0000, 3.0000, 7.0000, -1.0000, 1.0000, -1.0000, 20.0000, }
  ! 1.0000, 6.0000, 5.0000, 16.0000,
  ! -1.0000, 2.0000, -1.0000, -1.0000,
  ! -1.0000, 1.0000, 3.0000, 7.0000,
  ! -1.0000, 1.0000, -1.0000, 20.0000,
end program main
! prints matrix human-readable representation
!
! Writes the r-by-c matrix a to standard output, one matrix row per
! output line, every element formatted as f0.4 and followed by ", ".
!
!   a : matrix to print, dimensioned (r, c)
!   r : number of rows
!   c : number of columns
subroutine printmat(a, r, c)
  implicit none
  ! The extents are declared before the array that is dimensioned with
  ! them: under implicit none a specification expression may only use
  ! variables whose type has already been declared.
  integer, intent(in) :: r, c
  double precision, intent(in) :: a(r, c)
  integer :: i, j

  ! The row index is the outer loop on purpose: output is produced row by row.
  do i = 1, r
    do j = 1, c
      write (*, '(f0.4)', advance="no") a(i, j)   ! print element
      write (*, '(a)', advance="no") ", "         ! separate elements with comma
    end do
    print *, ''   ! print newline
  end do
end subroutine printmat
! Prints as a []float64{} array in row major storage format
!
! Writes "[]float64{", then every element of the r-by-c matrix a in
! row-by-row order (each formatted as f0.4 and followed by ", "), then
! "}" and a newline, all on a single output line.
subroutine printgo(a, r, c)
  implicit none
  integer :: r, c
  double precision :: a(r, c)
  integer :: row

  write (*, '(a)', advance="no") "[]float64{"
  do row = 1, r
    ! The unlimited-repeat group emits "<value>, " once per element of the row.
    write (*, '(*(f0.4, ", "))', advance="no") a(row, :)
  end do
  write (*, '(a)') "}"
end subroutine printgo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"gonum.org/v1/gonum/blas" | |
blasi "gonum.org/v1/gonum/blas/gonum" | |
) | |
// CALL DGEMM( 'N', 'N', is-1, nb, mb, -one, a( 1, IS ), lda, rhs( 1 ), mb, one, c( 1, js ), ldc ) | |
func main() { | |
const ( | |
is = 1 | |
js = is | |
nb = 2 | |
mb = nb | |
lda = 4 | |
ldc = 4 | |
) | |
a := []float64{4, 1, 1, 2, 0, 3, 4, 1, 0, 1, 3, 1, 0, 0, 0, 6} | |
rhs := []float64{2, 1, -1, 3, 3, 1, -1, 2} | |
c := []float64{1, 9, 7, 16, -1, 2, -1, -1, -1, 1, 3, 7, -1, 1, -1, 20} | |
// Generate fortran code: | |
// printFortran("a", a, 4, 4) | |
// printFortran("c", c, 4, 4) | |
// printFortran("rhs", rhs, 8, 1) | |
bp := blasi.Implementation{} | |
// The problem here is that LAPACK interprets rhs's first elements as the first column | |
// while Gonum interprets first elements as the row. Fix this by changing second argument to blas.Trans | |
bp.Dgemm(blas.NoTrans, blas.NoTrans, is, nb, mb, -1, a[is:], lda, rhs, mb, 1, c[js:], ldc) | |
printmat(c, 4, 4) | |
// Output: | |
// 1.00, 8.00, 3.00, 16.00, | |
// -1.00, 2.00, -1.00, -1.00, | |
// -1.00, 1.00, 3.00, 7.00, | |
// -1.00, 1.00, -1.00, 20.00, | |
// cfortran := tranpose(c, 4, 4) | |
// fmt.Println(cfortran) | |
// Output: [1 -1 -1 -1 8 2 1 1 3 -1 3 -1 16 -1 7 20] | |
} | |
// tranpose returns the transpose of the r-by-c row-major matrix a as a new
// c-by-r row-major slice (useful for converting between row-major and
// column-major storage). The input slice is not modified.
//
// When either dimension is 1 the linear layout of a vector and of its
// transpose coincide, so the data is simply copied.
func tranpose(a []float64, r, c int) []float64 {
	z := make([]float64, r*c)
	if c == 1 || r == 1 {
		copy(z, a)
		return z
	}
	for i := 0; i < r; i++ {
		// Walk every column of row i. The bound must be c, not r; otherwise
		// non-square inputs are partly skipped (r < c) or indexed out of
		// range (r > c).
		for j := 0; j < c; j++ {
			// Element (i, j) of a becomes element (j, i) of z, whose row
			// stride is r.
			z[j*r+i] = a[i*c+j]
		}
	}
	return z
}
// printmat prints the r-by-c row-major matrix a in human-readable form:
// one matrix row per output line, every element formatted with two decimals
// and followed by ", ".
func printmat(a []float64, r, c int) {
	for i := 0; i < r; i++ {
		// Iterate over the c columns of row i. The bound must be c, not r,
		// so that non-square matrices are printed completely and without
		// reading past the end of a row.
		for j := 0; j < c; j++ {
			fmt.Printf("%.2f, ", a[i*c+j])
		}
		fmt.Println()
	}
}
// Prints fortran array ready for compiling. | |
func printFortran(name string, a []float64, r, lda int) { | |
ra := tranpose(a, r, lda) | |
fmt.Printf("%v = reshape([", name) | |
for i := 0; i < r; i++ { | |
for j := 0; j < lda; j++ { | |
fmt.Printf("%.2f", ra[i*lda+j]) | |
if j != lda-1 { | |
fmt.Print(",") | |
} | |
} | |
if i != r-1 { | |
fmt.Print(",") | |
fmt.Println(" &") | |
} | |
} | |
fmt.Printf("], [%d,%d])\n", r, lda) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment