Created
December 14, 2018 13:37
-
-
Save albertz/2ebec1fc0243fc6443e78c3bacc15e1e to your computer and use it in GitHub Desktop.
get_switch_op_cond_ctx: get the control_flow_ops.CondContext of a Switch tf.Operation (if possible)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_switch_op_cond_ctx(op):
    """
    Get the cond control flow context which a Switch op represents, if it can be determined.
    See control_flow_util.IsCondSwitch, whose consumer-based logic this follows.

    A Switch op is not itself part of the cond context it represents,
    so the context is derived from the consumers of the switch outputs.

    :param tf.Operation op: switch op (type "Switch" or "RefSwitch")
    :return: the control flow context of the last visited consumer of the switch outputs,
      or None if the switch outputs have no consumers at all
    :rtype: tensorflow.python.ops.control_flow_ops.CondContext|None
    """
    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import control_flow_util
    # Same check as control_flow_util.IsSwitch.
    assert op.type in {"Switch", "RefSwitch"}, "not a switch op: %r" % (op,)
    assert op.outputs, "switch op without outputs: %r" % (op,)
    # Switch nodes are not part of the cond control flow context that they
    # represent, so consider the consumers of its outputs to determine if it is
    # cond switch or not. A switch is a cond switch iff all its consumers are in
    # cond contexts.
    is_cond_switch = True
    ctxt = None
    for o in op.outputs:
        for c in o.consumers():
            ctxt = c._get_control_flow_context()  # pylint: disable=protected-access
            # As in control_flow_util.IsCondSwitch, a loop Enter consumer is judged by
            # the context surrounding its own context. Guard against a missing context,
            # so that case ends up as "not a cond switch" instead of an AttributeError.
            if control_flow_util.IsLoopEnter(c) and ctxt is not None:
                ctxt = ctxt.outer_context
            is_cond_switch = is_cond_switch and (ctxt is not None and ctxt.IsCondContext())
    assert is_cond_switch, (
        "switch op %r is not a cond switch (last consumer context: %r)" % (op, ctxt))
    if ctxt is None:
        # With asserts enabled, this is only reached when the switch outputs have no
        # consumers at all; a consumer without a context fails the assert above.
        # This can happen if we have just constructed the switch, or this is via tf.gradients.
        return None
    assert isinstance(ctxt, control_flow_ops.CondContext), (
        "unexpected context type for switch op %r: %r" % (op, ctxt))
    return ctxt
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment