Skip to content

Instantly share code, notes, and snippets.

@deltam
Last active May 13, 2018 12:52
Show Gist options
  • Save deltam/28422648f11b28edcaca1bc8cc213cdd to your computer and use it in GitHub Desktop.
Save deltam/28422648f11b28edcaca1bc8cc213cdd to your computer and use it in GitHub Desktop.
Clojureで決定木を書く。コード生成もする
(ns sandbox.tree
  "Decision-tree learning, written while reading chapter 7 of
  『集合知プログラミング』 (Programming Collective Intelligence).")
;; https://resources.oreilly.com/examples/9780596529321/blob/master/PCI_Code%20Folder/chapter7/treepredict.py
(def my-data
  "Sample training rows for the example tree. Each row is a vector of four
  feature columns followed by the class label in the last position (the
  value read by `resultf`)."
  [["slashdot"  "USA"         "yes" 18 "None"]
   ["google"    "France"      "yes" 23 "Premium"]
   ["digg"      "USA"         "yes" 24 "Basic"]
   ["kiwitobes" "France"      "yes" 23 "Basic"]
   ["google"    "UK"          "no"  21 "Premium"]
   ["(direct)"  "New Zealand" "no"  12 "None"]
   ["(direct)"  "UK"          "no"  21 "Basic"]
   ["google"    "USA"         "no"  24 "Premium"]
   ["slashdot"  "France"      "yes" 19 "None"]
   ["digg"      "USA"         "no"  18 "None"]
   ["google"    "UK"          "no"  18 "None"]
   ["kiwitobes" "UK"          "no"  19 "None"]
   ["digg"      "New Zealand" "yes" 12 "Basic"]
   ["slashdot"  "UK"          "no"  21 "None"]
   ["google"    "UK"          "yes" 18 "Basic"]
   ["kiwitobes" "France"      "yes" 19 "Basic"]])
(defn v-op
  "Chooses the comparison used to test a row cell against split value v:
  `>=` (a threshold test) when v is a number, `=` (an equality test) otherwise."
  [v]
  (cond
    (number? v) >=
    :else =))
(defn v-pred
  "Returns a one-argument predicate over a row vector. It is true when the
  row's cell at index `col` satisfies the comparison chosen by `v-op` for v
  (>= for numbers, = otherwise), i.e. (cmp (row col) v)."
  [col v]
  (let [cmp (v-op v)]
    #(cmp (% col) v)))
(defn divide-set
  "Partitions rows by the split test on column `col` against value v.
  Returns a map from the test result (true / false) to the rows that
  produced it; a key is absent when no row falls on that side."
  [rows col v]
  (let [split? (v-pred col v)]
    (group-by split? rows)))
(def resultf
  "Extracts a row's class label, which is stored as its last element."
  last)
(defn uniq-count
  "Counts rows per class label (as read by `resultf`).
  Returns a map of label -> number of rows carrying it; {} for no rows."
  [rows]
  (->> rows
       (map resultf)
       frequencies))
(defn gini-impurity
  "Gini impurity of rows: the sum of p(k1) * p(k2) over every ordered pair
  of distinct class labels k1, k2, where p(k) is the fraction of rows whose
  label (via `uniq-count`) is k. It equals the chance that two labels drawn
  independently, with replacement, from the rows differ.
  Returns 0 for empty rows or rows with a single label."
  [rows]
  (let [total (float (count rows))
        counts (uniq-count rows)]
    (apply +
           (for [[k1 c1] counts, [k2 c2] counts
                 :when (not= k1 k2)]
             (* (/ c1 total) (/ c2 total))))))

(def gini-imprity
  "Misspelled name kept for backward compatibility; prefer `gini-impurity`."
  gini-impurity)
(defn entropy
  "Base-2 Shannon entropy of the rows' class-label distribution:
  -sum(p * log2 p) over the label frequencies from `uniq-count`.
  Returns 0.0 for empty rows."
  [rows]
  (let [n (float (count rows))
        ln2 (Math/log 2)]
    ;; Subtract each term from 0.0 in sequence (same order as the label map).
    (reduce - 0.0
            (for [c (vals (uniq-count rows))
                  :let [p (/ c n)]]
              (* p (/ (Math/log p) ln2))))))
(defn gain-set
  "Enumerates candidate splits of rows together with their information gain.
  Candidates are every feature column (all columns of the first row except
  the last, which holds the label) paired with every distinct value seen in
  that column. Only splits that leave both sides non-empty are kept.
  Gain is scoref(rows) minus the size-weighted scoref of the two sides.
  Returns a lazy seq of [gain {:col col, :val v} [true-rows false-rows]]."
  [rows scoref]
  (let [base (scoref rows)
        n (count rows)
        weighted-gain (fn [yes no]
                        (let [w (/ (float (count yes)) n)]
                          (- base
                             (* w (scoref yes))
                             (* (- 1 w) (scoref no)))))
        feature-cols (range (dec (count (first rows))))]
    (for [col feature-cols
          v (distinct (map #(nth % col) rows))
          :let [{yes true, no false} (divide-set rows col v)
                gain (weighted-gain yes no)]
          :when (and (seq yes) (seq no))]
      [gain {:col col, :val v} [yes no]])))
(defn best-gain-set
  "Picks the best split from a seq of [gain split-spec sides] candidates
  (as produced by `gain-set`). Returns [gain split-spec sides] for the first
  candidate whose gain is strictly greater than 0.0 and than every earlier
  candidate. When no candidate has positive gain, returns [0.0 nil nil];
  an empty (or nil) input returns [0.0]."
  [rows]
  (if (empty? rows)
    [0.0]
    (reduce (fn [best [gain v1 v2]]
              (if (> gain (first best))
                [gain v1 v2]
                best))
            [0.0 nil nil]
            rows)))
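;; 木の表現 (tree representation)
;; Internal node: {:col 0, :val "slashdot", true {:col 1, ...}, false {:col 3, ...}}
;; Leaf: the label-count map returned by uniq-count, e.g. {"None" 3}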
(defn build-tree
  "Recursively builds a decision tree from rows (feature columns followed by
  the class label in the last position).
  At each node the candidate split with the highest positive gain under
  scoref (default `entropy`) is chosen. The node is that split spec
  {:col col, :val v} with the two subtrees assoc'ed under the keys true
  and false. When no split has positive gain, returns a leaf: the
  label-count map from `uniq-count`."
  ([rows] (build-tree rows entropy))
  ([rows scoref]
   (let [[best-gain node [yes no]] (best-gain-set (gain-set rows scoref))]
     (if (pos? best-gain)
       ;; Recurse with the caller's scoref so every level of the tree uses the
       ;; same impurity metric (the 1-arity would silently reset it to entropy).
       (assoc node
              true (build-tree yes scoref)
              false (build-tree no scoref))
       (uniq-count rows)))))
(defn classify
  "Walks tree from the root for row and returns the leaf it reaches
  (a label-count map such as {\"None\" 3}). At each internal node the row
  follows the branch keyed by the result (true / false) of `v-pred` built
  from the node's :col and :val."
  [tree row]
  (if-let [col (tree :col)]
    (let [branch ((v-pred col (tree :val)) row)]
      (recur (tree branch) row))
    tree))
(defn tree->if-then
  "Compiles a decision tree into a nested `if` form that classifies the row
  bound to the symbol arg-row.
  Leaves (anything whose :col is nil) are emitted unchanged. Each internal
  node becomes (if (op (arg-row col) val) <true-branch> <false-branch>), with
  op chosen by `v-op`.
  Ops are emitted as fully qualified symbols (clojure.core/>=,
  clojure.core/=) instead of function objects. The generated code can then
  be pretty-printed and read back as ordinary source."
  [tree arg-row]
  (if (nil? (tree :col))
    tree
    (let [{col :col, v :val, ttr true, ftr false} tree
          op (v-op v)
          ;; Map the fn object to its symbol; fall back to the fn itself for
          ;; any op v-op might return that has no entry here.
          op-form (get {>= `>=, = `=} op op)]
      `(if (~op-form (~arg-row ~col) ~v)
         ~(tree->if-then ttr arg-row)
         ~(tree->if-then ftr arg-row)))))
(defmacro tree->classifier
  "Expands into an anonymous one-argument fn that classifies a row using
  the if-then code generated from tree by `tree->if-then`. The tree form is
  `eval`ed at macro-expansion time, so its value must be available then
  (e.g. a var already holding a built tree)."
  [tree]
  ;; Auto-gensyms (x#) only unify within a single syntax-quote; a `'arg#`
  ;; written inside ~(...) is the literal symbol arg#, not the fn parameter.
  ;; Create the parameter symbol explicitly and share it with the body.
  (let [arg (gensym "row")]
    `(fn [~arg]
       ~(tree->if-then (eval tree) arg))))
(defmacro def-classifier
  "Defines a one-argument fn named `name` that classifies a row with the
  if-then code generated from tree by `tree->if-then`. The tree form is
  `eval`ed at macro-expansion time."
  [name tree]
  (let [row-sym (gensym "row")]
    (list `defn name [row-sym]
          (tree->if-then (eval tree) row-sym))))
;; 木の表示 (printing the tree)
(defn print-tree
  "Prints tree to *out*. An internal node prints \"col : val ?\" on its own
  line, then its true branch after \"T->\" and its false branch after \"F->\".
  Each branch marker is preceded by indent, and nested branches get one more
  tab of indent. A leaf is printed with println. The one-arity form starts
  with a single tab of indent. Returns nil."
  ([tree] (print-tree tree "\t"))
  ([tree indent]
   (if-not (:col tree)
     (println tree)
     (let [child-indent (str indent "\t")]
       (println (:col tree) ":" (:val tree) "?")
       (doseq [[label k] [["T->" true] ["F->" false]]]
         (print indent label)
         (print-tree (get tree k) child-indent))))))
@deltam
Copy link
Author

deltam commented May 13, 2018

In the REPL:

user> (require '[sandbox.tree :as tr])
nil
user> (def t1 (tr/build-tree tr/my-data))
#'user/t1
user> (tr/print-tree t1)
0 : google ?
     T->3 : 21 ?
         T->{Premium 3}
         F->2 : no ?
             T->{None 1}
             F->{Basic 1}
     F->0 : slashdot ?
         T->{None 3}
         F->2 : yes ?
             T->{Basic 4}
             F->3 : 21 ?
                 T->{Basic 1}
                 F->{None 3}
nil
user> (tr/classify t1 ["slashdot" "USA" "yes" 18])
{"None" 3}
user> (let [row ["slashdot" "USA" "yes" 18]]
        (eval (tr/tree->if-then t1 'row)))
{"None" 3}
user> (clojure.pprint/pprint (tr/tree->if-then t1 'row))
(if
 (#function[clojure.core/=] (row 0) "google")
 (if
  (#function[clojure.core/>=] (row 3) 21)
  {"Premium" 3}
  (if (#function[clojure.core/=] (row 2) "no") {"None" 1} {"Basic" 1}))
 (if
  (#function[clojure.core/=] (row 0) "slashdot")
  {"None" 3}
  (if
   (#function[clojure.core/=] (row 2) "yes")
   {"Basic" 4}
   (if
    (#function[clojure.core/>=] (row 3) 21)
    {"Basic" 1}
    {"None" 3}))))
nil
user> (tr/def-classifier cls1 t1)
#'user/cls1
user> (cls1 ["slashdot" "USA" "yes" 18])
{"None" 3}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment