Last active: September 3, 2019 20:52
-
-
Save mgxd/3c62ae0c42b47025abfbd97ddaa1f569 to your computer and use it in GitHub Desktop.
Pydra Hooks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pydra import mark, Workflow | |
@mark.task
def add1(x):
    """Pydra function task: increment ``x`` by one.

    The returned value is exposed on the task's ``out`` output field, which
    downstream tasks read through ``lzout.out``.
    """
    incremented = x + 1
    return incremented
# Build a workflow with a single input field ``x`` and seed it with 2.
wf = Workflow(name='wf', input_spec=['x'])
wf.inputs.x = 2

# First node: wire the workflow input straight into add1.
first_task = add1(name='task1')
wf.add(first_task)
first_task.inputs.x = wf.lzin.x
# Callbacks registered below as task2's pre_run / post_run hooks.
def prehook(task, *args):
    """Announce that ``task`` is about to run.

    When the task's input ``x`` equals 3, additionally print a notice and
    block for 2 seconds. Any extra positional arguments are ignored.
    """
    import time

    print(f"Called before {task} executes")
    if task.inputs.x != 3:
        return
    print("Sleeping...")
    time.sleep(2)


def posthook(task, *args):
    """Announce that ``task`` has finished; extra positional arguments are ignored."""
    announcement = f"Called after {task} executes"
    print(announcement)
# Second node: consumes task1's output and carries both hooks.
second_task = add1(name='task2')
second_task.hooks.pre_run = prehook
second_task.hooks.post_run = posthook
wf.add(second_task)
second_task.inputs.x = wf.task1.lzout.out

# Publish task2's result as the workflow output ``out``, then run the
# workflow with the 'cf' plugin.
wf.set_output([('out', wf.task2.lzout.out)])
wf(plugin='cf')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pydra import mark | |
from pydra.engine.core import Workflow | |
def ping_server(task, result, *, url="http://localhost:8001", retries=5,
                poll_sleep=3, timeout=10):
    """Post-run task hook that signals a server for additional processing.

    Sends ``result.output.out`` to ``{url}/post`` as a form-encoded ``output``
    field, then polls ``{url}/done`` until it answers HTTP 200 or ``retries``
    polls have been made. Network errors and non-200 replies are reported
    with ``print`` and do not raise out of the hook.

    Parameters
    ----------
    task
        The task the hook is attached to; only used in log messages.
    result
        The task result; ``result.output.out`` is sent to the server.
    url
        Base URL of the server.
    retries
        Maximum number of ``/done`` polls (zero or negative means none).
    poll_sleep
        Seconds to wait between consecutive polls.
    timeout
        Per-request timeout in seconds handed to ``requests``. Without one,
        a stalled server would block the hook indefinitely.
    """
    import time
    from urllib.parse import urlencode

    import requests

    print(f"({task}) - Communicating with server")
    # urlencode escapes '&', '=', spaces, etc. in the value; building the body
    # with a bare f-string corrupted the form data for such outputs.
    payload = urlencode({"output": result.output.out})
    try:
        req = requests.post(f"{url}/post", data=payload, timeout=timeout)
    except requests.RequestException as err:
        # Connection refused / timeout: treat like a non-200 reply instead of
        # letting the exception escape the hook.
        print(f"Server not available ({err})")
        return
    if req.status_code != 200:
        print("Server not available")
        return

    print("Waiting for server response")
    for attempt in range(1, retries + 1):
        try:
            req = requests.get(f"{url}/done", timeout=timeout)
        except requests.RequestException as err:
            print(f"Polling attempt {attempt} failed ({err})")
        else:
            if req.status_code == 200:
                print("Server side processing complete")
                return
        # No point sleeping after the last attempt; we give up right after.
        if attempt < retries:
            time.sleep(poll_sleep)
    print("Server response took too long")
@mark.task
def adder(x):
    """Pydra function task returning ``x + 1`` on its ``out`` output field."""
    total = x + 1
    return total
# Chain three adders t1 -> t2 -> t3; only t2 carries the ping_server
# post-run hook.
wf = Workflow(name='wf', input_spec=['x'])

head = adder(name='t1', x=wf.lzin.x)
wf.add(head)

middle = adder(name='t2', x=wf.t1.lzout.out)
middle.hooks.post_run = ping_server
wf.add(middle)

tail = adder(name='t3', x=wf.t2.lzout.out)
wf.add(tail)

# Publish t3's result, seed the input, run, and show the result object.
wf.set_output([('out', wf.t3.lzout.out)])
wf.inputs.x = 1
wf(plugin='cf')
print(wf.result())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sanic import Sanic | |
from sanic.response import json, text | |
# Recent Sanic releases refuse to create an unnamed application, while older
# releases also accept the name as the first positional argument; an explicit
# literal name works on both.
app = Sanic("pydra_hook_server")
@app.route("/done")
async def done(request):
    """Answer a completion poll.

    This demo handler ignores the request and always replies with the same
    JSON body, so any poller treats the work as already finished.
    """
    payload = {"done": "data"}
    return json(payload)
@app.route("/post", methods=["POST"])
async def receive(request):
    """Echo the raw POST body back to the client as plain text.

    ``request.body`` is ``bytes``. Formatting it directly with ``%s`` put its
    repr (e.g. ``b'output=4'``) into the reply, so it is decoded first.
    Undecodable bytes are replaced instead of raising.
    """
    body = request.body.decode("utf-8", errors="replace")
    message = "You are trying to create a user with the following POST: " + body
    return text(message)
if __name__ == "__main__":
    # Listen on port 8001 on all interfaces; the hook script targets
    # http://localhost:8001.
    server_options = {"host": "0.0.0.0", "port": 8001}
    app.run(**server_options)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.