Skip to content

Instantly share code, notes, and snippets.

@mratsim
Created December 10, 2019 22:18
Show Gist options
  • Save mratsim/c54dd77fdc88fa425d94c82f8451120f to your computer and use it in GitHub Desktop.
Save mratsim/c54dd77fdc88fa425d94c82f8451120f to your computer and use it in GitHub Desktop.
# beacon_chain
# Copyright (c) 2018 Status Research & Development GmbH
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.
proc insertedAtExits(ast: NimNode, statement: NimNode): NimNode =
  ## Scan ``ast`` (must be a statement list, e.g. a proc body) and return a
  ## copy of it with ``statement`` inserted at each exit point that was found.
  ## ``ast`` itself is not modified; a new tree is built.
  ##
  ## For example
  ##
  ## ```
  ## proc allGreater(s: seq[int], x: int): bool =
  ##   for i in 0 ..< s.len:
  ##     if s[i] <= x:
  ##       return false
  ##   return true
  ## ```
  ##
  ## will be transformed into
  ##
  ## ```
  ## proc allGreater(s: seq[int], x: int): bool =
  ##   for i in 0 ..< s.len:
  ##     if s[i] <= x:
  ##       statement
  ##       return false
  ##   statement
  ##   return true
  ## ```
  ##
  ## This is used for benchmarking hooks and has the following limitations due
  ## to the simplicity of implementation:
  ##
  ## - It assumes that the return statement is costless.
  ##   I.e. "return digest(x)" will miss digest.
  ##   This can be later fixed by always assigning the return value to result
  ##   then insert the bench statement
  ##   then return.
  ##
  ## - It assumes that the last statement is costless
  ##   (procs ending with an expression are common in the codebase):
  ##   ``statement`` is inserted *before* the trailing statement/expression,
  ##   except for a trailing ``result = ...`` assignment where it is inserted
  ##   after.
  ##
  ## - A trailing ``for``/``while`` loop gets ``statement`` inserted once,
  ##   after the loop, not inside its body.
  ##
  ## - For a trailing ``try``, the end of the ``try`` body and of each
  ##   ``except`` body are treated as exits; a ``finally`` body is only
  ##   scanned for ``return`` statements.
  ##
  ## - A trailing ``if``/``case``/``when`` in which no branch is taken falls
  ##   through without running ``statement``.

  proc hook(): NimNode =
    ## A fresh copy of ``statement`` so that the exit points do not share a
    ## single node instance.
    copyNimTree(statement)

  proc inspect(node: NimNode, expressionsMayReturn: bool): NimNode =
    ## Recursively inspect the AST.
    ## A return statement can happen anywhere,
    ## while an expression that is also a return value can only
    ## happen as the last child of the current node.
    ## ``expressionsMayReturn`` is true when ``node`` sits in that trailing
    ## position, all the way up to the root.
    case node.kind
    of nnkReturnStmt:
      # Add our statement and re-add the return statement
      result = newStmtList(hook(), node)
    of nnkStmtList, nnkStmtListExpr, nnkBlockStmt, nnkBlockExpr:
      # New nested scope, we need to go deeper
      result = node.kind.newTree()
      for i in 0 ..< node.len - 1:
        # Only return statements can exit here,
        # up to the second-to-last statement
        result.add inspect(node[i], expressionsMayReturn = false)
      if node.len >= 1:
        # The very last statement may be a return,
        # or an expression (implicit return)
        # if we are in the last nested scope
        # of the last ... of the last nested scope
        result.add inspect(node[^1], expressionsMayReturn)
    of nnkWhileStmt, nnkForStmt:
      # The last statement of a loop body is not an exit point: it runs on
      # every iteration and control continues after the loop. Only explicit
      # returns are instrumented inside the loop.
      let loop = node.kind.newTree()
      for child in node:
        loop.add inspect(child, expressionsMayReturn = false)
      result =
        if expressionsMayReturn:
          # Trailing loop: the scope is left once the loop is done
          newStmtList(loop, hook())
        else:
          loop
    of nnkTryStmt:
      # Children are the try body followed by except/finally branches.
      # Each one is rebuilt by its own case below so that the branch nodes
      # stay direct children of the try node.
      result = node.kind.newTree()
      for child in node:
        result.add inspect(child, expressionsMayReturn)
    of nnkFinally:
      # Only look for explicit returns in the finally body
      result = node.kind.newTree()
      for child in node:
        result.add inspect(child, expressionsMayReturn = false)
    of nnkIfStmt, nnkIfExpr, nnkWhenStmt:
      # In a conditional scope every branch may be an exit,
      # not just the last one.
      result = node.kind.newTree()
      for conditionalBranch in node:
        result.add inspect(conditionalBranch, expressionsMayReturn)
    of nnkCaseStmt:
      result = node.kind.newTree()
      # Skip the first child which is the expression being matched
      result.add node[0]
      for i in 1 ..< node.len:
        result.add inspect(node[i], expressionsMayReturn)
    of nnkElifBranch, nnkElifExpr:
      # Only the last node carries the potential return statement/expression
      result = node.kind.newTree()
      result.add node[0]
      result.add inspect(node[^1], expressionsMayReturn)
    of nnkElse, nnkElseExpr:
      result = node.kind.newTree()
      result.add inspect(node[0], expressionsMayReturn)
    of nnkOfBranch, nnkExceptBranch:
      # Leading children are the matched values / exception types,
      # the body is the last child.
      result = node.kind.newTree()
      for i in 0 ..< node.len - 1:
        result.add node[i]
      result.add inspect(node[^1], expressionsMayReturn)
    else:
      # We have an ident, a function call, an assignment, ...
      # For now we only insert our statement just before, which means
      # if it was an expensive function call the bench will be flawed - TODO
      if expressionsMayReturn:
        result = newStmtList()
        if node.kind == nnkAsgn and node[0].eqIdent"result":
          # At least catch the common case when result assignment is last
          result.add node
          result.add hook()
        else:
          result.add hook()
          result.add node
      else:
        # Not at the last statement of the scope so it can't
        # be an expression that returns
        return node

  ast.expectKind(nnkStmtList)
  result = ast.inspect(expressionsMayReturn = true)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment