Created
May 30, 2012 13:25
-
-
Save lithtle/2836320 to your computer and use it in GitHub Desktop.
maximum likelihood estimation and EM algorithm (エディタのインデント幅を2にするのが面倒だったので4のまま)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| *~ | |
| sample10* | |
| sample100* | |
| sample1000* | |
| sample10000* | |
| sample100000* |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #! /usr/bin/ruby -Ku | |
# Data sets to plot, mapped to the x coordinate at which the horizontal
# mean line of each plot ends.
samples = {
  "sample10" => 10,
  "sample100" => 100,
  "sample1000" => 1000,
  "sample10000" => 10000,
  "sample100000" => 100000
}

# Drive gnuplot through a pipe; one EPS file is written per data set.
IO.popen("gnuplot", "w") do |gp|
  gp.puts "set terminal eps"
  samples.each do |name, x_max|
    gp.puts "set output '#{name}_result.eps'"
    # File.read closes the handle itself; strip drops any trailing newline
    # so the value can be embedded in the gnuplot command.
    mean = File.read("#{name}_mean").strip
    # Arrows persist between plots, so remove the previous data set's mean
    # line before drawing this one.
    gp.puts "unset arrow"
    gp.puts "set arrow nohead from 0, #{mean} to #{x_max}, #{mean}"
    gp.puts "plot '#{name}_out' w l"
    gp.puts "rep"
  end
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #! /usr/bin/ruby -Ku | |
| # -*- coding:utf-8 -*- | |
| =begin | |
| EMアルゴリズム | |
| @date: 2012-06-05 | |
| @author: lithtle | |
| =end | |
| require "pp" | |
| require "params" | |
# Tells whether a toss result is heads.
# === Args
# _character_ : a toss result; only the string "H" counts as heads
# === Returns
# _true_ when _character_ is "H", _false_ otherwise
def head?(character)
  character == "H"
end
# Reads a comma-separated text file and returns its contents as an array.
# === Args
# _fname_ : name of the file to read
# === Returns
# _train_ : training data, one array of fields per non-empty line
# === Example
# train = [ ["H", "H", "T", "T"], ["T", "H", "H", "H"], ...]
def get_testdata(fname)
  train = []
  # File.foreach (unlike Kernel#open) never interprets a leading "|" in
  # fname as a command, and it closes the file when iteration finishes.
  File.foreach(fname) do |line|
    tosses = line.chomp.split(",")
    # A blank line would produce an empty trial; skip it.
    train << tosses unless tosses.empty?
  end
  train
end
# Computes the expected value for one set of tosses such as [H, T, T, H...]:
# the mean, over the tosses, of the posterior probability that coin A
# produced the observed face.
# === Args
# _theta_ : current estimate of the probability that coin A is used
# _items_ : one set of trials
# ==== Example
# items => ["H", "H", "T", "T", "H", "T", "H" ... ]
# === Returns
# the expected value for coin A; _theta_ itself when _items_ is empty
def expectation(theta, items)
  # Without observations the division below would be 0.0 / 0 (NaN); hand
  # back the current estimate unchanged instead.
  return theta if items.empty?

  h_a = theta * HEAD_A                  # coin is A and it lands heads
  t_a = theta * (1.0 - HEAD_A)          # coin is A and it lands tails
  h_b = (1.0 - theta) * HEAD_B          # coin is B and it lands heads
  t_b = (1.0 - theta) * (1.0 - HEAD_B)  # coin is B and it lands tails

  # The posterior depends only on the face shown, so compute each value once
  # instead of once per toss.
  post_head = h_a / (h_a + h_b)
  post_tail = t_a / (t_a + t_b)

  # Sum the per-toss posteriors and divide by the number of tosses.
  sum = items.inject(0.0) do |acc, item|
    acc + (head?(item) ? post_head : post_tail)
  end
  sum / items.length
end
# EM algorithm.
# For every file named in OUTPUTS: reads the trials, starts theta at INIT,
# and replaces theta with the expectation of each trial in turn. The
# sequence of theta values is written to "<name>_out" (one value per line)
# and their mean to "<name>_mean".
# === Returns
# _result_ : array of results, one hash { "<name>x<TRIAL>" => mean theta }
#            per file in OUTPUTS
# === Raises
# ArgumentError when a file contains no trials (the mean is undefined)
def estimate
  OUTPUTS.map do |name|
    train = get_testdata(name) # load the test data from the file
    raise ArgumentError, "no training data in #{name}" if train.empty?

    theta = INIT
    history = train.map do |item| # item => ["H", "T", "H", "H"]
      # E-step computes the expectation; M-step adopts it as the new theta.
      theta = expectation(theta, item)
    end
    mean = history.inject(:+) / history.length

    # Write out how theta evolved, then its mean.
    File.open("#{name}_out", "w") { |f| f.write(history.join("\n")) }
    File.open("#{name}_mean", "w") { |f| f.write(mean.to_s) }

    { "#{name}x#{TRIAL}" => mean }
  end
end
# Entry point: runs the estimation for every sample file and pretty-prints
# the array of results.
def main
  pp estimate
end

# Run only when executed as a script, not when the file is loaded.
main if __FILE__ == $0
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #! /usr/bin/ruby -Ku | |
| # -*- coding:utf-8 -*- | |
| =begin | |
| 学習に必要なテストデータを作成するプログラム | |
| @date: 2012-06-05 | |
| @author: lithtle | |
| @note: | |
| サンプル数10、100、1000、10000のファイルを作る。 | |
| コインAとコインBの割合は7:3とする。推定の段階では未知 | |
| =end | |
| # 定数を得る | |
| require "params" | |
# Returns _n_ sets of test data, each holding the results of TRIAL tosses.
# For every toss, coin A is used with probability _prob_ (coin B otherwise);
# the chosen coin then shows heads with probability HEAD_A or HEAD_B.
# === Args:
# _prob_ : probability that coin A is used for a toss
# _n_ : number of test data sets ["H", "T"...] required
# === Returns
# _ret_ : array of test data sets
# === Example
# prob = 0.8, n = 2, TRIAL = 4
# => [["H", "H", "T", "H"], ["T", "H", "H", "H"]]
def make(prob, n)
  n.times.map do
    TRIAL.times.map do
      # The first draw picks the coin, the second decides the face.
      head_prob = rand < prob ? HEAD_A : HEAD_B
      rand < head_prob ? "H" : "T"
    end
  end
end
# Entry point: writes one sample file per entry of OUTPUTS. File OUTPUTS[i]
# receives NUM[i] shuffled trials, one comma-separated trial per line.
def main
  OUTPUTS.zip(NUM).each do |fname, count|
    # 0.7 is the probability of picking coin A for each toss.
    content = make(0.7, count).shuffle
    File.open(fname, "w") do |f|
      content.each { |item| f.puts item.join(",") }
    end
  end
end

# Run only when executed as a script, not when the file is loaded.
main if __FILE__ == $0
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Pipeline: generate the samples, run the EM estimation, plot the results.
# No target names a file it creates, so declare them all phony; otherwise a
# file or directory called e.g. "sample" would make that target look up to date.
.PHONY: all sample em graph

all: sample em graph

sample:
	ruby make_sample.rb

em:
	ruby em.rb

graph:
	ruby command_gnuplot.rb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #! /usr/bin/ruby -Ku | |
| =begin | |
| 最尤推定を行う簡単なプログラム | |
| @date: 2012-05-28 | |
| @author: lithtle | |
| @note: 適当な真の値を設定して平均していく、今回は2枚のコインのウラ表とする | |
| =end | |
| require "pp" | |
# True probability of heads for each coin
HEADS_A = 0.8
HEADS_B = 0.3
# Tells whether a toss result is heads.
# === Args
# _character_ : one of the characters "H" or "T"
# === Returns
# _true_ : heads (the argument is "H")
# _false_ : tails (anything else)
def head?(character)
  character == "H"
end
# Builds the training data for testing: every possible sequence of _n_
# tosses.
# === Args
# _n_ : number of times the coin is tossed
# === Returns
# _training_ : training data, 2 ** n sequences of "H"/"T"
# === Raises
# ArgumentError when _n_ is negative
# ==== Examples
# with n = 3
# training = [["H", "H", "H"], ["H", "H", "T"], ... ["T", "T", "T"]]
def get_training(n)
  raise ArgumentError, "n must not be negative" if n < 0

  # Integer#[] returns the bit at a given position of the number's binary
  # form, e.g. 5 => 101(2), so 5[0] = 1, 5[1] = 0, 5[2] = 1, 5[3..] = 0.
  # Counting i from 0 to 2 ** n - 1 and reading its n low bits, most
  # significant first, enumerates every sequence; bit 0 is "H", bit 1 is "T".
  (0...(2**n)).map do |i|
    (0...n).map { |k| i[n - 1 - k].zero? ? "H" : "T" }
  end
end
# Computes the likelihood of every training sequence under each coin.
# ("likefood" is this file's spelling of "likelihood"; it is kept because it
# appears in the method name and in the hash keys.)
# === Args
# _training_ : array of toss sequences, each an array of "H"/"T"
# === Returns
# _likefood_ : per sequence, the judged coin and both likelihoods. Each
#              likelihood is the product, over the tosses, of the probability
#              of the observed face assuming coin A (or B). The judgement is
#              "A" only when A's likelihood is strictly greater; ties give "B".
# ==== Examples
# likefood = [ {"judge" => "A", "likefood_a" => 0.31, "likefood_b" => 0.07}, ...]
def get_likefood(training)
  training.map do |train_set|
    # Start from 1.0 because the likelihoods are built up by multiplication.
    likefood_a = 1.0
    likefood_b = 1.0
    train_set.each do |toss|
      heads = head?(toss)
      likefood_a *= heads ? HEADS_A : (1 - HEADS_A)
      likefood_b *= heads ? HEADS_B : (1 - HEADS_B)
    end

    {
      "judge" => likefood_a > likefood_b ? "A" : "B",
      "likefood_a" => likefood_a,
      "likefood_b" => likefood_b
    }
  end
end
# Formats and prints the results: for each entry, a line with its index and
# the matching toss sequence, then a tab and the pretty-printed entry.
# === Args
# _training_ : test data
# _likefood_ : likelihood result data, in the same order as _training_
def result(training, likefood)
  likefood.each_with_index do |item, i|
    puts "#{i} : #{training[i].join(', ')}"
    print "\t"
    pp item
  end
end
# Entry point: enumerates every sequence of 4 tosses, computes the likelihood
# of each under both coins, and prints the results.
def main
  training = get_training(4)
  likefood = get_likefood(training)
  result(training, likefood)
end

# Run only when executed as a script, not when the file is loaded.
main if __FILE__ == $0
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #! /usr/bin/ruby -Ku | |
| # -*- coding:utf-8 -*- | |
| =begin | |
| パラメータを定義したファイル | |
| @date: 2012-06-05 | |
| @author: lithtle | |
| @note: 大域変数群 | |
| =end | |
# True probability of heads for each coin (known)
HEAD_A = 0.8
HEAD_B = 0.3
# Number of tosses in one trial
TRIAL = 30
# Initial value
INIT = 0.2
# Output file names (frozen so the shared constant cannot be mutated)
OUTPUTS = ["sample10", "sample100", "sample1000", "sample10000", "sample100000"].freeze
# Number of test data sets, in the same order as OUTPUTS
NUM = [10, 100, 1000, 10000, 100000].freeze
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment