Skip to content

Instantly share code, notes, and snippets.

@claymcleod
Last active August 29, 2015 14:09
Show Gist options
  • Save claymcleod/bd3d729c6dc1bf95afe2 to your computer and use it in GitHub Desktop.
Save claymcleod/bd3d729c6dc1bf95afe2 to your computer and use it in GitHub Desktop.
--[[ Arithmetic Expression Tree Program Skeleton
Recursive Function Version with Record-style Nodes
H. Conrad Cunningham, Professor
Computer and Information Science
University of Mississippi
Developed for CSci 658, Software Language Engineering, Fall 2013
1234567890123456789012345678901234567890123456789012345678901234567890
2013-09-03: Modified Recursive Function List-node and Object-based
versions to create this version
2013-09-07: Changed "type" field of table to be "tag"
2014-11-11: Corrected typo, added isExp structure checking function,
and restructured similar to newest Recursive Function
List version
--]]
--[[ ARITHMETIC EXPRESSION TREES
This program represents an arithmetic expression tree by a
record-style table with a type tag field and one or more named
fields to hold the operands.
--]]
-- Module table; every public function defined below is a field of M.
local M = {}

-- Node type tags, stored in each node's "tag" field.
local CONST_TYPE, VAR_TYPE, SUM_TYPE = "Const", "Var", "Sum"
local SUB_TYPE, PROD_TYPE, DIV_TYPE = "Sub", "Prod", "Div"
local NEG_TYPE, SIN_TYPE, COS_TYPE = "Neg", "Sine", "Cos"

-- Display labels that valToString prints for each kind of node
-- (each currently spelled the same as its tag).
local CONST_STR, VAR_STR, SUM_STR = "Const", "Var", "Sum"
local SUB_STR, PROD_STR, DIV_STR = "Sub", "Prod", "Div"
local NEG_STR, SIN_STR, COS_STR = "Neg", "Sine", "Cos"
-- Checking for valid expressions

-- Set of every legal node tag; isExp consults it.
local tags = { [CONST_TYPE] = true, [VAR_TYPE] = true,
               [SUM_TYPE] = true, [SUB_TYPE] = true,
               [PROD_TYPE] = true, [DIV_TYPE] = true,
               [NEG_TYPE] = true, [SIN_TYPE] = true,
               [COS_TYPE] = true }

-- isExp(t): true when t is a table whose "tag" field is one of the known
-- node tags above, false otherwise. Only the tag is checked; the node's
-- other fields (value, name, left, right, node) are not validated.
-- Declared local so it no longer leaks into the global environment; it is
-- exported as M.isExp for code outside the module that needs it.
local function isExp(t) -- note tags[t.tag] different in handout
  return type(t) == "table" and tags[t.tag] ~= nil
end
M.isExp = isExp
-- Tree node constructor functions

-- M.makeConst(v): build a constant leaf node { tag = CONST_TYPE, value = v }.
-- Raises an error, attributed to the caller, when v is not a number.
function M.makeConst(v)
  if type(v) ~= "number" then
    error("M.makeConst called with nonumeric value field: " ..
          tostring(v), 2)
  end
  return { tag = CONST_TYPE, value = v }
end
-- M.makeVar(n): build a variable leaf node { tag = VAR_TYPE, name = n }.
-- Raises an error, attributed to the caller, when n is not a string.
function M.makeVar(n)
  if type(n) ~= "string" then
    error("makeVar called with nonstring name argument: " ..
          tostring(n), 2)
  end
  return { tag = VAR_TYPE, name = n }
end
-- M.makeNeg(n): build a negation node { tag = NEG_TYPE, node = n }.
-- Raises an error, attributed to the caller, when n is not a valid
-- expression tree (per isExp).
function M.makeNeg(n)
  if isExp(n) then
    return { tag = NEG_TYPE, node = n }
  else
    -- Report the argument that actually failed the isExp check.
    error("Argument of M.makeNeg is not a valid expression: " ..
          tostring(n), 2)
  end
end
-- M.makeSine(n): build a sine node { tag = SIN_TYPE, node = n }.
-- Raises an error, attributed to the caller, when n is not a valid
-- expression tree (per isExp).
function M.makeSine(n)
  if isExp(n) then
    return { tag = SIN_TYPE, node = n }
  else
    -- Report the argument that actually failed the isExp check.
    error("Argument of M.makeSine is not a valid expression: " ..
          tostring(n), 2)
  end
end
-- M.makeCos(n): build a cosine node { tag = COS_TYPE, node = n }.
-- Raises an error, attributed to the caller, when n is not a valid
-- expression tree (per isExp).
function M.makeCos(n)
  if isExp(n) then
    return { tag = COS_TYPE, node = n }
  else
    -- Report the argument that actually failed the isExp check.
    error("Argument of M.makeCos is not a valid expression: " ..
          tostring(n), 2)
  end
end
-- M.makeSum(l, r): build an addition node { tag = SUM_TYPE, left, right }.
-- Validates the left operand first, then the right; an invalid operand
-- raises an error attributed to the caller.
function M.makeSum(l,r)
  if not isExp(l) then
    error("First argument of M.makeSum is not a valid expression: " ..
          tostring(l), 2)
  end
  if not isExp(r) then
    error("Second argument of M.makeSum is not a valid expression: "
          .. tostring(r), 2)
  end
  return { tag = SUM_TYPE, left = l, right = r }
end
-- M.makeSub(l, r): build a subtraction node { tag = SUB_TYPE, left, right }.
-- Validates the left operand first, then the right; an invalid operand
-- raises an error attributed to the caller.
function M.makeSub(l,r)
  if not isExp(l) then
    error("First argument of M.makeSub is not a valid expression: " ..
          tostring(l), 2)
  end
  if not isExp(r) then
    error("Second argument of M.makeSub is not a valid expression: "
          .. tostring(r), 2)
  end
  return { tag = SUB_TYPE, left = l, right = r }
end
-- M.makeProd(l, r): build a product node { tag = PROD_TYPE, left, right }.
-- Validates the left operand first, then the right; an invalid operand
-- raises an error attributed to the caller.
function M.makeProd(l,r)
  if not isExp(l) then
    error("First argument of M.makeProd is not a valid expression: " ..
          tostring(l), 2)
  end
  if not isExp(r) then
    error("Second argument of M.makeProd is not a valid expression: "
          .. tostring(r), 2)
  end
  return { tag = PROD_TYPE, left = l, right = r }
end
-- M.makeDiv(l, r): build a quotient node { tag = DIV_TYPE, left, right }.
-- Validates the left operand first, then the right; an invalid operand
-- raises an error attributed to the caller.
function M.makeDiv(l,r)
  if not isExp(l) then
    error("First argument of M.makeDiv is not a valid expression: " ..
          tostring(l), 2)
  end
  if not isExp(r) then
    error("Second argument of M.makeDiv is not a valid expression: "
          .. tostring(r), 2)
  end
  return { tag = DIV_TYPE, left = l, right = r }
end
-- Constant tree node singletons: one shared Const(0) and one shared
-- Const(1) node, handed out by M.derive for constants and variables.
-- Because they are shared, callers should treat them as read-only.
local CONST_ZERO, CONST_ONE = M.makeConst(0), M.makeConst(1)
-- Function "M.eval" evaluates expression tree "t" in environment "env",
-- a table mapping variable names to values. It dispatches on the node's
-- "tag" field, recursing into the "left"/"right" operands of binary nodes
-- and the "node" operand of Neg/Sine/Cos nodes.
-- Raises an error, attributed to the caller, if "t" is not a valid
-- expression, if "env" is not a table, or if "t" refers to a variable
-- that "env" does not bind.
function M.eval(t,env)
  if not isExp(t) then
    error("M.eval called with invalid expression argument: " ..
          tostring(t), 2)
  end
  if type(env) ~= "table" then
    error("M.eval called with invalid environment argument: " ..
          tostring(env), 2)
  end
  local tag = t.tag
  if tag == SUM_TYPE then
    return M.eval(t.left,env) + M.eval(t.right,env)
  elseif tag == SUB_TYPE then
    return M.eval(t.left,env) - M.eval(t.right,env)
  elseif tag == PROD_TYPE then
    return M.eval(t.left,env) * M.eval(t.right,env)
  elseif tag == DIV_TYPE then
    return M.eval(t.left,env) / M.eval(t.right,env)
  elseif tag == VAR_TYPE then
    local value = env[t.name]
    if value == nil then
      -- Fail here with the variable's name rather than returning nil and
      -- letting the enclosing arithmetic fail with an opaque message.
      error("M.eval called with unbound variable: " ..
            tostring(t.name), 2)
    end
    return value
  elseif tag == NEG_TYPE then
    return (-1) * M.eval(t.node,env)
  elseif tag == SIN_TYPE then
    return math.sin(M.eval(t.node,env))
  elseif tag == COS_TYPE then
    return math.cos(M.eval(t.node,env))
  elseif tag == CONST_TYPE then
    return t.value
  else
    error("M.eval called with unknown tree type tag: " ..
          tostring(t.tag), 2)
  end
end
-- chainRule(outer, du): build the chain-rule product f'(u) * du/dv, where
-- "outer" is the tree for f'(u) and "du" is the tree for du/dv.
-- M.derive returns the CONST_ONE / CONST_ZERO singletons for variables and
-- constants, so those two cases are recognised by identity and shortened
-- (f'(u) * 1 = f'(u), f'(u) * 0 = 0); any other "du" yields a Prod node.
local function chainRule(outer, du)
  if du == CONST_ONE then
    return outer
  elseif du == CONST_ZERO then
    return CONST_ZERO
  end
  return M.makeProd(outer, du)
end

-- Function "derive" takes an arithmetic expression tree "t" and a
-- variable name "v" (a string) and returns the derivative of "t" with
-- respect to "v" as another arithmetic expression tree. The result is
-- not simplified; M.simplify can fold its constant subtrees.
-- Raises an error, attributed to the caller, if "t" is not a valid
-- expression or "v" is not a string.
function M.derive(t,v)
  if not isExp(t) then
    error("M.derive called with invalid expression argument: " ..
          tostring(t), 2)
  end
  if type(v) ~= "string" then
    error("M.derive called with invalid variable: " ..
          tostring(v), 2)
  end
  local tag = t.tag
  if tag == SUM_TYPE then
    return M.makeSum(M.derive(t.left,v), M.derive(t.right,v))
  elseif tag == SUB_TYPE then
    return M.makeSub(M.derive(t.left,v), M.derive(t.right,v))
  elseif tag == PROD_TYPE then
    -- Product rule: (g*h)' = g*h' + h*g'
    return M.makeSum(M.makeProd(t.left, M.derive(t.right,v)),
                     M.makeProd(t.right, M.derive(t.left,v)))
  elseif tag == DIV_TYPE then
    -- Quotient rule: (g/h)' = (g'*h - h'*g) / (h*h)
    local g, h = t.left, t.right
    return M.makeDiv(M.makeSub(M.makeProd(M.derive(g,v), h),
                               M.makeProd(M.derive(h,v), g)),
                     M.makeProd(h, h))
  elseif tag == NEG_TYPE then
    -- (-u)' = -(u')
    return M.makeNeg(M.derive(t.node,v))
  elseif tag == SIN_TYPE then
    -- sin(u)' = cos(u) * u'
    return chainRule(M.makeCos(t.node), M.derive(t.node,v))
  elseif tag == COS_TYPE then
    -- cos(u)' = -sin(u) * u'
    return chainRule(M.makeNeg(M.makeSine(t.node)), M.derive(t.node,v))
  elseif tag == VAR_TYPE then
    if v == t.name then
      return CONST_ONE
    else
      return CONST_ZERO
    end
  elseif tag == CONST_TYPE then
    return CONST_ZERO
  else
    error("M.derive called with unknown tree type tag: " ..
          tostring(t.tag), 2)
  end
end
-- Display labels for the two-operand node kinds, keyed by tag.
local binaryLabels = {
  [SUM_TYPE] = SUM_STR, [SUB_TYPE] = SUB_STR,
  [PROD_TYPE] = PROD_STR, [DIV_TYPE] = DIV_STR,
}

-- Display labels for the one-operand node kinds, keyed by tag.
local unaryLabels = {
  [NEG_TYPE] = NEG_STR, [SIN_TYPE] = SIN_STR, [COS_TYPE] = COS_STR,
}

-- Function "valToString" renders expression tree "t" in prefix form,
-- e.g. Sum(Var(x),Const(3)). Binary nodes print as Label(left,right),
-- unary nodes as Label(operand), leaves as Var(name) / Const(value).
-- Raises an error, attributed to the caller, if "t" is not a valid
-- expression.
function M.valToString(t)
  if not isExp(t) then
    error("M.valToString called with invalid expression: " ..
          tostring(t), 2)
  end
  local tag = t.tag

  local label = binaryLabels[tag]
  if label then
    return label .. "(" .. M.valToString(t.left) .. ","
           .. M.valToString(t.right) .. ")"
  end

  label = unaryLabels[tag]
  if label then
    return label .. "(" .. M.valToString(t.node) .. ")"
  end

  if tag == VAR_TYPE then
    return VAR_STR .. "(" .. t.name .. ")"
  elseif tag == CONST_TYPE then
    return CONST_STR .. "(" .. tostring(t.value) .. ")"
  end

  error("M.valToString called with unknown tree type tag: " ..
        tostring(t.tag), 2)
end
-- Function "simplify" takes an arithmetic expression tree "t" and returns
-- an equivalent tree where every operator whose operands simplify to
-- constants is folded into a single Const node. Operators with a
-- non-constant operand are rebuilt around their simplified operands, so
-- the operator itself is preserved. Const and Var leaves are returned
-- unchanged (the same table).
-- Raises an error, attributed to the caller, if "t" is not a valid
-- expression.
function M.simplify(t)
  if not isExp(t) then
    error("M.simplify called with invalid expression: " ..
          tostring(t), 2)
  end
  local tag = t.tag

  if tag == CONST_TYPE or tag == VAR_TYPE then
    return t
  elseif tag == NEG_TYPE then
    local node = M.simplify(t.node)
    if node.tag == CONST_TYPE then
      return M.makeConst((-1) * node.value)
    end
    return M.makeNeg(node)
  elseif tag == SIN_TYPE then
    local node = M.simplify(t.node)
    if node.tag == CONST_TYPE then
      return M.makeConst(math.sin(node.value))
    end
    return M.makeSine(node)
  elseif tag == COS_TYPE then
    local node = M.simplify(t.node)
    if node.tag == CONST_TYPE then
      return M.makeConst(math.cos(node.value))
    end
    return M.makeCos(node)
  end

  -- Every remaining valid tag is a two-operand operator.
  local simplified_left = M.simplify(t.left)
  local simplified_right = M.simplify(t.right)
  if simplified_left.tag == CONST_TYPE and simplified_right.tag == CONST_TYPE then
    -- Both operands are constants: fold.
    local a, b = simplified_left.value, simplified_right.value
    if tag == SUM_TYPE then
      return M.makeConst(a + b)
    elseif tag == SUB_TYPE then
      return M.makeConst(a - b)
    elseif tag == PROD_TYPE then
      return M.makeConst(a * b)
    elseif tag == DIV_TYPE then
      return M.makeConst(a / b)
    end
  else
    -- At least one operand is not constant: rebuild the operator node.
    if tag == SUM_TYPE then
      return M.makeSum(simplified_left, simplified_right)
    elseif tag == SUB_TYPE then
      return M.makeSub(simplified_left, simplified_right)
    elseif tag == PROD_TYPE then
      return M.makeProd(simplified_left, simplified_right)
    elseif tag == DIV_TYPE then
      return M.makeDiv(simplified_left, simplified_right)
    end
  end

  error("M.simplify called with unknown tree type tag: " ..
        tostring(t.tag), 2)
end
-- End of the expression-tree module: hand the module table to require().
-- NOTE(review): Lua allows "return" only as the last statement of a block,
-- so the demo code that follows cannot be part of this same chunk. It looks
-- like a separate script (it calls require 'exprRecFuncRecord', presumably
-- this module) pasted after it -- keep it in its own file.
return M
-- Demo script: builds sample expression trees and prints their
-- evaluation, derivatives and simplified forms.
local m = require 'exprRecFuncRecord'

-- Prints a section heading framed by "---" lines.
local function heading(title)
  print("\n---\n" .. title .. "\n---\n")
end

-- Prints "<tree> => <value of tree under bindings>".
local function showValue(tree, bindings)
  print(m.valToString(tree) .. ' => ' .. m.eval(tree, bindings))
end

-- Prints "<tree> => <simplified tree>".
local function showSimplified(tree)
  print(m.valToString(tree) .. ' => ' .. m.valToString(m.simplify(tree)))
end

heading("Evaluating Simple Expressions")
local env = { x = 5, y = 7 }
local negSum = m.makeSum(m.makeConst(5), m.makeNeg(m.makeConst(3)))
showValue(negSum, env)
local trigDiff = m.makeSub(m.makeCos(m.makeConst(0)), m.makeSine(m.makeConst(0)))
showValue(trigDiff, env)
local negProd = m.makeProd(m.makeNeg(m.makeConst(5)), m.makeNeg(m.makeConst(5)))
showValue(negProd, env)
local negQuot = m.makeDiv(m.makeNeg(m.makeConst(5)), m.makeConst(5))
showValue(negQuot, env)

-- MAIN PROGRAM: evaluate and differentiate a mixed expression.
heading("Testing Given Functions")
local given = m.makeSum(m.makeSum(m.makeConst(7), m.makeVar("x")),
                        m.makeProd(m.makeConst(3), m.makeConst(9)))
print("Expression: " .. m.valToString(given))
print("Evaluation with x=5, y=7: " .. m.eval(given, env))
for _, name in ipairs({ "x", "y" }) do
  print("Derivative relative to " .. name .. ":\n " ..
        m.valToString(m.derive(given, name)))
end
local cosX = m.makeCos(m.makeVar("x"))
print("Derivative of " .. m.valToString(cosX) .. " => " ..
      m.valToString(m.derive(cosX, "x")))

heading("Testing Simplify Functions")
local constSum = m.makeSum(m.makeConst(1), m.makeConst(3))
showSimplified(constSum)
showSimplified(trigDiff)
showSimplified(negProd)
local varSum = m.makeSum(m.makeVar("x"), m.makeConst(3))
showSimplified(varSum)
showSimplified(given)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment