Created December 7, 2021 10:37
Save yangfch3/654240f75e2df33d9632ba4acd82051b to your computer and use it in GitHub Desktop.
Lua 决策树 (Lua decision tree)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- @author: yangfch3
-- @date: 2020/09/27 15:20
------------------------------------------
-- Base class for all decision-tree nodes.
-- Subclasses override Eval, and override SetChildren when they have children.
local DTNode = BaseClass("DTNode")

--- Constructor. The base node holds no state.
function DTNode:ctor() end

--- Evaluate this node.
-- Placeholder for subclasses to override. This base version does nothing
-- and returns no values.
-- @param decisionTree the tree being evaluated (forwarded to callbacks)
-- @param context caller-supplied evaluation context
function DTNode:Eval(decisionTree, context) end

--- Attach child nodes.
-- Node types that do not override this method do not accept children.
-- Calling it on them is a programming error.
-- @raise always
function DTNode:SetChildren()
    -- Level 2 reports the error at the caller's line, which is where the
    -- mistake was made. A bare assert(false) would only say "assertion failed!".
    error("DTNode:SetChildren: this node type does not accept children", 2)
end
-- Condition node: an N-way branch. `cond(decisionTree, context)` returns the
-- 1-based index of the child node to evaluate next.
local DTCondNode = BaseClass("DTCondNode", DTNode)

--- @param cond callback invoked as cond(decisionTree, context); must return
--   an integer child index in 1..#children
function DTCondNode:ctor(cond)
    self._cond = cond
end

--- Set the candidate child nodes in index order.
-- All arguments must be non-nil. A nil would leave a hole in the child list,
-- which makes the length operator (and the bounds check in Eval) unreliable.
-- The arguments are validated before anything is stored, so a failed call
-- leaves any previously configured children untouched.
-- @param ... child nodes
-- @raise if any child is nil
function DTCondNode:SetChildren(...)
    local count = select("#", ...) -- true count, including trailing nils
    local children = { ... }
    for i = 1, count do
        if children[i] == nil then
            error(string.format("DTCondNode:SetChildren: child #%d is nil", i), 2)
        end
    end
    self._childrenNodes = children
end

--- Evaluate the child selected by `cond` and return its result(s).
-- @param decisionTree the tree being evaluated
-- @param context caller-supplied evaluation context
-- @raise if SetChildren was never called, or `cond` returns an invalid index
function DTCondNode:Eval(decisionTree, context)
    local children = self._childrenNodes
    assert(children, "DTCondNode:Eval: SetChildren has not been called")
    local nodeIdx = self._cond(decisionTree, context)
    -- The index must be an integral number in range. A fractional value such
    -- as 1.5 would pass a plain range check but select no child. NaN fails
    -- the floor comparison.
    if type(nodeIdx) ~= "number" or nodeIdx ~= math.floor(nodeIdx)
        or nodeIdx < 1 or nodeIdx > #children then
        error(string.format("DTCondNode:Eval: invalid child index %s (expected 1..%d)",
            tostring(nodeIdx), #children))
    end
    return children[nodeIdx]:Eval(decisionTree, context)
end
-- Boolean node: a two-way branch. Evaluates the true child when
-- `cond(decisionTree, context)` returns a truthy value, otherwise the false child.
local DTBoolNode = BaseClass("DTBoolNode", DTNode)

--- @param cond callback invoked as cond(decisionTree, context); its result
--   is tested for Lua truthiness
function DTBoolNode:ctor(cond)
    self._cond = cond
end

--- Set both branches.
-- Both arguments are validated before either is stored. A failed call
-- therefore leaves previously configured branches intact instead of a
-- half-updated node.
-- @param trueNode node evaluated when cond is truthy
-- @param falseNode node evaluated otherwise
function DTBoolNode:SetChildren(trueNode, falseNode)
    assert(trueNode, "DTBoolNode:SetChildren: trueNode is nil")
    assert(falseNode, "DTBoolNode:SetChildren: falseNode is nil")
    self._trueNode = trueNode
    self._falseNode = falseNode
end

--- Evaluate the branch chosen by `cond` and return its result(s).
function DTBoolNode:Eval(decisionTree, context)
    -- Explicit if/else rather than the `b and x or y` idiom, which silently
    -- falls through to y whenever x is falsy.
    local node
    if self._cond(decisionTree, context) then
        node = self._trueNode
    else
        node = self._falseNode
    end
    return node:Eval(decisionTree, context)
end
-- Task node: runs `func(decisionTree, context)` for its side effects only and
-- discards whatever it returns. Evaluation then continues with the next node,
-- whose result becomes this node's result.
local DTTaskNode = BaseClass("DTTaskNode", DTNode)

--- @param func callback invoked as func(decisionTree, context)
function DTTaskNode:ctor(func)
    self._func = func
end

--- Set the node evaluated after the task runs.
-- The argument is validated before it is stored. A failed call therefore does
-- not overwrite a previously configured next node with nil.
-- @param nextNode the follow-up node (must be non-nil)
function DTTaskNode:SetChildren(nextNode)
    assert(nextNode, "DTTaskNode:SetChildren: nextNode is nil")
    self._nextNode = nextNode
end

--- Run the task, then evaluate and return the next node's result(s).
function DTTaskNode:Eval(decisionTree, context)
    self._func(decisionTree, context)
    return self._nextNode:Eval(decisionTree, context)
end
-- Execution node: the terminal node of a decision path. Its result is whatever
-- `func(decisionTree, context)` returns.
-- It defines no SetChildren of its own. Assuming BaseClass provides
-- inheritance, calls reach DTNode:SetChildren, which raises.
local DTExeNode = BaseClass("DTExeNode", DTNode)

--- Store the leaf callback.
-- @param func callback invoked as func(decisionTree, context) by Eval
DTExeNode.ctor = function(self, func)
    self._func = func
end

--- Invoke the leaf callback and forward all of its return values (tail call).
DTExeNode.Eval = function(self, decisionTree, context)
    return self._func(decisionTree, context)
end
-- Manually driven decision tree. The caller invokes Eval, which starts at the
-- root node and returns whatever the reached node returns.
local ManualDecisionTree = BaseClass("ManualDecisionTree")

--- Replace the node that evaluation starts from.
-- @param node the new root node
ManualDecisionTree.SetRootNode = function(self, node)
    self._rootNode = node
end

--- Constructor. Delegates to SetRootNode (via method dispatch).
-- rootNode is not checked here, so it may be nil and be set later.
-- @param rootNode initial root node
ManualDecisionTree.ctor = function(self, rootNode)
    self:SetRootNode(rootNode)
end

--- Evaluate the tree from its root.
-- The tree itself is handed to every node as the `decisionTree` argument.
-- @param context caller-supplied evaluation context
ManualDecisionTree.Eval = function(self, context)
    return self._rootNode:Eval(self, context)
end
-- Module exports: the concrete node types plus the tree runner.
-- The DTNode base class stays private to this file.
return {
    ManualDecisionTree = ManualDecisionTree,
    DTExeNode = DTExeNode,
    DTTaskNode = DTTaskNode,
    DTBoolNode = DTBoolNode,
    DTCondNode = DTCondNode,
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.