Skip to content

Instantly share code, notes, and snippets.

@kortschak
Created May 11, 2013 10:20
Show Gist options
  • Save kortschak/5559519 to your computer and use it in GitHub Desktop.
Save kortschak/5559519 to your computer and use it in GitHub Desktop.
Example of how a Mul might look.
package matrix
import (
"github.com/gonum/blas"
)
// blasEngine is the BLAS implementation Mul delegates to: Dgemm on the
// *Float64 fast path and Ddot on the Vectorer path.
//
// NOTE(review): nothing in this file assigns it. If blas.Float64 is an
// interface type, it must be set (presumably elsewhere in the package)
// before Mul reaches either BLAS path — confirm where it is initialised.
var blasEngine blas.Float64
// Float64 is a float64 matrix whose storage, dimensions, stride and storage
// order are held in the wrapped BlasMatrix.
//
// The zero value has no rows or columns; Mul allocates row-major storage for
// such a matrix on first use.
type Float64 struct {
	mat BlasMatrix
}
// isZero reports whether m holds no elements, that is, whether either its
// row count or its column count is zero. Mul uses this to decide whether m
// still needs storage allocated.
func (m *Float64) isZero() bool {
	rows, cols := m.mat.Rows, m.mat.Cols
	return !(rows != 0 && cols != 0)
}
// At returns the element of m at row r, column c.
//
// The flat offset into the backing data depends on the storage order:
// row-major storage advances Stride elements per row, column-major storage
// advances Stride elements per column. Any other order panics. Indices are
// not validated beyond the slice bounds check on the backing data.
func (m *Float64) At(r, c int) float64 {
	switch m.mat.Order {
	case blas.RowMajor:
		return m.mat.Data[r*m.mat.Stride+c]
	case blas.ColMajor:
		return m.mat.Data[c*m.mat.Stride+r]
	default:
		panic("matrix: illegal order")
	}
}
// Dims returns the number of rows and columns of m.
func (m *Float64) Dims() (r, c int) {
	r = m.mat.Rows
	c = m.mat.Cols
	return r, c
}
// Mul sets m to the matrix product a×b.
//
// If m has no elements it is allocated as a row-major ar×bc matrix;
// otherwise its dimensions must already be ar×bc. Mul panics with ErrShape
// if the inner dimensions of a and b disagree, or if a non-empty m has the
// wrong shape.
//
// Three strategies are tried, fastest first:
//   - a and b are both *Float64: a single Dgemm call. This path requires a,
//     b and m all to be row-major and panics otherwise.
//   - a and b both implement Vectorer: one Ddot per element of m.
//   - otherwise: element access through At.
//
// The two element-wise paths honour m's storage order (row- or
// column-major) and panic on any other order, matching At.
//
// NOTE(review): m is written in place while a and b are still being read,
// so m must not alias either operand.
func (m *Float64) Mul(a, b Matrix) {
	ar, ac := a.Dims()
	br, bc := b.Dims()
	if ac != br {
		panic(ErrShape)
	}

	if m.isZero() {
		m.mat = BlasMatrix{
			Order:  blas.RowMajor, // How do we know this? Assume blas.RowMajor for this example.
			Rows:   ar,
			Cols:   bc,
			Stride: bc, // Row-major stride is the column count; it would be ar for column-major.
			Data:   make([]float64, ar*bc),
		}
	} else if ar != m.mat.Rows || bc != m.mat.Cols {
		panic(ErrShape)
	}

	// Fast path: both operands are *Float64, so their backing data can be
	// handed straight to BLAS.
	if a, ok := a.(*Float64); ok {
		if b, ok := b.(*Float64); ok {
			// Dgemm takes a single order argument that governs A, B and the
			// destination C alike, so all three must agree. Only row-major
			// is handled here; the general case is more complicated.
			if a.mat.Order != blas.RowMajor || b.mat.Order != blas.RowMajor || m.mat.Order != blas.RowMajor {
				panic("matrix: I'm not even going to bother if you don't make it easy for me.")
			}
			blasEngine.Dgemm(
				a.mat.Order,
				blas.NoTrans, blas.NoTrans,
				ar, bc, ac,
				1.,
				a.mat.Data, a.mat.Stride,
				b.mat.Data, b.mat.Stride,
				0.,
				m.mat.Data, m.mat.Stride)
			return
		}
	}

	// The remaining paths write m one element at a time. Element (r, c)
	// lives at r*rStep + c*cStep in m.mat.Data, using the same layout rules
	// as At rather than assuming m is row-major.
	var rStep, cStep int
	switch m.mat.Order {
	case blas.RowMajor:
		rStep, cStep = m.mat.Stride, 1
	case blas.ColMajor:
		rStep, cStep = 1, m.mat.Stride
	default:
		panic("matrix: illegal order")
	}

	// Both operands can hand out whole rows/columns: compute one dot product
	// per element of m. Row r of a does not depend on c, so fetch it once per
	// row. The order of the matrices will affect performance here.
	if a, ok := a.(Vectorer); ok {
		if b, ok := b.(Vectorer); ok {
			rowBuf := make([]float64, ac)
			colBuf := make([]float64, br)
			for r := 0; r < ar; r++ {
				row := a.Row(r, rowBuf)
				for c := 0; c < bc; c++ {
					m.mat.Data[r*rStep+c*cStep] = blasEngine.Ddot(ac, row, 1, b.Col(c, colBuf), 1)
				}
			}
			return
		}
	}

	// Generic path: only At is available. Cache row r of a so each of its
	// elements is read once per row rather than once per output element.
	row := make([]float64, ac)
	for r := 0; r < ar; r++ {
		for i := range row {
			row[i] = a.At(r, i)
		}
		for c := 0; c < bc; c++ {
			// Accumulate using the cached row; it must not be re-declared
			// here or the dot product would be taken against zeros.
			var v float64
			for i, e := range row {
				v += e * b.At(i, c)
			}
			m.mat.Data[r*rStep+c*cStep] = v
		}
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment