Created
October 10, 2020 14:05
-
-
Save ahancock1/31d3778b4fb826eb79e6c111336bb798 to your computer and use it in GitHub Desktop.
Half Space Trees HST classifier .Net C# implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// <summary>
/// Tunable parameters read by <see cref="Classifier"/>.
/// </summary>
public class ClassifierSettings
{
    /// <summary>
    /// Number of fitted observations that make up one window; the classifier
    /// rolls its node masses over each time this many observations have been
    /// fitted. Defaults to 250.
    /// </summary>
    public int WindowSize { get; set; } = 250;

    /// <summary>
    /// Number of trees held by the classifier. Defaults to 25.
    /// </summary>
    public int Estimators { get; set; } = 25;

    /// <summary>
    /// Depth at which tree construction stops and an external node is emitted.
    /// Defaults to 15.
    /// </summary>
    public int MaxDepth { get; set; } = 15;

    /// <summary>
    /// Number of feature indices (0 .. Features - 1) considered when choosing
    /// a split. Defaults to 0, so it has to be set before trees are built.
    /// </summary>
    public int Features { get; set; } = 0;

    /// <summary>
    /// Lower bound of the initial range assigned to every feature.
    /// Defaults to 0.
    /// </summary>
    public double MinLimit { get; set; } = 0.0;

    /// <summary>
    /// Upper bound of the initial range assigned to every feature.
    /// Defaults to 1.
    /// </summary>
    public double MaxLimit { get; set; } = 1.0;
}
/// <summary>
/// Streaming anomaly scorer backed by an ensemble of randomly built binary
/// trees (Half-Space Trees). <see cref="Fit"/> routes each observation down
/// every tree and counts it at the nodes it passes; <see cref="Score"/> turns
/// the counts found along the same route into a single number.
/// </summary>
public class Classifier
{
    // Fraction of a feature's current range kept clear at each end when a
    // split point is drawn, so a split never lands right on a boundary.
    private const double Padding = 0.15d;

    // One root per tree. Sized once in the constructor and populated lazily
    // by the first call to Fit.
    private readonly Node[] _nodes;

    // Fixed seed; tree construction draws its randomness from this generator.
    private readonly Random _random = new Random(42);

    private readonly ClassifierSettings _settings;

    // Total number of observations passed to Fit. A long, so that a very long
    // stream cannot wrap the counter and re-enter the first-window branch.
    private long _count;

    /// <summary>
    /// Creates a classifier that holds <c>settings.Estimators</c> trees.
    /// The trees themselves are built on the first call to <see cref="Fit"/>.
    /// </summary>
    /// <param name="settings">Configuration; the instance is kept and read again later.</param>
    /// <exception cref="ArgumentNullException"><paramref name="settings"/> is null.</exception>
    public Classifier(ClassifierSettings settings)
    {
        if (settings == null)
        {
            throw new ArgumentNullException(nameof(settings));
        }

        _settings = settings;
        _nodes = new Node[_settings.Estimators];
    }

    /// <summary>
    /// Recursively builds a full binary tree down to <c>MaxDepth</c>.
    /// </summary>
    /// <param name="limits">
    /// Current range of every feature. Entries are narrowed while a subtree is
    /// being built and are back to their incoming values when the call returns.
    /// </param>
    /// <param name="depth">Depth of the node being created (root is 0).</param>
    /// <returns>The root of the subtree.</returns>
    private Node Build(Limit[] limits, int depth)
    {
        if (depth == _settings.MaxDepth)
        {
            return new Node
            {
                Type = NodeType.External,
                Depth = depth
            };
        }

        // Each candidate feature is offered together with the width of its
        // current range. NextChoice is declared outside this class —
        // presumably a weighted random pick of one key; confirm there.
        var feature = _random.NextChoice(
            Enumerable.Range(0, _settings.Features)
                .ToDictionary(i => i, i => limits[i].Max - limits[i].Min));

        var limit = limits[feature];
        var min = limit.Min;
        var max = limit.Max;
        var width = max - min;

        // Split point drawn from the padded interior of the range.
        var value = _random.NextDouble(
            min + Padding * width,
            max - Padding * width);

        // Left subtree works on [min, value].
        limit.Min = min;
        limit.Max = value;
        var left = Build(limits, depth + 1);

        // Right subtree works on [value, max].
        limit.Min = value;
        limit.Max = max;
        var right = Build(limits, depth + 1);

        // Limit objects are shared by reference with every caller up the
        // stack. Put the range back, otherwise sibling subtrees built later
        // would see this feature permanently narrowed to [value, max].
        limit.Min = min;
        limit.Max = max;

        return new Node
        {
            Left = left,
            Right = right,
            Type = NodeType.Internal,
            Depth = depth,
            Feature = feature,
            Value = value
        };
    }

    /// <summary>
    /// Builds every tree, each starting from a fresh set of feature ranges
    /// spanning <c>MinLimit</c> .. <c>MaxLimit</c>.
    /// </summary>
    private void Initialise()
    {
        var features = _settings.Features;

        // Iterate over the array that was actually allocated, so a later
        // change to settings.Estimators cannot run past its end.
        for (var i = 0; i < _nodes.Length; i++)
        {
            var limits = Enumerable.Range(0, features)
                .Select(_ =>
                    new Limit
                    {
                        Min = _settings.MinLimit,
                        Max = _settings.MaxLimit
                    })
                .ToArray();

            _nodes[i] = Build(limits, 0);
        }
    }

    /// <summary>
    /// Feeds one observation to the model. The first call also builds the trees.
    /// </summary>
    /// <param name="data">Observation whose <c>Features</c> are compared with node split values.</param>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
    public void Fit(IData data)
    {
        if (data == null)
        {
            throw new ArgumentNullException(nameof(data));
        }

        if (_count == 0)
        {
            Initialise();
        }

        // During the very first window the scored mass (RightMass) is filled
        // directly, so Score has something to read before the first roll-over.
        var firstWindow = _count < _settings.WindowSize;

        foreach (var root in _nodes)
        {
            foreach (var node in Path(root, n =>
                data.Features[n.Feature] < n.Value))
            {
                if (firstWindow)
                {
                    node.RightMass++;
                }

                node.LeftMass++;
            }
        }

        _count++;

        // End of a window: the freshly accumulated mass becomes the mass that
        // Score reads, and accumulation starts again from zero.
        if (_count % _settings.WindowSize == 0)
        {
            foreach (var root in _nodes)
            {
                Iterate(root, n =>
                {
                    n.RightMass = n.LeftMass;
                    n.LeftMass = 0;
                });
            }
        }
    }

    /// <summary>
    /// Lazily walks from <paramref name="node"/> towards a leaf, yielding every
    /// node visited. At each internal node the walk goes left when
    /// <paramref name="evaluator"/> returns true and right otherwise.
    /// </summary>
    private IEnumerable<Node> Path(Node node, Func<Node, bool> evaluator)
    {
        while (node != null)
        {
            yield return node;

            // External nodes have no children and no split of their own, so
            // there is nothing to evaluate.
            if (node.Type == NodeType.External)
            {
                yield break;
            }

            node = evaluator.Invoke(node) ? node.Left : node.Right;
        }
    }

    /// <summary>
    /// Applies <paramref name="action"/> to <paramref name="node"/> and then to
    /// every node beneath it (node first, then left subtree, then right).
    /// </summary>
    private void Iterate(Node node, Action<Node> action)
    {
        action.Invoke(node);

        if (node.Type == NodeType.External)
        {
            return;
        }

        Iterate(node.Left, action);
        Iterate(node.Right, action);
    }

    /// <summary>
    /// Scores one observation against the current model.
    /// </summary>
    /// <param name="data">Observation to score.</param>
    /// <returns>
    /// One minus the normalised, depth-weighted mass found along the
    /// observation's route through every tree: the less mass on the route, the
    /// closer the result is to 1. Returns <see cref="double.NaN"/> when nothing
    /// has been fitted yet or the classifier holds no trees.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
    public double Score(IData data)
    {
        if (data == null)
        {
            throw new ArgumentNullException(nameof(data));
        }

        // Nothing to normalise against; this is the value the 0/0 division
        // below would produce, stated explicitly.
        if (_count == 0 || _nodes.Length == 0)
        {
            return double.NaN;
        }

        // Until the first window is full, normalise by what has been seen.
        var size = Math.Min(_settings.WindowSize, _count);

        // Largest reachable total: every node on a route holding `size`
        // observations, weighted 2^0 + 2^1 + ... + 2^MaxDepth, in every tree.
        var max = _nodes.Length * size *
                  (Math.Pow(2, _settings.MaxDepth + 1) - 1);

        // Nodes holding less than a tenth of a window end the walk early.
        var limit = 0.1d * _settings.WindowSize;

        var score = 0d;

        foreach (var root in _nodes)
        {
            foreach (var node in Path(root, n =>
                data.Features[n.Feature] < n.Value))
            {
                score += node.RightMass * Math.Pow(2, node.Depth);

                if (node.RightMass < limit)
                {
                    break;
                }
            }
        }

        return 1 - score / max;
    }

    /// <summary>Mutable [Min, Max] range of a single feature used while building a tree.</summary>
    private class Limit
    {
        public double Min { get; set; }

        public double Max { get; set; }
    }

    /// <summary>A tree node; internal nodes carry a split, external nodes are leaves.</summary>
    private class Node
    {
        public NodeType Type { get; set; }

        // Child taken when the observation's feature value is below Value.
        public Node Left { get; set; }

        // Child taken otherwise.
        public Node Right { get; set; }

        // Mass accumulated during the window currently being fitted.
        // Unrelated to the Left child despite the name.
        public int LeftMass { get; set; }

        // Mass read by Score; replaced by LeftMass at the end of each window.
        // Unrelated to the Right child despite the name.
        public int RightMass { get; set; }

        // Index of the feature this node splits on (internal nodes only).
        public int Feature { get; set; }

        // Split threshold for Feature (internal nodes only).
        public double Value { get; set; }

        // Distance from the root; the root has depth 0.
        public int Depth { get; set; }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment