Skip to content

Instantly share code, notes, and snippets.

@ashemag
Last active November 18, 2021 01:14
Show Gist options
  • Save ashemag/7e2fc4b871ee2b7b60dcbd3bf2b0b016 to your computer and use it in GitHub Desktop.
import kfp
def configure_task(task, cache=True):
    """Apply the standard resource settings to a pipeline task.

    Requests and limits 50Gi of memory and requests 250m of CPU on the
    task's container. When ``cache`` is False, sets the max cache
    staleness to "P0D" so Kubeflow never reuses a cached result.
    """
    container = task.container
    # Same resource shape for every task in this pipeline.
    for resource, amount in (("memory", "50Gi"), ("cpu", "250m")):
        container.add_resource_request(resource, amount)
    container.add_resource_limit("memory", "50Gi")
    if not cache:
        # "P0D" (zero-day staleness) effectively disables result caching.
        task.execution_options.caching_strategy.max_cache_staleness = "P0D"
@kfp.dsl.pipeline(name="My pipeline name", description="My pipeline description")
def pipeline(experiment_name: str = "testing", data_filename: str = "full_data_nov", epochs: int = 100):
    """Data handling, a full-data train plus 3 cross-validation folds, then reporting.

    Args:
        experiment_name: Label threaded through every component.
        data_filename: Input data file name passed to the data handler.
        epochs: Number of training epochs for each train task.
    """
    data_handler_op = kfp.components.load_component_from_file("components/data_handler/data_handler.yaml")
    full_train_op = kfp.components.load_component_from_file("components/train/full_train.yaml")
    # One factory serves all folds; it was previously loaded once per fold.
    train_op = kfp.components.load_component_from_file("components/train/train.yaml")
    results_op = kfp.components.load_component_from_file("components/results/results.yaml")

    # Run the data handler; its "result" output feeds every train task.
    data_handler_task = data_handler_op(experiment_name, data_filename)
    configure_task(data_handler_task)
    # Stringify the output placeholder once instead of per task.
    result_file = "%s" % data_handler_task.outputs["result"]

    # Full-data training run; caching disabled so it always retrains.
    full_train_task = full_train_op(
        result_file=result_file,
        experiment_name=experiment_name,
        epochs=epochs,
    )
    configure_task(full_train_task, cache=False)

    # One cross-validation training task per fold (was three copy-pasted blocks).
    train_tasks = []
    for fold_idx in range(3):
        train_task = train_op(
            result_file=result_file,
            fold_idx=fold_idx,
            experiment_name=experiment_name,
            epochs=epochs,
        )
        configure_task(train_task, cache=False)
        train_tasks.append(train_task)

    # Aggregate the fold results and the full-data result into a report.
    results_task = results_op(
        train_tasks[0].outputs["result"],
        train_tasks[1].outputs["result"],
        train_tasks[2].outputs["result"],
        full_train_task.outputs["result"],
        experiment_name=experiment_name,
    )
    configure_task(results_task, cache=False)
if __name__ == "__main__":
    # Compile the pipeline to a YAML file sitting next to this script.
    output_path = __file__.replace(".py", ".yaml")
    kfp.compiler.Compiler().compile(pipeline, output_path)
    print("Exported pipeline definition to " + output_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment