Created September 7, 2023 18:16
-
-
Save nitori/e4b90a557dfd9494fce070ae0a999735 to your computer and use it in GitHub Desktop.
Evaluate a math expression using ast
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import ast | |
class NodeVisitor(ast.NodeVisitor):
    """AST visitor that rejects every node outside a small arithmetic whitelist.

    Visiting a tree returns None when every node is allowed. It raises
    SyntaxError at the first node whose exact type is not in ``allowed``, or
    at the first constant that is not an int or a float.
    """

    # Exact node types permitted anywhere in the tree: arithmetic operator
    # nodes, the BinOp/UnaryOp nodes that carry them, and literals.
    # Note: is_math_expression rewrites '^' to '**' before parsing, so BitXor
    # is not reached through that entry point.
    allowed = {
        ast.Add, ast.Sub,
        ast.Mult, ast.Div,
        ast.FloorDiv, ast.Mod,
        ast.Pow, ast.BitXor,
        ast.USub, ast.UAdd,
        ast.BinOp, ast.UnaryOp,
        ast.Constant,
    }

    def visit(self, node):
        """Reject *node* unless its exact type is whitelisted, then dispatch.

        Raises:
            SyntaxError: if ``type(node)`` is not in ``allowed``.
        """
        # Exact type comparison (not isinstance), so subclasses are rejected.
        if type(node) not in self.allowed:
            raise SyntaxError(f'Invalid syntax. {type(node)} is not allowed.')
        return super().visit(node)

    def visit_Constant(self, node: ast.Constant):
        """Allow only int and float literals.

        Raises:
            SyntaxError: for any other constant type, including bool.
        """
        # bool is a subclass of int, so True/False must be excluded explicitly.
        if not isinstance(node.value, (int, float)) or isinstance(node.value, bool):
            raise SyntaxError(f'Invalid syntax. Only numbers are allowed. {node.value!r} is of type {type(node.value)}')
        # Call generic_visit directly. ast.NodeVisitor.visit_Constant was only a
        # compatibility shim for the deprecated visit_Num/visit_Str hooks. It is
        # not present in recent Python releases, where super().visit_Constant()
        # would raise AttributeError.
        return self.generic_visit(node)
def is_math_expression(code: str, *, dump: bool = False) -> ast.Expression:
    """Validate *code* as one arithmetic expression and wrap it for eval-mode compile.

    '^' is treated as exponentiation: it is rewritten to '**' before parsing,
    instead of meaning Python's bitwise XOR.

    Args:
        code: Source text of the expression.
        dump: If true, print the expression's AST (``indent=2``) before the
            node whitelist is checked.

    Returns:
        An ``ast.Expression`` wrapping the validated expression node, suitable
        for ``compile(..., mode='eval')``.

    Raises:
        SyntaxError: if *code* does not parse, is not exactly one expression
            statement, contains a node type outside ``NodeVisitor.allowed`` or
            a non-numeric constant, or is nested too deeply to parse or
            validate.
    """
    code = code.replace('^', '**')  # ^ is xor, but make it behave like pow
    try:
        m = ast.parse(code)
    except (RecursionError, MemoryError) as err:
        # CPython reports pathologically nested source as RecursionError, or
        # as MemoryError ("too complex to parse"). Surface it as SyntaxError so
        # callers that only catch SyntaxError still reject the input cleanly.
        raise SyntaxError('Invalid syntax. Expression is too complex.') from err
    if len(m.body) != 1:
        raise SyntaxError('Invalid syntax. Only one expression is allowed.')
    expr = m.body[0]
    if not isinstance(expr, ast.Expr):
        raise SyntaxError('Invalid syntax. Only expressions are allowed.')
    if dump:
        print(ast.dump(expr.value, indent=2))
    try:
        NodeVisitor().visit(expr.value)
    except RecursionError as err:
        # The visitor recurses several Python frames per nesting level, so
        # deep but parseable input (e.g. '-' * 3000 + '1') overflows the stack.
        raise SyntaxError('Invalid syntax. Expression is too complex.') from err
    # ast.parse already attached source locations to expr.value; the mod-level
    # Expression wrapper carries none, so no fix_missing_locations is needed.
    return ast.Expression(expr.value)
def run_math_expression(code: str, *, dump: bool = False):
    """Validate *code* with is_math_expression and return its evaluated value.

    Args:
        code: Source text of the expression ('^' means exponentiation).
        dump: Forwarded to is_math_expression to print the parsed AST.

    Raises:
        SyntaxError: propagated from is_math_expression for rejected input.
        Arithmetic errors raised while evaluating (e.g. ZeroDivisionError)
        propagate unchanged.
    """
    # The validated tree contains no names, so an empty builtins namespace is
    # purely defence in depth.
    sandbox = {'__builtins__': None}
    validated = is_math_expression(code, dump=dump)
    bytecode = compile(validated, filename='<ast>', mode='eval')
    # NOTE(review): eval on untrusted input. The whitelist blocks names, calls
    # and attributes, but not resource exhaustion: e.g. '9**9**9**9' builds an
    # enormous int and can hang the process. Confirm callers bound this.
    return eval(bytecode, sandbox)
def main():
    """Run a small self-check of run_math_expression, printing one line per case.

    Each case pairs an input string with its expected value. None means the
    input must be rejected with SyntaxError. Each line is the input followed
    by ' ... OK' or ' ... FAIL: <reason>'.
    """
    cases = (
        ('1+2', 3),
        ('1+2*3', 7),
        ('-24*21+13*21/5', -449.4),
        ('2^7', 128),
        ('2**7', 128),
        ('2**3**4', 2417851639229258349412352),
        ('2^3^4', 2417851639229258349412352),
        ('import os', None),
        ('1+2; 3+4', None),
        ('__import__("os").system("echo hello")', None),
    )
    for source, want in cases:
        print(source, end='')
        try:
            got = run_math_expression(source)
        except SyntaxError as exc:
            # A rejection is only a success when the case expects one.
            verdict = ' ... OK' if want is None else f' ... FAIL: {exc}'
        else:
            if want is None:
                verdict = f' ... FAIL: {got!r} should be an error'
            elif got != want:
                verdict = f' ... FAIL: {got!r} != {want!r}'
            else:
                verdict = ' ... OK'
        print(verdict)
if __name__ == "__main__":
    # Run the self-check only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.