Created
March 27, 2023 13:47
-
-
Save nfisher/f31bdcb3fbc99686f2383315ced8bd4c to your computer and use it in GitHub Desktop.
Perceptrons
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
sepal_length | sepal_width | petal_length | petal_width | species | |
---|---|---|---|---|---|
5.1 | 3.5 | 1.4 | 0.2 | setosa | |
4.9 | 3.0 | 1.4 | 0.2 | setosa | |
4.7 | 3.2 | 1.3 | 0.2 | setosa | |
4.6 | 3.1 | 1.5 | 0.2 | setosa | |
5.0 | 3.6 | 1.4 | 0.2 | setosa | |
5.4 | 3.9 | 1.7 | 0.4 | setosa | |
4.6 | 3.4 | 1.4 | 0.3 | setosa | |
5.0 | 3.4 | 1.5 | 0.2 | setosa | |
4.4 | 2.9 | 1.4 | 0.2 | setosa | |
4.9 | 3.1 | 1.5 | 0.1 | setosa | |
5.4 | 3.7 | 1.5 | 0.2 | setosa | |
4.8 | 3.4 | 1.6 | 0.2 | setosa | |
4.8 | 3.0 | 1.4 | 0.1 | setosa | |
4.3 | 3.0 | 1.1 | 0.1 | setosa | |
5.8 | 4.0 | 1.2 | 0.2 | setosa | |
5.7 | 4.4 | 1.5 | 0.4 | setosa | |
5.4 | 3.9 | 1.3 | 0.4 | setosa | |
5.1 | 3.5 | 1.4 | 0.3 | setosa | |
5.7 | 3.8 | 1.7 | 0.3 | setosa | |
5.1 | 3.8 | 1.5 | 0.3 | setosa | |
5.4 | 3.4 | 1.7 | 0.2 | setosa | |
5.1 | 3.7 | 1.5 | 0.4 | setosa | |
4.6 | 3.6 | 1.0 | 0.2 | setosa | |
5.1 | 3.3 | 1.7 | 0.5 | setosa | |
4.8 | 3.4 | 1.9 | 0.2 | setosa | |
5.0 | 3.0 | 1.6 | 0.2 | setosa | |
5.0 | 3.4 | 1.6 | 0.4 | setosa | |
5.2 | 3.5 | 1.5 | 0.2 | setosa | |
5.2 | 3.4 | 1.4 | 0.2 | setosa | |
4.7 | 3.2 | 1.6 | 0.2 | setosa | |
4.8 | 3.1 | 1.6 | 0.2 | setosa | |
5.4 | 3.4 | 1.5 | 0.4 | setosa | |
5.2 | 4.1 | 1.5 | 0.1 | setosa | |
5.5 | 4.2 | 1.4 | 0.2 | setosa | |
4.9 | 3.1 | 1.5 | 0.1 | setosa | |
5.0 | 3.2 | 1.2 | 0.2 | setosa | |
5.5 | 3.5 | 1.3 | 0.2 | setosa | |
4.9 | 3.1 | 1.5 | 0.1 | setosa | |
4.4 | 3.0 | 1.3 | 0.2 | setosa | |
5.1 | 3.4 | 1.5 | 0.2 | setosa | |
5.0 | 3.5 | 1.3 | 0.3 | setosa | |
4.5 | 2.3 | 1.3 | 0.3 | setosa | |
4.4 | 3.2 | 1.3 | 0.2 | setosa | |
5.0 | 3.5 | 1.6 | 0.6 | setosa | |
5.1 | 3.8 | 1.9 | 0.4 | setosa | |
4.8 | 3.0 | 1.4 | 0.3 | setosa | |
5.1 | 3.8 | 1.6 | 0.2 | setosa | |
4.6 | 3.2 | 1.4 | 0.2 | setosa | |
5.3 | 3.7 | 1.5 | 0.2 | setosa | |
5.0 | 3.3 | 1.4 | 0.2 | setosa | |
7.0 | 3.2 | 4.7 | 1.4 | versicolor | |
6.4 | 3.2 | 4.5 | 1.5 | versicolor | |
6.9 | 3.1 | 4.9 | 1.5 | versicolor | |
5.5 | 2.3 | 4.0 | 1.3 | versicolor | |
6.5 | 2.8 | 4.6 | 1.5 | versicolor | |
5.7 | 2.8 | 4.5 | 1.3 | versicolor | |
6.3 | 3.3 | 4.7 | 1.6 | versicolor | |
4.9 | 2.4 | 3.3 | 1.0 | versicolor | |
6.6 | 2.9 | 4.6 | 1.3 | versicolor | |
5.2 | 2.7 | 3.9 | 1.4 | versicolor | |
5.0 | 2.0 | 3.5 | 1.0 | versicolor | |
5.9 | 3.0 | 4.2 | 1.5 | versicolor | |
6.0 | 2.2 | 4.0 | 1.0 | versicolor | |
6.1 | 2.9 | 4.7 | 1.4 | versicolor | |
5.6 | 2.9 | 3.6 | 1.3 | versicolor | |
6.7 | 3.1 | 4.4 | 1.4 | versicolor | |
5.6 | 3.0 | 4.5 | 1.5 | versicolor | |
5.8 | 2.7 | 4.1 | 1.0 | versicolor | |
6.2 | 2.2 | 4.5 | 1.5 | versicolor | |
5.6 | 2.5 | 3.9 | 1.1 | versicolor | |
5.9 | 3.2 | 4.8 | 1.8 | versicolor | |
6.1 | 2.8 | 4.0 | 1.3 | versicolor | |
6.3 | 2.5 | 4.9 | 1.5 | versicolor | |
6.1 | 2.8 | 4.7 | 1.2 | versicolor | |
6.4 | 2.9 | 4.3 | 1.3 | versicolor | |
6.6 | 3.0 | 4.4 | 1.4 | versicolor | |
6.8 | 2.8 | 4.8 | 1.4 | versicolor | |
6.7 | 3.0 | 5.0 | 1.7 | versicolor | |
6.0 | 2.9 | 4.5 | 1.5 | versicolor | |
5.7 | 2.6 | 3.5 | 1.0 | versicolor | |
5.5 | 2.4 | 3.8 | 1.1 | versicolor | |
5.5 | 2.4 | 3.7 | 1.0 | versicolor | |
5.8 | 2.7 | 3.9 | 1.2 | versicolor | |
6.0 | 2.7 | 5.1 | 1.6 | versicolor | |
5.4 | 3.0 | 4.5 | 1.5 | versicolor | |
6.0 | 3.4 | 4.5 | 1.6 | versicolor | |
6.7 | 3.1 | 4.7 | 1.5 | versicolor | |
6.3 | 2.3 | 4.4 | 1.3 | versicolor | |
5.6 | 3.0 | 4.1 | 1.3 | versicolor | |
5.5 | 2.5 | 4.0 | 1.3 | versicolor | |
5.5 | 2.6 | 4.4 | 1.2 | versicolor | |
6.1 | 3.0 | 4.6 | 1.4 | versicolor | |
5.8 | 2.6 | 4.0 | 1.2 | versicolor | |
5.0 | 2.3 | 3.3 | 1.0 | versicolor | |
5.6 | 2.7 | 4.2 | 1.3 | versicolor | |
5.7 | 3.0 | 4.2 | 1.2 | versicolor | |
5.7 | 2.9 | 4.2 | 1.3 | versicolor | |
6.2 | 2.9 | 4.3 | 1.3 | versicolor | |
5.1 | 2.5 | 3.0 | 1.1 | versicolor | |
5.7 | 2.8 | 4.1 | 1.3 | versicolor | |
6.3 | 3.3 | 6.0 | 2.5 | virginica | |
5.8 | 2.7 | 5.1 | 1.9 | virginica | |
7.1 | 3.0 | 5.9 | 2.1 | virginica | |
6.3 | 2.9 | 5.6 | 1.8 | virginica | |
6.5 | 3.0 | 5.8 | 2.2 | virginica | |
7.6 | 3.0 | 6.6 | 2.1 | virginica | |
4.9 | 2.5 | 4.5 | 1.7 | virginica | |
7.3 | 2.9 | 6.3 | 1.8 | virginica | |
6.7 | 2.5 | 5.8 | 1.8 | virginica | |
7.2 | 3.6 | 6.1 | 2.5 | virginica | |
6.5 | 3.2 | 5.1 | 2.0 | virginica | |
6.4 | 2.7 | 5.3 | 1.9 | virginica | |
6.8 | 3.0 | 5.5 | 2.1 | virginica | |
5.7 | 2.5 | 5.0 | 2.0 | virginica | |
5.8 | 2.8 | 5.1 | 2.4 | virginica | |
6.4 | 3.2 | 5.3 | 2.3 | virginica | |
6.5 | 3.0 | 5.5 | 1.8 | virginica | |
7.7 | 3.8 | 6.7 | 2.2 | virginica | |
7.7 | 2.6 | 6.9 | 2.3 | virginica | |
6.0 | 2.2 | 5.0 | 1.5 | virginica | |
6.9 | 3.2 | 5.7 | 2.3 | virginica | |
5.6 | 2.8 | 4.9 | 2.0 | virginica | |
7.7 | 2.8 | 6.7 | 2.0 | virginica | |
6.3 | 2.7 | 4.9 | 1.8 | virginica | |
6.7 | 3.3 | 5.7 | 2.1 | virginica | |
7.2 | 3.2 | 6.0 | 1.8 | virginica | |
6.2 | 2.8 | 4.8 | 1.8 | virginica | |
6.1 | 3.0 | 4.9 | 1.8 | virginica | |
6.4 | 2.8 | 5.6 | 2.1 | virginica | |
7.2 | 3.0 | 5.8 | 1.6 | virginica | |
7.4 | 2.8 | 6.1 | 1.9 | virginica | |
7.9 | 3.8 | 6.4 | 2.0 | virginica | |
6.4 | 2.8 | 5.6 | 2.2 | virginica | |
6.3 | 2.8 | 5.1 | 1.5 | virginica | |
6.1 | 2.6 | 5.6 | 1.4 | virginica | |
7.7 | 3.0 | 6.1 | 2.3 | virginica | |
6.3 | 3.4 | 5.6 | 2.4 | virginica | |
6.4 | 3.1 | 5.5 | 1.8 | virginica | |
6.0 | 3.0 | 4.8 | 1.8 | virginica | |
6.9 | 3.1 | 5.4 | 2.1 | virginica | |
6.7 | 3.1 | 5.6 | 2.4 | virginica | |
6.9 | 3.1 | 5.1 | 2.3 | virginica | |
5.8 | 2.7 | 5.1 | 1.9 | virginica | |
6.8 | 3.2 | 5.9 | 2.3 | virginica | |
6.7 | 3.3 | 5.7 | 2.5 | virginica | |
6.7 | 3.0 | 5.2 | 2.3 | virginica | |
6.3 | 2.5 | 5.0 | 1.9 | virginica | |
6.5 | 3.0 | 5.2 | 2.0 | virginica | |
6.2 | 3.4 | 5.4 | 2.3 | virginica | |
5.9 | 3.0 | 5.1 | 1.8 | virginica |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"encoding/csv" | |
"errors" | |
"fmt" | |
"log" | |
"math" | |
"math/rand" | |
"os" | |
"strconv" | |
) | |
func main() { | |
log.SetFlags(0) | |
input := &Data[float64]{} | |
err := LoadIris(input, true) | |
if err != nil { | |
log.Fatalln(err) | |
} | |
input.Split(0.70) | |
fmt.Println(input.TargetName) | |
_, features := input.Shape() | |
p := New(features, step[float64]) | |
err = Fit(p, 1000, 0.001, input) | |
if err != nil { | |
log.Fatalln("fit", err) | |
} | |
result, err := input.Test(func(X []float64) (float64, error) { | |
return Predict(p, X, false) | |
}) | |
if err != nil { | |
log.Println(err) | |
} | |
log.Printf("precision=%0.2f%%\n", result) | |
} | |
type stepFn[T Numeric] func(T, error) (T, error) | |
// Fit trains a perceptron (p) on the data set (d) over a number of iterations (iters). It uses the step function (step) | |
// to align the data to the labels. The learning rate (r) influences how much change there is to the weights for each | |
// correction. | |
func Fit[T Numeric](p *Perceptron[T], iters int, r T, d *Data[T]) error { | |
for n := 0; n < iters; n++ { | |
err := d.Train(func(X []T, target T) error { | |
Yjt, err := p.Step(Dot(p.Weights, X)) | |
if err != nil { | |
return err | |
} | |
if Yjt == target { | |
return nil | |
} | |
diff := Yjt - target | |
for j, w := range p.Weights { | |
p.Weights[j] = w - r*diff*X[j] | |
} | |
return nil | |
}) | |
if err != nil { | |
return err | |
} | |
} | |
return nil | |
} | |
func step[T Numeric](x T, err error) (T, error) { | |
if err != nil { | |
return 0, err | |
} | |
s := math.Round(float64(x)) | |
if s > 0 { | |
return T(s), nil | |
} | |
return 0, nil | |
} | |
func Predict[T Numeric](p *Perceptron[T], x []T, addBiasCol bool) (T, error) { | |
if addBiasCol { | |
x = append(x, 1) | |
} | |
return p.Step(Dot(p.Weights, x)) | |
} | |
func New[T Numeric](features int, fn stepFn[T]) *Perceptron[T] { | |
weights := make([]T, features) | |
return &Perceptron[T]{ | |
Weights: weights, | |
Step: fn, | |
} | |
} | |
type Perceptron[T Numeric] struct { | |
Weights []T | |
Step stepFn[T] | |
} | |
type Data[T Numeric] struct { | |
Values [][]T | |
Target []T | |
TargetName map[string]T | |
Headers []string | |
features int | |
train []int | |
test []int | |
} | |
func (d *Data[T]) Shape() (int, int) { | |
return len(d.Values), d.features | |
} | |
func (d *Data[T]) Split(pct float64) { | |
order := make([]int, len(d.Values)) | |
for i := 0; i < len(order); i++ { | |
order[i] = i | |
} | |
rand.Shuffle(len(order), func(i, j int) { | |
order[i], order[j] = order[j], order[i] | |
}) | |
trainSz := int(pct * float64(len(d.Values))) | |
d.train = order[:trainSz] | |
d.test = order[trainSz:] | |
} | |
func (d *Data[T]) Train(fn func(X []T, target T) error) error { | |
for _, i := range d.train { | |
err := fn(d.Values[i], d.Target[i]) | |
if err != nil { | |
return err | |
} | |
} | |
return nil | |
} | |
func (d *Data[T]) Test(fn func(X []T) (T, error)) (float64, error) { | |
var correct int | |
for _, i := range d.test { | |
p, err := fn(d.Values[i]) | |
if err != nil { | |
return 0, err | |
} | |
if p == d.Target[i] { | |
correct++ | |
} | |
} | |
return float64(correct) / float64(len(d.test)) * 100.0, nil | |
} | |
func LoadIris[T Numeric](d *Data[T], addBiasColumn bool) error { | |
r, err := os.Open("iris.csv") | |
if err != nil { | |
return err | |
} | |
f := csv.NewReader(r) | |
records, err := f.ReadAll() | |
if err != nil { | |
return err | |
} | |
d.Headers = records[0][:len(records[0])-1] | |
d.TargetName = map[string]T{} | |
records = records[1:] // trim header row | |
var a []T | |
for _, row := range records { | |
name := row[len(row)-1] | |
v, ok := d.TargetName[name] | |
if !ok { | |
v = T(len(d.TargetName)) + 1 | |
d.TargetName[name] = v | |
} | |
d.Target = append(d.Target, v) | |
a = []T{} | |
for _, s := range row[:len(row)-1] { | |
f, err := strconv.ParseFloat(s, 64) | |
if err != nil { | |
return err | |
} | |
a = append(a, T(f)) | |
} | |
if addBiasColumn { | |
a = append(a, 1) | |
} | |
d.Values = append(d.Values, a) | |
} | |
d.features = len(a) | |
return nil | |
} | |
func Dot[T Numeric](a, b []T) (T, error) { | |
if len(a) != len(b) { | |
return 0, fmt.Errorf("unaligned vectors a=%d, b=%d", len(a), len(b)) | |
} | |
if len(a) == 0 { | |
return 0, errors.New("empty vectors") | |
} | |
var sum T = 0 | |
for i := range a { | |
sum += a[i] * b[i] | |
} | |
return sum, nil | |
} | |
// Numeric is the set of element types the perceptron code accepts: every
// signed integer width and both floating-point widths, including named types
// derived from them. Unsigned integers are deliberately excluded because the
// weight update in Fit subtracts and would silently wrap around.
type Numeric interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~float32 | ~float64
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment