@edoakes
Last active October 12, 2021 14:15
# Imports assumed for this sketch: `pipeline` refers to the proposed pipeline API
# described in this gist, so the import path is an assumption, not a released Ray API.
# PreprocessOutput, Model1Output, Model2Output, and FinalResult are placeholder
# user-defined output types.
from ray import serve
from ray.serve import pipeline

@pipeline.step
def preprocess(_input: str) -> PreprocessOutput:
    pass

@pipeline.step(num_replicas=10, num_gpus=1)
class Model1:
    def __call__(self, _input: PreprocessOutput) -> Model1Output:
        pass

@pipeline.step(num_replicas=5, num_cpus=1)
class Model2:
    def __call__(self, _input: PreprocessOutput) -> Model2Output:
        pass

@pipeline.step
def combiner(arg1: Model1Output, arg2: Model2Output) -> FinalResult:
    pass
def build_pipeline():
    # pipeline.INPUT marks the start of the DAG. Multiple steps may take
    # pipeline.INPUT; it's effectively the entry node of the DAG.
    preprocess_output = preprocess(pipeline.INPUT)
    # *args/**kwargs are placeholders for model construction arguments.
    model1 = Model1(*args, **kwargs)
    model2 = Model2(*args, **kwargs)
    return combiner(model1(preprocess_output), model2(preprocess_output))
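
# Resulting DAG shape (sketch, inferred from build_pipeline above): the
# preprocess output fans out to both models, and combiner joins their results.
#
#   pipeline.INPUT -> preprocess -+-> Model1 -+
#                                 |           +-> combiner -> FinalResult
#                                 +-> Model2 -+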
@serve.deployment
class PipelineDriver:
    # Pipelines that the driver depends on can be declared as class variables.
    # All replicas of the deployment will hold a handle to the same pipeline.
    # The Serve controller will manage the pipeline's lifetime and deploy it for us.
    pipeline = build_pipeline()

    async def __call__(self, _input: str) -> FinalResult:
        return await self.pipeline.remote(_input)

PipelineDriver.deploy()
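
# Usage sketch (not part of the original gist; assumes the proposed pipeline API
# above plus the Ray Serve 1.x handle interface). A client could query the
# deployed driver roughly like this:
#
#   import ray
#   handle = PipelineDriver.get_handle()
#   final_result = ray.get(handle.remote("some raw input string"))
#
# Alternatively, the driver should be reachable over HTTP at the deployment's
# route, e.g. a POST to /PipelineDriver with the raw input in the request body.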