Last active
October 21, 2015 01:32
-
-
Save adrianseeley/3b760ff5bfc65c1d56c2 to your computer and use it in GitHub Desktop.
Mighty RF
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
var fs = require('fs'); | |
var rl = require('readline'); | |
// Streams ./train.csv line by line; the first column is the class label,
// the remaining columns are feature values. Skips the header row and any
// blank/short line, then invokes cb(train_inputs, train_outputs) on close.
function read_train (cb) {
  var inputs = [];
  var outputs = [];
  var skipped_header = false;
  var reader = rl.createInterface({input: fs.createReadStream('./train.csv')});
  reader.on('line', function (line) {
    // Drop the first (header) line.
    if (!skipped_header) {
      skipped_header = true;
      return;
    }
    var cells = line.split(',');
    // A single-cell line is blank or malformed; ignore it.
    if (cells.length == 1) {
      return;
    }
    outputs.push(parseFloat(cells[0]));
    var features = [];
    for (var i = 1; i < cells.length; i++) {
      features.push(parseFloat(cells[i]));
    }
    inputs.push(features);
  });
  reader.on('close', function () {
    return cb(inputs, outputs);
  });
};
// Streams ./test.csv line by line; every column is a feature value (no label
// column). Skips the header row and blank/short lines, assigns sequential
// 1-based ids, then invokes cb(test_inputs, test_ids) on close.
function read_test (cb) {
  console.log('reading test');
  var test_inputs = [];
  var test_ids = [];
  var test_id = 1;
  var headers_read = false;
  var file = rl.createInterface({input: fs.createReadStream('./test.csv')});
  file.on('line', function (line) {
    if (!headers_read) {
      headers_read = true;
      return;
    }
    var line_parts = line.split(',');
    // A single-cell line is blank or malformed; ignore it.
    if (line_parts.length == 1) {
      return;
    }
    test_ids.push(test_id);
    test_id++;
    var test_input_values = [];
    for (var line_part_idx = 0; line_part_idx < line_parts.length; line_part_idx++) {
      test_input_values.push(parseFloat(line_parts[line_part_idx]));
    }
    test_inputs.push(test_input_values);
  });
  file.on('close', function () {
    // BUG FIX: the original read test_inputs[0].length unconditionally,
    // which throws a TypeError when the file contains no data rows.
    if (test_inputs.length > 0) {
      console.log('read ' + test_inputs.length + ' test cases, with ' + test_inputs[0].length + ' components each');
    } else {
      console.log('read 0 test cases');
    }
    return cb(test_inputs, test_ids);
  });
};
// Bootstrap-samples the training set (sampling WITH replacement — the usual
// random-forest bagging step) and deals the samples round-robin into
// number_of_partitions buckets. Returns [partition_inputs, partition_outputs].
function partition_train (number_of_partitions, train_inputs, train_outputs) {
  var partitioned_inputs = [];
  var partitioned_outputs = [];
  for (var p = 0; p < number_of_partitions; p++) {
    partitioned_inputs.push([]);
    partitioned_outputs.push([]);
  }
  var target = 0;
  for (var i = 0; i < train_inputs.length; i++) {
    // Random draw with replacement, so a case may appear in several buckets.
    var pick = Math.floor(Math.random() * train_inputs.length);
    partitioned_inputs[target].push(train_inputs[pick]);
    partitioned_outputs[target].push(train_outputs[pick]);
    target = (target + 1) % number_of_partitions;
  }
  return [partitioned_inputs, partitioned_outputs];
};
// Returns the relative frequency of each class label in train_outputs as an
// object mapping label -> proportion (proportions sum to 1 for non-empty input).
function create_random_forest_tree_node_class_distribution (train_outputs) {
  var class_distribution = {};
  // BUG FIX: train_idx was missing `var`, leaking an implicit global that
  // could collide with identically leaked counters elsewhere in the file.
  for (var train_idx = 0; train_idx < train_outputs.length; train_idx++) {
    if (!class_distribution.hasOwnProperty(train_outputs[train_idx])) {
      class_distribution[train_outputs[train_idx]] = 0;
    }
    class_distribution[train_outputs[train_idx]]++;
  }
  // Convert raw counts into proportions.
  for (var class_key in class_distribution) {
    class_distribution[class_key] /= train_outputs.length;
  }
  return class_distribution;
};
// Shannon entropy (in bits) of a class distribution: -sum(p * log2(p)).
// Uses the native Math.log2 (ES2015) instead of the hand-rolled
// Math.log(p) / Math.log(2) helper: it is exact for powers of two and
// removes the dependency on the file-local log2 function.
function calculate_random_forest_tree_node_class_distribution_entropy (class_distribution) {
  var entropy = 0;
  for (var class_key in class_distribution) {
    var p = class_distribution[class_key];
    entropy -= p * Math.log2(p);
  }
  return entropy;
};
// Base-2 logarithm. Delegates to the native Math.log2 (ES2015), which is
// exact for powers of two, instead of the previous Math.log(value) / Math.log(2)
// form whose result could be off by one ulp. The cached log2_base constant is
// no longer needed.
function log2 (value) {
  return Math.log2(value);
};
// Information gain of splitting the training cases at component_value on
// feature component_idx: entropy_before minus the size-weighted entropy of
// the two resulting branches (<= goes left, > goes right).
function measure_random_forest_tree_node_information_gain (entropy_before, component_idx, component_value, train_inputs, train_outputs) {
  var left_counts = {};
  var right_counts = {};
  var left_total = 0;
  var right_total = 0;
  // Tally class counts on each side of the candidate split.
  for (var i = 0; i < train_inputs.length; i++) {
    var label = train_outputs[i];
    if (train_inputs[i][component_idx] <= component_value) {
      if (!left_counts.hasOwnProperty(label)) {
        left_counts[label] = 0;
      }
      left_counts[label]++;
      left_total++;
    } else {
      if (!right_counts.hasOwnProperty(label)) {
        right_counts[label] = 0;
      }
      right_counts[label]++;
      right_total++;
    }
  }
  // Normalize raw counts into per-branch class distributions.
  for (var lk in left_counts) {
    left_counts[lk] /= left_total;
  }
  for (var rk in right_counts) {
    right_counts[rk] /= right_total;
  }
  // Weighted post-split entropy; the gain is the reduction from entropy_before.
  var weighted_entropy_after =
    (calculate_random_forest_tree_node_class_distribution_entropy(left_counts) * (left_total / train_inputs.length)) +
    (calculate_random_forest_tree_node_class_distribution_entropy(right_counts) * (right_total / train_inputs.length));
  return entropy_before - weighted_entropy_after;
};
// Recursively builds one node of a decision tree.
// - current_tree_depth / maximum_tree_depth bound the recursion.
// - components_observed: the feature indices this tree is allowed to split on.
// Branch nodes carry component_idx/component_value plus left/right children
// (class_distribution stays null); leaf nodes carry only class_distribution.
function create_random_forest_tree_node (current_tree_depth, maximum_tree_depth, components_observed, train_inputs, train_outputs) {
  var random_forest_tree_node = {
    component_idx: null,
    component_value: null,
    left: null,
    right: null,
    class_distribution: null
  };
  if (train_inputs.length == 0) {
    // no cases to consider
    random_forest_tree_node.class_distribution = {};
    return random_forest_tree_node;
  } else if (current_tree_depth == maximum_tree_depth) {
    // calculate class distribution for terminal node
    random_forest_tree_node.class_distribution = create_random_forest_tree_node_class_distribution(train_outputs);
    return random_forest_tree_node;
  } else {
    // calculate the entropy of the current node before the split
    var class_distribution = create_random_forest_tree_node_class_distribution(train_outputs);
    var entropy_before = calculate_random_forest_tree_node_class_distribution_entropy(class_distribution);
    // Exhaustively try every (observed component, observed training value)
    // pair as a split candidate and keep the highest information gain.
    // BUG FIX: train_idx lacked `var` in both loops below, leaking an
    // implicit global shared across every call in the process; it is now
    // properly function-scoped.
    var information_gain = null;
    for (var components_observed_idx = 0; components_observed_idx < components_observed.length; components_observed_idx++) {
      for (var train_idx = 0; train_idx < train_inputs.length; train_idx++) {
        var current_information_gain = measure_random_forest_tree_node_information_gain(entropy_before, components_observed[components_observed_idx], train_inputs[train_idx][components_observed[components_observed_idx]], train_inputs, train_outputs);
        if (information_gain == null || current_information_gain > information_gain) {
          information_gain = current_information_gain;
          random_forest_tree_node.component_idx = components_observed[components_observed_idx];
          random_forest_tree_node.component_value = train_inputs[train_idx][components_observed[components_observed_idx]];
        }
      }
    }
    // Partition the cases on the chosen split: <= goes left, > goes right.
    var left_train_inputs = [];
    var right_train_inputs = [];
    var left_train_outputs = [];
    var right_train_outputs = [];
    for (var train_idx = 0; train_idx < train_inputs.length; train_idx++) {
      if (train_inputs[train_idx][random_forest_tree_node.component_idx] <= random_forest_tree_node.component_value) {
        left_train_inputs.push(train_inputs[train_idx]);
        left_train_outputs.push(train_outputs[train_idx]);
      } else {
        right_train_inputs.push(train_inputs[train_idx]);
        right_train_outputs.push(train_outputs[train_idx]);
      }
    }
    // create child nodes
    random_forest_tree_node.left = create_random_forest_tree_node(current_tree_depth + 1, maximum_tree_depth, components_observed, left_train_inputs, left_train_outputs);
    random_forest_tree_node.right = create_random_forest_tree_node(current_tree_depth + 1, maximum_tree_depth, components_observed, right_train_inputs, right_train_outputs);
    return random_forest_tree_node;
  }
};
// Builds a single decision tree for the forest. Draws
// number_of_components_observed feature indices at random (with replacement)
// from the input dimensionality, then grows the tree from the root.
function create_random_forest_tree (maximum_tree_depth, number_of_components_observed, train_inputs, train_outputs) {
  var observed = [];
  var dimensionality = train_inputs[0].length;
  while (observed.length < number_of_components_observed) {
    observed.push(Math.floor(Math.random() * dimensionality));
  }
  return create_random_forest_tree_node(0, maximum_tree_depth, observed, train_inputs, train_outputs);
};
// Trains number_of_trees decision trees, cycling round-robin through the
// training partitions so each tree sees one bootstrap sample. Logs the tree
// index as a crude progress indicator. Returns the array of root nodes.
function create_random_forest (number_of_trees, maximum_tree_depth, number_of_components_observed, train_partition_inputs, train_partition_outputs) {
  var forest = [];
  for (var tree_idx = 0; tree_idx < number_of_trees; tree_idx++) {
    console.log(tree_idx);
    // Equivalent to the manual wrap-around counter: 0,1,...,len-1,0,1,...
    var partition = tree_idx % train_partition_inputs.length;
    forest.push(create_random_forest_tree(maximum_tree_depth, number_of_components_observed, train_partition_inputs[partition], train_partition_outputs[partition]));
  }
  return forest;
};
// Walks from the given node down to a leaf for one input vector and returns
// that leaf's class distribution. Branch nodes have class_distribution ==
// null and route on component_value (<= left, > right); leaves carry the
// distribution. Implemented iteratively — equivalent to the tail recursion.
function recurse_lookup_random_forest_tree_node (node, single_test_input) {
  while (node.class_distribution == null) {
    node = (single_test_input[node.component_idx] <= node.component_value)
      ? node.left
      : node.right;
  }
  return node.class_distribution;
};
// Classifies every test input: sums the leaf class distributions produced by
// each tree into per-class vote totals, then returns (per input) the label
// with the highest total. Returned labels are object-key strings.
function run_random_forest (random_forest, test_inputs) {
  var test_outputs = [];
  for (var test_idx = 0; test_idx < test_inputs.length; test_idx++) {
    var class_estimates = {};
    for (var tree_idx = 0; tree_idx < random_forest.length; tree_idx++) {
      var current_class_estimates = recurse_lookup_random_forest_tree_node(random_forest[tree_idx], test_inputs[test_idx]);
      for (var class_estimate in current_class_estimates) {
        if (!class_estimates.hasOwnProperty(class_estimate)) {
          class_estimates[class_estimate] = 0;
        }
        class_estimates[class_estimate] += current_class_estimates[class_estimate];
      }
    }
    // Argmax over the accumulated votes.
    // BUG FIX: the original compared each candidate's vote total against the
    // stored class LABEL (largest_class_estimate) instead of the stored best
    // vote total (largest_class_estimate_at), so the selected class depended
    // on label values rather than votes.
    var best_class = null;
    var best_votes = null;
    for (var class_key in class_estimates) {
      if (best_votes == null || class_estimates[class_key] > best_votes) {
        best_class = class_key;
        best_votes = class_estimates[class_key];
      }
    }
    test_outputs.push(best_class);
  }
  return test_outputs;
};
// Accuracy: the fraction of estimated outputs matching the true outputs.
// Loose equality (==) is deliberate — true labels are numbers while the
// forest's estimates come back as object-key strings.
function calculate_validation_score (validation_outputs, est_validation_outputs) {
  var correct = 0;
  for (var i = 0; i < validation_outputs.length; i++) {
    if (validation_outputs[i] == est_validation_outputs[i]) {
      correct += 1;
    }
  }
  return correct / validation_outputs.length;
};
// Grid-searches the random-forest hyperparameters — partition count, number
// of trees, maximum tree depth, and number of components observed — over the
// inclusive [low, high] ranges with the given steps. For each combination it
// trains a forest on the partitioned training data, scores it on the
// validation set, and records [partitions, trees, depth, components, score].
// Side effects: draws a one-line progress display on stdout (cursorTo(0)
// rewrites the same terminal line) and rewrites mc_partial.csv after every
// combination so partial results survive an interrupted run.
// Returns the full array of result rows.
function monte_carlo_random_forest (partition_low, partition_high, partition_step, number_of_trees_low, number_of_trees_high, number_of_trees_step, maximum_tree_depth_low, maximum_tree_depth_high, maximum_tree_depth_step, number_of_components_observed_low, number_of_components_observed_high, number_of_components_observed_step, train_inputs, train_outputs, validation_inputs, validation_outputs) {
var results = [];
for (var partition_idx = partition_low; partition_idx <= partition_high; partition_idx += partition_step) {
// Re-partition once per partition-count setting; inner loops reuse it.
var train_partitions = partition_train(partition_idx, train_inputs, train_outputs);
var train_partition_inputs = train_partitions[0];
var train_partition_outputs = train_partitions[1];
for (var number_of_trees_idx = number_of_trees_low; number_of_trees_idx <= number_of_trees_high; number_of_trees_idx += number_of_trees_step) {
for (var maximum_tree_depth_idx = maximum_tree_depth_low; maximum_tree_depth_idx <= maximum_tree_depth_high; maximum_tree_depth_idx += maximum_tree_depth_step) {
for (var number_of_components_observed_idx = number_of_components_observed_low; number_of_components_observed_idx <= number_of_components_observed_high; number_of_components_observed_idx += number_of_components_observed_step) {
// Progress display: "name: [low (current) high]" for each swept parameter.
process.stdout.cursorTo(0);
process.stdout.write(
'partition: [' + partition_low + ' (' + partition_idx + ') ' + partition_high + ']' +
' Ntrees: [' + number_of_trees_low + ' (' + number_of_trees_idx + ') ' + number_of_trees_high + ']' +
' treeD: [' + maximum_tree_depth_low + ' (' + maximum_tree_depth_idx + ') ' + maximum_tree_depth_high + ']' +
' numCO: [' + number_of_components_observed_low + ' (' + number_of_components_observed_idx + ') ' + number_of_components_observed_high + ']');
var random_forest = create_random_forest(number_of_trees_idx, maximum_tree_depth_idx, number_of_components_observed_idx, train_partition_inputs, train_partition_outputs);
var random_forest_validation_outputs = run_random_forest(random_forest, validation_inputs);
var validation_score = calculate_validation_score(validation_outputs, random_forest_validation_outputs);
results.push([partition_idx, number_of_trees_idx, maximum_tree_depth_idx, number_of_components_observed_idx, validation_score]);
// partial=true: checkpoint to mc_partial.csv without the console banner.
write_monte_carlo_results(results, true);
}
}
}
}
return results;
};
// Writes a Kaggle-style submission CSV ('ImageId,Label' header, one row per
// prediction) to est<epoch-millis>.csv in the working directory.
function write_test (test_ids, test_outputs) {
  console.log('writing test to est[date].csv');
  var rows = ['ImageId,Label'];
  for (var i = 0; i < test_ids.length; i++) {
    rows.push(test_ids[i] + ',' + test_outputs[i]);
  }
  // join + trailing '\n' produces the same bytes as the original
  // string-concatenation loop.
  fs.writeFileSync('est' + new Date().getTime() + '.csv', rows.join('\n') + '\n', 'utf8');
  return;
};
// Serializes monte-carlo result rows to CSV. When partial is true the file
// mc_partial.csv is silently (re)written as a checkpoint; otherwise a final
// mc<epoch-millis>.csv is written with a console notice.
function write_monte_carlo_results (monte_carlo_results, partial) {
  var lines = ['number_of_partitions,number_of_trees,maximum_tree_depth,number_of_components_observed,validation_score'];
  for (var i = 0; i < monte_carlo_results.length; i++) {
    lines.push(monte_carlo_results[i].join(','));
  }
  var csv = lines.join('\n') + '\n';
  if (partial) {
    fs.writeFileSync('mc_partial.csv', csv, 'utf8');
  } else {
    console.log('writing monte carlo results to mc[date].csv');
    fs.writeFileSync('mc' + new Date().getTime() + '.csv', csv, 'utf8');
  }
};
// --- configuration for the full training run ---
var cfg_number_of_train_partitions = 500; // bootstrap partitions of the training set
var cfg_number_of_random_forest_trees = 500; // trees in the forest (cycled over partitions)
var cfg_maximum_random_forest_tree_depth = 100; // recursion cap per tree
var cfg_number_of_components_observed_per_random_forest_tree = 100; // features sampled per tree
var cfg_validation_percent = 0.5; // fraction held out for validation (monte-carlo path only)
// --- configuration for the (currently disabled) monte-carlo hyperparameter sweep ---
// Each parameter sweeps the inclusive range [low, high] in the given step.
var cfg_monte_carlo_partition_low = 25;
var cfg_monte_carlo_partition_high = 25;
var cfg_monte_carlo_partition_step = 1;
var cfg_monte_carlo_number_of_trees_low = 25;
var cfg_monte_carlo_number_of_trees_high = 25;
var cfg_monte_carlo_number_of_trees_step = 1;
var cfg_monte_carlo_maximum_tree_depth_low = 100;
var cfg_monte_carlo_maximum_tree_depth_high = 100;
var cfg_monte_carlo_maximum_tree_depth_step = 1;
var cfg_monte_carlo_number_of_components_observed_low = 100;
var cfg_monte_carlo_number_of_components_observed_high = 100;
var cfg_monte_carlo_number_of_components_observed_step = 1;
// Sanity check: with round-robin assignment, more partitions than trees would
// leave some partitions unused.
if (cfg_number_of_train_partitions > cfg_number_of_random_forest_trees) {
throw 'there should always be more trees than train partitions';
}
// The analogous monte-carlo check is intentionally disabled.
if (cfg_monte_carlo_partition_high > cfg_monte_carlo_number_of_trees_low) {
//throw 'monte carlo error, there should always be more trees than train partitions';
}
// Entry point: load train.csv, bag it into partitions, train the forest,
// then classify test.csv and write the predictions to est<timestamp>.csv.
// (The previous commented-out monte-carlo sweep path has been removed; it
// remains available via monte_carlo_random_forest + write_monte_carlo_results.)
read_train(function (train_inputs, train_outputs) {
  var train_partitions = partition_train(cfg_number_of_train_partitions, train_inputs, train_outputs);
  var train_partition_inputs = train_partitions[0];
  var train_partition_outputs = train_partitions[1];
  // BUG FIX: random_forest was assigned without `var`, creating an implicit
  // global; it is now scoped to this callback.
  var random_forest = create_random_forest(cfg_number_of_random_forest_trees, cfg_maximum_random_forest_tree_depth, cfg_number_of_components_observed_per_random_forest_tree, train_partition_inputs, train_partition_outputs);
  read_test(function (test_inputs, test_ids) {
    var test_outputs = run_random_forest(random_forest, test_inputs);
    // BUG FIX: the original then reassigned test_outputs from the undefined
    // name ret_test_outputs, throwing a ReferenceError before any results
    // could be written.
    write_test(test_ids, test_outputs);
  });
});
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment