Skip to content

Instantly share code, notes, and snippets.

@kidtronnix
Created September 7, 2018 14:58
Show Gist options
  • Save kidtronnix/ad7317c5693135967ee01f4bdcf833c2 to your computer and use it in GitHub Desktop.
Save kidtronnix/ad7317c5693135967ee01f4bdcf833c2 to your computer and use it in GitHub Desktop.
package nn
import (
"fmt"
"gonum.org/v1/gonum/mat"
)
// Train fits the network to the samples in x and the targets in y using
// mini-batch updates, for n.config.Epochs epochs.
//
// Rows of x and y are paired by index: row i of y is the target for row i of
// x, so y is assumed to have at least as many rows as x. Each epoch walks the
// rows in order, in chunks of n.config.BatchSize rows (the last chunk may be
// shorter), and passes each chunk to n.backward. After every epoch the network
// is scored with n.Evaluate on the full x and y. The score is printed to
// stdout. NOTE(review): the "%%" in the format suggests Evaluate returns a
// percentage; confirm against Evaluate.
//
// A non-positive BatchSize is treated as a single full batch. Otherwise a zero
// step would request an empty slice and a negative one an inverted slice, and
// both make mat.Dense.Slice panic.
func (n *MLP) Train(x, y *mat.Dense) {
	rows, xCols := x.Dims()
	_, yCols := y.Dims()

	batch := n.config.BatchSize
	if batch <= 0 {
		// Zero-value config: fall back to full-batch training instead of
		// panicking inside mat.Dense.Slice.
		batch = rows
	}

	for epoch := 1; epoch <= n.config.Epochs; epoch++ {
		for start := 0; start < rows; start += batch {
			end := start + batch
			if end > rows {
				end = rows
			}
			xBatch := x.Slice(start, end, 0, xCols)
			yBatch := y.Slice(start, end, 0, yCols)
			n.backward(xBatch, yBatch)
		}
		acc := n.Evaluate(x, y)
		fmt.Printf("epoch=%d accuracy=%0.1f%%\n", epoch, acc)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment