-
-
Save kmmbvnr/c5bac4383c1148edc45e to your computer and use it in GitHub Desktop.
Django state field that enforces a workflow path
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import namedtuple | |
from functools import wraps | |
from itertools import chain | |
from django.utils.functional import curry | |
from django.db.models import CharField | |
class StateField(CharField):
    """Character field whose value is restricted to a declared workflow.

    The workflow is described by ``steps``: an iterable mixing
    ``StateField.Starts``, ``StateField.Transition`` and ``StateField.Ends``
    markers. The field derives its ``choices`` from those steps and, once
    attached to a model, adds a ``get_<field name>_next_states()`` method to
    that model.

    NOTE(review): ``steps`` is a required constructor argument but no
    ``deconstruct()`` override is defined here -- confirm that migrations and
    ``clone()`` can rebuild the field.
    """

    # Marker for a state the workflow may start in.
    Starts = namedtuple('Starts', ['state'])
    # Marker for a state the workflow may end in.
    Ends = namedtuple('Ends', ['state'])
    # A permitted move between two states, with a human-readable label.
    Transition = namedtuple('Transition', ['from_state', 'to_state', 'label'])

    def __init__(self, steps, *args, **kwargs):
        """Store the workflow and configure the underlying ``CharField``.

        ``max_length``, ``db_index`` and ``choices`` are always set by this
        field; values passed by the caller for those keys are overridden.
        """
        # Materialise once so a one-shot iterable survives the scans below.
        self.steps = tuple(steps)
        self.transitions = [step for step in self.steps
                            if isinstance(step, StateField.Transition)]
        self.starts_with = [step for step in self.steps
                            if isinstance(step, StateField.Starts)]
        self.ends_with = [step for step in self.steps
                          if isinstance(step, StateField.Ends)]
        kwargs['max_length'] = 255
        kwargs['db_index'] = True
        kwargs['choices'] = self.get_state_choices()
        super().__init__(*args, **kwargs)

    def get_state_choices(self):
        """Return a choices-style list of every state named in ``steps``.

        States are listed once each, in the order they are first declared,
        so the result is identical on every run (a ``set`` would reorder the
        choices between processes). Start and end states are included even
        when no transition mentions them.
        """
        states = []
        for step in self.steps:
            if isinstance(step, StateField.Transition):
                candidates = (step.from_state, step.to_state)
            elif isinstance(step, (StateField.Starts, StateField.Ends)):
                candidates = (step.state,)
            else:
                continue
            for state in candidates:
                if state not in states:
                    states.append(state)
        return [(state, state.title()) for state in states]

    @staticmethod
    def before_transition(func):
        """Executed before transition to STATE.

        NOTE(review): as written the outer call receives ``func`` and the
        inner call receives ``state``, which looks inverted for use as
        ``@before_transition(STATE)``; ``state`` is unused and the wrapper
        only forwards to ``func``. Behaviour is left untouched -- confirm
        the intended usage.
        """
        def decorator(state):
            @wraps(func)
            def wrapper(self):
                return func(self)
            return wrapper
        return decorator

    @staticmethod
    def after_transition(func):
        """Executed after transition to STATE.

        NOTE(review): same shape as ``before_transition`` -- ``state`` is
        unused and the argument order looks inverted; confirm the intended
        usage.
        """
        def decorator(state):
            @wraps(func)
            def wrapper(self):
                return func(self)
            return wrapper
        return decorator

    @staticmethod
    def _get_next_states(instance, field):
        """Return ``(state, label)`` pairs reachable from the current state.

        ``instance`` is the model instance holding the value and ``field`` is
        the ``StateField`` describing the workflow. Only transitions that
        start at the instance's current state are returned.
        """
        current_state = getattr(instance, field.attname)
        return [(transition.to_state, transition.label.title())
                for transition in field.transitions
                if transition.from_state == current_state]

    def contribute_to_class(self, cls, name, virtual_only=False):
        """Attach the field to ``cls`` and add its next-states accessor.

        The accessor is named ``get_<field name>_next_states`` and takes no
        arguments besides the model instance.
        """
        # Forwarded positionally, exactly as received.
        super().contribute_to_class(cls, name, virtual_only)

        field = self

        def get_next_states(instance):
            # Installed as a plain function on the model class, so it is
            # bound to the model instance at call time.
            return field._get_next_states(instance, field)

        accessor_name = 'get_%s_next_states' % self.name
        get_next_states.__name__ = accessor_name
        setattr(cls, accessor_name, get_next_states)

    def validate(self, value, model_instance):
        """Run the inherited validation; transition checks are not done yet."""
        super().validate(value, model_instance)
        # TODO: Validate state transitions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.conf import settings
from django.db import models

from .fields import StateField
class Order(models.Model):
    """Order whose ``status`` moves through a fixed sales workflow."""

    # Values stored in ``status``.
    NEW = 'new'
    CANCELLED = 'cancelled'
    PAYMENT_PENDING = 'payment-pending'
    FULLY_PAID = 'fully-paid'
    SHIPPED = 'shipped'
    RECEIVED = 'received'
    COMPLETED = 'completed'

    # on_delete is mandatory from Django 2.0 onwards; CASCADE is what older
    # versions applied implicitly when it was omitted.
    user = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE)

    # Workflow: starts at NEW, ends at CANCELLED or COMPLETED.
    status = StateField(steps=(
        StateField.Starts(NEW),
        StateField.Transition(NEW, CANCELLED, 'Cancel'),
        StateField.Transition(NEW, PAYMENT_PENDING, 'Confirm Sale'),
        StateField.Transition(PAYMENT_PENDING, CANCELLED, 'Cancel'),
        StateField.Transition(PAYMENT_PENDING, FULLY_PAID, 'Process Payment'),
        StateField.Transition(FULLY_PAID, SHIPPED, 'Seller ships'),
        StateField.Transition(SHIPPED, RECEIVED, 'Buyer received'),
        StateField.Transition(RECEIVED, COMPLETED, 'Completed'),
        StateField.Ends(CANCELLED),
        StateField.Ends(COMPLETED),
    ), default=NEW)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment