Skip to content

Instantly share code, notes, and snippets.

@vermorel
Last active July 31, 2017 14:44
Show Gist options
  • Save vermorel/b2f413a0b546931a507a9cedc7322b9e to your computer and use it in GitHub Desktop.
Save vermorel/b2f413a0b546931a507a9cedc7322b9e to your computer and use it in GitHub Desktop.
Regression. A monolithic random forest C# implementation, categorical variables treated as ordinals
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using NUnit.Framework;
namespace Lokad
{
/// <summary>
/// A random forest tailored for regression, treating categorical variables as ordinals.
/// </summary>
public static class RandomForestR
{
/// <summary>
/// Convenience wrapper performing two tasks: first it builds the random forest from the
/// labeled <paramref name="instances"/>, then it evaluates every tree of that forest on
/// each of the <paramref name="unlabeled"/> instances.
/// All parameters apart from <paramref name="unlabeled"/> are forwarded unchanged to the
/// other <c>Regress</c> overload (the one returning the trees).
/// </summary>
/// <param name="features">Type (ordinal or categorical) of each feature column.</param>
/// <param name="instances">Labeled training instances, one row per instance.</param>
/// <param name="labels">Label of each training instance.</param>
/// <param name="unlabeled">Instances to evaluate, using the same column layout as <paramref name="instances"/>.</param>
/// <param name="treeCount">Number of trees to grow.</param>
/// <param name="maxDepth">Depth at which a node is forced to become a leaf.</param>
/// <param name="degreeOfParallelism">Degree of parallelism used while growing the trees.</param>
/// <returns>An array with one line per unlabeled instance and one column per tree.</returns>
/// <exception cref="ArgumentNullException"><paramref name="unlabeled"/> is null.</exception>
public static int[][] Regress(
    RFRInternal.FeatureType[] features,
    ushort[][] instances,
    int[] labels,
    ushort[][] unlabeled,
    int treeCount = 100,
    int maxDepth = 20,
    int degreeOfParallelism = 1)
{
    // Fail fast: growing the forest is the expensive part, so a missing prediction
    // set must be rejected before training rather than after it.
    if (unlabeled == null)
    {
        throw new ArgumentNullException(nameof(unlabeled));
    }

    var trees = Regress(
        features: features,
        instances: instances,
        labels: labels,
        treeCount: treeCount,
        maxDepth: maxDepth,
        degreeOfParallelism: degreeOfParallelism);

    var results = new int[unlabeled.Length][];
    for (int i = 0; i < unlabeled.Length; i++)
    {
        var instance = unlabeled[i];
        var row = new int[trees.Length];
        for (int t = 0; t < trees.Length; t++)
        {
            row[t] = trees[t].Regress(instance);
        }
        results[i] = row;
    }
    return results;
}
/// <summary>
/// Builds the random forest for regression.
/// </summary>
/// <remarks>
/// Each tree is grown from <c>labels.Length * </c><see cref="RFRInternal.InstanceFraction"/>
/// instances, and <c>(features.Length + 1) * </c><see cref="RFRInternal.FeatureFraction"/>
/// candidate features are considered when splitting. The base seed is fixed (42) and each
/// tree derives its own generator from it.
/// </remarks>
/// <param name="features">Type (ordinal or categorical) of each feature column.</param>
/// <param name="instances">Training instances, one row per instance.</param>
/// <param name="labels">Label of each training instance; <c>labels[i]</c> belongs to <c>instances[i]</c>.</param>
/// <param name="treeCount">Number of trees to grow.</param>
/// <param name="maxDepth">Depth at which a node is forced to become a leaf.</param>
/// <param name="degreeOfParallelism">Degree of parallelism used while growing the trees.</param>
/// <returns>The grown trees, in tree-index order.</returns>
/// <exception cref="ArgumentNullException">
/// <paramref name="features"/>, <paramref name="instances"/> or <paramref name="labels"/> is null.
/// </exception>
/// <exception cref="ArgumentException">There are more labels than instances.</exception>
public static RFRInternal.Tree[] Regress(
    RFRInternal.FeatureType[] features,
    ushort[][] instances,
    int[] labels,
    int treeCount = 100,
    int maxDepth = 20,
    int degreeOfParallelism = 1)
{
    if (features == null)
    {
        throw new ArgumentNullException(nameof(features));
    }
    if (instances == null)
    {
        throw new ArgumentNullException(nameof(instances));
    }
    if (labels == null)
    {
        throw new ArgumentNullException(nameof(labels));
    }
    // Sampling draws an index below labels.Length and reads instances at that index;
    // without this check the failure would surface sporadically, wrapped in an
    // AggregateException raised by the parallel query.
    if (labels.Length > instances.Length)
    {
        throw new ArgumentException(
            "Every label must have a matching instance.", nameof(labels));
    }

    var instanceSampleSize = (int)(labels.Length * RFRInternal.InstanceFraction);
    var featureSampleSize = (int)((features.Length + 1) * RFRInternal.FeatureFraction);
    var seed = 42;

    return BuildForest(
        classify: false,
        features: features,
        instances: instances,
        labels: labels,
        instanceSampleCount: instanceSampleSize,
        featureSampleCount: featureSampleSize,
        treeCount: treeCount,
        maxDepth: maxDepth,
        seed: seed,
        degreeOfParallelism: degreeOfParallelism);
}
/// <summary>
/// Grows <paramref name="treeCount"/> trees, each from its own random sample (with
/// repetitions) of the instances, after rewriting the categorical columns as ordinals.
/// </summary>
/// <returns>The trees, ordered by tree index.</returns>
private static RFRInternal.Tree[] BuildForest(
    bool classify,
    RFRInternal.FeatureType[] features,
    ushort[][] instances,
    int[] labels,
    int instanceSampleCount,
    int featureSampleCount,
    int treeCount,
    int maxDepth,
    int seed,
    int degreeOfParallelism = 1)
{
    // Once mapped, every column can be handled as an ordinal.
    var mapper = RFRInternal.OrdinalMapper.Build(
        ftypes: features,
        instances: instances,
        labels: labels,
        instanceSampleSize: instanceSampleCount,
        treeCount: treeCount,
        seed: seed);
    var ordinalInstances = mapper.MapInstances(instances);

    Func<int, RFRInternal.Tree> growTree = treeIndex =>
    {
        // One generator per tree, derived from the base seed and the tree index.
        var rng = new Random(seed + treeIndex);
        var bagInstances = new ushort[instanceSampleCount][];
        var bagLabels = new int[instanceSampleCount];

        // Sampling with repetitions: the same instance may be drawn several times.
        var drawn = 0;
        while (drawn < instanceSampleCount)
        {
            var pick = rng.Next(labels.Length);
            bagInstances[drawn] = ordinalInstances[pick];
            bagLabels[drawn] = labels[pick];
            drawn++;
        }

        var root = RFRInternal.BuildNode(
            classify,
            features.Length,
            featureSampleCount,
            bagInstances,
            bagLabels,
            rng.Next(),
            0,
            maxDepth);
        return RFRInternal.BuildTree(root, mapper);
    };

    return Enumerable.Range(0, treeCount)
        .AsParallel()
        .AsOrdered()
        .WithDegreeOfParallelism(degreeOfParallelism)
        .Select(growTree)
        .ToArray();
}
}
public static class RFRInternal
{
/// <summary>
/// The proportion of features considered as split candidates.
/// <see cref="RandomForestR"/> multiplies it by the feature count plus one to obtain
/// the number of candidates, and <see cref="BuildNode"/> draws that many features
/// afresh at every node it splits.
/// </summary>
public const float FeatureFraction = 0.5f;
/// <summary>
/// The proportion of selected examples for building a tree.
/// <see cref="RandomForestR"/> multiplies it by the number of labels to obtain the
/// per-tree sample size; the sample is drawn with repetitions.
/// </summary>
public const float InstanceFraction = 0.66f;
/// <summary>
/// A minimum value for the number of samples in a node. When <see cref="BuildNode"/>
/// receives fewer labels than this value, the node becomes a terminal leaf.
/// </summary>
public const int MinSampleCount = 4;
/// <summary>
/// Defines the type of the feature: either ordinal or categorical.
/// </summary>
/// <remarks>
/// Ordinal values are compared against split thresholds as they are. Categorical values
/// are first replaced by a rank through <see cref="OrdinalMapper"/>, which uses the raw
/// value as an array index.
/// </remarks>
public enum FeatureType
{
// Value is used directly.
Ordinal = 0,
// Value is translated to a rank before being used.
Categorical = 1,
}
/// <summary>
/// Low-memory counterpart of the <see cref="Node"/> class, meant for flattened trees
/// stored as a table of nodes where children are referenced by their integer position
/// in the table instead of by object reference.
/// </summary>
public struct CompactNode
{
    private readonly ushort _featureIndex;
    private readonly int _left;
    private readonly ushort _threshold;

    public CompactNode(ushort featureIndex, int left, ushort threshold)
    {
        _featureIndex = featureIndex;
        _left = left;
        _threshold = threshold;
    }

    /// <summary>
    /// Builds a terminal node carrying <paramref name="value"/>. The feature index is
    /// set to <see cref="ushort.MaxValue"/>, which is what marks a node as a leaf.
    /// </summary>
    public static CompactNode Leaf(int value)
    {
        return new CompactNode(ushort.MaxValue, value, ushort.MaxValue);
    }

    /// <summary> The feature split at this node. </summary>
    public ushort FeatureIndex
    {
        get { return _featureIndex; }
    }

    /// <summary> Table position of the left child (or the leaf value, see <see cref="Value"/>). </summary>
    public int Left
    {
        get { return _left; }
    }

    /// <summary> Split threshold of this node. </summary>
    public ushort Threshold
    {
        get { return _threshold; }
    }

    /// <summary> Table position of the right child: always right after the left one. </summary>
    public int Right
    {
        get { return _left + 1; }
    }

    /// <summary> A leaf is a terminal node. </summary>
    public bool IsLeaf
    {
        get { return _featureIndex == ushort.MaxValue; }
    }

    /// <summary> Value of a leaf; shares its storage with <see cref="Left"/>. </summary>
    public int Value
    {
        get { return _left; }
    }
}
/// <summary>
/// A decision tree stored as a flat table of <see cref="CompactNode"/>, with the root at
/// position 0, together with the mapper translating raw feature values to ordinals.
/// </summary>
public class Tree
{
    private readonly OrdinalMapper _mapper;
    private readonly CompactNode[] _nodes;

    public Tree(OrdinalMapper mapper, CompactNode[] nodes)
    {
        _mapper = mapper;
        _nodes = nodes;
    }

    /// <summary>
    /// Walks the tree from the root for the given raw (unmapped) instance and returns
    /// the value of the leaf that is reached.
    /// </summary>
    public int Regress(ushort[] instance)
    {
        var position = 0;
        while (true)
        {
            var current = _nodes[position];
            if (current.IsLeaf)
            {
                return current.Value;
            }

            var column = current.FeatureIndex;
            var ordinal = _mapper.GetOrdinal(column, instance[column]);

            // Thresholds are inclusive on the left side.
            position = ordinal > current.Threshold ? current.Right : current.Left;
        }
    }
}
/// <summary>
/// Flattens the node graph rooted at <paramref name="root"/> into a table of
/// <see cref="CompactNode"/> (breadth-first order, root at position 0) and wraps it,
/// together with <paramref name="mapper"/>, into a <see cref="Tree"/>.
/// </summary>
/// <exception cref="ArgumentNullException"><paramref name="root"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException">
/// A split node uses a feature index that cannot be stored in a <see cref="CompactNode"/>.
/// </exception>
public static Tree BuildTree(Node root, OrdinalMapper mapper)
{
    if (root == null)
    {
        throw new ArgumentNullException(nameof(root));
    }

    var list = new List<CompactNode>();
    var queue = new Queue<Node>();
    queue.Enqueue(root);
    while (queue.Count > 0)
    {
        var x = queue.Dequeue();
        if (x.IsLeaf)
        {
            list.Add(CompactNode.Leaf(x.Value));
            continue;
        }

        // CompactNode stores the index as a ushort and reserves ushort.MaxValue as the
        // leaf marker: anything outside [0, ushort.MaxValue) would be silently truncated
        // or mistaken for a leaf.
        if (x.FeatureIndex < 0 || x.FeatureIndex >= ushort.MaxValue)
        {
            throw new ArgumentOutOfRangeException(
                nameof(root),
                "A split node uses a feature index that does not fit in a compact node.");
        }

        // This node lands at position list.Count; the nodes still queued take the next
        // queue.Count positions, so the two children follow right after them.
        var leftPosition = list.Count + 1 + queue.Count;
        list.Add(new CompactNode(
            featureIndex: (ushort)x.FeatureIndex,
            left: leftPosition,
            threshold: x.Threshold));
        queue.Enqueue(x.Left);
        queue.Enqueue(x.Right);
    }
    return new Tree(mapper, list.ToArray());
}
/// <summary>
/// Node of a tree under construction: either a leaf carrying a value, or a split on a
/// feature with a left and a right child.
/// </summary>
public class Node
{
    /// <summary> Creates a leaf carrying <paramref name="value"/>. </summary>
    public Node(int value)
    {
        Value = value;
    }

    /// <summary> Creates a split node. </summary>
    public Node(int featureIndex, ushort threshold, Node left, Node right)
    {
        Left = left;
        Right = right;
        FeatureIndex = featureIndex;
        Threshold = threshold;
    }

    /// <summary> Index of the feature tested by a split node. </summary>
    public int FeatureIndex { get; }

    /// <summary> Inclusive left. </summary>
    public ushort Threshold { get; }

    /// <summary> Value carried by a leaf. </summary>
    public int Value { get; }

    /// <summary> Left child; null for a leaf. </summary>
    public Node Left { get; }

    /// <summary> Right child. </summary>
    public Node Right { get; }

    /// <summary> A node without a left child is a leaf. </summary>
    public bool IsLeaf
    {
        get { return Left == null; }
    }
}
public class OrdinalMapper
{
/// <summary>Array with first coordinate the feature index and second
/// coordinate the raw categorical value; the stored element is the
/// corresponding ordinal (rank).
/// Nulls for all ordinal features which don't need mappings. </summary>
private readonly ushort[][] _map;
/// <summary>
/// Array of <see cref="FeatureType"/> defining the type of feature of each line of
/// <see cref="_map"/>.
/// </summary>
private readonly FeatureType[] _features;
// Not public: instances are created through the static Build method.
OrdinalMapper(FeatureType[] features, ushort[][] map)
{
_features = features;
_map = map;
}
/// <summary>
/// Returns the ordinal-equivalent feature. If the feature is already
/// ordinal, it simply returns it. Otherwise, it uses the <see cref="OrdinalMapper"/> map.
/// </summary>
/// <param name="featureIndex">Index of the feature column.</param>
/// <param name="feature">Raw value of that feature.</param>
/// <returns>
/// The raw value for an ordinal feature; for a categorical feature, the rank recorded
/// for that category, or the middle rank when the category is beyond the map.
/// </returns>
public ushort GetOrdinal(int featureIndex, ushort feature)
{
    if (_features[featureIndex] == FeatureType.Ordinal)
    {
        return feature;
    }

    var map = _map[featureIndex];

    // Edge case: category beyond the ones recorded in the map. The map built by
    // Build holds the ranks 0 .. map.Length - 1, so the median rank is map.Length / 2
    // itself, not the rank of whichever category happens to sit at that index.
    if (feature >= map.Length)
    {
        return (ushort)(map.Length / 2);
    }

    return map[feature];
}
/// <summary>
/// Builds the mapper that turns the original problem into one with ordinal features
/// only. Ordinal columns are left alone; each categorical column gets a table
/// translating a raw category into a rank.
/// </summary>
/// <remarks>
/// For every tree index, a sample of instances is drawn from a generator seeded with
/// <c>seed + treeIndex</c>, the labels are summed per category, and the categories are
/// ranked by that sum. The per-tree ranks are accumulated, and the final table is the
/// ranking of those accumulated ranks.
/// </remarks>
public static OrdinalMapper Build(
    FeatureType[] ftypes,
    ushort[][] instances,
    int[] labels,
    int instanceSampleSize,
    int treeCount,
    int seed)
{
    // Column indices of the categorical features, in ascending order.
    var categoricalColumns = ftypes
        .Select((type, index) => type == FeatureType.Categorical ? index : -1)
        .Where(index => index >= 0)
        .ToArray();
    var categoricalCount = categoricalColumns.Length;

    // Largest raw value per categorical column: it sizes the per-category tables.
    var largest = new int[categoricalCount];
    foreach (var row in instances)
    {
        for (int c = 0; c < categoricalCount; c++)
        {
            int raw = row[categoricalColumns[c]];
            if (raw > largest[c])
            {
                largest[c] = raw;
            }
        }
    }

    var accumulatedRanks = new int[categoricalCount][];
    var labelSums = new int[categoricalCount][];
    for (int c = 0; c < categoricalCount; c++)
    {
        var size = largest[c] + 1;
        accumulatedRanks[c] = new int[size];
        labelSums[c] = new int[size];
    }

    for (int tree = 0; tree < treeCount; tree++)
    {
        var rng = new Random(seed + tree);

        // Start every tree from empty sums.
        foreach (var sums in labelSums)
        {
            Array.Clear(sums, 0, sums.Length);
        }

        for (int draw = 0; draw < instanceSampleSize; draw++)
        {
            var pick = rng.Next(labels.Length);
            var row = instances[pick];
            var label = labels[pick];
            for (int c = 0; c < categoricalCount; c++)
            {
                labelSums[c][row[categoricalColumns[c]]] += label;
            }
        }

        for (int c = 0; c < categoricalCount; c++)
        {
            var treeRanks = labelSums[c].ToRank();
            var totals = accumulatedRanks[c];
            for (int k = 0; k < treeRanks.Length; k++)
            {
                totals[k] += treeRanks[k];
            }
        }
    }

    // One slot per feature; slots of ordinal features stay null.
    var maps = new ushort[ftypes.Length][];
    for (int c = 0; c < categoricalCount; c++)
    {
        var finalRanks = accumulatedRanks[c].ToRank();
        var table = new ushort[accumulatedRanks[c].Length];
        for (int k = 0; k < finalRanks.Length; k++)
        {
            table[k] = (ushort)finalRanks[k];
        }
        maps[categoricalColumns[c]] = table;
    }
    return new OrdinalMapper(ftypes, maps);
}
/// <summary>
/// Translates every value of every instance through <see cref="GetOrdinal"/>.
/// When all features are ordinal there is nothing to translate and the very same
/// array is returned; otherwise a new matrix is allocated.
/// </summary>
public ushort[][] MapInstances(ushort[][] instances)
{
    var hasNonOrdinal = _features.Any(type => type != FeatureType.Ordinal);
    if (!hasNonOrdinal)
    {
        return instances;
    }

    // Rewrite the instances into a fresh matrix.
    var result = new ushort[instances.Length][];
    for (int row = 0; row < instances.Length; row++)
    {
        var source = instances[row];
        var target = new ushort[source.Length];
        for (int column = 0; column < source.Length; column++)
        {
            target[column] = GetOrdinal(column, source[column]);
        }
        result[row] = target;
    }
    return result;
}
/// <summary>
/// Recursively grows a (sub)tree from the given instances and labels.
/// </summary>
/// <remarks>
/// A leaf is produced when there is no instance (value 0), when all labels are equal
/// (that label), or when there are fewer than <see cref="MinSampleCount"/> labels, the
/// depth limit is reached, or no usable split exists (a label picked at random).
/// Otherwise <paramref name="featureSampleCount"/> features are drawn for this node,
/// the best split among them is kept, and both sides are grown recursively.
/// </remarks>
public static Node BuildNode(
    bool classify,
    int featureCount,
    int featureSampleCount,
    ushort[][] instances,
    int[] labels,
    int seed,
    int depth,
    int maxDepth)
{
    // Very degenerate case: nothing to learn from.
    if (instances.Length == 0)
    {
        return new Node(0);
    }

    var maxLabel = labels.Max();
    var minLabel = labels.Min();

    // A single distinct label left: the node is pure.
    if (minLabel == maxLabel)
    {
        return new Node(maxLabel);
    }

    var rng = new Random(seed);

    // Too few labels or too deep: settle for a label picked at random.
    var tooSmall = labels.Length < MinSampleCount;
    var tooDeep = depth >= maxDepth;
    if (tooSmall || tooDeep)
    {
        return new Node(labels[rng.Next(labels.Length)]);
    }

    // Features considered as split candidates for this node.
    var candidates = rng.NextNoDupplicate(featureCount, featureSampleCount);

    // Scratch buffers shared by all candidates: 'pairs' packs (feature value, label)
    // into one ulong so that a plain sort orders them; 'column' holds one feature column.
    var pairs = new ulong[labels.Length];
    var column = new ushort[instances.Length];

    var splits = new Split[featureSampleCount];
    for (int s = 0; s < featureSampleCount; s++)
    {
        var feature = candidates[s];
        column.Fill(instances, feature);

        Split candidate;
        if (IsDegenerate(column))
        {
            candidate = new Split(feature);
        }
        else if (classify)
        {
            candidate = ClassifyOrdinal(feature, column, labels, pairs, maxLabel);
        }
        else
        {
            candidate = RegressOrdinal(feature, column, labels, pairs);
        }
        splits[s] = candidate;
    }

    var best = splits.ArgMin();
    if (best.Degenerate)
    {
        return new Node(labels[rng.Next(labels.Length)]);
    }

    // Partition the instances around the chosen threshold (inclusive on the left).
    column.Fill(instances, best.FeatureIndex);
    var leftSize = best.LeftCount;
    var rightSize = instances.Length - leftSize;
    var leftInstances = new ushort[leftSize][];
    var leftLabels = new int[leftSize];
    var rightInstances = new ushort[rightSize][];
    var rightLabels = new int[rightSize];

    var l = 0;
    var r = 0;
    for (int i = 0; i < instances.Length; i++)
    {
        if (column[i] > best.Threshold)
        {
            rightInstances[r] = instances[i];
            rightLabels[r] = labels[i];
            r++;
        }
        else
        {
            leftInstances[l] = instances[i];
            leftLabels[l] = labels[i];
            l++;
        }
    }

    // The left child takes its seed from the generator first, then the right child.
    var leftChild = BuildNode(
        classify, featureCount, featureSampleCount,
        leftInstances, leftLabels, rng.Next(), depth + 1, maxDepth);
    var rightChild = BuildNode(
        classify, featureCount, featureSampleCount,
        rightInstances, rightLabels, rng.Next(), depth + 1, maxDepth);
    return new Node(best.FeatureIndex, best.Threshold, leftChild, rightChild);
}
/// <summary>
/// Candidate split of a node on one feature. Splits are ordered by their score
/// (lower is better); a degenerate split carries the worst possible score.
/// </summary>
[DebuggerDisplay("FeatureIndex:{FeatureIndex} Variance:{_variance}")]
private class Split : IComparable<Split>
{
    public readonly int FeatureIndex;
    public readonly bool Degenerate;
    private readonly float _variance;
    public readonly ushort Threshold;
    public readonly int LeftCount;

    /// <summary> Creates a usable split. </summary>
    public Split(int featureIndex, float variance, ushort threshold, int leftCount)
    {
        FeatureIndex = featureIndex;
        _variance = variance;
        Threshold = threshold;
        LeftCount = leftCount;
    }

    /// <summary>
    /// Creates a degenerate split: the feature cannot partition the node.
    /// </summary>
    public Split(int featureIndex)
        : this(featureIndex, float.MaxValue, 0, 0)
    {
        Degenerate = true;
    }

    /// <summary> Orders splits by score only. </summary>
    public int CompareTo(Split other)
    {
        var otherScore = other._variance;
        return _variance.CompareTo(otherScore);
    }
}
/// <summary>
/// Tells whether a feature column is degenerate, i.e. holds one constant value.
/// An empty column is reported as degenerate.
/// </summary>
private static bool IsDegenerate(ushort[] instances)
{
    foreach (var value in instances)
    {
        // Any value differing from the first one makes the column usable.
        if (value != instances[0])
        {
            return false;
        }
    }
    return true;
}
/// <summary>
/// Finds, for one feature column, the threshold minimizing the sum over both sides of
/// <c>count * unbiased variance</c> of the labels.
/// </summary>
/// <param name="featureIndex">Index of the feature, recorded in the returned split.</param>
/// <param name="instances">Values of that feature, one per instance.</param>
/// <param name="labels">Labels, aligned with <paramref name="instances"/>.</param>
/// <param name="pairs">
/// Scratch buffer with one slot per label; it is overwritten with
/// (feature value in the high 32 bits, label in the low 32 bits) and sorted.
/// </param>
/// <returns>
/// The best split (threshold inclusive on the left, with the number of instances going
/// left), or a degenerate split when no threshold leaves 2 points on each side.
/// </returns>
private static Split RegressOrdinal(
    int featureIndex,
    ushort[] instances,
    int[] labels,
    ulong[] pairs)
{
    for (int i = 0; i < pairs.Length; i++)
    {
        pairs[i] = ((ulong)instances[i]) << 32;
        unchecked { pairs[i] |= (uint)labels[i]; }
    }

    // The 'Array.Sort' is the performance bottleneck of the whole ordinal
    // forest with roughly 50% of the total CPU time spent here.
    // [vermorel] May 2016, micro-optimizing to uint[] sort barely speed-up
    Array.Sort(pairs);

    // Initially everything sits on the right side; the left side is empty.
    double leftSum = 0.0, leftSumSq = 0.0;
    double rightSum = 0.0, rightSumSq = 0.0;
    for (int i = 0; i < pairs.Length; i++)
    {
        int li = labels[i];
        rightSum += li;
        rightSumSq += li * (double)li;
    }

    // Sweep the sorted pairs, moving one label at a time from right to left.
    // The loop bound keeps at least 2 points on the right side.
    var minVariance = double.MaxValue;
    var minVarianceIndex = -1;
    for (int i = 0; i < pairs.Length - 2; i++)
    {
        int li;
        unchecked { li = (int)((uint)(pairs[i] & 0xFFFFFFFFUL)); }
        var li2 = li * (double)li;
        leftSum += li;
        leftSumSq += li2;
        rightSum -= li;
        rightSumSq -= li2;

        // The unbiased variance needs 2 points on the left side as well. Requiring 3,
        // as 'i < 2' did, was asymmetric with the right side and made a node of 4
        // samples impossible to split.
        var leftCount = i + 1;
        if (leftCount < 2) continue;

        // Only a change of feature value is a valid threshold.
        if ((pairs[i] >> 32) == (pairs[i + 1] >> 32)) continue;

        var rightCount = pairs.Length - leftCount;
        var variance =
            leftCount / (leftCount - 1.0) * (leftSumSq - leftSum * leftSum / leftCount) +
            rightCount / (rightCount - 1.0) * (rightSumSq - rightSum * rightSum / rightCount);
        if (variance < minVariance)
        {
            minVariance = variance;
            minVarianceIndex = i;
        }
    }

    // Edge case: no partition found.
    if (minVarianceIndex == -1)
    {
        return new Split(featureIndex);
    }

    return new Split(featureIndex,
        (float)minVariance, (ushort)(pairs[minVarianceIndex] >> 32), minVarianceIndex + 1);
}
/// <summary>
/// Finds, for one feature column, the threshold minimizing the size-weighted average of
/// the label entropies (natural logarithm) of the two sides.
/// </summary>
/// <param name="featureIndex">Index of the feature, recorded in the returned split.</param>
/// <param name="instances">Values of that feature, one per instance.</param>
/// <param name="labels">Labels, aligned with <paramref name="instances"/>; used as array indices.</param>
/// <param name="pairs">
/// Scratch buffer with one slot per label; it is overwritten with
/// (feature value in the high 32 bits, label in the low 32 bits) and sorted.
/// </param>
/// <param name="maxLabel">Largest label; sizes the per-label counters.</param>
/// <returns>
/// The best split (threshold inclusive on the left, with the number of instances going
/// left), or a degenerate split when the column offers no threshold.
/// </returns>
private static Split ClassifyOrdinal(
    int featureIndex,
    ushort[] instances,
    int[] labels,
    ulong[] pairs,
    int maxLabel)
{
    for (int i = 0; i < pairs.Length; i++)
    {
        pairs[i] = ((ulong)instances[i]) << 32;
        unchecked { pairs[i] |= (uint)labels[i]; }
    }

    var labelCounts = new int[maxLabel + 1];
    for (int i = 0; i < labels.Length; i++)
    {
        labelCounts[labels[i]] += 1;
    }

    // no bucket sort here, ordinal values can be large
    Array.Sort(pairs);

    var minEntropy = double.MaxValue;
    var minEntropyIndex = -1;
    var partialCounts = new int[maxLabel + 1];

    // The last position is never a candidate: it would put everything on the left,
    // which is not a partition (and would divide by an empty right side).
    for (int i = 0; i < pairs.Length - 1; i++)
    {
        int li;
        // 0xFFFFFFFFUL is uint.MaxValue: li is the label part of pairs[i]
        unchecked { li = (int)((uint)(pairs[i] & 0xFFFFFFFFUL)); }
        partialCounts[li] += 1;

        // Entropy only applies where the feature value changes; elsewhere the
        // threshold would not reflect the partition.
        if ((pairs[i] >> 32) == (pairs[i + 1] >> 32)) continue;

        var leftCount = i + 1;
        var rightCount = labels.Length - leftCount;
        double leftEntropy = 0.0, rightEntropy = 0.0;
        for (int j = 0; j <= maxLabel; j++)
        {
            var cj = partialCounts[j];

            // lpj: left-probability of label j
            var lpj = cj / (double)leftCount;
            if (lpj > 0 && lpj < 1)
            {
                leftEntropy -= lpj * Math.Log(lpj);
            }

            // rpj: right-probability of label j
            var rpj = (labelCounts[j] - cj) / (double)rightCount;
            if (rpj > 0 && rpj < 1)
            {
                rightEntropy -= rpj * Math.Log(rpj);
            }
        }

        // Each side is weighted by its own size; the right side holds rightCount
        // instances, not rightCount + 1.
        var entropy = (leftCount * leftEntropy + rightCount * rightEntropy) / labels.Length;
        if (entropy < minEntropy)
        {
            minEntropy = entropy;
            minEntropyIndex = i;
        }
    }

    // Edge case: degenerate partition.
    if (minEntropyIndex == -1)
    {
        return new Split(featureIndex);
    }

    return new Split(featureIndex, (float)minEntropy, (ushort)(pairs[minEntropyIndex] >> 32), minEntropyIndex + 1);
}
}
/// <summary>
/// Small array and random-sampling helpers used by the random forest.
/// </summary>
public static class ArrayHelper
{
    /// <summary>
    /// Generates <paramref name="sampleCount"/> random samples with no duplicate from the
    /// interval [0, <paramref name="fromCount"/>-1]. It performs the first
    /// <paramref name="sampleCount"/> steps of a Fisher-Yates shuffle of
    /// [1, <paramref name="fromCount"/>] and returns the shuffled head, minus 1.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="random"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="fromCount"/> is negative, or <paramref name="sampleCount"/> is
    /// negative or larger than <paramref name="fromCount"/>.
    /// </exception>
    public static int[] NextNoDupplicate(this Random random, int fromCount, int sampleCount)
    {
        if (random == null)
        {
            throw new ArgumentNullException(nameof(random));
        }
        if (fromCount < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(fromCount), "Must not be negative.");
        }
        if (sampleCount < 0 || sampleCount > fromCount)
        {
            throw new ArgumentOutOfRangeException(
                nameof(sampleCount), "Must lie between 0 and fromCount.");
        }

        // Values are generated lazily, on swap only: a zero has the semantic
        // 'undefined', and an undefined slot k implicitly holds k + 1.
        var sample = new int[fromCount];
        var headSample = new int[sampleCount];
        for (int i = 0; i < sampleCount; i++)
        {
            // The partner is drawn from the not-yet-fixed tail [i, fromCount): drawing
            // it from the whole range would make some subsets more likely than others.
            var n = random.Next(i, fromCount);
            var vi = sample[i];
            vi = vi > 0 ? vi : i + 1;
            var vn = sample[n];
            vn = vn > 0 ? vn : n + 1;

            // Slot i is final from now on, so its value goes straight to the output.
            sample[n] = vi;
            headSample[i] = vn - 1;
        }
        return headSample;
    }

    /// <summary>
    /// Returns the smallest element of <paramref name="array"/>; on ties, the first one.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="array"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="array"/> is empty.</exception>
    public static T ArgMin<T>(this IReadOnlyList<T> array) where T : IComparable<T>
    {
        if (array == null)
        {
            throw new ArgumentNullException(nameof(array));
        }
        if (array.Count == 0)
        {
            throw new ArgumentOutOfRangeException(nameof(array));
        }

        var min = array[0];
        for (int i = 1; i < array.Count; i++)
        {
            if (min.CompareTo(array[i]) > 0)
            {
                min = array[i];
            }
        }
        return min;
    }

    /// <summary> Gets a column from a matrix stored by lines. </summary>
    public static void Fill<T>(this T[] array, T[][] matrix, int column)
    {
        for (int i = 0; i < array.Length; i++)
        {
            array[i] = matrix[i][column];
        }
    }

    /// <summary>
    /// Pairs a value with its original position; ordered by value, then by position.
    /// </summary>
    struct IndexValue : IComparable<IndexValue>
    {
        public int Index { get; }
        public int Value { get; }

        public IndexValue(int index, int value)
        {
            Index = index;
            Value = value;
        }

        public int CompareTo(IndexValue other)
        {
            if (Value == other.Value) return Index.CompareTo(other.Index);
            return Value.CompareTo(other.Value);
        }
    }

    /// <summary>
    /// Returns, for each position of <paramref name="array"/>, the rank (0 for the
    /// smallest) of its value. Equal values are ranked by position, so the result is
    /// always a permutation of 0 .. Length - 1.
    /// </summary>
    public static int[] ToRank(this int[] array)
    {
        var indexed = new IndexValue[array.Length];
        for (int i = 0; i < array.Length; i++)
        {
            indexed[i] = new IndexValue(i, array[i]);
        }

        Array.Sort(indexed);

        var ranked = new int[array.Length];
        for (int i = 0; i < array.Length; i++)
        {
            ranked[indexed[i].Index] = i;
        }
        return ranked;
    }
}
/// <summary>
/// Sanity test of <see cref="RandomForestR"/> on data where column 0 fully explains
/// the labels, once with that column declared categorical and once ordinal.
/// </summary>
public class RandomForestRTests
{
    [Test]
    public void Regress_nonrandom_data()
    {
        var random = new Random(45);
        const int N = 500;
        const int F = 20;
        const int L = 2;

        // The generator is shared, so the order of the draws below matters.
        var instances = RandomMatrix(random, N, F);
        var labels = new ushort[N];
        for (int i = 0; i < N; i++)
        {
            labels[i] = (ushort)random.Next(L);
        }
        var unlabeled = RandomMatrix(random, N, F);

        // the column 0 fully explains the labels
        for (int i = 0; i < labels.Length; i++)
        {
            instances[i][0] = labels[i];
        }
        for (int i = 0; i < unlabeled.Length; i++)
        {
            unlabeled[i][0] = (ushort)random.Next(L);
        }

        var features = new RFRInternal.FeatureType[F];
        for (int f = 0; f < F; f++)
        {
            features[f] = random.Next(2) == 0
                ? RFRInternal.FeatureType.Ordinal
                : RFRInternal.FeatureType.Categorical;
        }

        // categorical selection
        features[0] = RFRInternal.FeatureType.Categorical;
        var predictions = RandomForestR.Regress(features, instances, ToInt(labels), unlabeled, treeCount: 500);
        var hits = CountFirstTreeHits(predictions, unlabeled);
        Assert.Greater(hits, N * 0.70);
        Console.WriteLine($"Accuracy: {hits / (float)N}");

        // ordinal selection
        features[0] = RFRInternal.FeatureType.Ordinal;
        predictions = RandomForestR.Regress(features, instances, ToInt(labels), unlabeled);
        hits = CountFirstTreeHits(predictions, unlabeled);
        Assert.Greater(hits, N * 0.70);
        Console.WriteLine($"Accuracy: {hits / (float)N}");
    }

    /// <summary>
    /// Builds a rows x columns matrix where every value of row x is drawn from [0, x].
    /// </summary>
    private static ushort[][] RandomMatrix(Random random, int rows, int columns)
    {
        var matrix = new ushort[rows][];
        for (int x = 0; x < rows; x++)
        {
            var line = new ushort[columns];
            for (int y = 0; y < columns; y++)
            {
                line[y] = (ushort)random.Next(x + 1);
            }
            matrix[x] = line;
        }
        return matrix;
    }

    /// <summary>
    /// Counts the instances whose first-tree output equals their column 0.
    /// </summary>
    private static int CountFirstTreeHits(int[][] predictions, ushort[][] unlabeled)
    {
        var hits = 0;
        for (int i = 0; i < unlabeled.Length; i++)
        {
            if (predictions[i][0] == unlabeled[i][0])
            {
                hits++;
            }
        }
        return hits;
    }

    private int[] ToInt(ushort[] array)
    {
        var converted = new int[array.Length];
        for (int i = 0; i < array.Length; i++)
        {
            converted[i] = array[i];
        }
        return converted;
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment