Last active
January 13, 2016 15:05
-
-
Save karino2/6bad6cf310b495f44e9c to your computer and use it in GitHub Desktop.
RandomForestをIPython notebookで作っていった
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 224, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import pandas as pd" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 226, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "df = pd.read_table('iris.txt', sep='\\s+', names=['x1', 'x2', 'x3','x4', 'y'])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 603, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "numradfeatures = 2\n", | |
| "numrandpos = 5\n", | |
| "featurenum = 4\n", | |
| "minnodesize = 1\n", | |
| "maxlevel = 16\n", | |
| "\n", | |
| "sampler = np.random.permutation(len(df))\n", | |
| "randomdf = df.take(sampler)\n", | |
| "dfTraining = randomdf[0:25]\n", | |
| "dfTest = randomdf[25:150]\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 604, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "\n", | |
| "class DecisionTreeBuilder(object):\n", | |
| " def __init__(self, dfTraining):\n", | |
| " self.dfTraining = dfTraining\n", | |
| " self.curidx = 0\n", | |
| " tnodes = []\n", | |
| " root = {\"isleaf\": False, \"level\": 0, \"featureidx\": -1, \"separateval\":0, \"answerclass\":0, \"leftidx\":-1, \"rightidx\": -1 }\n", | |
| " tnodes.append(root)\n", | |
| " self.tnodes = tnodes\n", | |
| " root['valids'] = np.random.randint(0, len(dfTraining), size=len(dfTraining))\n", | |
| " self.curdf = dfTraining.take(root['valids']).copy()\n", | |
| " def setupOneNode(self):\n", | |
| " self.cur = self.tnodes[self.curidx]\n", | |
| " # Non-root nodes only: the root's 'valids' has a slightly different meaning, so the root is prepared specially in advance\n", | |
| " if self.curidx != 0:\n", | |
| " self.curdf = self.dfTraining.ix[self.cur['valids']].copy()\n", | |
| " def isfinish(self):\n", | |
| " return self.curidx >= len(self.tnodes)\n", | |
| " def becomeLeafAndGotoNext(self):\n", | |
| " curdf = self.curdf\n", | |
| " ysize = curdf.groupby('y').size()\n", | |
| " cur = self.cur\n", | |
| " cur['answerclass'] = ysize.index[ysize == max(ysize)][0]\n", | |
| " cur['isleaf'] = True\n", | |
| " self.curidx += 1\n", | |
| " \n", | |
| " def calcGini(self, tryfid, trypos):\n", | |
| " curdf = self.curdf\n", | |
| " trydf = curdf.iloc[trypos]\n", | |
| " tryfval = trydf[tryfid]\n", | |
| "\n", | |
| " tryfname = curdf.columns[tryfid]\n", | |
| " curdf[\"isleft\"] = curdf[tryfname] < tryfval\n", | |
| "\n", | |
| " try:\n", | |
| " righttotal = curdf[tryfname].groupby(curdf['isleft']).count()[False]\n", | |
| " lefttotal = curdf[tryfname].groupby(curdf['isleft']).count()[True]\n", | |
| " except KeyError:\n", | |
| " # Either the True or the False group does not exist.\n", | |
| " return None\n", | |
| " \n", | |
| " if righttotal == 0 or lefttotal == 0:\n", | |
| " return None\n", | |
| " \n", | |
| " rightcountbyclass = curdf.groupby(['isleft', 'y'])['y'].count()[False]\n", | |
| "\n", | |
| " rightgini = 1.0 - sum([(float(rightcountbyclass[index])/righttotal)**2 for index in rightcountbyclass.index])\n", | |
| "\n", | |
| " leftcountbyclass = curdf.groupby(['isleft', 'y'])['y'].count()[True]\n", | |
| "\n", | |
| " leftgini = 1.0 - sum([(float(leftcountbyclass[index])/lefttotal)**2 for index in leftcountbyclass.index])\n", | |
| "\n", | |
| " # The feature and value that minimize this become the split used in the tree.\n", | |
| " totalgini = leftgini*float(lefttotal)/(lefttotal+righttotal) + rightgini*float(righttotal)/(righttotal+lefttotal)\n", | |
| " return {'gini': totalgini, 'splitval': tryfval, 'splitfid': tryfid, 'isleft':curdf['isleft'].copy()}\n", | |
| " def bestGini(self, curBest, newGini):\n", | |
| " if curBest == None:\n", | |
| " return newGini\n", | |
| " if newGini == None:\n", | |
| " return curBest\n", | |
| " if newGini['gini'] < curBest['gini']:\n", | |
| " return newGini\n", | |
| " return curBest\n", | |
| " def calcRandomSampleBestGini(self):\n", | |
| " curdf = self.curdf\n", | |
| " tryfeatureids = np.random.randint(0, featurenum, size=numradfeatures)\n", | |
| "\n", | |
| " bestRes = None\n", | |
| " for tryfid in tryfeatureids:\n", | |
| "\n", | |
| " tryposes = np.random.randint(0, len(curdf), size=numrandpos)\n", | |
| "\n", | |
| " for trypos in tryposes:\n", | |
| " curGini = self.calcGini(tryfid, trypos)\n", | |
| " \n", | |
| " bestRes = self.bestGini(bestRes, curGini)\n", | |
| " return bestRes\n", | |
| " \n", | |
| " def splitTwo(self, bestRes):\n", | |
| " # Split into two\n", | |
| " leftnode = {\"isleaf\": False, \"level\": -1, \"featureidx\": -1, \"separateval\":0, \"answerclass\":0, \"leftidx\":-1, \"rightidx\": -1 }\n", | |
| " rightnode = {\"isleaf\": False, \"level\": -1, \"featureidx\": -1, \"separateval\":0, \"answerclass\":0, \"leftidx\":-1, \"rightidx\": -1 }\n", | |
| "\n", | |
| " cur = self.cur\n", | |
| " leftnode['level'] = cur['level']+1\n", | |
| " rightnode['level'] = cur['level']+1\n", | |
| " \n", | |
| " cur['featureidx'] = bestRes['splitfid']\n", | |
| " cur['separateval'] = bestRes['splitval']\n", | |
| " \n", | |
| " curdf = self.curdf\n", | |
| "\n", | |
| " # curdf, isleft, ids\n", | |
| " leftnode['valids'] = curdf[bestRes['isleft']].index\n", | |
| " rightnode['valids'] = curdf[False == bestRes['isleft']].index\n", | |
| "\n", | |
| " tnodes = self.tnodes\n", | |
| " cur['leftidx'] = len(tnodes)\n", | |
| " tnodes.append(leftnode)\n", | |
| " cur['rightidx'] = len(tnodes)\n", | |
| " tnodes.append(rightnode)\n", | |
| "\n", | |
| " \n", | |
| " def handleOneNode(self):\n", | |
| " curdf = self.curdf\n", | |
| " firsty = curdf.iloc[0]['y']\n", | |
| " \n", | |
| " if 0 == len(curdf[curdf['y'] != firsty]) or len(curdf) < minnodesize or self.cur['level'] >= maxlevel:\n", | |
| " self.becomeLeafAndGotoNext()\n", | |
| " return\n", | |
| " \n", | |
| " bestRes = self.calcRandomSampleBestGini()\n", | |
| "\n", | |
| " if bestRes == None:\n", | |
| " # If lefttotal == 0 or righttotal == 0, turn this node into a leaf here\n", | |
| " self.becomeLeafAndGotoNext()\n", | |
| " return\n", | |
| "\n", | |
| " self.splitTwo(bestRes)\n", | |
| " # After this, go back to the top of the decision-tree building loop\n", | |
| " self.curidx += 1\n", | |
| " \n", | |
| " def train(self):\n", | |
| " while not self.isfinish():\n", | |
| " self.setupOneNode()\n", | |
| " self.handleOneNode()\n", | |
| " \n", | |
| "\n", | |
| " \n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 605, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "class DecisionTree(object):\n", | |
| " def __init__(self, tnodes):\n", | |
| " self.tree = tnodes\n", | |
| " def estimate(self, featureseries):\n", | |
| " node = self.tree[0]\n", | |
| " while not node['isleaf']:\n", | |
| " if featureseries[node['featureidx']] < node['separateval']:\n", | |
| " nextid = node['leftidx']\n", | |
| " else:\n", | |
| " nextid = node['rightidx']\n", | |
| " node = self.tree[nextid]\n", | |
| " return node['answerclass']\n", | |
| " \n", | |
| " " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 606, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import collections\n", | |
| "class RandomForest(object):\n", | |
| " def __init__(self):\n", | |
| " self.trees = []\n", | |
| " def add(self, tree):\n", | |
| " self.trees.append(tree)\n", | |
| " def train_and_add(self, dfTraining):\n", | |
| " dtbuilder = DecisionTreeBuilder(dfTraining)\n", | |
| " dtbuilder.train()\n", | |
| " dt = DecisionTree(dtbuilder.tnodes)\n", | |
| " self.add(dt)\n", | |
| " def estimate(self, features):\n", | |
| " answers = map(lambda tree: tree.estimate(features), self.trees)\n", | |
| " self._answers = answers\n", | |
| " return collections.Counter(answers).most_common(1)[0][0]\n", | |
| " " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 607, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def try_test_data(dfTest):\n", | |
| " guesses = map(lambda tupple: forest.estimate(tupple[1]), dfTest.iterrows())\n", | |
| " answers = list(dfTest['y'])\n", | |
| " return (len(forest.trees), len([i for i, ent in enumerate(answers) if ent != guesses[i]]))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 608, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "treenum=1, err_count=10\n", | |
| "treenum=2, err_count=15\n", | |
| "treenum=3, err_count=12\n", | |
| "treenum=4, err_count=11\n", | |
| "treenum=5, err_count=9\n", | |
| "treenum=6, err_count=8\n", | |
| "treenum=7, err_count=9\n", | |
| "treenum=8, err_count=7\n", | |
| "treenum=9, err_count=8\n", | |
| "treenum=10, err_count=9\n", | |
| "treenum=11, err_count=9\n", | |
| "treenum=12, err_count=9\n", | |
| "treenum=13, err_count=9\n", | |
| "treenum=14, err_count=10\n", | |
| "treenum=15, err_count=9\n", | |
| "treenum=16, err_count=9\n", | |
| "treenum=17, err_count=8\n", | |
| "treenum=18, err_count=9\n", | |
| "treenum=19, err_count=9\n", | |
| "treenum=20, err_count=9\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "forest = RandomForest()\n", | |
| "forest.train_and_add(dfTraining)\n", | |
| "for i in range(20):\n", | |
| " res = try_test_data(dfTest)\n", | |
| " print(\"treenum={0}, err_count={1}\".format(*res))\n", | |
| " forest.train_and_add(dfTraining)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 2", | |
| "language": "python", | |
| "name": "python2" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 2 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython2", | |
| "version": "2.7.10" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment