Last active
March 5, 2020 12:11
-
-
Save kuczmama/9ad5075657c57f12e738ed8730cc6297 to your computer and use it in GitHub Desktop.
A decision tree written in Ruby
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env ruby | |
# frozen_string_literal: true | |
require 'set' | |
require 'pry' | |
require 'csv' | |
require 'optparse' | |
require 'time' | |
# Make predictions with trees.
#
# A binary decision tree grown greedily: at each node the (feature, value)
# question with the largest Gini information gain is chosen, and recursion stops
# when no question yields positive gain. Rows are hashes of feature => value; the
# label column is named by +label_name+ and may be stored under a String or a
# Symbol key.
class DecisionTree
  attr_accessor :root

  # @param training_data [Array<Hash>] rows to learn from
  # @param label_name [Symbol, String] key of the column to predict
  def initialize(training_data, label_name = :label)
    @label_name = label_name
    @root = build_tree(training_data)
  end

  # Decision Tree Node.
  # Internal nodes hold a question (+label+ / +value+) with +left+ (answer true)
  # and +right+ (answer false) children; leaves hold +predictions+ instead.
  class Node
    attr_accessor :label, :value, :left, :right, :predictions

    def to_s
      parts = []
      if [label, value].none?(&:nil?)
        comparator = value.is_a?(Numeric) ? '>=' : '=='
        parts << "Is #{label} #{comparator} #{value}"
      end
      parts << "--> Predictions: #{predictions}" unless predictions.nil?
      parts.join
    end
  end

  # @return [String] an indented, multi-line rendering of the tree
  #   (previously this printed the tree and returned nil).
  def to_s
    tree_lines(@root).join("\n")
  end

  # @param features [Hash] feature name => value for one row
  # @return [Array] labels seen at the reached leaf, most frequent first
  def predict(features)
    predict_helper(features, @root)
  end

  private

  # unique_features: {:weight=>#<Set: {10, 3, 5}>, :color=>#<Set: {"Green", "Orange"}>}
  def find_unique_features(rows)
    rows.each_with_object({}) do |row, unique_features|
      row.each do |label, v|
        next if label_key?(label) # Ignore the label
        (unique_features[label] ||= Set.new) << v
      end
    end
  end

  # True when +key+ names the label column. Compared via to_s so String and
  # Symbol spellings agree, matching how #labels looks the label up.
  def label_key?(key)
    key.to_s == @label_name.to_s
  end

  def build_tree(rows)
    best_question = find_best_question(rows)
    node = Node.new
    if best_question[:info_gain].zero?
      node.predictions = labels_by_frequency(labels(rows))
      return node
    end

    node.label = best_question[:label]
    node.value = best_question[:value]
    left, right = partition(rows, node.label, node.value)
    node.left = build_tree(left)
    node.right = build_tree(right)
    node
  end

  # Unique labels ordered by descending count (ties keep first-seen order),
  # so callers taking predict(...)[0] get the leaf's majority label.
  def labels_by_frequency(row_labels)
    counts = row_labels.each_with_object(Hash.new(0)) { |label, acc| acc[label] += 1 }
    counts.keys.each_with_index.sort_by { |label, index| [-counts[label], index] }.map(&:first)
  end

  # Size-weighted Gini impurity of a split; both sides are non-empty when called.
  def calc_weighted_uncertainty(left, right)
    left_weight = left.length / (left.length + right.length).to_f
    left_weight * gini(left) + (1 - left_weight) * gini(right)
  end

  # Label of each row, read from the Symbol key if present, else the String key.
  # key? (rather than ||) keeps false/nil label values stored under the Symbol key.
  def labels(rows)
    symbol_key = @label_name.to_sym
    string_key = @label_name.to_s
    rows.map { |row| row.key?(symbol_key) ? row[symbol_key] : row[string_key] }
  end

  # { label: best_question_label, value: best_question_value, info_gain: gain }
  # info_gain stays 0.0 (label/value nil) when no split strictly improves purity.
  def find_best_question(rows)
    best = { label: nil, value: nil, info_gain: 0.0 }
    current_uncertainty = gini(labels(rows))
    find_unique_features(rows).each do |label, values|
      values.each do |value|
        left, right = partition(rows, label, value)
        next if left.empty? || right.empty?

        info_gain = current_uncertainty - calc_weighted_uncertainty(labels(left), labels(right))
        next unless info_gain > best[:info_gain]

        best = { label: label, value: value, info_gain: info_gain }
      end
    end
    best
  end

  # Gini impurity, sum_k p_k * (1 - p_k) over the label distribution; 0.0 when empty.
  def gini(row_labels)
    return 0.0 if row_labels.empty?

    total = row_labels.length.to_f
    counts = row_labels.each_with_object(Hash.new(0)) { |label, acc| acc[label] += 1 }
    counts.values.sum do |count|
      share = count / total
      share * (1.0 - share)
    end
  end

  # Answers a node's question for one feature value: equality for booleans and
  # strings, ">=" for numbers. A non-numeric value (e.g. a missing feature or a
  # text cell in a numeric column) answers false instead of raising.
  def match(value, question_value)
    case question_value
    when true, false, String then value == question_value
    when Numeric then value.is_a?(Numeric) && value >= question_value
    else raise "typeof #{question_value.class} is not supported"
    end
  end

  # [rows answering true, rows answering false]
  def partition(rows, label, question_value)
    rows.partition { |row| match(row[label], question_value) }
  end

  # Rendering of the subtree at +node+, one entry per line, indented with tabs.
  def tree_lines(node, spacing = '')
    return [] if node.nil?

    lines = ["#{spacing}#{node}"]
    if node.left
      lines << "#{spacing}-->true:"
      lines.concat(tree_lines(node.left, "#{spacing}\t"))
    end
    if node.right
      lines << "#{spacing}-->false:"
      lines.concat(tree_lines(node.right, "#{spacing}\t"))
    end
    lines
  end

  # Prints the subtree rooted at +root+ to stdout; returns nil.
  def print_tree(root, spacing = '')
    puts tree_lines(root, spacing) unless root.nil?
  end

  # Walks from +root+ to a leaf following the answers for +features+.
  def predict_helper(features, root = nil)
    node = root
    node = match(features[node.label], node.value) ? node.left : node.right until node.predictions
    node.predictions
  end
end
# Settings produced by Parser.parse.
Options = Struct.new(:max_train_rows, :label, :verbose)

# Command-line parsing for the decision-tree script.
class Parser
  # Parses +options+ with OptionParser#parse!, which removes the recognised
  # flags from the array and leaves positional arguments behind.
  #
  # @param options [Array<String>] argv-style arguments
  # @return [Options] max_train_rows (Integer or nil), label (default :label),
  #   verbose (default false)
  # @raise [OptionParser::InvalidArgument] when --max-train-rows is not an integer
  def self.parse(options)
    args = Options.new(nil, :label, false)
    # TODO: add test_data
    opt_parser = OptionParser.new do |opts|
      opts.banner = "Usage: #{$PROGRAM_NAME} TRAINING_DATA.csv [options]"
      # Integer coercion rejects input like "abc" instead of letting it become
      # 0 rows later via String#to_i.
      opts.on('-m ROWS', '--max-train-rows=ROWS', Integer, 'Max rows to read in from the TRAINING_DATA csv, default read in all rows') do |max_train_rows|
        args.max_train_rows = max_train_rows
      end
      opts.on('-h', '--help', 'Prints this help') do
        puts opts
        exit
      end
      opts.on('-v', '--verbose', 'Run verbosely') do
        args.verbose = true
      end
      opts.on('-l LABEL', '--label=LABEL', 'The column name that is the dependent variable. Default \'label\'') do |label|
        args.label = label
      end
      # TODO: - date labels, regression, classifier
    end
    opt_parser.parse!(options)
    args
  end
end
# With no arguments at all, show the help text (which exits).
options = Parser.parse(ARGV.empty? ? %w[--help] : ARGV)
# ARGV itself was parsed in place, so only positional arguments remain; without
# one, the CSV.foreach(ARGV[0]) below would fail on nil.
abort "Missing TRAINING_DATA.csv argument; run #{$PROGRAM_NAME} --help for usage" if ARGV.empty?
# Type sniffing for raw CSV cell text.
# NOTE(review): this reopens String for the whole process; a refinement would be
# less invasive, but callers below rely on these method names.
class String
  # True when converting to Float and back reproduces the text exactly
  # (so "1.5" qualifies but "1.50" and "10" do not).
  def is_float?
    self == to_f.to_s
  end

  # True when converting to Integer and back reproduces the text exactly
  # (so "10" qualifies but "010" does not).
  def is_int?
    self == to_i.to_s
  end

  # True when Time.parse accepts the text; any StandardError it raises means
  # "not a date".
  def is_date?
    parsed = Time.parse(self)
    !parsed.nil?
  rescue StandardError
    false
  end
end
# Accumulators for the CSV load below: rows read so far and their count.
i = 0
training_data = []
# Expands a Time into a Hash of calendar/zone features, keyed by the name of the
# Time method that produced each value.
# NOTE(review): gmtoff/gmt_offset/utc_offset and utc?/gmt? appear to be aliases
# in Ruby's Time, so those features duplicate one another.
def date_features(date_time)
  %w[month wday yday dst? gmtoff gmt_offset utc_offset utc? gmt?
     sunday? tuesday? monday? thursday? wednesday? saturday? friday?]
    .map { |name| [name, date_time.send(name)] }
    .to_h
end
# Load the training CSV into an array of hashes keyed by column header, stopping
# after --max-train-rows rows when that option is set.
CSV.foreach(ARGV[0], headers: true) do |csv_row|
  break if options.max_train_rows && i >= options.max_train_rows.to_i

  row = {}
  csv_row.each do |column, raw|
    if raw.nil?
      row[column] = 0 # missing cell
    elsif raw.is_float? || raw.is_int?
      row[column] = raw.to_f # Regression
    elsif raw.is_date?
      # Replace the date column with derived features. Keys are prefixed with the
      # column name so several date columns (or a plain "month" column) do not
      # overwrite each other's values.
      date_features(Time.parse(raw)).each { |name, value| row["#{column}_#{name}"] = value }
    else
      row[column] = raw # Classifier - string
    end
  end
  i += 1
  training_data << row
end
# training_data = [
#   { weight: 10, color: 'Green', label: 'Apple' },
#   { weight: 10, color: 'Orange', label: 'Orange' },
#   { weight: 3, color: 'Green', label: 'Grape' },
#   { weight: 5, color: 'Green', label: 'Grape' }
# ]

# Hold out the last 10% of the loaded rows for evaluation.
# TODO: change this... (e.g. shuffle before splitting, make the ratio configurable)
train_length = (training_data.length * 0.9).to_i
train = training_data[0...train_length]
test_data = training_data[train_length..-1]
label_key = options.label.to_s

if options[:verbose]
  puts "Creating a decision tree with #{train.length} rows from #{ARGV[0]}..."
end
decision_tree = DecisionTree.new(train, label_key)

puts 'finding accuracy: '
correct = 0
test_data.each do |data|
  actual = data.delete(label_key)
  predictions = decision_tree.predict(data)
  prediction = predictions.first
  correct += 1 if actual == prediction
  # Only numeric (regression) labels can be subtracted; string labels, or a leaf
  # with no predictions, would otherwise raise.
  if actual.is_a?(Numeric) && prediction.is_a?(Numeric)
    puts "diff: #{actual - prediction} actual: #{actual} prediction: #{predictions}"
  else
    puts "actual: #{actual} prediction: #{predictions}"
  end
end
unless test_data.empty?
  accuracy = 100.0 * correct / test_data.length
  puts "accuracy: #{accuracy.round(2)}% (#{correct}/#{test_data.length} exact matches)"
end

puts decision_tree
# probability_tree = decision_tree.predict(training_data[0..1000])
# binding.pry
# puts predict(probability_tree, bid_size: 80.0, ask_size: 200.0, previous_price: 20.0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment