Skip to content

Instantly share code, notes, and snippets.

@nitori
Last active November 12, 2022 02:48
Show Gist options
  • Select an option

  • Save nitori/d1865458adbf64c81e821c332d244ab2 to your computer and use it in GitHub Desktop.

Select an option

Save nitori/d1865458adbf64c81e821c332d244ab2 to your computer and use it in GitHub Desktop.
Convert curried `if_then_else(cond)(then_case)(else_case)` function calls into Python conditional expressions (`then_case if cond else else_case`).
import ast
class ConvertIfExpr(ast.NodeTransformer):
    """Rewrite curried ``if_then_else`` calls into conditional expressions.

    A call chain of the form ``if_then_else(cond)(then_case)(else_case)``
    is replaced by the expression ``then_case if cond else else_case``.
    Chains nested inside the condition or either branch are converted too.
    Calls that do not match the pattern are kept, but their children are
    still visited.
    """

    @staticmethod
    def _single_arg(call: ast.Call):
        """Return the sole positional argument of *call*, or ``None``.

        Returns ``None`` in three cases, because none of them can be mapped
        onto a single ``IfExp`` operand without losing information:
        the call has keyword arguments, it does not have exactly one
        positional argument, or that argument is starred (``*x``).
        """
        if call.keywords or len(call.args) != 1:
            return None
        arg = call.args[0]
        if isinstance(arg, ast.Starred):
            return None
        return arg

    @classmethod
    def _match(cls, node: ast.Call):
        """Return ``(condition, then_case, else_case)`` if *node* matches.

        *node* must be the outermost call of
        ``if_then_else(cond)(then_case)(else_case)``. Returns ``None`` when
        the shape does not match.
        """
        then_call = node.func
        if not isinstance(then_call, ast.Call):
            return None
        cond_call = then_call.func
        if not (
            isinstance(cond_call, ast.Call)
            and isinstance(cond_call.func, ast.Name)
            and cond_call.func.id == 'if_then_else'
        ):
            return None
        condition = cls._single_arg(cond_call)
        then_case = cls._single_arg(then_call)
        else_case = cls._single_arg(node)
        if condition is None or then_case is None or else_case is None:
            return None
        return condition, then_case, else_case

    def visit_Call(self, node: ast.Call) -> ast.AST:
        '''
        Assumes the form if_then_else(cond)(then_case)(else_case)
        Each of the three calls must have exactly one positional,
        non-starred argument and no keyword arguments. Any other call is
        left unchanged.
        '''
        parts = self._match(node)
        if parts is None:
            # Not an if_then_else chain. Keep the call, but still rewrite
            # any matching chains nested inside it.
            return self.generic_visit(node)
        condition, then_case, else_case = parts
        # Visit the operands as well, so that if_then_else chains nested in
        # the condition or either branch are also converted.
        new_node = ast.IfExp(
            test=self.visit(condition),
            body=self.visit(then_case),
            orelse=self.visit(else_case),
        )
        # Give the replacement the outer call's source position.
        return ast.copy_location(new_node, node)
# Two demo inputs: s1 is the minimal form; s2 nests the chain inside sin().
s1 = 'if_then_else(cond)(then_case)(else_case)'
s2 = '''
(sin(
if_then_else( safe_div(prey_captured)(1.0) )
( safe_div(hit_wall)(moves_remaining) )
( move_forward(0.0) )
)
)'''
# Parse the nested example. Pass s1 here instead to try the simpler case.
t = ast.parse(s2)
# Print before transforming: visit() rewrites the tree in place.
print(ast.unparse(t))
t2 = ConvertIfExpr().visit(t)
# Fill in line/column info on any freshly created nodes (modifies in place).
ast.fix_missing_locations(t2)
print(ast.unparse(t2))
# Output when parsing s1 (before, then after conversion):
# if_then_else(cond)(then_case)(else_case)
# then_case if cond else else_case
# Output when parsing s2 (before, then after conversion):
# sin(if_then_else(safe_div(prey_captured)(1.0))(safe_div(hit_wall)(moves_remaining))(move_forward(0.0)))
# sin(safe_div(hit_wall)(moves_remaining) if safe_div(prey_captured)(1.0) else move_forward(0.0))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment