Created
September 1, 2016 10:27
-
-
Save genya0407/16e6ed23906c261f00761ef10305abd2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"encoding/csv" | |
"fmt" | |
"github.com/gonum/plot" | |
"github.com/gonum/plot/plotter" | |
"github.com/gonum/plot/plotutil" | |
"github.com/gonum/plot/vg" | |
flags "github.com/jessevdk/go-flags" | |
//"github.com/k0kubun/pp" | |
"io/ioutil" | |
"math" | |
"math/rand" | |
"os" | |
"strconv" | |
"strings" | |
"time" | |
) | |
// a is the learning rate: every weight delta computed in epoc is scaled by -a.
const a = 0.1

// K is the number of middle-layer units, not counting the bias unit at
// index 0. main sets it from the command line before any network is built.
var K int
// Network holds the trainable weights of the two-layer perceptron.
//
// MiddleWeight[i][m] connects input i (0..81; input 0 is the bias -1) to
// middle unit m. OutputWeight[m][o] connects middle unit m (0..K; unit 0 is
// the bias -1) to output unit o (0..9).
type Network struct {
	OutputWeight, MiddleWeight [][]float64
}
// Perceptrons holds the activations produced by one forward pass.
// Middle[0] is the constant bias -1. Middle[1..K] and Output[0..9] are
// sigmoid activations.
type Perceptrons struct {
	Output, Middle []float64
}
func initial_perceptrons() Perceptrons { | |
p := Perceptrons{} | |
p.Output = make([]float64, 10) | |
p.Middle = make([]float64, K+1) | |
return p | |
} | |
// Data is one labelled sample.
// input is the bias value -1 followed by the values read from the sample's
// input lines. output is the target vector for the output layer.
type Data struct {
	input, output []float64
}
func get_data(filename string) []Data { | |
var Datas []Data | |
raw_data, _ := ioutil.ReadFile(filename) | |
splitted_data := strings.Split(string(raw_data), "\n") | |
for n := 0; n+10 < len(splitted_data); n = n + 12 { | |
var input []float64 | |
var output []float64 | |
input = []float64{-1} | |
for _, line := range splitted_data[n : n+9] { | |
for _, bit := range strings.Split(strings.TrimSpace(line), " ") { | |
num, _ := strconv.Atoi(bit) | |
input = append(input, float64(num)) | |
} | |
} | |
for _, bit := range strings.Split(splitted_data[n+10], " ") { | |
num, _ := strconv.Atoi(bit) | |
output = append(output, float64(num)) | |
} | |
Datas = append(Datas, Data{input, output}) | |
} | |
return Datas | |
} | |
func initial_network() Network { | |
network := Network{} | |
network.OutputWeight = make([][]float64, K+1) | |
for i, _ := range network.OutputWeight { | |
network.OutputWeight[i] = make([]float64, 9+1) | |
} | |
network.MiddleWeight = make([][]float64, 81+1) | |
for i, _ := range network.MiddleWeight { | |
network.MiddleWeight[i] = make([]float64, K+1) | |
} | |
for i := 0; i <= 81; i++ { | |
for m := 0; m <= K; m++ { | |
network.MiddleWeight[i][m] = 0.005 - rand.Float64()*0.01 | |
} | |
} | |
for m := 0; m <= K; m++ { | |
for o := 0; o <= 9; o++ { | |
network.OutputWeight[m][o] = 0.005 - rand.Float64()*0.01 | |
} | |
} | |
return network | |
} | |
// sig is the logistic sigmoid 1 / (1 + e^-x). It is the activation function
// of every computed middle and output unit.
func sig(x float64) float64 {
	denom := 1.0 + math.Exp(-x)
	return 1.0 / denom
}
// output_to_number returns the index of the largest value in O (its argmax),
// i.e. the output unit that "wins". Ties go to the lowest index, and an
// empty slice yields 0.
//
// The running maximum is seeded from O[0] rather than 0.0. Previously, a
// vector whose values were all <= 0 always reported index 0.
func output_to_number(O []float64) int {
	best := 0
	for i := 1; i < len(O); i++ {
		if O[i] > O[best] {
			best = i
		}
	}
	return best
}
// calc_error returns the sum-of-squares error 1/2 * Σ_o (O[o] - T[o])^2
// over every element of O (the output layer).
//
// T must be at least as long as O. The length comes from O instead of a
// hard-coded 10, so it stays correct if the output layer size changes.
func calc_error(O []float64, T []float64) float64 {
	sum := 0.0
	for o, out := range O {
		d := out - T[o]
		sum += d * d
	}
	return sum / 2.0
}
func calc_perceptrons(network Network, input []float64) Perceptrons { | |
p := initial_perceptrons() | |
p.Middle[0] = -1 | |
for m := 1; m < K+1; m++ { | |
s := 0.0 | |
for i := 0; i < 81+1; i++ { | |
s += network.MiddleWeight[i][m] * input[i] | |
} | |
p.Middle[m] = sig(s) | |
} | |
for o := 0; o < 9+1; o++ { | |
s := 0.0 | |
for m := 0; m < K+1; m++ { | |
s += network.OutputWeight[m][o] * p.Middle[m] | |
} | |
p.Output[o] = sig(s) | |
} | |
return p | |
} | |
func epoc(network Network, teachers []Data) (Network, float64, int) { | |
// 重み配列の差分 | |
var OutputWeightDelta [][]float64 | |
var MiddleWeightDelta [][]float64 | |
OutputWeightDelta = make([][]float64, K+1) | |
for i, _ := range OutputWeightDelta { | |
OutputWeightDelta[i] = make([]float64, 9+1) | |
} | |
MiddleWeightDelta = make([][]float64, 81+1) | |
for i, _ := range MiddleWeightDelta { | |
MiddleWeightDelta[i] = make([]float64, K+1) | |
} | |
err := 0.0 | |
success_count := 0 | |
for _, teacher := range teachers { | |
// 内部状態の計算 | |
perceptrons := calc_perceptrons(network, teacher.input) | |
// 認識成功してたらsuccess_countをインクリメント | |
if output_to_number(perceptrons.Output[:]) == output_to_number(teacher.output) { | |
success_count++ | |
} | |
// 二乗誤差関数の値を計算 | |
err += calc_error(perceptrons.Output, teacher.output) | |
// 入力層から中間層への重みの差分を計算 | |
for i := 0; i <= 81; i++ { | |
for m := 0; m <= K; m++ { | |
var second_sect float64 | |
for o := 0; o < 10; o++ { | |
second_sect = second_sect + (perceptrons.Output[o]-teacher.output[o])*perceptrons.Output[o]*(1.0-perceptrons.Output[o])*network.OutputWeight[m][o] | |
} | |
MiddleWeightDelta[i][m] = -a * second_sect * perceptrons.Middle[m] * (1.0 - perceptrons.Middle[m]) * teacher.input[i] | |
} | |
} | |
// 中間層から出力層への重みの差分を計算 | |
for m := 0; m <= K; m++ { | |
for o := 0; o <= 9; o++ { | |
OutputWeightDelta[m][o] = -a * (perceptrons.Output[o] - teacher.output[o]) * perceptrons.Output[o] * (1.0 - perceptrons.Output[o]) * perceptrons.Middle[m] | |
} | |
} | |
// 入力層から中間層への重みの学習 | |
for i := 0; i <= 81; i++ { | |
for m := 0; m <= K; m++ { | |
network.MiddleWeight[i][m] += MiddleWeightDelta[i][m] | |
} | |
} | |
// 中間層から出力層への重みの学習 | |
for m := 0; m <= K; m++ { | |
for o := 0; o <= 9; o++ { | |
network.OutputWeight[m][o] += OutputWeightDelta[m][o] | |
} | |
} | |
} | |
// 二乗誤差関数の平均を取る | |
err = err / 100.0 | |
return network, err, success_count | |
} | |
func train() Network { | |
network := initial_network() | |
teachers := get_data(opts.Teacher) | |
err := math.Inf(1) | |
ep := 0 | |
success_count := 0 | |
var recog_rates plotter.XYs | |
var errors plotter.XYs | |
for ep < opts.Epoc { | |
network, err, success_count = epoc(network, teachers) | |
ep++ | |
_recog_rate := make(plotter.XYs, 1) | |
_recog_rate[0].X = float64(ep) | |
_recog_rate[0].Y = float64(success_count) | |
recog_rates = append(recog_rates, _recog_rate...) | |
_error := make(plotter.XYs, 1) | |
_error[0].X = float64(ep) | |
_error[0].Y = err | |
errors = append(errors, _error...) | |
} | |
if len(opts.RenderGraph) > 0 { | |
p, _ := plot.New() | |
p.X.Label.Text = "epoc count" | |
p.Y.Label.Text = "recognition rate(%)" | |
p.Y.Max = 100 | |
p.Y.Min = 0 | |
plotutil.AddLinePoints(p, "", recog_rates) | |
p.Save(5*vg.Inch, 5*vg.Inch, "../認識率.png") | |
q, _ := plot.New() | |
q.X.Label.Text = "epoc count" | |
q.Y.Label.Text = "error" | |
q.Y.Min = 0 | |
plotutil.AddLinePoints(q, "", errors) | |
q.Save(5*vg.Inch, 5*vg.Inch, "../二乗誤差.png") | |
} | |
return network | |
} | |
func recognize(network Network, filename string) float64 { | |
datas := get_data(filename) | |
success_count := 0 | |
for _, data := range datas { | |
p := calc_perceptrons(network, data.input) | |
if output_to_number(p.Output[:]) == output_to_number(data.output) { | |
success_count++ | |
} | |
} | |
return float64(success_count) / float64(len(datas)) | |
} | |
// Options describes the command-line flags; main parses them with go-flags.
type Options struct {
	// Epoc is the number of training epochs (--epoc).
	Epoc int `long:"epoc" default:"180"`
	// MiddleNumber is the number of middle-layer units; main copies it into K (-K, --MiddleNumber).
	MiddleNumber int `short:"K" long:"MiddleNumber" default:"100"`
	// OutputFile is the path the recognition-rate CSV is written to (-o).
	OutputFile string `short:"o" default:"/dev/null"`
	// RenderGraph is non-empty when -g was given; train then writes its plots.
	RenderGraph []bool `short:"g"`
	// Teacher is the training-data file (--teacher).
	Teacher string `long:"teacher" default:"./data/mlp_train.data"`
	// AdditionalData is an optional extra data file to evaluate (-a); its CSV row is labelled "***".
	AdditionalData string `short:"a" default:""`
}

// opts holds the parsed command-line options; main fills it, train reads it.
var opts Options
func main() { | |
flags.Parse(&opts) | |
K = opts.MiddleNumber | |
rand.Seed(time.Now().UnixNano()) | |
network := train() | |
csv_file, _ := os.Create(opts.OutputFile) | |
defer csv_file.Close() | |
writer := csv.NewWriter(csv_file) | |
writer.Write([]string{"入力データの誤り率(%)", "認識率(%)"}) | |
recog_rate := recognize(network, "./data/mlp_train.data") | |
writer.Write([]string{"0", fmt.Sprint(math.Trunc(recog_rate * 100.0))}) | |
for _, i := range []string{"1", "3", "5", "10"} { | |
recog_rate := recognize(network, "./data/mlp_test"+i+".data") | |
writer.Write([]string{i, fmt.Sprint(math.Trunc(recog_rate * 100.0))}) | |
} | |
if opts.AdditionalData != "" { | |
recog_rate := recognize(network, opts.AdditionalData) | |
writer.Write([]string{"***", fmt.Sprint(math.Trunc(recog_rate * 100.0))}) | |
} | |
writer.Flush() | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment