Skip to content

Instantly share code, notes, and snippets.

@MilesCranmer
Last active January 27, 2021 22:44
Show Gist options
  • Select an option

  • Save MilesCranmer/e7b543f38ea1c277097f5cccd29de330 to your computer and use it in GitHub Desktop.

Select an option

Save MilesCranmer/e7b543f38ea1c277097f5cccd29de330 to your computer and use it in GitHub Desktop.
Implementing SymbolicUtils.jl interface for SymbolicRegression.jl
using SymbolicUtils
"""
    Node

Mutable expression-tree node holding operators, variables, and constants.

- `degree`: 0 for a leaf (constant or variable), 1 for a unary operator node,
  2 for a binary operator node.
- `val`: the `Float32` value of a constant leaf, the integer index of a variable
  leaf, or `nothing` for operator nodes.
- `constant`: `true` only for constant leaves (`false` for variables and operators).
- `op`: operator index, numbered separately for unary and binary operators;
  leaves store 1.
- `l`, `r`: child nodes; children an operator does not use are `nothing`.

Constructors: `Node(::Float32)` builds a constant leaf, `Node(::Integer)` builds a
variable leaf, and `Node(op, l[, r])` builds an operator node, first wrapping any
raw `Float32`/`Integer` child into a leaf.
"""
mutable struct Node
    # NOTE(review): `degree`, `op` and `val` use the abstract `Integer` type, which
    # keeps Julia from specializing on these fields; concrete types (e.g. `Int`)
    # would be faster but would change what gets stored.
    degree::Integer
    val::Union{Float32, Integer, Nothing}
    constant::Bool
    op::Integer
    l::Union{Node, Nothing}
    r::Union{Node, Nothing}

    # Leaves: a Float32 is a constant, an Integer enumerates a variable.
    Node(val::Float32) = new(0, val, true, 1, nothing, nothing)
    Node(val::Integer) = new(0, val, false, 1, nothing, nothing)

    # Unary operator nodes; a raw number child is promoted to a leaf first.
    Node(op::Integer, l::Node) = new(1, nothing, false, op, l, nothing)
    Node(op::Integer, l::Union{Float32, Integer}) = Node(op, Node(l))

    # Binary operator nodes; raw number children are promoted to leaves first.
    Node(op::Integer, l::Node, r::Node) = new(2, nothing, false, op, l, r)
    Node(op::Integer, l::Union{Float32, Integer}, r::Node) = Node(op, Node(l), r)
    Node(op::Integer, l::Node, r::Union{Float32, Integer}) = Node(op, l, Node(r))
    Node(op::Integer, l::Union{Float32, Integer}, r::Union{Float32, Integer}) =
        Node(op, Node(l), Node(r))
end
# Total number of nodes in a (sub)tree; a missing child (`nothing`) counts as zero.
countNodes(::Nothing) = 0
function countNodes(tree::Node)
    return 1 + sum(countNodes, (tree.l, tree.r))
end
# User-defined operations. A node's `op` field is the operator's 1-based position in
# the matching tuple (`unaops` for degree 1, `binops` for degree 2), so reordering
# these tuples changes the meaning of existing trees.
# `const` lets functions that index these tuples (printing, SymbolicUtils'
# `operation`) infer the operator types instead of reading untyped globals.
const binops = (+, *, /, -)
const unaops = (cos, exp)
# Equation printing:
"""
    stringOp(op, tree; bracketed=false, varMap=nothing) -> String

Render the binary (degree-2) node `tree`, whose operator is `op`.

Arithmetic operators (`+ - * / ^`) are printed infix. Both operands are rendered
with `bracketed=false`, so they carry their own parentheses. The infix result is
wrapped in parentheses unless the caller passes `bracketed=true`, meaning it
already supplies the brackets. Any other operator is printed in call syntax
`op(l, r)`. Its operands are rendered with `bracketed=true`, since the call's
parentheses already delimit them. `varMap` is forwarded to `stringTree` for
variable names.
"""
function stringOp(op::F, tree::Node;
                  bracketed::Bool=false,
                  varMap::Union{Array{String, 1}, Nothing}=nothing)::String where {F}
    opname = string(op)
    if !(op in (+, -, *, /, ^))
        # Call syntax: the argument list itself brackets each operand.
        lhs = stringTree(tree.l, bracketed=true, varMap=varMap)
        rhs = stringTree(tree.r, bracketed=true, varMap=varMap)
        return "$opname($lhs, $rhs)"
    end
    lhs = stringTree(tree.l, bracketed=false, varMap=varMap)
    rhs = stringTree(tree.r, bracketed=false, varMap=varMap)
    infix = "$lhs $opname $rhs"
    return bracketed ? infix : "($infix)"
end
# Convert an equation to a string
"""
    stringTree(tree; bracketed=false, varMap=nothing) -> String

Render `tree` as a human-readable equation.

- Constant leaves print their value.
- Variable leaves print as `x<index>`, or as `varMap[index]` when a name list is
  given.
- Unary nodes print as `f(arg)`, with `f` taken from `unaops`.
- Binary nodes are delegated to `stringOp` with the operator from `binops`.

`bracketed=true` tells a binary node that its caller already supplies the
surrounding parentheses.
"""
function stringTree(tree::Node;
                    bracketed::Bool=false,
                    varMap::Union{Array{String, 1}, Nothing}=nothing)::String
    if tree.degree == 0
        if tree.constant
            return string(tree.val)
        elseif isnothing(varMap)
            # Test for `nothing` by identity, not `==`.
            return "x$(tree.val)"
        else
            return varMap[tree.val]
        end
    elseif tree.degree == 1
        # The call's parentheses already bracket the argument.
        arg = stringTree(tree.l, bracketed=true, varMap=varMap)
        return "$(unaops[tree.op])($arg)"
    else
        return stringOp(binops[tree.op], tree, bracketed=bracketed, varMap=varMap)
    end
end
# Print an equation
"""
    printTree(tree; varMap=nothing, io=stdout)

Print `tree` (as rendered by `stringTree`, using `varMap` for variable names)
followed by a newline. Output goes to `io`, which defaults to `stdout`.
"""
function printTree(tree::Node; varMap::Union{Array{String, 1}, Nothing}=nothing,
                   io::IO=stdout)
    println(io, stringTree(tree, varMap=varMap))
    return nothing
end
# SymbolicUtils term interface: a node with children is a term, a leaf is an atom.
SymbolicUtils.istree(x::Node)::Bool = 0 < x.degree

# Head of a term. Degree-1 nodes index `unaops`; every other degree indexes `binops`.
function SymbolicUtils.operation(x::Node)::Function
    table = x.degree == 1 ? unaops : binops
    return table[x.op]
end

# Children of a term: just the left child for unary nodes, both children otherwise.
function SymbolicUtils.arguments(x::Node)::Array{Node}
    if x.degree == 1
        return [x.l]
    end
    return [x.l, x.r]
end
# SymbolicUtils term interface: rebuild an expression with head `f` and children
# `args`. `x` only selects this method. The expression is built by calling `f` on
# the children, which dispatches to the Node operator overloads.
#
# - One argument: `f(a1)`.
# - Two or more: right-associative fold `f(a1, f(a2, ... f(a_{n-1}, a_n)))`, so
#   n-ary terms such as a multi-argument `+` or `*` become nested binary nodes.
#
# `foldr` avoids the recursive unqualified `similarterm` call (not necessarily in
# scope) and the repeated `args[2:end]` copies.
function SymbolicUtils.similarterm(x::Node, f, args)
    nargs = length(args)
    if nargs == 0
        throw(ArgumentError("similarterm: no arguments given for operation $f"))
    elseif nargs == 1
        return f(first(args))
    end
    return foldr(f, args)
end
# Every Node is treated by SymbolicUtils as a scalar `Number` expression.
SymbolicUtils.symtype(::Node) = Number
# NOTE(review): this method is untyped in every argument, so it is not restricted
# to Node. It answers `Number` for any `promote_symtype` call that no more specific
# method handles, and may replace a SymbolicUtils fallback with the same signature
# (type piracy). It is kept unchanged here; consider narrowing it.
SymbolicUtils.promote_symtype(_f, _argtypes...) = Number
# Structural hash of a Node. It mixes in the fields that `isequal(::Node, ::Node)`
# compares:
# - degree;
# - for leaves, the constant flag and value;
# - for operator nodes, the operator index and the children, recursively.
#
# Defining the two-argument form means Nodes also hash structurally inside tuples,
# arrays, and other containers. Defining only `hash(x)` would leave those cases on
# Base's identity-based fallback. A plain `hash(x)` still reaches this method via
# Base's one-argument fallback.
function Base.hash(x::Node, h::UInt)
    h = hash(x.degree, h)
    if x.degree == 0
        return hash(x.val, hash(x.constant, h))
    elseif x.degree == 1
        return hash(x.l, hash(x.op, h))
    else
        return hash(x.r, hash(x.l, hash(x.op, h)))
    end
end
# Structural equality of two Nodes (used by Dict/Set keys and by SymbolicUtils).
# Nodes are equal when their degrees match and:
# - for leaves, the constant flag and value match;
# - for unary nodes, the operator and the child match;
# - for binary nodes, the operator and both children match.
function Base.isequal(x::Node, y::Node)::Bool
    x.degree == y.degree || return false
    if x.degree == 0
        # `isequal` rather than `==`: a NaN constant equals itself, and 0f0 and -0f0
        # stay distinct. This matches how `hash(x.val)` treats these values, as the
        # isequal/hash contract requires.
        return x.constant == y.constant && isequal(x.val, y.val)
    elseif x.degree == 1
        return x.op == y.op && isequal(x.l, y.l)
    else
        return x.op == y.op && isequal(x.l, y.l) && isequal(x.r, y.r)
    end
end
# Ordering on Nodes, applied in this sequence:
# 1. Trees with fewer nodes come first.
# 2. Among trees of equal size, constants come before non-constants.
# 3. Two constants are ordered by value.
# 4. Any remaining tie is broken by comparing hashes.
function Base.isless(x::Node, y::Node)::Bool
    nx, ny = countNodes(x), countNodes(y)
    nx == ny || return nx < ny
    if x.constant && y.constant
        # `isless` rather than `<`: it is a total order even for NaN and ±0, which
        # keeps sorting well-defined.
        return isless(x.val, y.val)
    elseif x.constant != y.constant
        # Exactly one is constant; it sorts first.
        return x.constant
    end
    return hash(x) < hash(y)
end
# Overload every operator in `binops` for Node operands.
# - If all Node operands are constant leaves, the operation is evaluated immediately
#   and stored as a single constant leaf.
# - Otherwise a degree-2 Node is built, whose `op` is the operator's position in
#   `binops`. A plain number operand becomes a Float32 constant leaf (an Integer
#   would otherwise be taken as a variable index).
# Folded results are converted to Float32 because constant leaves only accept
# Float32. Without this, e.g. `Node(1f0) + 2.0` yields a Float64 and hits a
# MethodError in the Node constructor.
for (op, f) in enumerate(map(Symbol, binops))
    @eval begin
        function Base.$f(l::Node, r::Node)
            if l.constant && r.constant
                return Node(convert(Float32, $f(l.val, r.val)))
            end
            return Node($op, l, r)
        end
        function Base.$f(l::Node, r::Number)
            l.constant && return Node(convert(Float32, $f(l.val, r)))
            return Node($op, l, convert(Float32, r))
        end
        function Base.$f(l::Number, r::Node)
            r.constant && return Node(convert(Float32, $f(l, r.val)))
            return Node($op, convert(Float32, l), r)
        end
    end
end
# Overload every operator in `unaops` for Node arguments.
# - A constant leaf is folded into a new Float32 constant leaf.
# - Any other Node becomes a degree-1 Node whose `op` is the operator's position in
#   `unaops`.
#
# No `Base.$f(::Number)` method is defined. It would be type piracy on Base
# functions for Base-owned argument types. It could also never return a Node: any
# call reaching it would re-dispatch to itself, and non-Float32 results have no
# Node constructor. Plain numbers keep Base's own methods.
for (op, f) in enumerate(map(Symbol, unaops))
    @eval function Base.$f(l::Node)
        l.constant && return Node(convert(Float32, $f(l.val)))
        return Node($op, l)
    end
end
# Demo 1: build 1.0 + x1 + 1.0, then print it before and after simplification.
t = let x1 = Node(1)
    Node(1f0) + x1 + Node(1f0)
end
printTree(t)
t = SymbolicUtils.simplify(t)
printTree(t)
# Demo 2: build 1.0 + x1 + 1.0 + x1*3 and print it, then simplify it and print
# only the Julia type of the simplified result.
t = let head = Node(1f0) + Node(1) + Node(1f0), tail = Node(1) * 3
    head + tail
end
printTree(t)
t = SymbolicUtils.simplify(t)
println(typeof(t))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment