Created
February 28, 2024 18:53
-
-
Save Drag13/4823312832209fc3c8c39b6db48dc5b4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Sample weather table used by the self-checks in this file.
 * Row 0 is the header; each following row is one observation whose last
 * cell is the PLAY decision ("YES" / "NO").
 */
const TEST_DATA: string[][] = [
  ["OUTLOOK", "TEMPERATURE", "HUMIDITY", "WIND", "PLAY"],
  ["SUNNY", "HOT", "HIGH", "WEAK", "NO"],
  ["SUNNY", "HOT", "HIGH", "STRONG", "NO"],
  ["OVERCAST", "HOT", "HIGH", "WEAK", "YES"],
  ["RAIN", "MILD", "HIGH", "WEAK", "YES"],
  ["RAIN", "COOL", "NORMAL", "WEAK", "YES"],
  ["RAIN", "COOL", "NORMAL", "STRONG", "NO"],
  ["OVERCAST", "COOL", "NORMAL", "STRONG", "YES"],
  ["SUNNY", "MILD", "HIGH", "WEAK", "NO"],
  ["SUNNY", "COOL", "NORMAL", "WEAK", "YES"],
  ["RAIN", "MILD", "NORMAL", "WEAK", "YES"],
  ["SUNNY", "MILD", "NORMAL", "STRONG", "YES"],
  ["OVERCAST", "MILD", "HIGH", "STRONG", "YES"],
  ["OVERCAST", "HOT", "NORMAL", "WEAK", "YES"],
  ["RAIN", "MILD", "HIGH", "STRONG", "NO"],
];

/** The observations only: every row of TEST_DATA except the header. */
const [, ...TEST_DATA_BODY] = TEST_DATA;
/**
 * One-vs-rest entropy (in bits) of `value` inside `dataSet`: the items are
 * treated as two groups, those strictly equal to `value` and everything else.
 *
 * @param dataSet - Labels to inspect.
 * @param value - The label whose share defines the first group.
 * @returns 0 for an empty data set or when every item equals `value`;
 *          otherwise -(p * log2(p) + q * log2(q)) where p and q are the
 *          shares of the two groups.
 * @throws Error when the data set is non-empty but never contains `value`.
 */
function calculateEntropy<T>(dataSet: readonly T[], value: T) {
  const total = dataSet.length;
  if (total === 0) {
    return 0;
  }

  const matching = dataSet.reduce((count, item) => (item === value ? count + 1 : count), 0);
  if (matching === 0) {
    throw new Error(`VALUE "${value}" not found in the dataset`);
  }
  if (matching === total) {
    // A single group carries no uncertainty (and log2(0) must be avoided).
    return 0;
  }

  const remaining = total - matching;
  const matchingTerm = (matching / total) * Math.log2(matching / total);
  const remainingTerm = (remaining / total) * Math.log2(remaining / total);
  return -1 * (matchingTerm + remainingTerm);
}
/**
 * Collects the cell at position `index` from every row.
 *
 * @param dataset - Rows to read from.
 * @param index - Zero-based column position.
 * @returns One entry per row, in row order; rows shorter than `index`
 *          contribute `undefined`.
 */
function extractColumn(dataset: (readonly string[])[], index: number) {
  const column: string[] = [];
  for (const row of dataset) {
    column.push(row[index]);
  }
  return column;
}
// TEST_CASE #1: entropy of the decision column (index 4) of the sample table.
{
  const decisionEntropy = calculateEntropy(extractColumn(TEST_DATA_BODY, 4), "YES");
  const rounded = decisionEntropy.toFixed(2);
  console.assert(rounded === "0.94", "Calculated wrong entropy, should be 0.94");
}

// TEST_CASE #2: a column split evenly between two values has entropy 1.
const TEST_DATA_COLUMN = ["1", "1", "2", "2"];
{
  const rounded = calculateEntropy(TEST_DATA_COLUMN, "1").toFixed(2);
  console.assert(rounded === "1.00", "Calculated wrong entropy, should be 1");
}
/**
 * Groups the decision (last cell of each row) by the value found in column
 * `index`.
 *
 * @param data - Rows to group; the last cell of every row is its decision.
 * @param index - Zero-based position of the attribute column to group by.
 * @returns One array of decisions per distinct attribute value, ordered by
 *          the first appearance of that value. The attribute values
 *          themselves are not returned.
 */
function calculateColumnStat(data: string[][], index: number): string[][] {
  // Typed map: an untyped `new Map()` would make this function return `any[]`.
  const decisionsByValue = new Map<string, string[]>();

  for (const row of data) {
    const attributeValue = row[index];
    const decision = row[row.length - 1];
    const bucket = decisionsByValue.get(attributeValue);
    if (bucket === undefined) {
      decisionsByValue.set(attributeValue, [decision]);
    } else {
      bucket.push(decision);
    }
  }

  return [...decisionsByValue.values()];
}
/**
 * Information gain obtained by splitting `dataset` on column `columnIndex`:
 * the entropy of the decision column minus the size-weighted entropy of the
 * decisions inside each group produced by the split.
 *
 * The decision column is the last column (its position is taken from the
 * first row). Entropy is computed by `calculateEntropy`, which is a
 * one-vs-rest measure, so the result is an exact information gain only when
 * the decision column holds at most two distinct labels.
 *
 * @param dataset - Observation rows without a header row.
 * @param columnIndex - Zero-based position of the attribute to evaluate.
 * @returns The gain in bits; 0 for an empty dataset.
 */
function calculateGain(dataset: string[][], columnIndex: number): number {
  if (dataset.length === 0) {
    // No rows means no uncertainty to remove (and no first row to inspect).
    return 0;
  }

  const decisionColumn = extractColumn(dataset, dataset[0].length - 1);
  // Use a label that is guaranteed to be present instead of a hard-coded
  // "YES": two-group entropy is symmetric, so YES/NO data yields the same
  // number, while data with other labels no longer throws "not found".
  const decisionColumnEntropy = calculateEntropy(decisionColumn, decisionColumn[0]);

  const total = dataset.length;
  const weightedEntropy: number = calculateColumnStat(dataset, columnIndex).reduce(
    (sum: number, decisions: string[]) => {
      const groupEntropy = (decisions.length / total) * calculateEntropy(decisions, decisions[0]);
      return sum + groupEntropy;
    },
    0
  );

  return decisionColumnEntropy - weightedEntropy;
}
// TEST_CASE #3: gain of splitting the sample table on column 3.
{
  const gain = calculateGain(TEST_DATA_BODY, 3);
  const actualValue = gain.toFixed(3);
  console.assert(actualValue === "0.048", `GAIN should be 0.048, got ${actualValue}`);
}

// TEST_CASE #4: gain of splitting the sample table on column 0.
{
  const gain = calculateGain(TEST_DATA_BODY, 0);
  const actualValue = gain.toFixed(4);
  console.assert(actualValue === "0.2467", `GAIN should be 0.2467, got ${actualValue}`);
}
/**
 * Finds the attribute column whose split yields the highest information gain.
 *
 * Every column except the last one (the decision) is evaluated with
 * `calculateGain`; the column count is taken from the first row.
 *
 * @param dataset - Observation rows without a header row.
 * @returns `{ v, index }` where `v` is the best gain and `index` the column
 *          that produced it. When the dataset is empty or no column has a
 *          gain strictly greater than 0, the result is `{ v: 0, index: -1 }`;
 *          callers must treat `index === -1` as "no useful split". On ties
 *          the lowest column index wins.
 */
function findMostValuableAttribute(dataset: string[][]) {
  const maxGain = { v: 0, index: -1 };
  if (dataset.length === 0) {
    // Without a first row there is no column count to iterate over.
    return maxGain;
  }

  const attributeCount = dataset[0].length - 1;
  for (let column = 0; column < attributeCount; column++) {
    const gain = calculateGain(dataset, column);
    if (gain > maxGain.v) {
      maxGain.v = gain;
      maxGain.index = column;
    }
  }
  return maxGain;
}
// Print the best split (gain and column index) for the sample table.
{
  const bestAttribute = findMostValuableAttribute(TEST_DATA_BODY);
  console.log(bestAttribute);
}
/**
 * Lists the distinct values found in column `index`.
 *
 * @param dataset - Rows to scan.
 * @param index - Zero-based column position.
 * @returns Each distinct value once, ordered by first appearance.
 */
function getChoicesByIndex(dataset: string[][], index: number): string[] {
  // Typed set: an untyped `new Set()` would make the result `unknown[]`.
  const choices = new Set<string>();
  for (const row of dataset) {
    choices.add(row[index]);
  }
  return [...choices];
}
/**
 * Partitions the rows by the value they hold in column `index`.
 *
 * @param dataset - Rows to partition; the row arrays are reused, not copied.
 * @param index - Zero-based column position to partition on.
 * @returns A map from each distinct column value to the rows carrying it.
 *          Keys follow first-appearance order; rows keep their input order.
 */
function splitByColumnIndex(dataset: string[][], index: number) {
  const groups = new Map<string, string[][]>();
  for (const row of dataset) {
    const key = row[index];
    const bucket = groups.get(key);
    if (bucket === undefined) {
      groups.set(key, [row]);
    } else {
      bucket.push(row);
    }
  }
  return groups;
}
/**
 * Tells whether every row holds the same value in column `index`.
 *
 * @param dataset - Rows to compare.
 * @param index - Zero-based column position.
 * @returns `true` for an empty dataset or when all rows agree with the
 *          first row on that column; `false` otherwise.
 */
function sameValue(dataset: string[][], index: number) {
  // Row 0 is the reference, so comparison starts at row 1.
  for (let rowIndex = 1; rowIndex < dataset.length; rowIndex++) {
    if (dataset[rowIndex][index] !== dataset[0][index]) {
      return false;
    }
  }
  return true;
}
/**
 * One node of the decision tree. All fields start out unset.
 */
class DTreeNode {
  /** Decision label stored on this node, when it has one. */
  decision?: string;
  /** Name of the attribute this node branches on, when it has one. */
  attribute?: string;
  /** Child nodes keyed by the value of `attribute`. */
  branches?: Map<string, DTreeNode>;
}
/**
 * Decision tree built by repeatedly splitting the data on the attribute
 * with the highest information gain until each subset agrees on a decision.
 */
class DTree {
  /** Root of the trained tree; an empty node until `train` is called. */
  _rootNode = new DTreeNode();
  /** Header row of the training data (attribute names, decision name last). */
  _head: string[] | undefined = undefined;

  /**
   * Builds the tree from a table.
   *
   * @param dataset - Header row followed by observation rows; the last cell
   *                  of each observation is its decision.
   */
  train(dataset: string[][]) {
    this._head = dataset[0];
    this._rootNode = this.buildTree(dataset.slice(1));
  }

  /**
   * Walks the tree for one observation.
   *
   * @param data - Attribute values in the same column order as the header
   *               row given to `train`.
   * @returns `decision`: the label of the leaf that was reached, or
   *          `undefined` when the tree is untrained or the observation
   *          holds a value that has no branch; `i`: the number of branch
   *          hops performed.
   */
  predict(data: string[]) {
    const maxSteps = 100_000; // safety cap on the walk
    const head = this._head ?? [];
    let i = 0;
    let node: DTreeNode | undefined = this._rootNode;

    while (node !== undefined && node.decision == null && i < maxSteps) {
      const attr = node.attribute;
      const valueIndex = head.findIndex((x) => x === attr);
      const value = data[valueIndex];
      // A missing branch (value never seen during training) ends the walk
      // with an undefined decision instead of dereferencing `undefined`.
      node = node.branches?.get(value);
      i++;
    }

    return { decision: node?.decision, i };
  }

  /**
   * Recursively builds the subtree for `dataset` (observation rows only).
   */
  private buildTree(dataset: string[][]): DTreeNode {
    const node = new DTreeNode();
    if (dataset.length === 0) {
      // Nothing to learn from: leave the node without a decision.
      return node;
    }

    const decisionIndex = dataset[0].length - 1;
    if (sameValue(dataset, decisionIndex)) {
      node.decision = dataset[0][decisionIndex];
      return node;
    }

    const best = findMostValuableAttribute(dataset);
    const subsets = best.index >= 0 ? splitByColumnIndex(dataset, best.index) : undefined;
    if (subsets === undefined || subsets.size < 2) {
      // No attribute separates these rows (e.g. identical attributes with
      // conflicting decisions). Splitting would hand the same rows to the
      // recursive call forever, so settle on the most frequent decision.
      node.decision = this.majorityDecision(dataset, decisionIndex);
      return node;
    }

    node.attribute = this._head?.[best.index];
    const branches = new Map<string, DTreeNode>();
    for (const [value, subset] of subsets) {
      branches.set(value, this.buildTree(subset));
    }
    node.branches = branches;
    return node;
  }

  /**
   * Returns the most frequent decision in a non-empty `dataset`; on a tie
   * the label that reaches the top count first wins.
   */
  private majorityDecision(dataset: string[][], decisionIndex: number): string {
    const counts = new Map<string, number>();
    let winner = dataset[0][decisionIndex];
    let winnerCount = 0;
    for (const row of dataset) {
      const label = row[decisionIndex];
      const count = (counts.get(label) ?? 0) + 1;
      counts.set(label, count);
      if (count > winnerCount) {
        winner = label;
        winnerCount = count;
      }
    }
    return winner;
  }
}
// Demo: train on the weather table, dump the tree, classify one observation.
{
  const weatherTree = new DTree();
  weatherTree.train(TEST_DATA);
  console.log(weatherTree._rootNode);

  // The table is upper-case, so the query is upper-cased before the lookup.
  const rawObservation = ["Sunny", "Cool", "High", "Strong"];
  const REAL_DATA = rawObservation.map((cell) => cell.toUpperCase());
  const prediction = weatherTree.predict(REAL_DATA);
  console.log(prediction);
}
/**
 * Second sample table: a header row followed by observations whose last
 * cell is the decision. Every cell (header included) is upper-cased.
 */
const DATA: string[][] = [
  ["status", "tired", "holidays", "have new game", "decision"],
  ["clean", "no", "no", "yes", "no"],
  ["light", "yes", "no", "yes", "no"],
  ["light", "no", "no", "yes", "no"],
  ["moderate", "yes", "no", "yes", "no"],
  ["moderate", "no", "no", "yes", "yes"],
  ["heavy", "yes", "no", "yes", "yes"],
  ["heavy", "yes", "yes", "yes", "yes"],
  ["heavy", "no", "no", "no", "yes"],
].map((row) => row.map((cell) => cell.toUpperCase()));
// Demo: train on DATA and classify one observation.
{
  const tree = new DTree();
  tree.train(DATA);

  // DATA is upper-case, so the query is upper-cased before the lookup.
  const rawObservation = ["heavy", "no", "yes", "yes"];
  const REAL_DATA = rawObservation.map((cell) => cell.toUpperCase());
  const prediction = tree.predict(REAL_DATA);
  console.log(prediction);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment