Skip to content

Instantly share code, notes, and snippets.

@ZechCodes
Created August 15, 2022 02:51
Show Gist options
  • Save ZechCodes/01749f44c819175dc64913f879a73ad3 to your computer and use it in GitHub Desktop.
Save ZechCodes/01749f44c819175dc64913f879a73ad3 to your computer and use it in GitHub Desktop.
A crazy little async state machine that's intended to read like English as much as possible. Just download both files (`count_bs_after_as.py` and `state_machine.py`) and run `count_bs_after_as.py` to see it count every B that immediately follows an A in the string.
import asyncio
from state_machine import StateMachine, State
class CountBsAfterAs(StateMachine):
    """Count how many times a "b" immediately follows an "a", one character at a time.

    Characters are fed in through ``next``. Every predicate lower-cases its
    input, so matching is case-insensitive. The running total is kept in
    ``num_bs_after_as``.

    State bodies and predicates receive the machine itself as ``self``
    (``StateMachine`` passes itself as the first argument).

    Transition order matters: each state tries its outgoing transitions in the
    order they are declared below. The unconditional
    ``goes_to(not_an_a)`` at the end is the fallback and must stay last.
    """

    def __init__(self):
        super().__init__()
        # Number of "b" characters seen directly after an "a".
        self.num_bs_after_as = 0

    @State
    async def begin(self, character: str):
        # Entry state: transition straight away using the character given to
        # ``start`` (the demo passes "", which matches nothing and falls
        # through to ``not_an_a``).
        await self.next(character)

    @State
    async def found_an_a(self, character: str):
        # Resting state after an "a"; the next character decides where to go.
        return

    @State
    async def not_an_a(self, character: str):
        # Resting state after any character that did not lead to found_an_a.
        return

    @State
    async def b_after_an_a(self, character: str):
        # Reached only from ``found_an_a`` on a "b": record it, then move on
        # immediately with the same character. Its only outgoing transition
        # is the fallback to ``not_an_a``.
        self.num_bs_after_as += 1
        await self.next(character)

    # An "a" seen in begin, found_an_a or not_an_a leads to found_an_a.
    # ``when`` returns the transition object, so the name below ends up bound
    # to that transition rather than to the function.
    @ (begin & found_an_a & not_an_a).goes_to(found_an_a).when
    async def character_is_an_a(self, character: str) -> bool:
        return character.lower() == "a"

    # A "b" directly after an "a" leads to b_after_an_a.
    @found_an_a.goes_to(b_after_an_a).when
    async def character_is_a_b(self, character: str) -> bool:
        return character.lower() == "b"

    # Fallback: anything not matched above goes to not_an_a. Declared last so
    # the conditional transitions above are checked first.
    (begin & found_an_a & b_after_an_a & not_an_a).goes_to(not_an_a)
async def main():
    """Run the demo: stream a sample string through the counter and report."""
    counter = CountBsAfterAs()
    # ``begin`` needs a first character; "" matches no "a", so it falls through.
    await counter.start(CountBsAfterAs.begin, "")
    sample = "Abbbabaaaaacccbaaab"
    for letter in sample:
        await counter.next(letter)
    print("Found", counter.num_bs_after_as, "B's after A's")
if __name__ == "__main__":
    # Run the demo only when executed as a script, not as an import side effect.
    asyncio.run(main())
from __future__ import annotations
from typing import Awaitable, Callable, Generic, ParamSpec, TypeAlias, TypeVar
P = ParamSpec("P")
T = TypeVar("T")
Predicate: TypeAlias = Callable[[P], Awaitable[bool]]
StateRunFunction: TypeAlias = Callable[[P], Awaitable[T]]
StateLeaveFunction: TypeAlias = Callable[[P], Awaitable[None]]
class State(Generic[T, P]):
    """A named state of a ``StateMachine`` whose body is an async function.

    Meant to be applied as a decorator inside a ``StateMachine`` subclass: the
    decorated function becomes the state's body (see ``run``). Outgoing
    transitions are registered with ``goes_to`` and are checked by ``next`` in
    registration order.
    """

    def __init__(self, run: StateRunFunction):
        self._run = run
        self._leave: StateLeaveFunction | None = None
        # Outgoing transitions keyed by destination state; insertion order is
        # the order in which ``next`` checks them.
        self.transitions: dict[State[T, P], Transition] = {}
        # Attribute name the state was bound to in its owning class.
        self.name = None

    def __and__(self, other) -> AggregateState:
        """Group this state with *other* so they can share a transition."""
        return AggregateState(self, other)

    def __repr__(self):
        return f"{type(self).__name__}({self.name!r})"

    def __set_name__(self, owner, name):
        # Keep only the first name. ``on_leave`` returns this same State, so
        # decorating a hook with it binds the State a second time under the
        # hook's name, which would otherwise rename the state.
        if self.name is None:
            self.name = name

    def goes_to(self, to_state: State) -> Transition:
        """Register the transition to *to_state* and return it.

        The transition is unconditional until ``Transition.when`` gives it a
        predicate. Calling this again for the same destination replaces the
        earlier transition.
        """
        self.transitions[to_state] = Transition(self, to_state)
        return self.transitions[to_state]

    async def leave(self, *args: P.args, **kwargs: P.kwargs):
        """Run the hook registered with ``on_leave``, if there is one."""
        if self._leave is not None:
            await self._leave(*args, **kwargs)

    async def next(self, *args: P.args, **kwargs: P.kwargs) -> State[T, P]:
        """Return the destination of the first transition whose predicate passes.

        Raises:
            Exception: if no registered transition accepts the arguments.
        """
        for state, transition in self.transitions.items():
            if await transition.can_transition(*args, **kwargs):
                return state
        raise Exception(
            f"{self} has no transitions that match the given arguments\n {args=}\n {kwargs=}"
        )

    def on_leave(self, leave: StateLeaveFunction) -> State[T, P]:
        """Register *leave* as this state's exit hook; usable as a decorator."""
        self._leave = leave
        return self

    async def run(self, *args: P.args, **kwargs: P.kwargs) -> T:
        """Execute the state's body with the given arguments and return its result."""
        return await self._run(*args, **kwargs)
class Transition:
    """A directed edge from one state to another, guarded by an async predicate.

    Without a predicate the transition is unconditional (it always passes).
    """

    def __init__(
        self, from_state: State, to_state: State, predicate: Predicate | None = None
    ):
        self.from_state = from_state
        self.to_state = to_state
        # Attribute name the transition was bound to in its owning class, if any.
        self.name = None
        # Fall back to the always-true predicate only when none was supplied.
        self.predicate = predicate if predicate is not None else self._predicate

    def __repr__(self):
        return f"{type(self).__name__}({self.from_state!r} -> {self.to_state!r})"

    def __set_name__(self, owner, name):
        self.name = name

    async def can_transition(self, *args: P.args, **kwargs: P.kwargs) -> bool:
        """Return whether the predicate accepts the given arguments."""
        return await self.predicate(*args, **kwargs)

    def when(self, predicate: Predicate) -> Transition:
        """Guard this transition with *predicate* and return the transition.

        Returning ``self`` lets this be used as a decorator; the decorated name
        is then bound to the transition rather than to the function.
        """
        self.predicate = predicate
        return self

    async def _predicate(self, *_, **__) -> bool:
        # Default predicate: always allow the transition.
        return True
class AggregateState:
    """A group of states, built with ``&``, that can share one transition.

    ``(a & b & c).goes_to(d)`` registers a transition to ``d`` on each of
    ``a``, ``b`` and ``c``.
    """

    def __init__(self, *states):
        # Flatten nested groups so ``a & (b & c)`` holds a, b and c directly.
        self.states = []
        for state in states:
            if isinstance(state, AggregateState):
                self.states.extend(state.states)
            else:
                self.states.append(state)

    def __and__(self, other):
        # Build a new group rather than mutating this one, so an intermediate
        # group can be reused (``ab & c`` and ``ab & d``) without cross-talk.
        return AggregateState(*self.states, other)

    def goes_to(self, to_state: State) -> SharedTransition:
        """Register a transition to *to_state* on every state in the group."""
        return SharedTransition(*self.states, to_state=to_state)
class SharedTransition:
    """One transition definition applied to several source states at once.

    Holds the per-state transitions created by each source state's
    ``goes_to``, so that a single ``when`` can guard all of them.
    """

    def __init__(self, *from_states: State, to_state: State):
        self.from_states = from_states
        self.to_state = to_state
        # Initialised like State/Transition so ``name`` exists even when the
        # shared transition is never bound to a class attribute.
        self.name = None
        self.transitions = [state.goes_to(self.to_state) for state in from_states]

    def when(self, predicate: Predicate) -> SharedTransition:
        """Guard every underlying transition with *predicate*; returns ``self``.

        Returning ``self`` lets this be used as a decorator on the predicate.
        """
        self.transitions = [
            transition.when(predicate) for transition in self.transitions
        ]
        return self

    def __set_name__(self, owner, name):
        self.name = name
class StateMachine(Generic[T, P]):
    """Base class for async state machines whose states are ``State`` attributes.

    Call ``start`` once with the initial state, then ``next`` for each input.
    Every state body, leave hook and transition predicate receives the machine
    as its first argument, followed by the arguments given to ``start``/``next``.
    """

    def __init__(self):
        self._current_state: State[T, P] | None = None
        self._current_value: T | None = None

    @property
    def state(self) -> State[T, P] | None:
        """The current state, or ``None`` before ``start`` has been called."""
        return self._current_state

    @property
    def value(self) -> T | None:
        """The result of the most recently completed state body, or ``None``."""
        return self._current_value

    async def start(
        self, initial_state: State[T, P], *args: P.args, **kwargs: P.kwargs
    ) -> StateMachine[T, P]:
        """Enter *initial_state*, running its body with the given arguments.

        Raises:
            Exception: if the machine has already been started.
        """
        if self._current_state is not None:
            raise Exception(
                f"{self} has already been started"
            )
        await self._set_current_state(initial_state, *args, **kwargs)
        return self

    async def next(self, *args: P.args, **kwargs: P.kwargs) -> T:
        """Follow the first matching transition from the current state.

        Returns the result of the newly entered state's body.

        Raises:
            Exception: if the machine has not been started, or if no transition
                from the current state matches the arguments.
        """
        if self._current_state is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise Exception(f"{self} has not been started")
        next_state = await self._current_state.next(self, *args, **kwargs)
        return await self._set_current_state(next_state, *args, **kwargs)

    async def _set_current_state(
        self, new_state: State[T, P], *args: P.args, **kwargs: P.kwargs
    ) -> T:
        """Leave the current state (if any), enter *new_state* and run its body."""
        if self._current_state is not None:
            await self._current_state.leave(self, *args, **kwargs)
        self._current_state = new_state
        # NOTE(review): a state body may itself call ``self.next`` and move the
        # machine to another state. When that body returns, this assignment
        # overwrites the nested transition's value, so ``value`` can then
        # belong to a different state than ``state``. Confirm this is intended.
        self._current_value = await self._current_state.run(self, *args, **kwargs)
        return self._current_value
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment