Last active
June 12, 2020 15:38
-
-
Save gryBox/2663a8257e32ccc273a6e1d1eef70834 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from prefect import Flow, Task, Parameter | |
import copy | |
def remove_middle_params(G: "Flow", params_to_remove: dict) -> Flow:
    """
    Remove parameters from a flow and rewire their downstream tasks to a new upstream task.

    This is useful after merging two flows where a parameter was only acting as a
    conduit between two disparate flows.

    Args:
        - G (Flow): The flow to manipulate. It is deep-copied, so the flow passed in
            is not modified.
        - params_to_remove (dict): A dictionary of `{"<param_name>": {**kwargs}}` pairs,
            where each key is the name of the parameter to remove and each value is a
            dictionary of keyword arguments passed to `flow.get_tasks()` to identify
            the replacement upstream task.

    Returns:
        - Flow: A new flow with re-defined edges and the extraneous parameters removed.

    Raises:
        - ValueError: If the parameter was not found in the flow.
        - ValueError: If the upstream task was not found in the flow.
        - ValueError: If the upstream task is not uniquely identified.
    """
    flow1 = copy.deepcopy(G)

    for param_name, upstream_id_attr in params_to_remove.items():
        # Find the parameter to remove in the flow
        middle_param = flow1.get_tasks(name=param_name, task_type=Parameter)
        # Find the task that will take the parameter's place as the upstream task
        upstream_task = flow1.get_tasks(**upstream_id_attr)

        # Validate each lookup on its own: an empty result for either one must raise
        # the documented ValueError rather than an IndexError on the `[0]` below.
        if not middle_param:
            raise ValueError(f"Parameter {param_name} was not found in Flow {flow1}")
        if not upstream_task:
            raise ValueError(
                f"No upstream task matching {upstream_id_attr} was found in Flow {flow1}"
            )
        if len(upstream_task) > 1:
            raise ValueError(
                f"{upstream_task[0].name} was not uniquely identified in Flow {flow1}"
            )

        param_to_replace = middle_param[0]
        upstream = upstream_task[0]

        # Edges leaving the parameter are the ones that must be re-pointed. Take a
        # snapshot so that mutating `flow1.edges` below cannot disturb the iteration.
        edges_to_replace = list(flow1.edges_from(param_to_replace))

        for edge in edges_to_replace:
            # Recreate the edge from the new upstream task, preserving key/mapping.
            # validate=False is kept from the original: the old edge still exists at
            # this point and is only removed right after.
            flow1.add_edge(
                upstream_task=upstream,
                downstream_task=edge.downstream_task,
                key=edge.key,
                mapped=edge.mapped,
                validate=False,
            )
            print(f"Removing old edge {edge}")
            flow1.edges.remove(edge)

        # Remove the now-disconnected middle parameter
        flow1.tasks.remove(param_to_replace)

    return flow1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is what a call to the function looks like.
A flow goes from looking like this:
To this: