Last active
January 9, 2025 16:20
-
-
Save HarryR/91bd05c32f9e7df514881f3fef38efcf to your computer and use it in GitHub Desktop.
Typesafe task pipeline for Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time
from dataclasses import dataclass
from typing import (
    Awaitable,
    Callable,
    Generic,
    Iterable,
    Literal,
    Tuple,
    TypeVar,
    Union,
)

# Generic parameters of a single step: the value it consumes and the value it produces.
_TaskInputT = TypeVar("_TaskInputT")
_TaskOutputT = TypeVar("_TaskOutputT")

# Position a step occupies in a pipeline; stored in StepInfo.task_type.
TaskTypeT = Literal["first", "intermediate", "final"]
@dataclass
class StepInfo(Generic[_TaskInputT, _TaskOutputT]):
    """Execution record for one pipeline step.

    ``StepExecutor.__call__`` builds one of these for every step it runs,
    whether the step succeeded or raised.
    """

    # Where the step sat in the pipeline: 'first', 'intermediate' or 'final'.
    task_type: TaskTypeT
    # "module.qualname" of the step callable, or str() of it if it has no __name__.
    task_name: str
    # Elapsed time of the awaited step, in seconds.
    duration_seconds: float
    # The value the step was called with.
    step_input: _TaskInputT
    # What the step returned; None if it raised (or itself returned None).
    step_output: _TaskOutputT | None
    # The exception the step raised, or None if it completed normally.
    step_error: Exception | None
def _get_qualified_name(fn) -> str:
    """Return a ``"<module>.<qualname>"`` label for *fn*, used to name steps.

    Objects without a ``__name__`` attribute (for example ``functools.partial``
    instances or callable class instances) are described with ``str(fn)``.
    A missing ``__module__`` is rendered as ``'<unknown>'``, and a missing
    ``__qualname__`` falls back to ``__name__``.
    """
    if not hasattr(fn, '__name__'):
        return str(fn)
    module = getattr(fn, '__module__', '<unknown>')
    qualname = getattr(fn, '__qualname__', fn.__name__)
    return f"{module}.{qualname}"
@dataclass
class StepExecutor(Generic[_TaskInputT, _TaskOutputT]):
    """Run one async step and produce a ``StepInfo`` record describing it.

    Any ``Exception`` raised by ``fn`` is captured into
    ``StepInfo.step_error`` instead of being propagated; in that case the
    returned result is ``None``.
    """

    # Position of this step in the pipeline; copied into StepInfo.task_type.
    task_type: TaskTypeT
    # The async callable implementing the step.
    fn: Callable[[_TaskInputT], Awaitable[_TaskOutputT | None]]

    async def __call__(
        self, input_data: _TaskInputT
    ) -> Tuple[_TaskOutputT | None, StepInfo[_TaskInputT, _TaskOutputT]]:
        """Await ``fn(input_data)`` and return ``(result, step_info)``.

        Args:
            input_data: Value passed to the step callable.

        Returns:
            A ``(result, info)`` pair. ``result`` is whatever ``fn`` returned,
            or ``None`` if it raised an ``Exception`` (the exception is then
            stored in ``info.step_error``).
        """
        result: _TaskOutputT | None = None
        step_err: Exception | None = None
        # perf_counter() is monotonic and high-resolution, so the measured
        # duration is not distorted by system clock adjustments (time.time()
        # can jump backwards or forwards).
        start = time.perf_counter()
        try:
            result = await self.fn(input_data)
        except Exception as err:
            # Captured deliberately so the failure is reported via StepInfo.
            # asyncio.CancelledError derives from BaseException, so task
            # cancellation is not swallowed here.
            step_err = err
        elapsed = time.perf_counter() - start
        return result, StepInfo(
            task_type=self.task_type,
            task_name=_get_qualified_name(self.fn),
            step_input=input_data,
            step_output=result,
            step_error=step_err,
            duration_seconds=elapsed,
        )
# Type parameters used by ``pipeline``: the value handed from step to step,
# and the value produced by the last step.
_TaskIntermediateT = TypeVar("_TaskIntermediateT")
_TaskResultT = TypeVar("_TaskResultT")
async def pipeline( | |
task_input: _TaskInputT, | |
steps: Union[ | |
Tuple[ | |
Callable[[_TaskInputT], Awaitable[_TaskIntermediateT]], | |
Iterable[Callable[[_TaskIntermediateT], Awaitable[_TaskIntermediateT]]], | |
Callable[[_TaskIntermediateT], Awaitable[_TaskResultT]] | |
], | |
Tuple[ | |
Callable[[_TaskInputT], Awaitable[_TaskIntermediateT]], | |
Callable[[_TaskIntermediateT], Awaitable[_TaskResultT]] | |
], | |
Tuple[Callable[[_TaskInputT], Awaitable[_TaskResultT]]] | |
] | |
) -> tuple[None|_TaskResultT, list]: | |
"""Execute a sequence of asynchronous tasks while collecting execution metrics. | |
Args: | |
task_input: Initial input to the pipeline steps | |
steps: Tuple of 1-N tasks forming the pipeline | |
Returns: | |
Tuple of (final_result, execution_logs) where final_result may be None on failure | |
Raises: | |
TypeError: If pipeline structure is invalid | |
Example: | |
```python | |
async def first_task(x: int) -> str: | |
return str(x) | |
async def middle_task(x: str) -> str: | |
return x + "!" | |
async def final_task(x: str) -> bool: | |
return len(x) > 5 | |
async def main(*argv: str) -> int: | |
pipeline = (first_task, (middle_task, middle_task, middle_task, middle_task), final_task) | |
pipeline = (first_task, (middle_task), final_task) | |
pipeline = (first_task, (), final_task) | |
pipeline = (first_task, final_task) | |
result, logs = await process_pipeline(35, pipeline) | |
# or even with a single step | |
pipeline = (final_task,) | |
result, logs = await process_pipeline('hello', pipeline) | |
``` | |
""" | |
logs = [] | |
match steps: | |
case (first_step, middle_tasks, last_step): | |
middle_logs:list[StepInfo[_TaskIntermediateT, _TaskIntermediateT]] = [] | |
e = StepExecutor('first', first_step) | |
x, log = await e(task_input) | |
logs.append(log) | |
if x is None: | |
return None, logs | |
for m in middle_tasks: | |
f = StepExecutor('intermediate', m) | |
x, log = await f(x) | |
logs.append(log) | |
middle_logs.append(log) | |
if x is None: | |
return None, logs | |
g = StepExecutor('final', last_step) | |
y, log = await g(x) | |
logs.append(log) | |
return y, logs | |
case (first_step, last_step): | |
e = StepExecutor('final', first_step) | |
x, log = await e(task_input) | |
logs.append(log) | |
if x is None: | |
return None, logs | |
f = StepExecutor('final', last_step) | |
y, log = await f(x) | |
logs.append(log) | |
return y, logs | |
case (single_step,): | |
e = StepExecutor('final', single_step) | |
x, log = await e(task_input) | |
logs.append(log) | |
return x, logs | |
raise TypeError |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment