Skip to content

Instantly share code, notes, and snippets.

@Drag13
Created February 28, 2024 18:53
Show Gist options
  • Save Drag13/4823312832209fc3c8c39b6db48dc5b4 to your computer and use it in GitHub Desktop.
Save Drag13/4823312832209fc3c8c39b6db48dc5b4 to your computer and use it in GitHub Desktop.
// Classic "play tennis" weather table: the first row is the header, the last
// column (PLAY) is the decision label, every other column is an attribute.
const TEST_DATA = [
  ["OUTLOOK", "TEMPERATURE", "HUMIDITY", "WIND", "PLAY"],
  ["SUNNY", "HOT", "HIGH", "WEAK", "NO"],
  ["SUNNY", "HOT", "HIGH", "STRONG", "NO"],
  ["OVERCAST", "HOT", "HIGH", "WEAK", "YES"],
  ["RAIN", "MILD", "HIGH", "WEAK", "YES"],
  ["RAIN", "COOL", "NORMAL", "WEAK", "YES"],
  ["RAIN", "COOL", "NORMAL", "STRONG", "NO"],
  ["OVERCAST", "COOL", "NORMAL", "STRONG", "YES"],
  ["SUNNY", "MILD", "HIGH", "WEAK", "NO"],
  ["SUNNY", "COOL", "NORMAL", "WEAK", "YES"],
  ["RAIN", "MILD", "NORMAL", "WEAK", "YES"],
  ["SUNNY", "MILD", "NORMAL", "STRONG", "YES"],
  ["OVERCAST", "MILD", "HIGH", "STRONG", "YES"],
  ["OVERCAST", "HOT", "NORMAL", "WEAK", "YES"],
  ["RAIN", "MILD", "HIGH", "STRONG", "NO"],
];
// Data rows only (header dropped); a new outer array sharing the same row arrays.
const [, ...TEST_DATA_BODY] = TEST_DATA;
/**
 * Two-class ("value vs. everything else") Shannon entropy of `dataSet`, in bits.
 *
 * Elements strictly equal to `value` form one class; all remaining elements
 * form the other.
 *
 * @returns 0 for an empty set or when every element equals `value`.
 * @throws Error when `dataSet` is non-empty but contains no `value`.
 */
function calculateEntropy<T>(dataSet: readonly T[], value: T) {
  const total = dataSet.length;
  if (total === 0) return 0;

  let hits = 0;
  for (const item of dataSet) {
    if (item === value) hits++;
  }

  if (hits === 0) throw new Error(`VALUE "${value}" not found in the dataset`);
  // A pure set carries no uncertainty (and log2(0) would produce NaN below).
  if (hits === total) return 0;

  // p * log2(p) for a class of `count` elements.
  const term = (count: number) => {
    const p = count / total;
    return p * Math.log2(p);
  };
  return -1 * (term(hits) + term(total - hits));
}
/** Returns the cell at position `index` from every row of `dataset`, in row order. */
function extractColumn(dataset: (readonly string[])[], index: number) {
  const cellOf = (row: readonly string[]) => row[index];
  return dataset.map(cellOf);
}
// TEST_CASE #1: entropy of the PLAY column (index 4) relative to "YES".
{
  const playColumn = extractColumn(TEST_DATA_BODY, 4);
  const playEntropy = calculateEntropy(playColumn, "YES");
  console.assert(playEntropy.toFixed(2) === "0.94", "Calculated wrong entropy, should be 0.94");
}
// TEST_CASE #2: an evenly split two-value column has exactly 1 bit of entropy.
const TEST_DATA_COLUMN = ["1", "1", "2", "2"];
{
  const evenEntropy = calculateEntropy(TEST_DATA_COLUMN, "1");
  console.assert(evenEntropy.toFixed(2) === "1.00", "Calculated wrong entropy, should be 1");
}
/**
 * Groups the decision labels (each row's last cell) by the value found in
 * column `index`.
 *
 * @param data rows of equal layout; the last cell of each row is its label
 * @param index column whose values define the groups
 * @returns one label array per distinct value of column `index`, in the order
 *   those values first appear (Map insertion order)
 */
function calculateColumnStat(data: readonly (readonly string[])[], index: number): string[][] {
  // Typed map: an untyped `new Map()` is Map<any, any> and leaks `any` to callers.
  const stats = new Map<string, string[]>();
  for (const row of data) {
    const key = row[index];
    const labels = stats.get(key) ?? [];
    labels.push(row[row.length - 1]);
    stats.set(key, labels);
  }
  return [...stats.values()];
}
/**
 * Information gain (in bits) of splitting `dataset` on column `columnIndex`,
 * using each row's last cell as the decision label.
 *
 * Entropies come from `calculateEntropy`, which scores one reference label
 * against all others. This therefore equals the usual Shannon information gain
 * when the decision column has at most two distinct labels.
 *
 * @param dataset body rows only (no header row)
 * @param columnIndex attribute column to evaluate
 * @returns the gain; 0 for an empty dataset
 */
function calculateGain(dataset: string[][], columnIndex: number) {
  const firstRow = dataset[0];
  if (firstRow === undefined) {
    return 0;
  }
  const decisionColumn = extractColumn(dataset, firstRow.length - 1);
  // Use a label that is guaranteed to be present rather than a hard-coded "YES":
  // calculateEntropy throws when its reference value is missing, and two-class
  // entropy is symmetric, so the choice of reference does not change the result.
  const decisionColumnEntropy = calculateEntropy(decisionColumn, decisionColumn[0]);
  const total = dataset.length;
  // Entropy remaining after the split: each group's entropy weighted by its size.
  const remainingEntropy = calculateColumnStat(dataset, columnIndex).reduce(
    (acc, labels) => acc + (labels.length / total) * calculateEntropy(labels, labels[0]),
    0
  );
  return decisionColumnEntropy - remainingEntropy;
}
// TEST_CASE #3: gain of splitting the weather rows on WIND (column 3).
{
  const windGain = calculateGain(TEST_DATA_BODY, 3).toFixed(3);
  console.assert(windGain === "0.048", `GAIN should be 0.048, got ${windGain}`);
}
// TEST_CASE #4: gain of splitting the weather rows on OUTLOOK (column 0).
{
  const outlookGain = calculateGain(TEST_DATA_BODY, 0).toFixed(4);
  console.assert(outlookGain === "0.2467", `GAIN should be 0.2467, got ${outlookGain}`);
}
/**
 * Picks the attribute column with the highest information gain.
 *
 * Every column except the last (the decision label) is a candidate.
 *
 * @param dataset body rows only (no header row)
 * @returns `{ v, index }`: the best gain and its column index. Ties keep the
 *   lowest index. `{ v: 0, index: -1 }` when no column has a positive gain or
 *   the dataset is empty.
 */
function findMostValuableAttribute(dataset: string[][]) {
  const maxGain = { v: 0, index: -1 };
  // Guard the empty dataset instead of reading `.length` of a missing first row.
  const attributeCount = (dataset[0]?.length ?? 0) - 1;
  for (let i = 0; i < attributeCount; i++) {
    const gainFromColumn = calculateGain(dataset, i);
    if (gainFromColumn > maxGain.v) {
      maxGain.v = gainFromColumn;
      maxGain.index = i;
    }
  }
  return maxGain;
}
// Show which weather attribute the root split would use, and its gain.
{
  const bestAttribute = findMostValuableAttribute(TEST_DATA_BODY);
  console.log(bestAttribute);
}
/**
 * Distinct values found in column `index`, in the order they first appear.
 *
 * @returns a typed `string[]` (an untyped `new Set()` would make this `unknown[]`)
 */
function getChoicesByIndex(dataset: readonly (readonly string[])[], index: number): string[] {
  return [...new Set(dataset.map((row) => row[index]))];
}
/**
 * Partitions `dataset` by the value in column `index`.
 *
 * @returns a Map from each distinct value (in first-seen order) to the rows
 *   holding it; rows are the original arrays, not copies
 */
function splitByColumnIndex(dataset: string[][], index: number) {
  const groups = new Map<string, string[][]>();
  for (const row of dataset) {
    const key = row[index];
    let bucket = groups.get(key);
    if (bucket === undefined) {
      bucket = [];
      groups.set(key, bucket);
    }
    bucket.push(row);
  }
  return groups;
}
/**
 * True when every row holds the same value in column `index`
 * (vacuously true for an empty dataset).
 */
function sameValue(dataset: string[][], index: number) {
  if (dataset.length === 0) return true;
  const expected = dataset[0][index];
  return !dataset.some((row) => row[index] !== expected);
}
/**
 * A node of the decision tree built by `DTree`.
 *
 * A leaf carries `decision`. An internal node carries the `attribute` it splits
 * on and, in `branches`, one child per value of that attribute seen in its rows.
 */
class DTreeNode {
  // Predicted label; assigned on leaf nodes.
  decision: string | undefined;
  // Header name of the column this node splits on; assigned on internal nodes.
  attribute: string | undefined;
  // Child subtree for each attribute value observed during training.
  branches: Map<string, DTreeNode> | undefined;
}
/**
 * ID3-style decision tree over string-valued tabular data.
 *
 * The training set's first row is the header. The last column is the decision
 * label and every other column is a candidate attribute.
 */
class DTree {
  /** Root of the trained tree (an empty node before `train` is called). */
  _rootNode = new DTreeNode();
  /** Header row of the training set; used to map attribute names to column indexes. */
  _head: string[] | undefined = undefined;

  /**
   * Builds the tree from `dataset` (header row first), replacing any previous tree.
   */
  train(dataset: string[][]) {
    this._head = dataset[0];
    const body = dataset.slice(1);
    this._rootNode = this.buildTree(body);
  }

  /**
   * Classifies one row by walking the tree from the root.
   *
   * @param data attribute values laid out in header column order
   * @returns `decision`: the predicted label, or `undefined` when the row
   *   reaches an attribute value that was never seen during training (or the
   *   tree is untrained). `i`: the number of steps taken.
   */
  predict(data: string[]) {
    let i = 0;
    let node: DTreeNode | undefined = this._rootNode;
    // Defensive cap on the walk length (e.g. against a hand-edited, cyclic `_rootNode`).
    while (node !== undefined && node.decision == null && i < 100_000) {
      const attr = node.attribute;
      const valueIndex =
        attr === undefined || this._head === undefined ? -1 : this._head.indexOf(attr);
      const value = valueIndex === -1 ? undefined : data[valueIndex];
      // No matching branch means there is nothing to follow: stop with no
      // decision instead of dereferencing a missing child on the next pass.
      node = value === undefined ? undefined : node.branches?.get(value);
      i++;
    }
    return { decision: node?.decision, i };
  }

  /**
   * Recursively builds the subtree for `dataset` (body rows only, no header).
   *
   * - Rows that agree on the decision become a leaf with that decision.
   * - Rows that disagree but that no attribute can separate become a leaf
   *   with the majority decision.
   * - Otherwise the rows are split on the attribute with the highest gain.
   */
  private buildTree(dataset: string[][]): DTreeNode {
    const node = new DTreeNode();
    const firstRow = dataset[0];
    if (firstRow === undefined) {
      // No rows (e.g. a training set that is only a header): empty node.
      return node;
    }
    const decisionIndex = firstRow.length - 1;
    if (sameValue(dataset, decisionIndex)) {
      node.decision = firstRow[decisionIndex];
      return node;
    }
    const best = findMostValuableAttribute(dataset);
    if (best.index === -1) {
      // Labels disagree but no attribute has positive gain (e.g. identical
      // attribute values with conflicting labels). Splitting on column -1 would
      // put every row into one group and recurse on the same rows forever.
      node.decision = DTree.majorityDecision(dataset, decisionIndex);
      return node;
    }
    node.attribute = this._head?.[best.index];
    const branches = new Map<string, DTreeNode>();
    for (const [value, subset] of splitByColumnIndex(dataset, best.index)) {
      branches.set(value, this.buildTree(subset));
    }
    node.branches = branches;
    return node;
  }

  /**
   * Most frequent value in column `index`. On a tie, the value that first
   * reaches the top count wins. Returns `undefined` for an empty dataset.
   */
  private static majorityDecision(dataset: string[][], index: number): string | undefined {
    const counts = new Map<string, number>();
    let best: string | undefined;
    let bestCount = 0;
    for (const row of dataset) {
      const label = row[index];
      if (label === undefined) continue;
      const count = (counts.get(label) ?? 0) + 1;
      counts.set(label, count);
      if (count > bestCount) {
        best = label;
        bestCount = count;
      }
    }
    return best;
  }
}
// Demo 1: train on the weather table (header included) and classify one day.
{
  const weatherTree = new DTree();
  weatherTree.train(TEST_DATA);
  console.log(weatherTree._rootNode);
  const day = ["Sunny", "Cool", "High", "Strong"].map((word) => word.toUpperCase());
  console.log(weatherTree.predict(day));
}
// Demo 2 data: a header row followed by labelled rows, all upper-cased so that
// values compare equal to the upper-cased query below.
const DATA = [
  ["status", "tired", "holidays", "have new game", "decision"],
  ["clean", "no", "no", "yes", "no"],
  ["light", "yes", "no", "yes", "no"],
  ["light", "no", "no", "yes", "no"],
  ["moderate", "yes", "no", "yes", "no"],
  ["moderate", "no", "no", "yes", "yes"],
  ["heavy", "yes", "no", "yes", "yes"],
  ["heavy", "yes", "yes", "yes", "yes"],
  ["heavy", "no", "no", "no", "yes"],
].map((row) => row.map((cell) => cell.toUpperCase()));
// Demo 2: train on DATA and classify one row.
{
  const secondTree = new DTree();
  secondTree.train(DATA);
  const query = ["heavy", "no", "yes", "yes"].map((word) => word.toUpperCase());
  console.log(secondTree.predict(query));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment