Skip to content

Instantly share code, notes, and snippets.

@nitori
Created September 7, 2023 18:16
Show Gist options
  • Select an option

  • Save nitori/e4b90a557dfd9494fce070ae0a999735 to your computer and use it in GitHub Desktop.

Select an option

Save nitori/e4b90a557dfd9494fce070ae0a999735 to your computer and use it in GitHub Desktop.
Evaluate a math expression using ast
import ast
class NodeVisitor(ast.NodeVisitor):
    """Validate that an AST subtree is plain arithmetic on numeric literals.

    ``visit(node)`` walks the subtree and raises ``SyntaxError`` at the first
    node whose *exact* type is not listed in ``allowed``. It also raises when
    a constant is not an ``int``/``float``; ``bool`` is rejected explicitly
    even though it subclasses ``int``. An accepted tree returns ``None``.
    """

    # Exact node types permitted anywhere in the tree. Anything else (names,
    # calls, attributes, subscripts, comparisons, containers, ...) is rejected
    # simply by not being listed.
    # NOTE: BitXor is listed, but is_math_expression() rewrites '^' to '**'
    # before parsing, so it cannot occur through that entry point.
    # NOTE(review): Pow is allowed with arbitrary operands. Something like
    # 9**9**9**9 passes validation but can exhaust CPU/memory when evaluated.
    allowed = {
        ast.Add, ast.Sub,
        ast.Mult, ast.Div,
        ast.FloorDiv, ast.Mod,
        ast.Pow, ast.BitXor,
        ast.USub, ast.UAdd,
        ast.BinOp, ast.UnaryOp,
        ast.Constant,
    }

    def visit(self, node):
        """Reject *node* unless its exact type is allowed, then dispatch it.

        Raises:
            SyntaxError: If ``type(node)`` is not in ``allowed``.
        """
        # Exact type (not isinstance) so subclasses cannot slip through.
        if type(node) not in self.allowed:
            raise SyntaxError(f'Invalid syntax. {type(node)} is not allowed.')
        return super().visit(node)

    def visit_Constant(self, node: ast.Constant):
        """Accept only ``int``/``float`` literals (not ``bool``).

        Raises:
            SyntaxError: If the constant's value is not a non-bool int/float.
        """
        if not isinstance(node.value, (int, float)) or isinstance(node.value, bool):
            raise SyntaxError(f'Invalid syntax. Only numbers are allowed. {node.value!r} is of type {type(node.value)}')
        # Do not call super().visit_Constant(): the stdlib version only existed
        # to dispatch legacy visit_Num/visit_Str hooks. It was removed together
        # with ast.Num & co. in Python 3.14, where the super() call raises
        # AttributeError. generic_visit is what it fell back to anyway.
        return self.generic_visit(node)
def is_math_expression(code: str, *, dump: bool = False) -> ast.Expression:
    """Parse *code* and return it as an ``ast.Expression`` if it is pure arithmetic.

    Every ``^`` is rewritten to ``**`` before parsing, so it means power
    rather than XOR. The source must be exactly one expression statement.
    The statement must use only the nodes and constants that ``NodeVisitor``
    accepts.

    Args:
        code: Source text of the expression.
        dump: When true, print an indented ``ast.dump`` of the expression.

    Returns:
        An ``ast.Expression`` wrapping the parsed expression, suitable for
        ``compile(..., mode='eval')``.

    Raises:
        SyntaxError: Raised in any of these cases:
            - *code* does not parse.
            - *code* does not contain exactly one statement.
            - That statement is not an expression.
            - The expression contains a disallowed node or constant.
    """
    source = code.replace('^', '**')  # '^' is XOR in Python; treat it as power
    module = ast.parse(source)
    ast.fix_missing_locations(module)

    statements = module.body
    if len(statements) != 1:
        raise SyntaxError('Invalid syntax. Only one expression is allowed.')
    (statement,) = statements
    if not isinstance(statement, ast.Expr):
        raise SyntaxError('Invalid syntax. Only expressions are allowed.')

    body = statement.value
    if dump:
        print(ast.dump(body, indent=2))
    NodeVisitor().visit(body)  # raises SyntaxError on anything non-arithmetic
    return ast.Expression(body)
def run_math_expression(code: str, *, dump: bool = False):
    """Validate *code* as a pure arithmetic expression and evaluate it.

    Args:
        code: Source text of the expression (``^`` means power).
        dump: When true, print the expression's AST before evaluating.

    Returns:
        The value of the expression.

    Raises:
        SyntaxError: If *code* is not an allowed arithmetic expression.
        ArithmeticError: Evaluation errors (for example ZeroDivisionError)
            are not caught here and propagate to the caller.
    """
    tree = is_math_expression(code, dump=dump)
    compiled = compile(tree, filename='<ast>', mode='eval')
    # The validated tree has no names, so builtins are never reachable.
    # Disabling them is defence in depth only.
    # NOTE(review): huge powers such as 9**9**9**9 pass validation and can
    # hang or exhaust memory here. Consider a timeout or exponent limit if
    # input is untrusted.
    return eval(compiled, {'__builtins__': None})
def main():
    """Run a small self-test of ``run_math_expression``, one line per case.

    Each case pairs an input with either its expected value or ``None`` when
    the input must be rejected. Each case prints the input followed by
    `` ... OK`` or `` ... FAIL: <reason>``.
    """
    import math  # only needed for the float comparison in this self-test

    expressions = [
        ('1+2', 3),
        ('1+2*3', 7),
        ('-24*21+13*21/5', -449.4),
        ('2^7', 128),
        ('2**7', 128),
        ('2**3**4', 2417851639229258349412352),
        ('2^3^4', 2417851639229258349412352),
        ('import os', None),
        ('1+2; 3+4', None),
        ('__import__("os").system("echo hello")', None),
        ('1/0', None),
    ]
    for expr, expected in expressions:
        print(expr, end='')
        try:
            result = run_math_expression(expr)
        except (SyntaxError, ArithmeticError) as e:
            # SyntaxError: the input was rejected. ArithmeticError: a valid
            # expression failed to evaluate (1/0, float overflow). Both are
            # the "error" outcome instead of crashing the whole self-test.
            if expected is not None:
                print(f' ... FAIL: {e}')
                continue
        else:
            if expected is None:
                print(f' ... FAIL: {result!r} should be an error')
                continue
            if isinstance(result, float) or isinstance(expected, float):
                # Float results carry rounding error, so compare approximately.
                matches = math.isclose(result, expected)
            else:
                # Ints are exact and may exceed float precision; compare exactly.
                matches = result == expected
            if not matches:
                print(f' ... FAIL: {result!r} != {expected!r}')
                continue
        print(' ... OK')
if __name__ == '__main__':
    # Executed as a script (not imported): run the built-in self-test.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment