Skip to content

Instantly share code, notes, and snippets.

@soypat
Last active December 21, 2021 16:52
Show Gist options
  • Save soypat/8dd5bca0a99a3defd858e58ea2af4946 to your computer and use it in GitHub Desktop.
Save soypat/8dd5bca0a99a3defd858e58ea2af4946 to your computer and use it in GitHub Desktop.
DGEMM vector-to-matrix corner case when translating LAPACK (column-major Fortran) to Gonum (row-major Go).
! Reference output for a DGEMM call that reads the flat vector rhs as an
! mb x nb column-major matrix; the printed c is compared against a Gonum
! (row-major) translation of the same call.
!
! To compile program with LAPACK:
! gfortran main.f90 -L/home/pato/src/lia/lapack -llapack -lrefblas
program main
  implicit none
  ! Kind of DOUBLE PRECISION, the precision reference DGEMM operates in.
  integer, parameter :: dp = kind(1.0d0)
  ! Block offsets and sizes used by the DGEMM call below.
  integer, parameter :: is = 2, js = 2, nb = 2, mb = 2
  integer, parameter :: m = 4, n = 4          ! constants for test.
  integer, parameter :: lda = m, ldc = m      ! leading dimensions of a and c
  real(dp), parameter :: one = 1.0_dp
  real(dp) :: rhs(8)
  real(dp) :: a(lda, m), c(ldc, n)

  rhs = [2.0_dp, 1.0_dp, -1.0_dp, 3.0_dp, 3.0_dp, 1.0_dp, -1.0_dp, 2.0_dp]
  ! reshape fills column by column, so each continuation line is one column.
  a = reshape([4.0_dp, 0.0_dp, 0.0_dp, 0.0_dp, &
               1.0_dp, 3.0_dp, 1.0_dp, 0.0_dp, &
               1.0_dp, 4.0_dp, 3.0_dp, 0.0_dp, &
               2.0_dp, 1.0_dp, 1.0_dp, 6.0_dp], [lda, m])
  c = reshape([1.0_dp, -1.0_dp, -1.0_dp, -1.0_dp, &
               9.0_dp, 2.0_dp, 1.0_dp, 1.0_dp, &
               7.0_dp, -1.0_dp, 3.0_dp, -1.0_dp, &
               16.0_dp, -1.0_dp, 7.0_dp, 20.0_dp], [ldc, n])

  ! c(1:is-1, js:js+nb-1) <- c(1:is-1, js:js+nb-1) - a(1:is-1, is:is+mb-1) * B,
  ! where B is rhs(1:mb*nb) viewed as an mb x nb matrix with leading dimension mb.
  call dgemm('N', 'N', is - 1, nb, mb, -one, a(1, is), lda, rhs(1), mb, one, c(1, js), ldc)
  call printmat(c, ldc, n)
  ! call printgo(c, ldc, n)
  ! Output:
  ! []float64{1.0000, 6.0000, 5.0000, 16.0000, -1.0000, 2.0000, -1.0000, -1.0000, -1.0000, 1.0000, 3.0000, 7.0000, -1.0000, 1.0000, -1.0000, 20.0000, }
  ! 1.0000, 6.0000, 5.0000, 16.0000,
  ! -1.0000, 2.0000, -1.0000, -1.0000,
  ! -1.0000, 1.0000, 3.0000, 7.0000,
  ! -1.0000, 1.0000, -1.0000, 20.0000,
end program main
! Prints the r x c matrix a in human-readable form: one output line per row,
! each element written with f0.4 and followed by ", ".
subroutine printmat(a, r, c)
  implicit none
  ! r and c are declared before a because they appear in its bounds
  ! (specification expression); under implicit none they must be typed first.
  integer, intent(in) :: r, c
  double precision, intent(in) :: a(r, c)
  integer :: i
  do i = 1, r
    ! The unlimited format writes every element of row i followed by ", "
    ! and then ends the record, so no separate newline write is needed.
    write (*, '(*(f0.4, ", "))') a(i, :)
  end do
end subroutine printmat
! Prints the r x c matrix a as a Go "[]float64{...}" literal, row by row
! (row-major order). Every element is written with f0.4 followed by ", ",
! and the literal is closed with "}" and a newline.
subroutine printgo(a, r, c)
  implicit none
  integer :: r, c
  double precision :: a(r, c)
  integer :: row
  write (*, '(a)', advance='no') '[]float64{'
  do row = 1, r
    ! Non-advancing: all rows stay on the same output line.
    write (*, '(*(f0.4, ", "))', advance='no') a(row, :)
  end do
  write (*, '(a)') '}'
end subroutine printgo
package main
import (
"fmt"
"gonum.org/v1/gonum/blas"
blasi "gonum.org/v1/gonum/blas/gonum"
)
// main replays the Fortran call
//
//	CALL DGEMM( 'N', 'N', is-1, nb, mb, -one, a( 1, IS ), lda, rhs( 1 ), mb, one, c( 1, js ), ldc )
//
// through Gonum's row-major BLAS and prints the resulting c. Go indices are
// 0-based, so is and js below equal the Fortran is-1 and js-1 (is is also the
// row count M handed to Dgemm).
func main() {
	const (
		is, js   = 1, 1
		mb, nb   = 2, 2
		lda, ldc = 4, 4
	)
	// Row-major copies of the Fortran arrays a (4x4), rhs (length 8) and c (4x4).
	amat := []float64{
		4, 1, 1, 2,
		0, 3, 4, 1,
		0, 1, 3, 1,
		0, 0, 0, 6,
	}
	vec := []float64{2, 1, -1, 3, 3, 1, -1, 2}
	cmat := []float64{
		1, 9, 7, 16,
		-1, 2, -1, -1,
		-1, 1, 3, 7,
		-1, 1, -1, 20,
	}
	// To regenerate the Fortran literals:
	// printFortran("a", amat, 4, 4)
	// printFortran("c", cmat, 4, 4)
	// printFortran("rhs", vec, 8, 1)

	// LAPACK reads vec as a column-major mb x nb block, while Gonum reads it
	// row-major, so with NoTrans Gonum multiplies by the transpose of the block
	// DGEMM uses. Passing blas.Trans as the second argument reproduces the
	// Fortran result.
	var impl blasi.Implementation
	impl.Dgemm(blas.NoTrans, blas.NoTrans, is, nb, mb, -1, amat[is:], lda, vec, mb, 1, cmat[js:], ldc)
	printmat(cmat, 4, 4)
	// Output:
	// 1.00, 8.00, 3.00, 16.00,
	// -1.00, 2.00, -1.00, -1.00,
	// -1.00, 1.00, 3.00, 7.00,
	// -1.00, 1.00, -1.00, 20.00,
	// cfortran := tranpose(cmat, 4, 4)
	// fmt.Println(cfortran)
	// Output: [1 -1 -1 -1 8 2 1 1 3 -1 3 -1 16 -1 7 20]
}
// tranpose returns a new slice holding the transpose of the r×c row-major
// matrix a, i.e. a c×r row-major matrix. Because a row-major transpose has the
// same layout as the column-major original, this converts between the two
// storage orders. a is expected to hold r*c elements and is not modified.
func tranpose(a []float64, r, c int) []float64 {
	z := make([]float64, r*c)
	if c == 1 || r == 1 {
		// A single row or column has identical storage in both orders.
		copy(z, a)
		return z
	}
	for i := 0; i < r; i++ {
		// Row i of a has c elements; bounding by r would truncate or overrun
		// whenever the matrix is not square.
		for j := 0; j < c; j++ {
			// Element (i, j) of a becomes element (j, i) of the c×r result.
			z[j*r+i] = a[i*c+j]
		}
	}
	return z
}
// printmat prints the r×c row-major matrix a in human-readable form: one line
// per row, each element formatted as "%.2f, ".
func printmat(a []float64, r, c int) {
	for i := 0; i < r; i++ {
		// Each row has c elements; bounding by r would skip or overrun
		// columns whenever the matrix is not square.
		for j := 0; j < c; j++ {
			fmt.Printf("%.2f, ", a[i*c+j])
		}
		fmt.Println()
	}
}
// printFortran writes a Fortran assignment "name = reshape([...], [r,lda])"
// for the r×lda row-major matrix a (lda is its column count), so Go test data
// can be pasted into a Fortran program. Values are emitted in column-major
// order with "%.2f", lda values per output line, with lines joined by ", &"
// continuations.
func printFortran(name string, a []float64, r, lda int) {
	colMajor := tranpose(a, r, lda)
	fmt.Printf("%v = reshape([", name)
	for line := 0; line < r; line++ {
		start := line * lda
		for k, v := range colMajor[start : start+lda] {
			if k > 0 {
				fmt.Print(",")
			}
			fmt.Printf("%.2f", v)
		}
		if line < r-1 {
			fmt.Print(",")
			fmt.Println(" &")
		}
	}
	fmt.Printf("], [%d,%d])\n", r, lda)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment