Skip to content

Instantly share code, notes, and snippets.

@mariocesar
Last active December 28, 2021 19:59
Show Gist options
  • Save mariocesar/989c042b7c9b00bc015e to your computer and use it in GitHub Desktop.
Save mariocesar/989c042b7c9b00bc015e to your computer and use it in GitHub Desktop.
Django state field that enforces a workflow path
from collections import namedtuple
from functools import wraps
from itertools import chain
from django.utils.functional import curry
from django.db.models import CharField
class StateField(CharField):
    """A ``CharField`` whose values are described by a workflow definition.

    ``steps`` is a sequence of ``StateField.Starts``, ``StateField.Transition``
    and ``StateField.Ends`` records.  The field's ``choices`` are derived from
    every state mentioned in those steps, and the model class the field is
    attached to gains a ``get_next_<field name>_states()`` method listing the
    transitions available from an instance's current state.
    """

    Starts = namedtuple('Starts', ['state'])
    Ends = namedtuple('Ends', ['state'])
    Transition = namedtuple('Transition', ['from_state', 'to_state', 'label'])

    def __init__(self, steps, *args, **kwargs):
        # Materialise once: ``steps`` is scanned several times below, which
        # would silently exhaust a generator passed by the caller.
        self.steps = tuple(steps)
        self.transitions = [step for step in self.steps if isinstance(step, StateField.Transition)]
        self.starts_with = [step for step in self.steps if isinstance(step, StateField.Starts)]
        self.ends_with = [step for step in self.steps if isinstance(step, StateField.Ends)]

        # These override anything the caller passed for the same keys.
        kwargs['max_length'] = 255
        kwargs['db_index'] = True
        # Must be our own helper: Django's Field.get_choices() reads
        # self.choices, which does not exist until super().__init__ runs.
        kwargs['choices'] = self.get_state_choices()
        super().__init__(*args, **kwargs)

    def get_state_choices(self):
        """Return ``(value, label)`` pairs for every state in the workflow.

        States come from ``Starts``/``Ends`` markers and from both ends of each
        ``Transition``.  They are de-duplicated in declaration order so the
        result is identical in every process (a ``set`` ordering depends on the
        string hash seed, which makes the field's ``choices`` appear to change).
        Labels are ``str.title()`` of the value, e.g. ``'payment-pending'`` ->
        ``'Payment-Pending'``.
        """
        states = []
        for step in self.steps:
            if isinstance(step, StateField.Transition):
                states.extend((step.from_state, step.to_state))
            elif isinstance(step, (StateField.Starts, StateField.Ends)):
                states.append(step.state)
        # dict.fromkeys keeps first-seen order while dropping duplicates.
        return [(state, state.title()) for state in dict.fromkeys(states)]

    @staticmethod
    def before_transition(func):
        """Executed before transition to STATE.

        NOTE(review): placeholder. Nothing in this class registers or invokes
        the result. Applying it replaces ``func`` with a factory that takes a
        ``state`` argument and returns a thin wrapper around ``func``.
        """
        def decorator(state):
            @wraps(func)
            def wrapper(self):
                return func(self)
            return wrapper
        return decorator

    @staticmethod
    def after_transition(func):
        """Executed after transition to STATE.

        NOTE(review): placeholder with the same shape as ``before_transition``.
        Nothing in this class registers or invokes the result.
        """
        def decorator(state):
            @wraps(func)
            def wrapper(self):
                return func(self)
            return wrapper
        return decorator

    def _get_next_states(self, field):
        """Return ``(to_state, label)`` pairs reachable from the current state.

        ``contribute_to_class`` installs this on the *model* class, so ``self``
        here is the model instance and ``field`` is the StateField.
        """
        current_state = getattr(self, field.attname)
        return [(transition.to_state, transition.label.title())
                for transition in field.transitions
                # Outgoing transitions start at the current state.
                if transition.from_state == current_state]

    def contribute_to_class(self, cls, name, virtual_only=False):
        """Attach the field to ``cls`` and add ``get_next_<name>_states()``."""
        from functools import partialmethod

        # Passed positionally: newer Django versions name this parameter
        # ``private_only``, but the position is unchanged.
        super().contribute_to_class(cls, name, virtual_only)
        # partialmethod over the plain function binds the model instance as
        # ``self`` when called, with this field pre-filled as ``field``.
        setattr(cls, 'get_next_%s_states' % self.name,
                partialmethod(type(self)._get_next_states, field=self))

    def validate(self, value, model_instance):
        """Run CharField validation (which includes checking ``choices``)."""
        super().validate(value, model_instance)
        # TODO: Validate state transitions
from django.conf import settings
from django.db import models

from .fields import StateField
class Order(models.Model):
    """A purchase order whose ``status`` follows the workflow declared below.

    Workflow: new -> payment-pending -> fully-paid -> shipped -> received ->
    completed, with cancellation allowed from ``new`` and ``payment-pending``.
    ``cancelled`` and ``completed`` are the declared end states.
    """

    NEW = 'new'
    CANCELLED = 'cancelled'
    PAYMENT_PENDING = 'payment-pending'
    FULLY_PAID = 'fully-paid'
    SHIPPED = 'shipped'
    RECEIVED = 'received'
    COMPLETED = 'completed'

    # on_delete is required since Django 2.0; CASCADE was the implicit default
    # in earlier versions, so behaviour there is unchanged.
    user = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE)
    status = StateField(steps=(
        StateField.Starts(NEW),
        StateField.Transition(NEW, CANCELLED, 'Cancel'),
        StateField.Transition(NEW, PAYMENT_PENDING, 'Confirm Sale'),
        StateField.Transition(PAYMENT_PENDING, CANCELLED, 'Cancel'),
        StateField.Transition(PAYMENT_PENDING, FULLY_PAID, 'Process Payment'),
        StateField.Transition(FULLY_PAID, SHIPPED, 'Seller ships'),
        StateField.Transition(SHIPPED, RECEIVED, 'Buyer received'),
        StateField.Transition(RECEIVED, COMPLETED, 'Completed'),
        StateField.Ends(CANCELLED),
        StateField.Ends(COMPLETED),
    ), default=NEW)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment