Created
          August 18, 2015 05:52 
        
      - 
      
 - 
        
Save peace098beat/2358ee62fb031976abc5 to your computer and use it in GitHub Desktop.  
    AdaBoost
  
        
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | #! coding:utf-8 | |
| # ****************************************** # | |
| # | |
| # adboost.py | |
| # | |
| # ****************************************** # | |
| from __future__ import division | |
| from numpy import * | |
| class AdaBoost: | |
| def __init__(self, training_set): | |
| # トレーニングデータ | |
| self.training_set = training_set | |
| # N | |
| self.N = len(self.training_set) | |
| # 重みベクトル | |
| self.weights = ones(self.N)/self.N | |
| # ルール:弱識別機 | |
| self.RULES = [] | |
| # アルファ | |
| self.ALPHA = [] | |
| def set_rule(self, func, test=False): | |
| # 誤差関数から誤差Jmを求める | |
| errors = array([t[1] != func(t[0]) for t in self.training_set]) | |
| e = (errors*self.weights).sum() | |
| # testフラグがあれば、エラー値を返す | |
| if test: | |
| return e | |
| # 重み係数αの算出 | |
| alpha = 0.5 * log((1-e)/e) | |
| print 'e=%.2f a=%.2f' % (e, alpha) | |
| # 重みベクトルBuffer | |
| w = zeros(self.N) | |
| # 重みベクトルの更新 | |
| for i in range(self.N): | |
| if errors[i] == 1: | |
| w[i] = self.weights[i] * exp(alpha) | |
| else: | |
| w[i] = self.weights[i] * exp(-alpha) | |
| # 重み係数の正規化 | |
| self.weights = w / w.sum() | |
| self.RULES.append(func) | |
| self.ALPHA.append(alpha) | |
| def evaluate(self): | |
| # 実際に弱識別機基底に投影し、重み値と閾値から、0・1クラスタリング | |
| NR = len(self.RULES) | |
| for (x, l) in self.training_set: | |
| hx = [self.ALPHA[i]*self.RULES[i](x) for i in range(NR)] | |
| print x, sign(l) == sign(sum(hx)) | |
| if __name__ == '__main__': | |
| # 学習データ | |
| examples = [] | |
| examples.append(((1, 2), 1)) | |
| examples.append(((1, 4), 1)) | |
| examples.append(((2.5, 5.5), 1)) | |
| examples.append(((3.5, 6.5), 1)) | |
| examples.append(((4, 5.4), 1)) | |
| examples.append(((2, 1), -1)) | |
| examples.append(((2, 4), -1)) | |
| examples.append(((3.5, 3.5), -1)) | |
| examples.append(((5, 2), -1)) | |
| examples.append(((5, 5.5), -1)) | |
| # アダブーストクラス生成 | |
| m = AdaBoost(examples) | |
| # 弱識別機 | |
| m.set_rule(lambda x: 2*(x[0] < 1.5)-1) | |
| m.set_rule(lambda x: 2*(x[0] < 4.5)-1) | |
| m.set_rule(lambda x: 2*(x[1] > 5)-1) | |
| # クラスタ結果 | |
| m.evaluate() | 
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment