Skip to content

Instantly share code, notes, and snippets.

@genya0407
Created September 1, 2016 10:27
Show Gist options
  • Save genya0407/16e6ed23906c261f00761ef10305abd2 to your computer and use it in GitHub Desktop.
Save genya0407/16e6ed23906c261f00761ef10305abd2 to your computer and use it in GitHub Desktop.
package main
import (
"encoding/csv"
"fmt"
"github.com/gonum/plot"
"github.com/gonum/plot/plotter"
"github.com/gonum/plot/plotutil"
"github.com/gonum/plot/vg"
flags "github.com/jessevdk/go-flags"
//"github.com/k0kubun/pp"
"io/ioutil"
"math"
"math/rand"
"os"
"strconv"
"strings"
"time"
)
const (
// a is the learning rate: epoc multiplies every gradient term by -a when it
// computes the weight corrections.
a = 0.1
)
// K is the number of middle-layer units. Slices sized by it use K+1 entries
// because index 0 is reserved for the bias unit. main assigns it from
// opts.MiddleNumber before training starts.
var K int
// Network holds the two weight matrices of the perceptron network.
type Network struct {
// OutputWeight[m][o] is the weight from middle unit m to output unit o.
// initial_network allocates it with K+1 rows of 10 columns.
OutputWeight [][]float64
// MiddleWeight[i][m] is the weight from input i to middle unit m.
// initial_network allocates it with 82 rows of K+1 columns.
MiddleWeight [][]float64
}
// Perceptrons holds the unit activations produced by one forward pass.
type Perceptrons struct {
// Output holds the activations of the 10 output units.
Output []float64
// Middle holds K+1 values: index 0 is the bias value (calc_perceptrons sets
// it to -1) and indices 1..K are the middle-unit activations.
Middle []float64
}
// initial_perceptrons returns a Perceptrons value whose activation slices are
// allocated and zeroed: 10 output entries and K+1 middle entries.
func initial_perceptrons() Perceptrons {
	return Perceptrons{
		Output: make([]float64, 10),
		Middle: make([]float64, K+1),
	}
}
// Data is one sample as loaded by get_data.
type Data struct {
// input starts with the constant -1 (bias input) followed by the values
// parsed from the sample's input lines.
input []float64
// output holds the values parsed from the sample's target line.
output []float64
}
// get_data loads the samples stored in filename.
//
// The file is consumed as records of 12 lines: lines 0-8 of a record carry the
// whitespace-separated input values, line 10 carries the whitespace-separated
// target values, and lines 9 and 11 are skipped. Each input vector starts with
// the constant -1 (bias input) followed by the parsed values in file order. A
// trailing fragment with fewer than 11 lines left is ignored.
//
// When the file cannot be read, the error is reported on standard error and
// nil is returned. Tokens that are not valid integers are reported on standard
// error and stored as whatever strconv.Atoi yields (0 for malformed text).
func get_data(filename string) []Data {
	raw, err := ioutil.ReadFile(filename)
	if err != nil {
		fmt.Fprintf(os.Stderr, "get_data: reading %q: %v\n", filename, err)
		return nil
	}
	lines := strings.Split(string(raw), "\n")

	var datas []Data
	for n := 0; n+10 < len(lines); n += 12 {
		// Slot 0 is the bias input; the 9 input lines are appended after it.
		input := make([]float64, 1, 81+1)
		input[0] = -1
		for k, line := range lines[n : n+9] {
			input = append_ints(input, line, filename, n+k+1)
		}
		// The target line is tokenized exactly like the input lines so that a
		// trailing "\r" or blank cannot corrupt or extend the target vector.
		output := append_ints(nil, lines[n+10], filename, n+11)
		datas = append(datas, Data{input: input, output: output})
	}
	return datas
}

// append_ints parses every whitespace-separated token of line as an integer and
// appends it to dst; filename and lineNo (1-based) only label diagnostics.
func append_ints(dst []float64, line, filename string, lineNo int) []float64 {
	for _, token := range strings.Fields(line) {
		num, err := strconv.Atoi(token)
		if err != nil {
			fmt.Fprintf(os.Stderr, "get_data: %s:%d: %q is not an integer, stored as %d\n", filename, lineNo, token, num)
		}
		dst = append(dst, float64(num))
	}
	return dst
}
// initial_network allocates both weight matrices and fills every entry with
// 0.005 - rand.Float64()*0.01, i.e. a small value in (-0.005, 0.005].
//
// MiddleWeight (82 x K+1) is filled first, row by row, and OutputWeight
// (K+1 x 10) second, so the sequence of random draws is fixed for a given seed.
func initial_network() Network {
	var net Network

	net.OutputWeight = make([][]float64, K+1)
	for m := range net.OutputWeight {
		net.OutputWeight[m] = make([]float64, 9+1)
	}

	net.MiddleWeight = make([][]float64, 81+1)
	for i := range net.MiddleWeight {
		row := make([]float64, K+1)
		for m := range row {
			row[m] = 0.005 - rand.Float64()*0.01
		}
		net.MiddleWeight[i] = row
	}

	for _, row := range net.OutputWeight {
		for o := range row {
			row[o] = 0.005 - rand.Float64()*0.01
		}
	}
	return net
}
// sig is the logistic sigmoid 1 / (1 + e^-x).
func sig(x float64) float64 {
	denominator := 1.0 + math.Exp(-x)
	return 1.0 / denominator
}
// output_to_number returns the index of the largest value in O (the argmax).
//
// Ties resolve to the lowest index and NaN entries are skipped. An empty slice
// (or one holding only NaN) yields 0. Unlike a scan that starts from a running
// maximum of 0, this also picks the right index when no entry is positive.
func output_to_number(O []float64) int {
	best := -1
	for i, v := range O {
		if math.IsNaN(v) {
			continue
		}
		if best < 0 || O[best] < v {
			best = i
		}
	}
	if best < 0 {
		return 0
	}
	return best
}
// calc_error returns half the sum of squared differences between the first 10
// entries of O (network outputs) and T (targets). Both slices must hold at
// least 10 values.
func calc_error(O []float64, T []float64) float64 {
	var squares float64
	for o := 0; o < 10; o++ {
		diff := O[o] - T[o]
		squares += math.Pow(diff, 2)
	}
	return squares / 2.0
}
// calc_perceptrons runs one forward pass of network on input and returns the
// resulting middle and output activations.
//
// input must hold at least 82 values (index 0 is expected to be the bias
// input). Middle[0] is fixed to -1 as the bias unit feeding the output layer;
// every other unit is the sigmoid of its weighted input sum.
func calc_perceptrons(network Network, input []float64) Perceptrons {
	state := initial_perceptrons()

	// Middle layer: unit 0 is the bias, units 1..K are computed from the input.
	state.Middle[0] = -1
	for m := 1; m <= K; m++ {
		var weighted float64
		for i := 0; i <= 81; i++ {
			weighted += network.MiddleWeight[i][m] * input[i]
		}
		state.Middle[m] = sig(weighted)
	}

	// Output layer: every unit sees all K+1 middle values, bias included.
	for o := range state.Output {
		var weighted float64
		for m, activation := range state.Middle {
			weighted += network.OutputWeight[m][o] * activation
		}
		state.Output[o] = sig(weighted)
	}
	return state
}
// epoc runs one training epoch: every teacher sample is fed through the
// network once and the weights are corrected immediately after each sample
// (online backpropagation with learning rate a).
//
// It returns the network (the weight slices are updated in place, so the
// caller's matrices change too), the mean squared-error value over the
// samples, and the number of samples that were classified correctly by the
// weights in effect just before that sample's update. With no teachers the
// mean error is 0.
func epoc(network Network, teachers []Data) (Network, float64, int) {
	// hiddenStep[m] is the per-sample correction factor of middle unit m; the
	// change of MiddleWeight[i][m] is hiddenStep[m] * input[i].
	hiddenStep := make([]float64, K+1)

	errSum := 0.0
	successCount := 0
	for _, teacher := range teachers {
		// Forward pass with the current weights.
		p := calc_perceptrons(network, teacher.input)

		// Count the sample as a success when the recognized digit is right.
		if output_to_number(p.Output) == output_to_number(teacher.output) {
			successCount++
		}

		// Accumulate the squared-error function value.
		errSum += calc_error(p.Output, teacher.output)

		// Input -> middle corrections. The back-propagated sum depends only on
		// the middle unit, not on the input index, so it is computed once per
		// unit. It must use the output weights from before this sample's
		// update, hence it is evaluated before OutputWeight is touched below.
		for m := 0; m <= K; m++ {
			var backProp float64
			for o := 0; o < 10; o++ {
				backProp += (p.Output[o] - teacher.output[o]) * p.Output[o] * (1.0 - p.Output[o]) * network.OutputWeight[m][o]
			}
			hiddenStep[m] = -a * backProp * p.Middle[m] * (1.0 - p.Middle[m])
		}

		// Learn the input -> middle weights.
		for i := 0; i <= 81; i++ {
			x := teacher.input[i]
			for m := 0; m <= K; m++ {
				network.MiddleWeight[i][m] += hiddenStep[m] * x
			}
		}

		// Learn the middle -> output weights.
		for m := 0; m <= K; m++ {
			for o := 0; o <= 9; o++ {
				network.OutputWeight[m][o] += -a * (p.Output[o] - teacher.output[o]) * p.Output[o] * (1.0 - p.Output[o]) * p.Middle[m]
			}
		}
	}

	// Average the squared-error function over the samples actually seen.
	meanErr := 0.0
	if len(teachers) > 0 {
		meanErr = errSum / float64(len(teachers))
	}
	return network, meanErr, successCount
}
// train builds a freshly initialized network, loads the teacher samples from
// opts.Teacher and runs opts.Epoc training epochs on them.
//
// When opts.RenderGraph is non-empty, two graphs are written afterwards: the
// recognition rate per epoch (in percent of the teacher samples) and the mean
// error per epoch. A graph that cannot be rendered is reported on standard
// error; training results are returned regardless.
func train() Network {
	network := initial_network()
	teachers := get_data(opts.Teacher)

	epochs := opts.Epoc
	if epochs < 0 {
		epochs = 0
	}
	recogRates := make(plotter.XYs, epochs)
	errCurve := make(plotter.XYs, epochs)
	for ep := 0; ep < epochs; ep++ {
		var meanErr float64
		var successCount int
		network, meanErr, successCount = epoc(network, teachers)

		// The axis is labelled in percent, so scale the success count by the
		// number of samples instead of assuming exactly 100 of them.
		rate := 0.0
		if len(teachers) > 0 {
			rate = float64(successCount) * 100.0 / float64(len(teachers))
		}
		x := float64(ep + 1)
		recogRates[ep].X = x
		recogRates[ep].Y = rate
		errCurve[ep].X = x
		errCurve[ep].Y = meanErr
	}

	if len(opts.RenderGraph) > 0 {
		if err := save_curve("recognition rate(%)", true, recogRates, "../認識率.png"); err != nil {
			fmt.Fprintf(os.Stderr, "train: rendering recognition-rate graph: %v\n", err)
		}
		if err := save_curve("error", false, errCurve, "../二乗誤差.png"); err != nil {
			fmt.Fprintf(os.Stderr, "train: rendering error graph: %v\n", err)
		}
	}
	return network
}

// save_curve plots points against the epoch count as one unnamed
// line-and-points series and saves it to file as a 5x5 inch image. The Y axis
// starts at 0; percentAxis additionally presets its maximum to 100.
func save_curve(yLabel string, percentAxis bool, points plotter.XYs, file string) error {
	p, err := plot.New()
	if err != nil {
		return err
	}
	p.X.Label.Text = "epoc count"
	p.Y.Label.Text = yLabel
	if percentAxis {
		p.Y.Max = 100
	}
	p.Y.Min = 0
	if err := plotutil.AddLinePoints(p, "", points); err != nil {
		return err
	}
	return p.Save(5*vg.Inch, 5*vg.Inch, file)
}
// recognize evaluates network on the samples stored in filename and returns
// the fraction (0..1) whose recognized digit matches the target. If the file
// yields no samples the result is NaN (0 divided by 0).
func recognize(network Network, filename string) float64 {
	samples := get_data(filename)
	hits := 0
	for i := range samples {
		predicted := output_to_number(calc_perceptrons(network, samples[i].input).Output)
		expected := output_to_number(samples[i].output)
		if predicted == expected {
			hits++
		}
	}
	return float64(hits) / float64(len(samples))
}
// Options describes the command-line settings; main fills the package-level
// opts value from the arguments with flags.Parse.
type Options struct {
// Epoc is the number of training epochs that train runs.
Epoc int `long:"epoc" default:"180"`
// MiddleNumber is the middle-layer size; main copies it into K.
MiddleNumber int `short:"K" long:"MiddleNumber" default:"100"`
// OutputFile is the path main creates and writes the CSV results to.
OutputFile string `short:"o" default:"/dev/null"`
// RenderGraph enables the graph output: train saves the graphs when this
// slice is non-empty (presumably one entry per -g on the command line).
RenderGraph []bool `short:"g"`
// Teacher is the path of the training data file loaded by train.
Teacher string `long:"teacher" default:"./data/mlp_train.data"`
// AdditionalData, when non-empty, names one more data file that main
// evaluates after the built-in data sets.
AdditionalData string `short:"a" default:""`
}
// opts is the parsed command-line configuration read by train and main.
var opts Options
// main parses the command line, trains a network and writes a CSV table of
// recognition rates (truncated to whole percent) to opts.OutputFile: one row
// for the training data, one per bundled noisy test set, and optionally one
// for opts.AdditionalData.
//
// It exits with status 0 after printing help, and with status 1 when the
// arguments are invalid or the output file cannot be created or written.
func main() {
	if _, err := flags.Parse(&opts); err != nil {
		// flags.Parse has already printed the help text or the error message,
		// so only the exit status is left to decide.
		if flagsErr, ok := err.(*flags.Error); ok && flagsErr.Type == flags.ErrHelp {
			os.Exit(0)
		}
		os.Exit(1)
	}
	K = opts.MiddleNumber
	rand.Seed(time.Now().UnixNano())
	network := train()

	csvFile, err := os.Create(opts.OutputFile)
	if err != nil {
		fmt.Fprintf(os.Stderr, "creating %q: %v\n", opts.OutputFile, err)
		os.Exit(1)
	}
	writer := csv.NewWriter(csvFile)

	// Write errors are sticky inside the csv.Writer and are collected once via
	// writer.Error() after the final Flush, so they are not checked per row.
	writeRate := func(label string, rate float64) {
		writer.Write([]string{label, fmt.Sprint(math.Trunc(rate * 100.0))})
	}

	writer.Write([]string{"入力データの誤り率(%)", "認識率(%)"})
	// NOTE(review): the 0% row always reads the default training file, even
	// when --teacher points elsewhere — confirm whether opts.Teacher is meant.
	writeRate("0", recognize(network, "./data/mlp_train.data"))
	for _, i := range []string{"1", "3", "5", "10"} {
		writeRate(i, recognize(network, "./data/mlp_test"+i+".data"))
	}
	if opts.AdditionalData != "" {
		writeRate("***", recognize(network, opts.AdditionalData))
	}

	writer.Flush()
	err = writer.Error()
	// Close explicitly (not deferred) because os.Exit would skip deferred calls.
	if closeErr := csvFile.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		fmt.Fprintf(os.Stderr, "writing %q: %v\n", opts.OutputFile, err)
		os.Exit(1)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment