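// This program trains a one-dimensional linear model (y = Wx + b) with
// spago to fit the function y = 3x + 1, then serializes the trained
// model to disk and loads it back.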
package main

import (
    "fmt"
    "log"
    "reflect"

    "github.com/nlpodyssey/spago/ag"
    "github.com/nlpodyssey/spago/gd"
    "github.com/nlpodyssey/spago/gd/sgd"
    "github.com/nlpodyssey/spago/initializers"
    "github.com/nlpodyssey/spago/losses"
    "github.com/nlpodyssey/spago/mat"
    "github.com/nlpodyssey/spago/mat/float"
    "github.com/nlpodyssey/spago/mat/rand"
    "github.com/nlpodyssey/spago/nn"
)
const (
    epochs   = 100  // number of epochs
    examples = 1000 // number of examples
)
// Linear is a fully connected layer computing y = Wx + b.
type Linear struct {
    nn.Module
    W nn.Param `spago:"type:weights"`
    B nn.Param `spago:"type:biases"`
}
// NewLinear returns a Linear layer with zero-valued parameters of the
// given input and output sizes.
func NewLinear[T float.DType](in, out int) *Linear {
    return &Linear{
        W: nn.NewParam(mat.NewEmptyDense[T](out, in)),
        B: nn.NewParam(mat.NewEmptyVecDense[T](out)),
    }
}
// InitWithRandomWeights initializes W with Xavier (Glorot) uniform
// initialization; the bias keeps its initial zero value.
func (m *Linear) InitWithRandomWeights(seed uint64) *Linear {
    initializers.XavierUniform(m.W.Value(), 1.0, rand.NewLockedRand(seed))
    return m
}
// Forward computes Wx + b.
func (m *Linear) Forward(x ag.Node) ag.Node {
    return ag.Add(ag.Mul(m.W, x), m.B)
}
func main() {
    m := NewLinear[float64](1, 1).InitWithRandomWeights(42)
    optimizer := gd.NewOptimizer(m, sgd.New[float64](sgd.NewConfig(0.001, 0.9, true))) // SGD with momentum

    normalize := func(x float64) float64 { return x / float64(examples) }
    objective := func(x float64) float64 { return 3*x + 1 }
    criterion := losses.MSE

    // learn performs one forward/backward pass and returns the loss value.
    learn := func(input, expected float64) float64 {
        x, target := ag.Scalar(input), ag.Scalar(expected)
        y := m.Forward(x)
        loss := criterion(y, target, true)
        defer ag.Backward(loss) // backpropagate (freeing the graph) once the loss value has been read
        return loss.Value().Scalar().F64()
    }
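    // Training loop: backpropagate on every example and apply a single
    // optimizer update at the end of each epoch.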
    for epoch := 0; epoch < epochs; epoch++ {
        for i := 0; i < examples; i++ {
            x := normalize(float64(i))
            loss := learn(x, objective(x))
            if i%100 == 0 {
                fmt.Printf("Loss: %.6f\n", loss)
            }
        }
        optimizer.Do()
    }
fmt.Printf("\nW: %.2f | B: %.2f\n\n", m.W.Value().Scalar().F64(), m.B.Value().Scalar().F64()) | |
fmt.Printf("%#v", m) | |
err := nn.DumpToFile(m, "model") | |
if err != nil { | |
log.Fatal(err) | |
} | |
m2, err := nn.LoadFromFile[*Linear]("model") | |
if err != nil { | |
log.Fatal(err) | |
} | |
fmt.Printf("\nW: %.2f | B: %.2f\n\n", m2.W.Value().Scalar().F64(), m2.B.Value().Scalar().F64()) | |
fmt.Println(reflect.TypeOf(m2).Kind()) | |
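    // Illustrative check (a sketch reusing only the calls above): run the
    // reloaded model on an input in the normalized range; with W close to 3
    // and B close to 1, the prediction should approach 3*0.5 + 1 = 2.5.
    pred := m2.Forward(ag.Scalar(0.5))
    fmt.Printf("prediction for x=0.5: %.2f\n", pred.Value().Scalar().F64())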
    // Save just the weight matrix (not the whole model) to a file
    err = nn.DumpToFile(m.W.Value(), "w")
    if err != nil {
        log.Fatal(err)
    }
    // Load the weight matrix back from the file
    w, err := nn.LoadFromFile[mat.Dense[float64]]("w")
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("\nW: %.2f\n", w.Scalar().F64())
}