Last active
October 31, 2017 06:49
-
-
Save 43x2/adae99e3122b0d1b4a041ba0736cb6a4 to your computer and use it in GitHub Desktop.
最小二乗法を最急降下法で解く Python スクリプト
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
##
## Usage: lsm.py -d(--data) (input_data_file) -r(--ratio) (ratio) (initial0) (initial1) ... (initialN)
##
## -d, --data -- name of the x,y data file (tab-separated)
## -r, --ratio -- correction coefficient (η), i.e. the step size applied to the gradient
## (initial0)...(initialN) -- initial values of the coefficients of x^0...x^N
##
import sys | |
def _read_points(stream):
    """Parse tab-separated ``x<TAB>y`` lines from *stream* into (x, y) float pairs.

    Blank lines (for example a trailing empty line) are skipped. Columns
    after the second one are ignored.

    Raises:
        IndexError: a non-blank line has fewer than two tab-separated fields.
        ValueError: a field cannot be converted to float.
    """
    points = []
    for line in stream:
        if not line.strip():
            continue
        fields = line.split('\t')
        points.append((float(fields[0]), float(fields[1])))
    return points


def _evaluate(points, coefficients):
    """Return ``(predictions, rss)`` for the polynomial given by *coefficients*.

    ``predictions[i]`` is f(x_i); ``rss`` is the residual sum of squares over
    all points. The predictions are returned so the gradient step can reuse
    them instead of re-evaluating the polynomial.
    """
    constant, *higher = coefficients
    predictions = []
    rss = 0.0
    for x, y in points:
        fx = constant
        for order, coefficient in enumerate(higher, start=1):
            fx += coefficient * x ** order
        rss += (y - fx) ** 2
        predictions.append(fx)
    return predictions, rss


def _rss_gradient(points, predictions, n_coefficients):
    """Return d(RSS)/d(coefficient) for each of the *n_coefficients* terms."""
    gradient = [0.0] * n_coefficients
    for (x, y), fx in zip(points, predictions):
        scale = -2.0 * (y - fx)  # the factor -2 from differentiating the square
        for order in range(n_coefficients):
            gradient[order] += scale * x ** order
    return gradient


def _ask_iterations():
    """Ask how many more iterations to run; return a non-negative int (0 = quit).

    An empty answer or end-of-input counts as 0. Anything that is not a
    non-negative integer is rejected and the question is asked again, because
    a negative count would leave the training loop with no way to reach its
    next stopping point.
    """
    while True:
        try:
            answer = input(' How many (more) times? (0 to quit) ').strip()
        except EOFError:
            return 0
        if answer == '':
            return 0
        try:
            count = int(answer)
        except ValueError:
            count = -1
        if count >= 0:
            return count
        print(' Please enter a non-negative integer.')


def main():
    """Fit a polynomial to x,y data by least squares using steepest descent.

    Command-line arguments (read from ``sys.argv``):
        -d/--data   tab-separated x,y data file
        -r/--ratio  step size multiplied with the gradient on every update
        initials    one or more initial coefficients for x^0 ... x^N

    The function interactively asks how many iterations to run, prints the
    current coefficients and RSS each time that count is reached, and repeats
    until the user answers 0 (or gives an empty answer / closes stdin).

    Returns:
        0 when the user chooses to quit.
    """
    import argparse

    # Parse the arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--data', type=argparse.FileType('r'), required=True)
    parser.add_argument('-r', '--ratio', type=float, required=True)
    parser.add_argument('initials', nargs='+', type=float)
    args = parser.parse_args(sys.argv[1:])

    # Load the measured data; the file is closed even if a line fails to parse.
    with args.data:
        points = _read_points(args.data)

    coefficients = list(args.initials)  # start from the initial coefficients
    times = desired_times = 0
    while True:
        predictions, rss = _evaluate(points, coefficients)

        # Progress output / ask for the number of further iterations.
        if times == desired_times:
            print('After {0} times...'.format(times))
            print(' Initial: {0}'.format(args.initials))
            print(' Trained: {0}'.format(coefficients))
            print(' RSS: {0}'.format(rss))
            extra = _ask_iterations()
            if extra == 0:
                return 0  # done
            desired_times = times + extra
        elif times % 100 == 0 or times <= 30:
            # The first 30 iterations are always shown to check that training helps.
            print('After {0} times... RSS: {1}'.format(times, rss))

        # Training step (steepest descent).
        gradient = _rss_gradient(points, predictions, len(coefficients))
        for order, slope in enumerate(gradient):
            coefficients[order] -= args.ratio * slope
        times += 1
if __name__ == '__main__':
    # Run the tool and hand main()'s return value to the shell as the exit status.
    exit_status = main()
    sys.exit(exit_status)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment