Skip to content

Instantly share code, notes, and snippets.

@masatoi
Created February 16, 2019 11:08
Show Gist options
  • Save masatoi/0b860a1f85b2ea3025c9128883b47add to your computer and use it in GitHub Desktop.
Save masatoi/0b860a1f85b2ea3025c9128883b47add to your computer and use it in GitHub Desktop.
cl-random-forest-workspace.lisp
;;; -*- coding:utf-8; mode:lisp -*-
(in-package :cl-random-forest)
;;; Small dataset ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(defparameter *n-class* 4)
(defparameter *target*
(make-array 11 :element-type 'fixnum
:initial-contents '(0 0 1 1 2 2 2 3 3 3 3)))
(defparameter *datamatrix*
(make-array '(11 2)
:element-type 'double-float
:initial-contents '((-1.0d0 -2.0d0)
(-2.0d0 -1.0d0)
(1.0d0 -2.0d0)
(3.0d0 -1.5d0)
(-2.0d0 2.0d0)
(-3.0d0 1.0d0)
(-2.0d0 1.0d0)
(3.0d0 2.0d0)
(2.0d0 2.0d0)
(1.0d0 2.0d0)
(1.0d0 1.0d0))))
;;; Decision tree ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; make decision tree
(defparameter *dtree*
(make-dtree *n-class* *datamatrix* *target*
:max-depth 5 :min-region-samples 1 :n-trial 10))
;; prediction
(predict-dtree *dtree* *datamatrix* 2)
(predict-dtree-majority-vote *dtree* *datamatrix* 2)
(defun predict-node (node)
(let ((max 0d0)
(max-class 0)
(dist (node-class-distribution node))
(n-class (dtree-n-class (node-dtree node))))
(loop for i fixnum from 0 to (1- n-class) do
(when (> (aref dist i) max)
(setf max (aref dist i)
max-class i)))
max-class))
(defun extract-node (node)
(if (and node (node-test-attribute node))
`(if (>= (aref d i ,(node-test-attribute node)) ,(node-test-threshold node))
,(extract-node (node-left-node node))
,(extract-node (node-right-node node)))
(predict-node node)))
(defun construct-dtree-lambda (dtree)
`(lambda (d i)
(declare (optimize (speed 3) (space 0) (safety 0) (debug 0) (compilation-speed 0))
(type (simple-array double-float) d)
(type fixnum i))
,(extract-node (dtree-root dtree))))
(construct-dtree-lambda *dtree*)
;; 生成されるlambda式
(LAMBDA (DATAMATRIX DATUM-INDEX)
(DECLARE
(OPTIMIZE (SPEED 3) (SPACE 0) (SAFETY 0) (DEBUG 0) (COMPILATION-SPEED 0))
(TYPE (SIMPLE-ARRAY DOUBLE-FLOAT) DATAMATRIX)
(TYPE FIXNUM DATUM-INDEX))
(IF (>= (AREF DATAMATRIX DATUM-INDEX 0) -0.7394168078526362d0)
(IF (>= (AREF DATAMATRIX DATUM-INDEX 1) 0.8903535809681147d0)
3
1)
(IF (>= (AREF DATAMATRIX DATUM-INDEX 1) 0.5876648784761986d0)
2
0)))
;; コンパイル
(defparameter compiled-dtree (compile nil (construct-dtree-lambda *dtree*)))
;;呼び出し
(funcall compiled-dtree *datamatrix* 0)
;;;;;
(defparameter mnist-dim 784)
(defparameter mnist-n-class 10)
(let ((mnist-train (clol.utils:read-data "/home/wiz/datasets/mnist.scale" mnist-dim :multiclass-p t))
(mnist-test (clol.utils:read-data "/home/wiz/datasets/mnist.scale.t" mnist-dim :multiclass-p t)))
;; Add 1 to labels in order to form class-labels beginning from 0
(dolist (datum mnist-train) (incf (car datum)))
(dolist (datum mnist-test) (incf (car datum)))
(multiple-value-bind (datamat target)
(clol-dataset->datamatrix/target mnist-train)
(defparameter mnist-datamatrix datamat)
(defparameter mnist-target target))
(multiple-value-bind (datamat target)
(clol-dataset->datamatrix/target mnist-test)
(defparameter mnist-datamatrix-test datamat)
(defparameter mnist-target-test target)))
;;; Make Decision Tree ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(defparameter mnist-dtree
(make-dtree mnist-n-class mnist-datamatrix mnist-target
:max-depth 10 :n-trial 28 :min-region-samples 5))
(test-dtree mnist-dtree mnist-datamatrix mnist-target)
(test-dtree mnist-dtree mnist-datamatrix-test mnist-target-test)
(time
(loop repeat 100 do
(loop for i from 0 below (array-dimension mnist-datamatrix 0) do
(predict-dtree mnist-dtree mnist-datamatrix i))))
(time (defparameter dtree-predictor (compile nil (construct-dtree-lambda mnist-dtree))))
(time (defparameter dtree-lambda (construct-dtree-lambda mnist-dtree))) ; これは非常に高速
(time
(loop repeat 100 do
(loop for i from 0 below (array-dimension mnist-datamatrix 0) do
(funcall dtree-predictor mnist-datamatrix i))))
(defparameter mnist-forest
(make-forest mnist-n-class mnist-datamatrix mnist-target
:n-tree 500 :bagging-ratio 0.1 :max-depth 5 :n-trial 28 :min-region-samples 5))
(time
(defparameter dtree-predictor-list
(loop for dtree in (forest-dtree-list mnist-forest)
for i from 0
collect (progn
(print i)
(compile nil (construct-dtree-lambda dtree))))))
(defun argmax (arr)
(let ((max 0)
(max-i 0))
(loop for i from 0 below (length arr) do
(when (> (aref arr i) max)
(setf max (aref arr i)
max-i i)))
max-i))
(defun predict-dtree-predictor-list (dtree-predictor-list datamatrix index)
(let ((cnt (make-array (array-dimension datamatrix 1))))
(loop for predictor in dtree-predictor-list do
(incf (aref cnt (funcall predictor datamatrix index))))
(argmax cnt)))
(predict-dtree-predictor-list dtree-predictor-list mnist-datamatrix 0)
(defun test-dtree-predictor-list (dtree-predictor-list datamatrix target)
(loop for i from 0 below (array-dimension datamatrix 0)
count (= (predict-dtree-predictor-list dtree-predictor-list datamatrix i)
(aref target i))))
(time (test-dtree-predictor-list dtree-predictor-list mnist-datamatrix-test mnist-target-test))
;; 9385
;; Evaluation took:
;; 0.286 seconds of real time
;; 0.284000 seconds of total run time (0.284000 user, 0.000000 system)
;; 99.30% CPU
;; 967,442,117 processor cycles
;; 62,876,912 bytes consed
(time (test-forest mnist-forest mnist-datamatrix-test mnist-target-test))
;; Accuracy: 94.33%, Correct: 9433, Total: 10000
;; Evaluation took:
;; 2.659 seconds of real time
;; 2.660000 seconds of total run time (2.660000 user, 0.000000 system)
;; 100.04% CPU
;; 9,021,268,236 processor cycles
;; 1,216 bytes consed
;; 事前に全ての葉の予測値を出しておく方式(多数決 majority-vote)
;; predict-all-leaf
(defparameter leaf1 (find-leaf (dtree-root mnist-dtree) mnist-datamatrix 0))
(argmax (node-class-distribution leaf1))
(defun set-leaf-prediction! (dtree)
(do-leaf (lambda (node)
(setf (node-leaf-prediction node)
(argmax (node-class-distribution node))))
(dtree-root dtree)))
(defun set-leaf-prediction-forest! (forest)
(dolist (dtree (forest-dtree-list forest))
(set-leaf-prediction! dtree)))
(time (set-leaf-prediction mnist-dtree))
(time (set-leaf-prediction-forest mnist-forest))
(defun predict-dtree-majority-vote (dtree datamatrix datum-index)
(node-leaf-prediction (find-leaf (dtree-root dtree) datamatrix datum-index)))
(defun test-dtree-majority-vote (dtree datamatrix target &key quiet-p)
(declare (optimize (speed 3) (safety 0))
(type dtree dtree)
(type (simple-array double-float) datamatrix)
(type (simple-array fixnum (*)) target))
(let ((n-correct 0)
(len (length target)))
(declare (type fixnum n-correct len))
(loop for i fixnum from 0 below len do
(when (= (predict-dtree-majority-vote dtree datamatrix i)
(aref target i))
(incf n-correct)))
(calc-accuracy n-correct len :quiet-p quiet-p)))
(defun predict-forest-majority-vote (forest datamatrix datum-index)
(let ((class-count-array (forest-class-count-array forest)))
;; init class-count-array
(loop for i fixnum from 0 below (length class-count-array) do
(setf (aref class-count-array i) 0d0))
(dolist (dtree (forest-dtree-list forest))
(let ((predicted-class
(node-leaf-prediction (find-leaf (dtree-root dtree) datamatrix datum-index))))
(incf (aref class-count-array predicted-class) 1.0d0)))
(argmax class-count-array)))
(defun test-forest-majority-vote (forest datamatrix target &key quiet-p)
(declare (optimize (speed 3) (safety 0))
(type forest forest)
(type (simple-array double-float) datamatrix)
(type (simple-array fixnum) target))
(let ((n-correct 0)
(len (length target)))
(declare (type fixnum n-correct len))
(loop for i fixnum from 0 below len do
(when (= (predict-forest-majority-vote forest datamatrix i)
(aref target i))
(incf n-correct)))
(calc-accuracy n-correct len :quiet-p quiet-p)))
(time (test-forest-majority-vote mnist-forest mnist-datamatrix mnist-target))
(time (test-forest-majority-vote mnist-forest mnist-datamatrix-test mnist-target-test))
;;
(ql:quickload :wiz-util)
(time (set-leaf-prediction *dtree*))
(do-leaf (lambda (node)
(node-class-distribution node)
(node-leaf-prediction node)
)
(dtree-root *dtree*))
(require :sb-sprof)
(sb-sprof:
(sb-sprof:with-profiling (:max-samples 1000
:report :flat
:loop nil)
(test-forest-majority-vote mnist-forest mnist-datamatrix mnist-target))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment