Last active
September 29, 2023 16:41
-
-
Save rca/7421319 to your computer and use it in GitHub Desktop.
Python Finite State Machine implementation; logic mostly extracted from https://github.com/kmmbvnr/django-fsm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Finite State Machine | |
This FSM implementation is extracted from the django-fsm package and licensed | |
under the same BSD-like license at: | |
https://github.com/kmmbvnr/django-fsm/blob/master/LICENSE | |
Basic usage: | |
``` | |
from fsm import State, transition | |
class Container(object): | |
# some states | |
OFFLINE = 'offline' | |
ONLINE = 'online' | |
# the state machine | |
state_machine = State() | |
def __init__(self, container_id): | |
self.state = self.OFFLINE | |
self.container_id = container_id | |
self.log_proc = None | |
@transition(state_machine, source=OFFLINE, target=ONLINE) | |
def offline(self): | |
self.log_proc = attach_to_container(self.container_id) | |
@exception_transition((Disconnected,), target=OFFLINE) | |
@transition(state_machine, source=ONLINE, target=ONLINE) | |
def online(self): | |
for log in self.log_proc.get_logs(): | |
self.do_something(log) | |
def loop(self): | |
getattr(self, self.state)() | |
``` | |
""" | |
from collections import defaultdict | |
from functools import wraps | |
class Signal(object):
    """
    Minimal observer registry: ``send`` calls every matching receiver.

    A receiver registered with ``sender=None`` hears every send; otherwise it
    is only called when the ``sender`` passed to ``send`` equals its filter.
    """

    def __init__(self):
        # (receiver, sender_filter) pairs, kept in registration order
        self.connections = []

    def connect(self, receiver, sender=None):
        """Register *receiver*, optionally restricted to one *sender*."""
        entry = (receiver, sender)
        self.connections.append(entry)

    def send(self, sender, instance, name, source, target):
        """
        Call each matching receiver as
        ``receiver(sender, instance=..., name=..., source=..., target=...)``.
        """
        payload = dict(instance=instance, name=name, source=source, target=target)
        for receiver, wanted in self.connections:
            if wanted is None or wanted == sender:
                receiver(sender, **payload)
class TransitionNotAllowed(Exception):
    """
    Signal that the object's current state does not permit the call.

    Raised by ``transition``-decorated methods when no transition is
    registered for the current state or one of its conditions fails.
    """
class State(object):
    """
    State field shared by the ``transition``-decorated methods of a class.

    ``name`` is the instance attribute that holds the current state, and
    ``transitions`` collects every method decorated against this field.
    A ``State`` compares, orders and hashes by its string value, so it is
    interchangeable with the plain string it wraps (``State('a') == 'a'``).
    """

    # attribute read from instances to obtain their current state
    name = 'state'

    def __init__(self, state='offline'):
        self.state = state
        # decorated methods registered by ``transition`` against this field
        self.transitions = []

    # Python 3 ignores ``__cmp__`` and has no ``cmp`` builtin, so equality
    # and ordering are spelled out as rich comparisons on ``str(self)``.
    def __eq__(self, other):
        return str(self) == str(other)

    def __ne__(self, other):
        return str(self) != str(other)

    def __lt__(self, other):
        return str(self) < str(other)

    def __le__(self, other):
        return str(self) <= str(other)

    def __gt__(self, other):
        return str(self) > str(other)

    def __ge__(self, other):
        return str(self) >= str(other)

    # Equal objects must hash equally; Python 3 also makes a class unhashable
    # once it defines ``__eq__`` without ``__hash__``.
    # NOTE: the hash follows ``self.state``; don't reassign it while the
    # object is used as a dict key or set member.
    def __hash__(self):
        return hash(str(self))

    def __cmp__(self, other):
        """Three-way compare by string value (kept for Python 2 callers)."""
        mine, theirs = str(self), str(other)
        return (mine > theirs) - (mine < theirs)

    def __repr__(self):
        return '<State: {}>'.format(self.state)

    def __str__(self):
        return self.state
class StateMachine(object):
    """
    Base class that runs an object's states as methods.

    Each state value is the name of a method on the object; ``loop`` keeps
    calling the method named by ``self.state``. Subclasses must provide a
    ``TERMINATED`` attribute: ``loop`` reads ``self.TERMINATED``, but this
    class does not define it.
    """

    def __init__(self, state=None, target_state=None):
        self.state = state
        self.target_state = target_state

    def loop(self, target_state=None):
        """
        Step through states until the target state or ``TERMINATED`` is reached.

        :param target_state: optional new target; a falsy value keeps the
            previously configured ``self.target_state``

        ``reached`` starts as None rather than the current state, so a
        "state loop" such as ONLINE -> ONLINE_TASK -> ONLINE_TASK2 -> ONLINE
        still runs when the target equals the starting state.

        NOTE(review): when no target is configured (``target_state`` is None),
        ``reached`` already equals it, so no state method runs at all.
        Confirm this is intended rather than "run until TERMINATED".
        """
        self.target_state = target_state if target_state else self.target_state

        reached = None
        while reached != self.target_state and self.state != self.TERMINATED:
            handler = getattr(self, self.state)
            handler()
            reached = self.state
# Module-level signals sent by ``transition``-decorated methods: the first
# just before the method body runs, the second after the state is switched.
pre_transition, post_transition = Signal(), Signal()
class FSMMeta(object):
    """
    Transition metadata attached to a decorated method as ``_fsm_meta``.

    Maps each allowed source state to its target state and to the condition
    callables that must all be truthy for the transition to be allowed.
    The source ``'*'`` is a wildcard used for any state without its own entry.
    """

    def __init__(self, field):
        # object whose ``name`` attribute names the instance attribute that
        # holds the current state
        self.field = field
        # source state -> target state (None target: validate only)
        self.transitions = {}
        # source state -> list of callables ``f(instance)``
        self.conditions = {}

    def add_transition(self, source, target, conditions=None):
        """
        Register a transition from *source* to *target*.

        :param source: source state, or ``'*'`` for any state
        :param target: state to switch to; a falsy target leaves it unchanged
        :param conditions: callables ``f(instance)`` that must all be truthy
        :raises AssertionError: if *source* already has a transition
        """
        if source in self.transitions:
            raise AssertionError('Duplicate transition for %s state' % source)
        self.transitions[source] = target
        # fresh list per call; a ``[]`` default would be shared across calls
        self.conditions[source] = [] if conditions is None else conditions

    def _get_state_field(self, instance):
        """Return the state field (*instance* is accepted but unused)."""
        return self.field

    def current_state(self, instance):
        """
        Return the current state stored on *instance*.
        """
        field_name = self._get_state_field(instance).name
        return getattr(instance, field_name)

    def next_state(self, instance):
        """
        Return the target state registered for the instance's current state.

        Falls back to the ``'*'`` wildcard entry when the current state (as a
        string) has no explicit transition.

        :raises KeyError: if neither the current state nor ``'*'`` is registered
        """
        key = str(self.current_state(instance))
        try:
            return self.transitions[key]
        except KeyError:
            return self.transitions['*']

    def has_transition(self, instance):
        """
        Return True if any transition exists from the current state.
        """
        key = str(self.current_state(instance))
        return key in self.transitions or '*' in self.transitions

    def conditions_met(self, instance):
        """
        Return True if every condition for the current state holds.

        The state is looked up as a string, consistent with ``next_state``
        and ``has_transition``, falling back to ``'*'``. Evaluation stops at
        the first falsy condition.
        """
        state = str(self.current_state(instance))
        if state not in self.conditions:
            state = '*'
        return all(condition(instance) for condition in self.conditions.get(state, []))

    def to_next_state(self, instance):
        """
        Switch *instance* to its next state.

        A falsy target (e.g. None) leaves the state unchanged. The value is
        written straight into ``instance.__dict__``, bypassing ``__setattr__``.
        """
        field_name = self._get_state_field(instance).name
        state = self.next_state(instance)
        if state:
            instance.__dict__[field_name] = state
def transition(field, source='*', target=None, save=False, conditions=None):
    """
    Method decorator marking an allowed state transition.

    :param field: the ``State`` whose ``name`` names the instance attribute
        holding the current state; the decorated method is appended to its
        ``transitions`` list (skipped when *field* is falsy)
    :param source: source state, or a list/tuple of them; ``'*'`` (default)
        matches any state
    :param target: state to switch to after the method returns; None to
        only validate the current state without changing it
    :param save: call ``instance.save()`` after switching state
    :param conditions: callables ``f(instance)`` that must all be truthy for
        the transition to be allowed
    :returns: a decorator

    The decorated method raises ``TransitionNotAllowed`` when no transition
    matches the current state or a condition fails. If the method body
    raises, the state is left unchanged and ``post_transition`` is not sent.

    Stacking several ``transition`` decorators on one method registers all
    their sources on one shared ``FSMMeta``; only the innermost one wraps the
    function.
    """
    if conditions is None:
        # fresh list per decorator use; a ``[]`` default would be shared
        conditions = []

    # pylint: disable=C0111
    def inner_transition(func):
        if hasattr(func, '_fsm_meta'):
            # Already wrapped by an inner ``transition``: ``wraps`` copied
            # ``_fsm_meta`` onto that wrapper, so reuse wrapper and metadata.
            _change_state = func
        else:
            func._fsm_meta = FSMMeta(field=field)

            @wraps(func)
            def _change_state(instance, *args, **kwargs):
                meta = func._fsm_meta
                if not (meta.has_transition(instance) and meta.conditions_met(instance)):
                    raise TransitionNotAllowed(
                        "Can't switch from state '%s' using method '%s'"
                        % (meta.current_state(instance), func.__name__))
                source_state = meta.current_state(instance)
                pre_transition.send(
                    sender=instance.__class__,
                    instance=instance,
                    name=func.__name__,
                    source=source_state,
                    target=meta.next_state(instance))
                result = func(instance, *args, **kwargs)
                meta.to_next_state(instance)
                if save:
                    instance.save()
                post_transition.send(
                    sender=instance.__class__,
                    instance=instance,
                    name=func.__name__,
                    source=source_state,
                    target=meta.current_state(instance))
                return result

        sources = source if isinstance(source, (list, tuple)) else [source]
        for state in sources:
            func._fsm_meta.add_transition(state, target, conditions)

        if field:
            field.transitions.append(_change_state)
        return _change_state

    return inner_transition
def exception_transition(exceptions, target, reraise=True):
    """
    Decorator that sets ``self.state`` to *target* when the method raises.

    :param exceptions: an exception class, or tuple of classes, to intercept
    :param target: value assigned to ``self.state`` when one of them is raised
    :param reraise: re-raise the exception after the state change (default);
        if False the exception is swallowed and the call returns None
    :returns: a decorator

    Note this always writes the attribute named ``state``, regardless of the
    ``name`` of any ``State`` field used with ``transition``.
    """
    def exception_transition_inner(func):
        @wraps(func)
        def exception_transition_wrapper(self, *args, **kwargs):
            try:
                return func(self, *args, **kwargs)
            # ``except X as e`` / bare ``except X`` works on Python 2.6+ and 3;
            # the old ``except X, e`` form is a SyntaxError on Python 3.
            except exceptions:
                self.state = target
                if reraise:
                    raise
                return None
        return exception_transition_wrapper
    return exception_transition_inner
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment