Skip to content

Instantly share code, notes, and snippets.

@gryBox
Last active June 12, 2020 15:38
Show Gist options
  • Save gryBox/2663a8257e32ccc273a6e1d1eef70834 to your computer and use it in GitHub Desktop.
Save gryBox/2663a8257e32ccc273a6e1d1eef70834 to your computer and use it in GitHub Desktop.
from prefect import Flow, Task, Parameter
import copy
def remove_middle_params(G: "Flow", params_to_remove: dict) -> "Flow":
    """
    Remove pass-through parameters from a flow and rewire their downstream tasks.

    This is useful after merging two flows where a parameter was only acting
    as a conduit between them: every edge leaving the parameter is re-created
    from a chosen upstream task instead, and the parameter is then dropped.

    The input flow is deep-copied first; all changes are made on the copy.

    Args:
        - G (Flow): The flow to manipulate.
        - params_to_remove (dict): A dictionary of `{param_name: kwargs}` pairs,
            where `param_name` is the name of the parameter to remove and
            `kwargs` is a dictionary of keyword arguments passed to
            `flow.get_tasks()` to identify the replacement upstream task.

    Returns:
        - Flow: A new flow with re-defined edges and the extraneous
            parameters removed.

    Raises:
        - ValueError: If the parameter was not found in the flow.
        - ValueError: If no upstream task matched the given `kwargs`.
        - ValueError: If the upstream task is not uniquely identified.
    """
    flow1 = copy.deepcopy(G)

    for param_name, upstream_id_attr in params_to_remove.items():
        # Find the parameter to remove in the flow
        middle_param = flow1.get_tasks(name=param_name, task_type=Parameter)
        # Find the task that will take the parameter's place as upstream
        upstream_task = flow1.get_tasks(**upstream_id_attr)

        # Validate each lookup on its own so that every failure mode surfaces
        # as the documented ValueError rather than an IndexError further down.
        if not middle_param:
            raise ValueError(f"Parameter {param_name} was not found in Flow {flow1}")
        if not upstream_task:
            raise ValueError(
                f"No upstream task matching {upstream_id_attr} was found in Flow {flow1}"
            )
        if len(upstream_task) > 1:
            raise ValueError(
                f"{upstream_task[0].name} was not uniquely identified in Flow {flow1}"
            )

        param_to_replace = middle_param[0]
        upstream = upstream_task[0]

        # Snapshot the affected edges (those leaving the parameter) before
        # mutating the flow, so the loop never iterates a changing collection.
        edges_to_replace = list(flow1.edges_from(param_to_replace))

        for edge in edges_to_replace:
            # Re-create the edge with the same key/mapped settings, but
            # originating from the replacement upstream task.
            flow1.add_edge(
                upstream_task=upstream,
                downstream_task=edge.downstream_task,
                key=edge.key,
                mapped=edge.mapped,
                validate=False,
            )
            print(f"Removing old edge {edge}")
            flow1.edges.remove(edge)

        # Remove middle param
        flow1.tasks.remove(param_to_replace)

    return flow1
@gryBox
Copy link
Author

gryBox commented Jun 10, 2020

This is what a call to the function looks like.

# Map each conduit parameter name to the `get_tasks()` keyword arguments
# that select the task which should replace it as the upstream.
conduit_params = {
    "primary_source_df": {"tags": ["primary_sources"]},
    "secondary_source_df": {"tags": ["secondary_sources"]},
}

# Rebind the name to the rewired flow returned by the function.
node_resources_fl = remove_middle_params(
    G=node_resources_fl,
    params_to_remove=conduit_params,
)

A flow goes from looking like this:
image

To this:

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment