Created December 10, 2020 11:48
Save miraculixx/4ee1bb5fb5466add074ae0d3680347a7 to your computer and use it in GitHub Desktop.
omegaml async task chaining
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Chain fit and predict, i.e. predict will run only if fit succeeds.
# NOTE: `om` is not defined in this snippet; presumably it is the omegaml
# instance whose runtime has been patched with chain() -- confirm.
with om.runtime.chain() as crt:
    crt.model('regmodelx').fit('sample[y]', 'sample[x]')
    crt.model('regmodelx').predict([5], rName='foox')
    # crt.run is the chain's run() only while the with block is active,
    # so the chain has to be submitted from inside it
    result = crt.run()
# sometime later: fetch the result of the submitted chain
print(result.get())
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import types | |
from contextlib import contextmanager | |
class TaskChain:
    """Collect runtime tasks and submit them together as one celery chain.

    Tasks are recorded with :meth:`add`. Calling :meth:`delay` or
    :meth:`apply_async` turns the most recently added task into an
    immutable celery signature, and :meth:`run` submits all collected
    signatures as a single ``celery.chain``.
    """

    def __init__(self):
        # tasks recorded via add(); the last entry is replaced in place by
        # its celery signature once delay()/apply_async() is called for it
        self.tasks = []

    def add(self, task):
        """Append *task* as the new last element of the chain."""
        self.tasks.append(task)

    def delay(self, *args, **kwargs):
        """Shortcut for ``apply_async(args=args, kwargs=kwargs)``.

        Returns:
            self, the chain.
        """
        return self.apply_async(args=args, kwargs=kwargs)

    def apply_async(self, args=None, kwargs=None, **celery_kwargs):
        """Turn the last added task into an immutable celery signature.

        Args:
            args: positional arguments for the task.
            kwargs: keyword arguments for the task; ``None`` is treated as
                an empty dict. The dict is handed to the task's
                ``_apply_kwargs`` together with *celery_kwargs* before the
                signature is built.
            **celery_kwargs: celery options passed on to ``signature``.
                ``immutable`` is always forced to True.

        Returns:
            self, so the chain is what the caller gets back.

        Raises:
            IndexError: if no task has been added to the chain yet.
        """
        if not self.tasks:
            raise IndexError(
                'TaskChain.apply_async() called before any task was added')
        # always give _apply_kwargs a real dict, even for a bare apply_async()
        kwargs = {} if kwargs is None else kwargs
        task = self.tasks[-1]
        task._apply_kwargs(kwargs, celery_kwargs)
        # immutable means results are not passed on from task to task.
        # Set it last (instead of as an extra keyword next to
        # **celery_kwargs) so a caller- or _apply_kwargs-supplied
        # 'immutable' cannot raise "got multiple values for keyword".
        celery_kwargs['immutable'] = True
        sig = task.task.signature(args=args, kwargs=kwargs, **celery_kwargs)
        self.tasks[-1] = sig
        return self

    def run(self):
        """Submit all collected signatures as one celery chain.

        Returns:
            the result of ``celery.chain(*self.tasks).apply_async()``.
        """
        from celery import chain
        chained = chain(*self.tasks)
        return chained.apply_async()
@contextmanager
def chain(self):
    """Collect tasks submitted on this runtime into a single celery chain.

    Meant to be bound to a runtime instance as a method. While the context
    is active, ``self.task(...)`` creates the task as usual but records it
    in a shared :class:`TaskChain` and returns that chain instead, and
    ``self.run()`` submits the recorded tasks as one chain. On exit,
    ``task`` and ``run`` are restored to exactly what the instance itself
    had before entering. If it had no instance-level attribute, the
    override is removed so the class attribute shows through again.

    Yields:
        the runtime itself (``self``), patched for chaining.
    """
    task_chain = TaskChain()
    missing = object()
    # snapshot only the instance's own attributes so exit can undo our
    # patching precisely (the original set run = None, clobbering any
    # class-level run for good)
    saved = {attr: vars(self).get(attr, missing) for attr in ('task', 'run')}
    orig_task = self.task

    def chaining_task(*args, **kwargs):
        # build the task normally but record it; the caller's following
        # .delay()/.apply_async() then goes to the chain, not the task
        task_chain.add(orig_task(*args, **kwargs))
        return task_chain

    task_chain.runtime = self
    self.task = chaining_task
    self.run = task_chain.run
    try:
        yield self
    finally:
        for attr, value in saved.items():
            if value is missing:
                # we introduced this instance attribute; drop it again
                vars(self).pop(attr, None)
            else:
                setattr(self, attr, value)
# Attach chain() to this specific runtime instance as a bound method, so
# that `om.runtime.chain()` is called with the runtime as `self`.
# NOTE(review): `om` is not defined in the visible code; presumably the
# omegaml instance from the caller's session -- confirm.
setattr(om.runtime, 'chain', types.MethodType(chain, om.runtime))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.