Skip to content

Instantly share code, notes, and snippets.

@gryBox
Forked from emcake/recursion_flow.py
Created January 15, 2021 21:28
Show Gist options
  • Save gryBox/497b54de90040d24bc7f85b5bc13fa85 to your computer and use it in GitHub Desktop.
Save gryBox/497b54de90040d24bc7f85b5bc13fa85 to your computer and use it in GitHub Desktop.
import prefect
from prefect import Flow, task, Task, case, Parameter
from prefect.environments import LocalEnvironment
from prefect.environments.storage import S3
from prefect.tasks.prefect import FlowRunTask
from prefect.tasks.control_flow import merge
from prefect.engine.results import PrefectResult
@task
def print_fib_parameter(fib):
    """Write the Fibonacci parameter received by this flow run to the run logger."""
    run_logger = prefect.context.get('logger')
    message = f'my parameter: {fib}'
    run_logger.info(message)
@task
def is_terminal(fib):
    """Report whether ``fib`` is a base case of the recursion (``fib <= 2``)."""
    reached_base = fib <= 2
    return reached_base
@task
def recurse_parameter(n):
    """Build the parameters mapping for a child run: ``{'fib_value': n}``."""
    child_parameters = {'fib_value': n}
    return child_parameters
@task
def sum_up(items):
    """Return the sum of ``items`` (the values read back from the child runs)."""
    total = sum(items)
    return total
@task
def one():
    """Return the constant 1, used by the flow as the base-case value."""
    base_value = 1
    return base_value
# Slug (and name) of the `cache` task. get_result_from_subflow looks for a task
# run with exactly this slug inside each child run, so the two must agree.
result_task_slug = 'result'


@task(name=result_task_slug, slug=result_task_slug, result=PrefectResult())
def cache(x):
    """
    Return ``x`` unchanged from a task configured with a PrefectResult.

    This is what makes the recursion work: a parent run finds this task inside
    a child run by its slug (``result_task_slug``) in get_result_from_subflow,
    checks that its result is a PrefectResult, and reads the value back as the
    child's answer.
    """
    stored = x
    return stored
@task
def get_result_from_subflow(flow_run_id):
    """
    Extract the cached result value from a child flow run.

    Fetches the child run's info through the Prefect Client. It then returns
    the value of the task run whose slug equals ``result_task_slug`` (the
    ``cache`` task).

    The approach works with any Result type, as long as the caching task and
    this extraction agree. This implementation accepts only PrefectResult,
    which is what ``cache`` is configured with.

    Args:
        flow_run_id: id of the child flow run (the output of the mapped
            FlowRunTask in the flow below).

    Returns:
        The value stored by the child run's result task.

    Raises:
        ValueError: if the result task's result is not a PrefectResult, or if
            the child run contains no task run with the expected slug.
    """
    client = prefect.Client()
    info = client.get_flow_run_info(flow_run_id)
    for task_run in info.task_runs.to_list():
        if task_run.task_slug != result_task_slug:
            continue
        # Hydrate the stored result before inspecting its type and value.
        task_run.state.load_result()
        if not isinstance(task_run.state._result, PrefectResult):
            raise ValueError('expected PrefectResult')
        # Slugs are unique within a flow, so the first match is the only one.
        return task_run.state.result
    # Previously this fell through and returned None, which only failed later
    # (inside sum_up) with an unrelated-looking TypeError.
    raise ValueError(
        f'no task run with slug {result_task_slug!r} found in flow run {flow_run_id}'
    )
# Project and flow name. They are used both to register this flow and, through
# FlowRunTask, to start child runs of this same flow. That is what makes the
# computation recursive.
my_project = 'recursion'
my_flow = 'recursive-fibonacci'
with Flow(my_flow) as flow:
    # Fibonacci index computed by this run.
    n = Parameter('fib_value', default=1)
    printed = print_fib_parameter(n)
    ## non-terminal tasks
    # Recursive branch:
    # - build the list [n - 1, n - 2];
    # - start one child run of this flow per element (wait=True);
    # - read each child's cached result back;
    # - add the results together.
    m_1 = n - 1
    m_2 = n - 2
    lst = prefect.tasks.core.collections.List()
    inputs = lst(m_1, m_2)
    parameters = recurse_parameter.map(inputs)
    subflows = FlowRunTask(flow_name=my_flow, project_name=my_project, wait=True).map(parameters=parameters)
    values = get_result_from_subflow.map(subflows)
    f = sum_up(values)
    # Base-case branch: the constant 1.
    t = one()
    # Branch on is_terminal(n): `t` is the true branch, `inputs` the false one.
    # Gating `inputs` is what gates the recursive branch built on top of it.
    # This assumes Prefect's default of skipping tasks downstream of a skipped
    # one — confirm against the Prefect version in use.
    prefect.tasks.control_flow.ifelse(is_terminal(n), t,inputs)
    # Presumably yields the result of whichever branch (t or f) actually ran.
    merged = merge(t,f)
    # Store the final value under `result_task_slug`, so a parent run can
    # read it back with get_result_from_subflow.
    res = cache(merged)
# Module-level side effect: registers the flow under `my_project` whenever this
# file is executed or imported.
flow.register(project_name=my_project)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment