Created
August 30, 2023 21:12
-
-
Save alpinus4/386572d1138b6d4a3291978c99657a83 to your computer and use it in GitHub Desktop.
Simple Python state machine
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Callable
from typing import Union
@dataclass
class Transition:
    """A guarded edge between two states of a state machine.

    Attributes are set positionally or by keyword, in this order:
    ``from_state``, ``to_state``, ``condition_func``.
    """

    # State the transition leaves.
    from_state: Union[str, Enum]
    # State the transition leads to.
    to_state: Union[str, Enum]
    # Zero-argument predicate; the transition may fire when it returns a truthy value.
    condition_func: Callable[[], bool]
class StateMachine(ABC):
    """
    Simple state machine.

    It checks transitions, and moves one state at a time (if possible) during
    the tick() method.
    States can be virtually anything hashable, but are best kept as either
    strings or enums. ``None`` is reserved to mean "no state"; every other
    value, including falsy ones such as ``0`` or ``""``, is a valid state.

    Subclasses must implement ``_state_logic``, ``_enter_state`` and
    ``_exit_state``.
    """

    def __init__(self):
        self._state = None
        self._previous_state = None
        self._states = {}
        # Maps each registered state to the list of transitions leaving it.
        self._transitions = {}

    @abstractmethod
    def _state_logic(self):
        """Run the per-tick behaviour for the current state (``self._state``)."""

    def _get_transition(self):
        """Return the target state of the first transition whose condition holds.

        Transitions leaving the current state are checked in the order they
        were added. Returns ``None`` when no condition holds or when the
        current state has no registered transitions.
        """
        # .get() keeps a state that was set but never registered from raising.
        for transition in self._transitions.get(self._state, ()):
            if transition.condition_func():
                return transition.to_state
        return None

    @abstractmethod
    def _enter_state(self, new_state: Union[str, Enum], old_state: Union[str, Enum]):
        """Hook called by ``set_state`` after switching to ``new_state``."""

    @abstractmethod
    def _exit_state(self, old_state: Union[str, Enum], new_state: Union[str, Enum]):
        """Hook called by ``set_state`` when leaving ``old_state``."""

    def set_state(self, new_state: Union[str, Enum]):
        """Switch to ``new_state`` and fire the exit/enter hooks.

        The current state is updated before the hooks run. ``_exit_state`` is
        skipped when there was no previous state, and ``_enter_state`` is
        skipped when ``new_state`` is ``None``.
        """
        self._previous_state = self._state
        self._state = new_state
        # Compare against None explicitly so falsy states (0, "") still get hooks.
        if self._previous_state is not None:
            self._exit_state(self._previous_state, new_state)
        if new_state is not None:
            self._enter_state(new_state, self._previous_state)

    def add_state(self, state: Union[str, Enum]) -> "StateMachine":
        """Register ``state`` and return ``self`` for chaining.

        Registering a state that already exists keeps its transitions.
        """
        self._states[state] = state
        self._transitions.setdefault(state, [])
        return self

    def add_transition(
        self,
        from_state: Union[str, Enum],
        to_state: Union[str, Enum],
        condition_func: Callable[[], bool],
    ) -> "StateMachine":
        """Add a transition guarded by ``condition_func`` and return ``self``.

        Raises:
            KeyError: if ``from_state`` has not been registered with ``add_state``.
        """
        self._transitions[from_state].append(Transition(from_state, to_state, condition_func))
        return self

    def tick(self):
        """Run the current state's logic, then take at most one transition.

        Does nothing while no state has been set.
        """
        if self._state is None:
            return
        self._state_logic()
        next_state = self._get_transition()
        if next_state is not None:
            self.set_state(next_state)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# EXAMPLE USAGE
from state_machine import StateMachine | |
from enum import Enum | |
class Wolf:
    """Example agent whose behaviour is driven by a ``WolfStateMachine``.

    ``hunger`` starts at 10; eating lowers it by 3 and sleeping raises it by
    0.5. The wolf counts as hungry while ``hunger`` is above 5.
    """

    class State(Enum):
        """States the wolf can be in."""

        HUNT = 0
        EAT = 1
        SLEEP = 2

    def __init__(self):
        self.hunger = 10
        # WolfStateMachine is defined further down the module; the name is
        # only looked up when a Wolf is instantiated.
        self.machine = WolfStateMachine(self)

    def tick(self) -> None:
        """Advance the wolf's state machine by one step."""
        self.machine.tick()

    def is_there_food_to_eat(self):
        """Report whether food is available; this example always returns True."""
        return True

    def is_hungry(self):
        """Return True while ``hunger`` is above 5."""
        return self.hunger > 5

    def hunt(self):
        """Print the hunting action."""
        print("Hunt")

    def eat(self):
        """Print the eating action and reduce ``hunger`` by 3."""
        print("Eat")
        self.hunger -= 3

    def sleep(self):
        """Print the sleeping action and raise ``hunger`` by 0.5."""
        print("Sleep")
        self.hunger += 0.5
class WolfStateMachine(StateMachine): | |
def __init__(self, wolf: Wolf): | |
super().__init__() | |
self.wolf = wolf | |
self.add_state(Wolf.State.HUNT)\ | |
.add_state(Wolf.State.EAT)\ | |
.add_state(Wolf.State.SLEEP)\ | |
.add_transition(Wolf.State.HUNT, Wolf.State.EAT, lambda: self.wolf.is_there_food_to_eat())\ | |
.add_transition(Wolf.State.EAT, Wolf.State.SLEEP, lambda: not self.wolf.is_hungry()) \ | |
.add_transition(Wolf.State.EAT, Wolf.State.HUNT, lambda: self.wolf.is_hungry() and not self.wolf.is_there_food_to_eat()) \ | |
.add_transition(Wolf.State.SLEEP, Wolf.State.EAT, lambda: self.wolf.is_there_food_to_eat() and self.wolf.is_hungry())\ | |
.set_state(Wolf.State.HUNT) | |
# logic to be done based on state | |
def _state_logic(self): | |
match self._state: | |
case Wolf.State.HUNT: | |
self.wolf.hunt() | |
case Wolf.State.EAT: | |
self.wolf.eat() | |
case Wolf.State.SLEEP: | |
self.wolf.sleep() | |
# called when entering state | |
def _enter_state(self, new_state, old_state): | |
if new_state == Wolf.State.SLEEP: | |
print("Wolf is going to sleep") | |
# called when exiting state | |
def _exit_state(self, old_state, new_state): | |
if old_state == Wolf.State.SLEEP: | |
print("Wolf is waking up") | |
# Demo: create a wolf and advance its state machine for 20 ticks.
wolf = Wolf()
for _ in range(20):
    wolf.tick()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment