Created
May 11, 2013 10:20
-
-
Save kortschak/5559519 to your computer and use it in GitHub Desktop.
Example of how a Mul might look.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package matrix | |
import ( | |
"github.com/gonum/blas" | |
) | |
var blasEngine blas.Float64 | |
type Float64 struct { | |
mat BlasMatrix | |
} | |
func (m *Float64) isZero() bool { | |
return m.mat.Cols == 0 || m.mat.Rows == 0 | |
} | |
func (m *Float64) At(r, c int) float64 { | |
if m.mat.Order == blas.RowMajor { | |
return m.mat.Data[r*m.mat.Stride+c] | |
} else if m.mat.Order == blas.ColMajor { | |
return m.mat.Data[c*m.mat.Stride+r] | |
} | |
panic("matrix: illegal order") | |
} | |
func (m *Float64) Dims() (r, c int) { return m.mat.Rows, m.mat.Cols } | |
func (m *Float64) Mul(a, b Matrix) { | |
ar, ac := a.Dims() | |
br, bc := b.Dims() | |
if ac != br { | |
panic(ErrShape) | |
} | |
if m.isZero() { | |
m.mat = BlasMatrix{ | |
Order: blas.RowMajor, // How do we know this? Assume blas.RowMajor for this example. | |
Rows: ar, | |
Cols: bc, | |
Stride: bc, //This is either ar or bc depending on Order - which we don't have a way of determining yet. | |
Data: make([]float64, ar*bc), | |
} | |
} else if ar != m.mat.Rows || bc != m.mat.Cols { | |
panic(ErrShape) | |
} | |
// This is the fast path; both are really BlasMatrix types. | |
if a, ok := a.(*Float64); ok { | |
if b, ok := b.(*Float64); ok { | |
if a.mat.Order != blas.RowMajor || b.mat.Order != blas.RowMajor { | |
panic("matrix: I'm not even going to bother if you don't make it easy for me.") | |
// i.e. order is more complicated than this for the general case | |
// At the very least they should agree. | |
} | |
blasEngine.Dgemm( | |
a.mat.Order, | |
blas.NoTrans, blas.NoTrans, | |
ar, bc, ac, | |
1., | |
a.mat.Data, a.mat.Stride, | |
b.mat.Data, b.mat.Stride, | |
0., | |
m.mat.Data, m.mat.Stride) | |
return | |
} | |
} | |
// Progressively worse runtime cases... not all of them - there are four, margin width yadda yadda. | |
if a, ok := a.(Vectorer); ok { | |
if b, ok := b.(Vectorer); ok { | |
// The order of the matrices will make a difference here to performance... this is just POC. | |
row := make([]float64, ac) | |
col := make([]float64, br) | |
for r := 0; r < ar; r++ { | |
for c := 0; c < bc; c++ { | |
// Assume blas.RowMajor. | |
m.mat.Data[r*m.mat.Stride+m.mat.Cols] = blasEngine.Ddot(ac, a.Row(r, row), 1, b.Col(c, col), 1) | |
} | |
} | |
return | |
} | |
} | |
row := make([]float64, ac) | |
for r := 0; r < ar; r++ { | |
for i := range row { | |
row[i] = a.At(r, i) | |
} | |
for c := 0; c < bc; c++ { | |
var v float64 | |
row := make([]float64, ac) | |
for i, e := range row { | |
v += e * b.At(i, c) | |
} | |
// Assume blas.RowMajor. | |
m.mat.Data[r*m.mat.Stride+m.mat.Cols] = v | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment