Created
April 2, 2024 17:48
-
-
Save nahkd123/5a4226f5965e4beb5d58a4fd73907a52 to your computer and use it in GitHub Desktop.
A 450-line math expression engine
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* (c) Tran Huu An 2024. Licensed under MIT license. | |
*/ | |
import java.io.Serial; | |
import java.util.ArrayList; | |
import java.util.List; | |
import java.util.Set; | |
import java.util.function.DoubleBinaryOperator; | |
import java.util.function.DoubleUnaryOperator; | |
import java.util.regex.Matcher; | |
import java.util.regex.Pattern; | |
/**
 * <p>
 * A "simple" mathematical expression engine. It supports the binary operators
 * {@code +}, {@code -}, {@code *} and {@code /} ({@code *} and {@code /} bind
 * tighter than {@code +} and {@code -}; operators of equal precedence are
 * applied left to right), unary {@code +} and {@code -}, parentheses,
 * variables and one- or two-argument function calls. It can be extended for
 * your needs by implementing {@link Context}.
 * </p>
 *
 * @author nahkd123
 * @see #compile(String)
 * @see #evaluate(Context)
 * @see #evaluate()
 */
public interface Expression {
    /**
     * Resolves variables and functions by name during evaluation. Each lookup
     * returns {@code null} when the name is not known to this context; the
     * default implementations know no names at all.
     */
    public static interface Context {
        default Double getVariable(String name) {
            return null;
        }

        default DoubleUnaryOperator getFunction1(String name) {
            return null;
        }

        default DoubleBinaryOperator getFunction2(String name) {
            return null;
        }
    }

    /**
     * A {@link Context} that asks each of the given contexts in order and
     * returns the first non-{@code null} answer ({@code null} if none knows the
     * name).
     */
    public static class MergedContext implements Context {
        private final Context[] contexts;

        public MergedContext(Context... contexts) {
            // Copy so that later writes to the caller's array cannot change lookups.
            this.contexts = contexts.clone();
        }

        @Override
        public Double getVariable(String name) {
            for (Context context : contexts) {
                Double variable = context.getVariable(name);
                if (variable != null) return variable;
            }
            return null;
        }

        @Override
        public DoubleUnaryOperator getFunction1(String name) {
            for (Context context : contexts) {
                DoubleUnaryOperator function1 = context.getFunction1(name);
                if (function1 != null) return function1;
            }
            return null;
        }

        @Override
        public DoubleBinaryOperator getFunction2(String name) {
            for (Context context : contexts) {
                DoubleBinaryOperator function2 = context.getFunction2(name);
                if (function2 != null) return function2;
            }
            return null;
        }
    }

    /**
     * The built-in context. It provides:
     * <ul>
     * <li>the constants {@code pi} (also the Greek letter U+03C0), {@code e},
     * {@code true} (1) and {@code false} (0);</li>
     * <li>{@code random}, a new {@link Math#random()} value on every lookup;</li>
     * <li>common one- and two-argument functions backed by {@link Math}.</li>
     * </ul>
     */
    public static class UniverseContext implements Context {
        @Override
        public Double getVariable(String name) {
            return switch (name) {
                case "pi", "\u03C0" -> Math.PI;
                case "e" -> Math.E;
                case "true" -> 1d;
                case "false" -> 0d;
                case "random" -> Math.random();
                default -> null;
            };
        }

        @Override
        public DoubleUnaryOperator getFunction1(String name) {
            return switch (name) {
                case "sin" -> Math::sin;
                case "cos" -> Math::cos;
                case "tan" -> Math::tan;
                case "asin" -> Math::asin;
                case "acos" -> Math::acos;
                case "atan" -> Math::atan;
                case "exp" -> Math::exp;
                case "log" -> Math::log;
                case "log10" -> Math::log10;
                case "signum" -> Math::signum;
                case "floor" -> Math::floor;
                case "ceil" -> Math::ceil;
                case "round" -> Math::round;
                case "sqrt" -> Math::sqrt;
                case "cbrt" -> Math::cbrt;
                case "deg" -> Math::toDegrees;
                case "rad" -> Math::toRadians;
                default -> null;
            };
        }

        @Override
        public DoubleBinaryOperator getFunction2(String name) {
            return switch (name) {
                case "atan2" -> Math::atan2;
                case "pow" -> Math::pow;
                default -> null;
            };
        }
    }

    public static final UniverseContext UNIVERSE = new UniverseContext();

    /** Thrown when a compiled expression cannot be evaluated (unknown variable or function). */
    public static class EvalException extends RuntimeException {
        @Serial
        private static final long serialVersionUID = 3019515149401799088L;

        public EvalException(String message) {
            super(message);
        }

        public EvalException(String message, Throwable cause) {
            super(message, cause);
        }
    }

    /**
     * <p>
     * Evaluate the expression, using your own context. You can provide your own
     * functions and variables by implementing the context. Please note that
     * functions like {@code sin()} will not be available unless you provide
     * them yourself.
     * </p>
     *
     * @param context The evaluation context.
     * @return The evaluation result.
     * @throws EvalException if the expression can't be evaluated because of an
     *                       unknown variable or function.
     */
    public double evaluate(Context context) throws EvalException;

    /**
     * <p>
     * Evaluate the expression, using {@link #UNIVERSE} context. The
     * {@link #UNIVERSE} context contains all commonly used functions and
     * variables, like {@code sin()}, {@code random} or {@code pi} (note that
     * {@code random} is a variable, not a function).
     * </p>
     *
     * @return The evaluation result.
     * @throws EvalException if the expression can't be evaluated because of an
     *                       unknown variable or function.
     */
    default double evaluate() throws EvalException {
        return evaluate(UNIVERSE);
    }

    /** A numeric literal. */
    public static record Const(double value) implements Expression {
        @Override
        public double evaluate(Context context) throws EvalException {
            return value;
        }
    }

    /** A variable resolved through {@link Context#getVariable(String)} at evaluation time. */
    public static record Variable(String name) implements Expression {
        @Override
        public double evaluate(Context context) throws EvalException {
            Double value = context.getVariable(name);
            if (value == null) throw new EvalException("Unknown variable: " + name);
            return value;
        }
    }

    /** The supported binary operators, each identified by its source symbol. */
    public static enum Operator {
        ADD("+") {
            @Override
            public double apply(double a, double b) {
                return a + b;
            }
        },
        SUBTRACT("-") {
            @Override
            public double apply(double a, double b) {
                return a - b;
            }
        },
        MULTIPLY("*") {
            @Override
            public double apply(double a, double b) {
                return a * b;
            }
        },
        DIVIDE("/") {
            @Override
            public double apply(double a, double b) {
                return a / b;
            }
        };

        private final String symbol;

        private Operator(String symbol) {
            this.symbol = symbol;
        }

        public String getSymbol() {
            return symbol;
        }

        public abstract double apply(double a, double b);
    }

    /** Applies {@code operator} to the values of {@code a} and {@code b} (evaluated in that order). */
    public static record Operate(Expression a, Expression b, Operator operator) implements Expression {
        @Override
        public double evaluate(Context context) throws EvalException {
            return operator.apply(a.evaluate(context), b.evaluate(context));
        }
    }

    /** A one-argument function call resolved through {@link Context#getFunction1(String)}. */
    public static record Call1(String name, Expression param1) implements Expression {
        @Override
        public double evaluate(Context context) throws EvalException {
            DoubleUnaryOperator function1 = context.getFunction1(name);
            if (function1 == null) throw new EvalException("Unknown function: " + name + "(x)");
            return function1.applyAsDouble(param1.evaluate(context));
        }
    }

    /** A two-argument function call resolved through {@link Context#getFunction2(String)}. */
    public static record Call2(String name, Expression param1, Expression param2) implements Expression {
        @Override
        public double evaluate(Context context) throws EvalException {
            DoubleBinaryOperator function2 = context.getFunction2(name);
            if (function2 == null) throw new EvalException("Unknown function: " + name + "(x, y)");
            return function2.applyAsDouble(param1.evaluate(context), param2.evaluate(context));
        }
    }

    /**
     * Tokenizer: either a run of word characters and dots (a number or an
     * identifier) or a single punctuation character. Word characters are
     * Unicode-aware so that identifiers such as the Greek letter pi can be
     * tokenized.
     */
    public static final Pattern PATTERN = Pattern.compile("([\\w.]+|[+*/(),-])", Pattern.UNICODE_CHARACTER_CLASS);

    /**
     * Operator precedence groups, highest first. Kept as a raw array for
     * compatibility; do not modify its contents.
     */
    @SuppressWarnings("rawtypes")
    public static final Set[] ORDER_OF_OPERATIONS = {
        Set.of(Operator.MULTIPLY, Operator.DIVIDE),
        Set.of(Operator.ADD, Operator.SUBTRACT)
    };

    /**
     * <p>
     * Compile math expression from string to {@link Expression}, which you can
     * {@link #evaluate()} at any time.
     * </p>
     *
     * @param source The math expression in string form.
     * @return The compiled expression.
     * @throws IllegalArgumentException if the expression can't be compiled: it is
     *                                  empty, contains a character that is not
     *                                  part of any token, has unbalanced
     *                                  parentheses, a malformed number or an
     *                                  otherwise invalid token sequence.
     */
    public static Expression compile(String source) throws IllegalArgumentException {
        Matcher matcher = PATTERN.matcher(source);
        List<Expression> expressions = new ArrayList<>();
        int consumed = 0;
        while (matcher.find()) {
            // find() skips over anything PATTERN does not match; only whitespace may be skipped.
            requireWhitespace(source, consumed, matcher.start());
            String token = matcher.group(1);
            expressions.add(isNumber(token) ? parseNumber(token) : new $$$Token(token));
            consumed = matcher.end();
        }
        requireWhitespace(source, consumed, source.length());
        return validate(reduce(expressions));
    }

    /** A raw, not-yet-parsed token. It only exists during compilation and cannot be evaluated. */
    public static record $$$Token(String token) implements Expression {
        @Override
        public double evaluate(Context context) throws EvalException {
            throw new EvalException("Not parsed");
        }
    }

    /** Throws if {@code source[from, to)} contains anything other than whitespace. */
    private static void requireWhitespace(String source, int from, int to) {
        for (int k = from; k < to; k++) {
            char c = source.charAt(k);
            if (!Character.isWhitespace(c))
                throw new IllegalArgumentException("Unexpected character '" + c + "' at index " + k);
        }
    }

    /** Parses a digits-and-dots token, reporting malformed input such as {@code "."} or {@code "1.2.3"}. */
    private static Const parseNumber(String token) {
        try {
            return new Const(Double.parseDouble(token));
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Invalid number: " + token, e);
        }
    }

    /** True for identifier tokens: not a number, not punctuation and not an operator symbol. */
    private static boolean isSymbol(String token) {
        if (isNumber(token)) return false;
        if (token.equals("(") || token.equals(")") || token.equals(",")) return false;
        for (Operator operator : Operator.values()) if (token.equals(operator.symbol)) return false;
        return true;
    }

    /** True when the token consists only of ASCII digits and dots. */
    private static boolean isNumber(String token) {
        for (int i = 0; i < token.length(); i++)
            if (token.charAt(i) != '.' && (token.charAt(i) < '0' || token.charAt(i) > '9')) return false;
        return true;
    }

    /** True when {@code expr} is a raw token holding one of the {@link Operator} symbols. */
    private static boolean isOperatorToken(Expression expr) {
        if (!(expr instanceof $$$Token token)) return false;
        for (Operator operator : Operator.values()) if (token.token.equals(operator.symbol)) return true;
        return false;
    }

    /**
     * Reduces a token list (mutated in place) to a single expression: resolves
     * identifiers, calls and groups, folds unary signs, then applies binary
     * operators by precedence.
     *
     * @throws IllegalArgumentException if the list is empty or cannot be reduced.
     */
    private static Expression reduce(List<Expression> expressions) {
        if (expressions.isEmpty()) throw new IllegalArgumentException("Empty expression");
        while (scan(expressions));
        applyUnaryOperators(expressions);
        while (expressions.size() > 1) {
            int lastSize = expressions.size();
            for (Set<?> set : ORDER_OF_OPERATIONS) {
                boolean applied;
                do {
                    applied = false;
                    // Operands sit at even indices and operators at odd ones.
                    for (int i = 1; i < expressions.size() - 1; i += 2) {
                        Operator operator = asOperator(expressions.get(i));
                        if (set.contains(operator)) {
                            Expression left = expressions.get(i - 1);
                            Expression right = expressions.get(i + 1);
                            expressions.remove(i + 1);
                            expressions.remove(i);
                            expressions.set(i - 1, new Operate(left, right, operator));
                            applied = true;
                            // Restart from the left so equal-precedence operators associate left to right.
                            break;
                        }
                    }
                } while (applied);
            }
            if (lastSize == expressions.size())
                throw new IllegalArgumentException("Parse error: Stuck in infinite loop");
        }
        return expressions.get(0);
    }

    /**
     * Folds a {@code +} or {@code -} sign into the operand that follows it when
     * the sign starts the list or directly follows another operator. Unary minus
     * becomes {@code -1 * operand}, which flips the sign exactly.
     */
    private static void applyUnaryOperators(List<Expression> expressions) {
        // Right to left, so sign chains such as "--x" fold from the innermost sign outwards.
        for (int i = expressions.size() - 2; i >= 0; i--) {
            if (!(expressions.get(i) instanceof $$$Token sign)) continue;
            boolean minus = sign.token.equals("-");
            if (!minus && !sign.token.equals("+")) continue;
            if (expressions.get(i + 1) instanceof $$$Token) continue; // no operand follows
            if (i > 0 && !isOperatorToken(expressions.get(i - 1))) continue; // binary use
            Expression operand = expressions.remove(i + 1);
            expressions.set(i, minus ? new Operate(new Const(-1), operand, Operator.MULTIPLY) : operand);
        }
    }

    /**
     * Makes one rewriting pass over the token list. It turns identifiers into
     * {@link Variable}s, or into {@link Call1}/{@link Call2} nodes when they are
     * followed by {@code (}. It also reduces parenthesized groups recursively
     * into single nodes.
     *
     * @return {@code true} if anything was rewritten; callers repeat until {@code false}.
     */
    private static boolean scan(List<Expression> expressions) {
        boolean mark = false;
        for (int i = 0; i < expressions.size(); i++) {
            if (!(expressions.get(i) instanceof $$$Token token)) continue;
            if (isSymbol(token.token)) {
                // Last
                if (i == expressions.size() - 1) {
                    expressions.set(i, new Variable(token.token));
                    return true;
                }
                // Function
                if (expressions.get(i + 1) instanceof $$$Token nextToken && nextToken.token.equals("(")) {
                    List<Expression> subExpression = new ArrayList<>();
                    List<Expression> params = new ArrayList<>();
                    int depth = 0;
                    int removes = 3; // name, "(" and ")"
                    for (int j = i + 2; j < expressions.size(); j++) {
                        if (expressions.get(j) instanceof $$$Token altToken) {
                            if (altToken.token.equals(")")) {
                                depth--;
                                if (depth == -1) {
                                    // "f()" has no parameters; an empty one after a comma is rejected by reduce.
                                    if (!params.isEmpty() || !subExpression.isEmpty())
                                        params.add(reduce(subExpression));
                                    break;
                                }
                            }
                            if (altToken.token.equals("(")) depth++;
                            if (altToken.token.equals(",") && depth == 0) {
                                params.add(reduce(subExpression));
                                subExpression = new ArrayList<>();
                                removes++;
                                continue;
                            }
                        }
                        subExpression.add(expressions.get(j));
                        removes++;
                    }
                    if (depth != -1) throw new IllegalArgumentException("Missing ')'");
                    if (params.size() < 1 || params.size() > 2)
                        throw new IllegalArgumentException("Unsupported number of function parameters: "
                            + params.size());
                    // Replace name, "(", the arguments and ")" with the call node.
                    expressions.subList(i, i + removes).clear();
                    if (params.size() == 1)
                        expressions.add(i, new Call1(token.token, params.get(0)));
                    else
                        expressions.add(i, new Call2(token.token, params.get(0), params.get(1)));
                    mark = true;
                    continue;
                }
                // Regular variable otherwise
                expressions.set(i, new Variable(token.token));
                mark = true;
                continue;
            }
            // Group
            if (token.token.equals("(")) {
                List<Expression> subExpression = new ArrayList<>();
                int depth = 0;
                int removes = 2; // "(" and ")"
                for (int j = i + 1; j < expressions.size(); j++) {
                    if (expressions.get(j) instanceof $$$Token altToken) {
                        if (altToken.token.equals(")")) {
                            depth--;
                            if (depth == -1) break;
                        }
                        if (altToken.token.equals("(")) depth++;
                        if (altToken.token.equals(",") && depth == 0)
                            throw new IllegalArgumentException("Unexpected ','");
                    }
                    subExpression.add(expressions.get(j));
                    removes++;
                }
                // Without this check the removal below would run past the end of the list.
                if (depth != -1) throw new IllegalArgumentException("Missing ')'");
                expressions.subList(i, i + removes).clear();
                expressions.add(i, reduce(subExpression));
                mark = true;
            }
        }
        return mark;
    }

    /** Maps an operator token to its {@link Operator}, rejecting anything else. */
    private static Operator asOperator(Expression expr) {
        if (!(expr instanceof $$$Token token)) throw new IllegalArgumentException("Unexpected expression: " + expr);
        return switch (token.token) {
            case "+" -> Operator.ADD;
            case "-" -> Operator.SUBTRACT;
            case "*" -> Operator.MULTIPLY;
            case "/" -> Operator.DIVIDE;
            default -> throw new IllegalArgumentException("Unexpected token: " + token.token);
        };
    }

    /** Recursively checks that no raw token survived parsing; returns {@code expr} unchanged. */
    private static Expression validate(Expression expr) {
        if (expr instanceof $$$Token token) throw new IllegalArgumentException("Unexpected token: " + token.token);
        if (expr instanceof Call1 call1) validate(call1.param1);
        if (expr instanceof Call2 call2) {
            validate(call2.param1);
            validate(call2.param2);
        }
        if (expr instanceof Operate operate) {
            validate(operate.a);
            validate(operate.b);
        }
        return expr;
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public class ExpressionTest { | |
public static void main(String[] args) { | |
System.out.println(Expression.compile("1 + 2 * 32 / (4 + 5 * 6) + sin(pi) + e").evaluate()); | |
System.out.println(1d + 2d * 32d / (4d + 5d * 6d) + Math.sin(Math.PI) + Math.E); | |
// => 5.600634769635516 | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment