Created
August 15, 2022 02:51
-
-
Save ZechCodes/01749f44c819175dc64913f879a73ad3 to your computer and use it in GitHub Desktop.
A crazy little async state machine that's intended to read like English as much as possible. Just download both files (`state_machine.py` and `count_bs_after_as.py`) and run `count_bs_after_as.py` to see it count every B that directly follows an A in the string.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from state_machine import StateMachine, State | |
class CountBsAfterAs(StateMachine):
    """Count each "b" that comes directly after an "a" (case-insensitive).

    Characters are pushed through the machine one at a time with ``next``;
    the running total is kept in ``num_bs_after_as``.
    """

    def __init__(self):
        super().__init__()
        # Tally incremented every time the machine enters ``b_after_an_a``.
        self.num_bs_after_as = 0

    @State
    async def begin(self, character: str):
        # Entry state: immediately route the seed character onward.
        await self.next(character)

    @State
    async def found_an_a(self, character: str):
        # Resting state after an "a"; waits for the next character.
        return None

    @State
    async def not_an_a(self, character: str):
        # Resting state for any other input.
        return None

    @State
    async def b_after_an_a(self, character: str):
        # Transient state: record the hit, then move on with the same character.
        self.num_bs_after_as = self.num_bs_after_as + 1
        await self.next(character)

    # Transitions are tried in the order they are registered on each state,
    # so the predicate-guarded edges below must be declared before the
    # unconditional fallback at the end of the class.
    @(begin & found_an_a & not_an_a).goes_to(found_an_a).when
    async def character_is_an_a(self, character: str) -> bool:
        return "a" == character.lower()

    @found_an_a.goes_to(b_after_an_a).when
    async def character_is_a_b(self, character: str) -> bool:
        return "b" == character.lower()

    # Fallback edge with no predicate, i.e. always taken if nothing above matched.
    (begin & found_an_a & b_after_an_a & not_an_a).goes_to(not_an_a)
async def main(text: str = "Abbbabaaaaacccbaaab") -> int:
    """Run ``CountBsAfterAs`` over *text*, print the tally, and return it.

    Args:
        text: Characters to feed through the machine one at a time. Defaults
            to the original demo string.

    Returns:
        The number of "b"s found directly after an "a".
    """
    machine = CountBsAfterAs()
    # ``begin`` needs a character to route on; an empty string matches no
    # predicate, so the machine settles in ``not_an_a`` before real input.
    await machine.start(CountBsAfterAs.begin, "")
    for character in text:
        await machine.next(character)
    print("Found", machine.num_bs_after_as, "B's after A's")
    return machine.num_bs_after_as


# Only run the demo when executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from typing import Awaitable, Callable, Generic, ParamSpec, TypeAlias, TypeVar | |
P = ParamSpec("P") | |
T = TypeVar("T") | |
Predicate: TypeAlias = Callable[[P], Awaitable[bool]] | |
StateRunFunction: TypeAlias = Callable[[P], Awaitable[T]] | |
StateLeaveFunction: TypeAlias = Callable[[P], Awaitable[None]] | |
class State(Generic[T, P]):
    """A single node of a ``StateMachine``, created by decorating an async function.

    The wrapped coroutine function is the state's run action. Outgoing edges
    live in ``transitions`` (target state -> ``Transition``) and are evaluated
    in insertion order by ``next``.
    """

    def __init__(self, run: StateRunFunction):
        self._run = run
        self._leave: StateLeaveFunction | None = None
        # Insertion order matters: ``next`` returns the first target whose
        # transition predicate passes.
        self.transitions: dict[State[T, P], Transition] = {}
        # Filled in by ``__set_name__`` when the state is assigned in a class body.
        self.name: str | None = None

    def __and__(self, other) -> AggregateState:
        """Group this state with *other* so one transition can be declared for both."""
        return AggregateState(self, other)

    def __repr__(self):
        return f"{type(self).__name__}({self.name!r})"

    def __set_name__(self, owner, name):
        self.name = name

    def goes_to(self, to_state: State) -> Transition:
        """Register a transition from this state to *to_state* and return it.

        Registering the same target again replaces the previous transition
        but keeps its original position in the evaluation order.
        """
        transition = Transition(self, to_state)
        self.transitions[to_state] = transition
        return transition

    async def leave(self, *args: P.args, **kwargs: P.kwargs):
        """Await the ``on_leave`` hook with the given arguments, if one is set."""
        if self._leave is not None:
            await self._leave(*args, **kwargs)

    async def next(self, *args: P.args, **kwargs: P.kwargs) -> State[T, P]:
        """Return the first target state whose transition accepts the arguments.

        Raises:
            RuntimeError: if no registered transition matches.
        """
        for state, transition in self.transitions.items():
            if await transition.can_transition(*args, **kwargs):
                return state
        raise RuntimeError(
            f"{self} has no transitions that match the given arguments\n {args=}\n {kwargs=}"
        )

    def on_leave(self, leave: StateLeaveFunction) -> State[T, P]:
        """Register *leave* as the exit hook; returns ``self`` for decorator use."""
        self._leave = leave
        return self

    async def run(self, *args: P.args, **kwargs: P.kwargs) -> T:
        """Await the wrapped run action and return its result."""
        return await self._run(*args, **kwargs)
class Transition:
    """A directed edge between two states, guarded by an async predicate.

    Without an explicit predicate the edge is unconditional: the fallback
    predicate always resolves to ``True``.
    """

    def __init__(
        self, from_state: State, to_state: State, predicate: Predicate | None = None
    ):
        self.from_state, self.to_state = from_state, to_state
        # Filled in by ``__set_name__`` when assigned in a class body.
        self.name = None
        # A falsy ``predicate`` (including ``None``) falls back to the
        # always-true default.
        self.predicate = self._predicate if not predicate else predicate

    def __set_name__(self, owner, name):
        self.name = name

    async def can_transition(self, *args: P.args, **kwargs: P.kwargs) -> bool:
        """Await the guard with the arguments that were passed to the machine."""
        outcome = await self.predicate(*args, **kwargs)
        return outcome

    def when(self, predicate: Predicate) -> Transition:
        """Replace the guard with *predicate*; returns ``self`` for decorator use."""
        self.predicate = predicate
        return self

    async def _predicate(self, *_, **__) -> bool:
        return True
class AggregateState:
    """An ordered group of states that share a transition declaration.

    Built with ``&`` (``state_a & state_b & state_c``); ``goes_to`` registers
    the same target on every member state.
    """

    def __init__(self, *states):
        # Flatten nested groups so ``a & (b & c)`` yields [a, b, c] rather
        # than a group that contains another group.
        self.states = []
        for state in states:
            if isinstance(state, AggregateState):
                self.states.extend(state.states)
            else:
                self.states.append(state)

    def __and__(self, other) -> AggregateState:
        """Return a new group with *other* appended; ``self`` is left unchanged.

        Returning a fresh group (instead of mutating in place) lets a group
        stored in a variable be reused as a base for several combinations.
        """
        return AggregateState(*self.states, other)

    def goes_to(self, to_state: State) -> SharedTransition:
        """Declare a transition from every member state to *to_state*."""
        return SharedTransition(*self.states, to_state=to_state)
class SharedTransition:
    """One transition declaration fanned out over several source states."""

    def __init__(self, *from_states: State, to_state: State):
        self.from_states = from_states
        self.to_state = to_state
        # Default matches ``Transition``; overwritten by ``__set_name__`` when
        # the shared transition is assigned in a class body.
        self.name = None
        # Each source state registers its own transition to ``to_state``.
        self.transitions = [state.goes_to(to_state) for state in from_states]

    def when(self, predicate: Predicate) -> SharedTransition:
        """Attach *predicate* to every underlying transition; returns ``self``."""
        self.transitions = [
            transition.when(predicate) for transition in self.transitions
        ]
        return self

    def __set_name__(self, owner, name):
        self.name = name
class StateMachine(Generic[T, P]):
    """Drives ``State`` objects, taking one transition per ``next`` call.

    The machine passes itself as the first argument to every state run
    action, leave hook and (via ``State.next``) transition predicate.
    """

    def __init__(self):
        self._current_state: State[T, P] | None = None
        self._current_value: T | None = None

    @property
    def state(self) -> State[T, P] | None:
        """The active state, or ``None`` before ``start`` has been called."""
        return self._current_state

    @property
    def value(self) -> T | None:
        """The result most recently stored from a state's run action."""
        return self._current_value

    async def start(
        self, initial_state: State[T, P], *args: P.args, **kwargs: P.kwargs
    ) -> StateMachine[T, P]:
        """Enter *initial_state*, run it with the given arguments, return ``self``.

        Raises:
            RuntimeError: if the machine has already been started.
        """
        if self._current_state is not None:
            raise RuntimeError(f"{self} has already been started")
        await self._set_current_state(initial_state, *args, **kwargs)
        return self

    async def next(self, *args: P.args, **kwargs: P.kwargs) -> T:
        """Follow the first matching transition out of the current state.

        Returns:
            The value stored after running the new state.

        Raises:
            RuntimeError: if called before ``start``.
        """
        if self._current_state is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise RuntimeError(f"{self} has not been started")
        next_state = await self._current_state.next(self, *args, **kwargs)
        return await self._set_current_state(next_state, *args, **kwargs)

    async def _set_current_state(
        self, new_state: State[T, P], *args: P.args, **kwargs: P.kwargs
    ) -> T:
        """Leave the current state (if any), enter *new_state* and run it."""
        if self._current_state is not None:
            await self._current_state.leave(self, *args, **kwargs)
        self._current_state = new_state
        # NOTE(review): if ``run`` itself calls ``self.next()``, the nested
        # call switches state first, and this assignment then overwrites
        # ``_current_value`` with the outer state's result — confirm intended.
        self._current_value = await new_state.run(self, *args, **kwargs)
        return self._current_value
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment