Last active
October 20, 2021 23:09
-
-
Save d1manson/81c5982b144671783b37b71de12c7be5 to your computer and use it in GitHub Desktop.
Prefect control flow helper, similar to `case`. See https://github.com/PrefectHQ/prefect/issues/5071#issuecomment-947846404
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, TYPE_CHECKING, Dict | |
import prefect | |
from prefect import Task, Flow | |
from prefect.triggers import all_successful | |
from prefect.tasks.control_flow.conditional import CompareValue | |
from prefect.engine import signals | |
if TYPE_CHECKING: | |
from prefect.engine import state # noqa | |
from prefect import core # noqa | |
# Only the `if_equal` context manager is exported from this module.
__all__ = ("if_equal",)
class EndIfEqual(Task):
    """
    No-op task that closes an `if_equal` block.

    It is constructed with `skip_on_upstream_skip=False`, so skipped upstream
    tasks do not automatically skip it. Its custom trigger raises SKIP only
    when the upstream task bound under the `parent_condition` key was
    skipped. In every other case it defers to Prefect's `all_successful`
    trigger.

    Args:
        - name (str): the name given to this task
    """

    def __init__(self, name: str):
        def trigger(upstream_states: Dict["core.Edge", "state.State"]):
            # Only the enclosing block's condition (key "parent_condition")
            # can force a SKIP; any other skipped upstream is tolerated here.
            parent_was_skipped = any(
                edge.key == 'parent_condition' and edge_state.is_skipped()
                for edge, edge_state in upstream_states.items()
            )
            if parent_was_skipped:
                raise signals.SKIP(
                    "Parent if-equal block was skipped, so this block should return SKIP.")
            return all_successful(upstream_states)

        super().__init__(name, skip_on_upstream_skip=False, trigger=trigger)

    def run(self, parent_condition):
        """Do nothing; this task only acts as a join point in the flow graph."""
        return None
class if_equal(object):
    """
    A conditional block in a flow definition, similar to Prefect's `case`.

    This is adapted from:
    https://raw.githubusercontent.com/PrefectHQ/prefect/e699fce534f77106e52dda1f0f0b23a3f8bcdf81/src/prefect/tasks/control_flow/case.py

    Used as a context-manager, `if_equal` creates a block of tasks that are only
    run if the result of `task` is equal to `value`. Entering the block returns
    an "end-if" task (an `EndIfEqual`). On exit, the block wires every task
    that has no downstream task inside the block into that end-if, so tasks
    outside the block can depend on the block as a whole.

    Args:
        - task (Task): The task to use in the comparison
        - value (Any): A constant the result of `task` will be compared with
        - name (str, optional): Prefix for the names of the helper tasks this
            block creates (`"<name>:if(<value>)"` and `"<name>:end"`).
            Defaults to `"if_equal"`.

    Raises:
        - TypeError: if `value` is a Task

    Example:
        An `if_equal` block is similar to Python's if-blocks. It delimits a block
        of tasks that will only be run if the result of `task` is equal to
        `value`:

        ```python
        a = task_a()
        x = task_x()
        with if_equal(x, '42') as conditional:
            b = task_b()
            b.set_upstream(a)
        c = task_c()
        c.set_upstream(conditional)
        ```

        In this example, task c will run after task b, whether or not b is skipped.
        And if a fails, the failure will propagate through to the end.

    The `value` argument can be any non-task object.
    See https://github.com/PrefectHQ/prefect/issues/5071#issuecomment-947846404
    """

    def __init__(self, task: Task, value: Any, name="if_equal"):
        if isinstance(value, Task):
            raise TypeError("`value` cannot be a task")
        self.task = task
        self.value = value
        self._name = name
        self._tasks = set()
        self._flow = None
        # These are assigned properly in __enter__. Initialising them here means
        # add_task() no longer raises AttributeError on a block that has not
        # been entered.
        self._cond = None
        self._end_if_equal = None
        self.__parent_case = None

    def add_task(self, task: Task, flow: Flow) -> None:
        """Add a new task under the if_equal statement.

        Args:
            - task (Task): the task to add
            - flow (Flow): the flow to use

        Raises:
            - ValueError: if `flow` differs from the flow of tasks added earlier
        """
        if self._flow is None:
            self._flow = flow
        elif self._flow is not flow:
            raise ValueError(
                "Multiple flows cannot be used with the same if_equal statement"
            )
        self._tasks.add(task)
        # All enclosing if_equal blocks must also know about this task, so that
        # their __exit__ rewires it too.
        # Warning: this breaks if there are Prefect `case` blocks in the stack.
        if self.__parent_case is not None:
            self.__parent_case.add_task(task, flow)

    def __enter__(self):
        parent = prefect.context.get("case")
        # Both helper tasks are created *before* this block is pushed onto the
        # context, so they are registered with the enclosing block (if any)
        # rather than with this one.
        # NOTE(review): `parent._cond` assumes the parent is another if_equal;
        # a Prefect `case` parent has no `_cond` attribute.
        self._end_if_equal = EndIfEqual(name=f"{self._name}:end")(
            parent_condition=parent and parent._cond
        )
        self._cond = CompareValue(self.value, name=f"{self._name}:if({self.value})").bind(
            value=self.task
        )
        self.__parent_case = parent
        prefect.context.update(case=self)
        return self._end_if_equal

    def __exit__(self, *args):
        # Restore the enclosing block (or clear it) as the active "case".
        if self.__parent_case is None:
            prefect.context.pop("case", None)
        else:
            prefect.context.update(case=self.__parent_case)

        # This deals with the skip vs fail issue by copying upstream dependencies
        # onto the end-if. See https://github.com/PrefectHQ/prefect/issues/5071
        self._end_if_equal.set_upstream(self.task)

        for task in self._tasks:
            # Compute the upstream set once and split it by block membership.
            upstream = self._flow.upstream_tasks(task)
            upstream_tasks_in_context = upstream.intersection(self._tasks)
            upstream_tasks_not_in_context = upstream.difference(self._tasks)
            downstream_tasks_in_context = self._flow.downstream_tasks(
                task).intersection(self._tasks)

            for u_task in upstream_tasks_not_in_context:
                # Copy dependencies from outside the block onto the end-if, so an
                # upstream failure propagates instead of being masked by a skip.
                # See https://github.com/PrefectHQ/prefect/issues/5071
                self._end_if_equal.set_upstream(u_task, flow=self._flow)

            if not downstream_tasks_in_context:
                # Nothing else within the block depends on this task, so connect
                # it to the end-if.
                self._end_if_equal.set_upstream(task, flow=self._flow)

            if not upstream_tasks_in_context:
                # No other task inside the block is upstream of this one, so the
                # condition must be made upstream directly for its skip to reach it.
                task.set_upstream(self._cond, flow=self._flow)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment