Created March 30, 2021 06:05
Save alysivji/0475531d2703373c024a60a63b8a2db6 to your computer and use it in GitHub Desktop.
Exporting Mermaid Markdown for django-fsm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8; mode: django -*- | |
import graphviz | |
from optparse import make_option | |
from django.core.management.base import BaseCommand | |
from django.utils.encoding import force_text | |
from django_fsm import FSMFieldMixin, GET_STATE, RETURN_VALUE | |
try: | |
from django.db.models import get_apps, get_app, get_models, get_model | |
NEW_META_API = False | |
except ImportError: | |
from django.apps import apps | |
NEW_META_API = True | |
from django import VERSION | |
HAS_ARGPARSE = VERSION >= (1, 10) | |
def all_fsm_fields_data(model):
    """Collect ``(field, model)`` pairs for every FSM field declared on *model*.

    When ``NEW_META_API`` is set (the legacy ``get_apps``/``get_models``
    imports were unavailable), candidate fields come from
    ``model._meta.get_fields()``; otherwise from ``model._meta.fields``.
    Only instances of ``FSMFieldMixin`` are kept.
    """
    meta = model._meta
    candidates = meta.get_fields() if NEW_META_API else meta.fields
    return [(field, model) for field in candidates if isinstance(field, FSMFieldMixin)]
def node_name(field, state):
    """Return the graphviz node id for *state* of *field*.

    The id has the form ``<app_label>.<verbose_name>.<field name>.<state>``.
    Spaces in the model's verbose name become underscores, so states are
    namespaced by app, model and field inside a single graph.
    """
    meta = field.model._meta
    parts = (
        meta.app_label,
        meta.verbose_name.replace(' ', '_'),
        field.name,
        state,
    )
    # str() on each part matches the "%s" formatting of the previous version.
    return ".".join(str(part) for part in parts)
def node_label(field, state):
    """Return the text displayed for *state* in the graph.

    Integer states are looked up in ``field.choices``, treated as flat
    ``(value, label)`` pairs, and rendered via their display label. Two cases
    fall back to the value itself:

    * the value is missing from the choices;
    * the field has no choices at all (``choices`` is None).

    Non-integer states are returned unchanged.
    """
    if isinstance(state, int):
        # ``choices`` is None on fields declared without choices; dict(None)
        # would raise TypeError.
        choices = dict(field.choices or ())
        # str() matches force_text for str and lazy-translation labels on
        # Python 3 (this file already needs 3.6+ for f-strings), and avoids
        # force_text, which Django 4.0 removed.
        return str(choices.get(state, state))
    return state
def add_transition_mermaid(transition_source, transition_target, transition_name, source_name, field, sources, targets, edges):
    """Record one transition for the Mermaid exporter in the caller's sets.

    Mermaid nodes are identified by the raw state value, so each node entry
    is ``(state, label)``. The target's label is the target itself.
    *field* is accepted for signature parity with ``add_transition`` and is
    not used.
    """
    source_entry = (source_name, transition_source)
    target_entry = (transition_target,) * 2
    sources.add(source_entry)
    targets.add(target_entry)
    edges.add((source_name, transition_target, transition_name))
def generate_mermaid(fields_data):
    """Print a Mermaid ``stateDiagram-v2`` for each FSM field in *fields_data*.

    :param fields_data: iterable of ``(field, model)`` pairs, as produced by
        ``all_fsm_fields_data``.
    :returns: list of the printed diagrams, one string per field, in order.

    Nodes are identified by raw state values, so each field is emitted as
    its own, separate diagram.
    """
    diagrams = []
    for field, model in fields_data:
        sources, targets, edges, any_targets, any_except_targets = set(), set(), set(), set(), set()

        # Collect nodes and edges from the field's transitions.
        for transition in field.get_all_transitions(model):
            if transition.source == '*':
                # Expanded below into an edge from every known source state.
                any_targets.add((transition.target, transition.name))
            elif transition.source == '+':
                # Like '*', but without the self-loop on the target.
                any_except_targets.add((transition.target, transition.name))
            else:
                if transition.target is not None:
                    if isinstance(transition.target, (GET_STATE, RETURN_VALUE)):
                        # Dynamic target: one edge per declared allowed state.
                        for transition_target in transition.target.allowed_states or ():
                            add_transition_mermaid(transition.source, transition_target, transition.name,
                                                   transition.source, field, sources, targets, edges)
                    else:
                        add_transition_mermaid(transition.source, transition.target, transition.name,
                                               transition.source, field, sources, targets, edges)
                if transition.on_error:
                    # This exporter has no separate node ids: the raw source
                    # state is the node name.
                    edges.add((transition.source, transition.on_error, f"{transition.name}_on_error"))

        for target, name in any_targets:
            for source_name, _ in sources:
                edges.add((source_name, target, name))
        for target, name in any_except_targets:
            for source_name, _ in sources:
                if target != source_name:
                    edges.add((source_name, target, name))

        transition_template = " {source} --> {target} : {name}\n"
        lines = ["stateDiagram-v2\n"]
        # ``field.default`` is Django's NOT_PROVIDED sentinel (a truthy class)
        # when no default is declared, so ask the field explicitly.
        if field.has_default():  # Adding initial state notation
            lines.append(f" [*] --> {field.get_default()}\n")
        # Sort so the diagram text is stable between runs; str() keys keep
        # mixed state types (e.g. None targets) comparable.
        for source, target, transition_name in sorted(edges, key=lambda edge: tuple(str(part) for part in edge)):
            lines.append(transition_template.format(source=source, target=target, name=transition_name))

        mermaid_markdown = "".join(lines)
        print(mermaid_markdown)
        diagrams.append(mermaid_markdown)
    return diagrams
def generate_dot(fields_data):
    """Build a graphviz ``Digraph`` of the transitions of every FSM field.

    Each ``(field, model)`` pair in *fields_data* becomes one ``cluster_*``
    subgraph, drawn as follows:

    * states that only ever appear as targets are double circles, all other
      states are circles;
    * ``on_error`` edges are dotted;
    * when the field's default state appears in the graph, a point node with
      an edge into that state marks the initial state.

    :returns: the assembled ``graphviz.Digraph``.
    """
    result = graphviz.Digraph()
    for field, model in fields_data:
        sources, targets, edges, any_targets, any_except_targets = set(), set(), set(), set(), set()

        # Node entries are (node id, label); edges are (source id, target id,
        # attrs) with attrs a tuple of pairs, so everything stays hashable and
        # duplicates collapse in the sets.
        for transition in field.get_all_transitions(model):
            if transition.source == '*':
                any_targets.add((transition.target, transition.name))
            elif transition.source == '+':
                any_except_targets.add((transition.target, transition.name))
            else:
                source_name = node_name(field, transition.source)
                if transition.target is not None:
                    if isinstance(transition.target, (GET_STATE, RETURN_VALUE)):
                        # Dynamic target: one edge per declared allowed state.
                        for transition_target in transition.target.allowed_states or ():
                            add_transition(transition.source, transition_target, transition.name,
                                           source_name, field, sources, targets, edges)
                    else:
                        add_transition(transition.source, transition.target, transition.name,
                                       source_name, field, sources, targets, edges)
                if transition.on_error:
                    on_error_name = node_name(field, transition.on_error)
                    targets.add((on_error_name, node_label(field, transition.on_error)))
                    edges.add((source_name, on_error_name, (('style', 'dotted'),)))

        # '*' transitions start from every state seen as a source above.
        for target, name in any_targets:
            target_name = node_name(field, target)
            targets.add((target_name, node_label(field, target)))
            for source_name, _ in sources:
                edges.add((source_name, target_name, (('label', name),)))
        # '+' transitions do the same but skip the self-loop on the target.
        for target, name in any_except_targets:
            target_name = node_name(field, target)
            targets.add((target_name, node_label(field, target)))
            for source_name, _ in sources:
                if target_name == source_name:
                    continue
                edges.add((source_name, target_name, (('label', name),)))

        # construct subgraph
        opts = field.model._meta
        subgraph = graphviz.Digraph(
            name="cluster_%s_%s_%s" % (opts.app_label, opts.object_name, field.name),
            graph_attr={'label': "%s.%s.%s" % (opts.app_label, opts.object_name, field.name)})

        # Sorted iteration keeps the generated DOT stable between runs.
        final_states = targets - sources
        for name, label in sorted(final_states, key=lambda node: node[0]):
            subgraph.node(name, label=label, shape='doublecircle')
        for name, label in sorted((sources | targets) - final_states, key=lambda node: node[0]):
            subgraph.node(name, label=label, shape='circle')

        # Adding initial state notation. Match on node ids, not labels: for
        # choices fields the label is the display text, never the raw default.
        # ``field.default`` is NOT_PROVIDED when unset, hence has_default().
        if field.has_default():
            default_name = node_name(field, field.get_default())
            if any(name == default_name for name, _ in sources | targets):
                initial_name = node_name(field, '_initial')
                subgraph.node(name=initial_name, label='', shape='point')
                subgraph.edge(initial_name, default_name)

        for source_name, target_name, attrs in sorted(edges, key=lambda edge: (edge[0], edge[1], repr(edge[2]))):
            subgraph.edge(source_name, target_name, **dict(attrs))

        result.subgraph(subgraph)
    return result
def add_transition(transition_source, transition_target, transition_name, source_name, field, sources, targets, edges):
    """Record one transition for the graphviz exporter in the caller's sets.

    Adds the following to the caller's sets:

    * a ``(node id, label)`` pair for each endpoint, to *sources* and
      *targets* respectively;
    * an edge ``(source id, target id, attrs)`` to *edges*. *attrs* is a
      tuple of ``(key, value)`` pairs so it stays hashable, and here carries
      the transition name as the edge label.

    Callers pass ``node_name(field, transition_source)`` as *source_name*.
    """
    target_id = node_name(field, transition_target)
    sources.add((source_name, node_label(field, transition_source)))
    target_label = node_label(field, transition_target)
    targets.add((target_id, target_label))
    edge_attrs = (('label', transition_name),)
    edges.add((source_name, target_id, edge_attrs))
def get_graphviz_layouts():
    """Return the set of graphviz layout engine names offered by ``--layout``.

    Lookup order:

    1. the top-level ``graphviz.ENGINES`` of current graphviz releases;
    2. the older ``graphviz.backend.ENGINES`` location;
    3. a static set of common engines, used when graphviz cannot be imported
       or exposes neither attribute.
    """
    fallback = {'sfdp', 'circo', 'twopi', 'dot', 'neato', 'fdp', 'osage', 'patchwork'}
    try:
        import graphviz
    except ImportError:
        return fallback
    engines = getattr(graphviz, 'ENGINES', None)
    if engines is None:
        engines = getattr(getattr(graphviz, 'backend', None), 'ENGINES', None)
    return engines if engines is not None else fallback
class Command(BaseCommand):
    """Export django-fsm transition graphs for the selected FSM fields.

    For every selected field the command prints a Mermaid ``stateDiagram-v2``
    to stdout. It then does one of two things:

    * renders a graphviz graph to ``--output``, when that option is given;
    * otherwise, prints the graphviz DOT source.
    """

    # '__all__' is the Django 3.2+ spelling (boolean True is deprecated there
    # and rejected from Django 4.1). Older versions only test truthiness, so
    # the string works for them as well.
    requires_system_checks = '__all__'

    if not HAS_ARGPARSE:
        option_list = BaseCommand.option_list + (
            make_option('--output', '-o', action='store', dest='outputfile',
                        help=('Render output file. Type of output dependent on file extensions. '
                              'Use png or jpg to render graph to image.')),
            # NOQA
            make_option('--layout', '-l', action='store', dest='layout', default='dot',
                        help=('Layout to be used by GraphViz for visualization. '
                              'Layouts: %s.' % ' '.join(get_graphviz_layouts()))),
        )
        args = "[appname[.model[.field]]]"
    else:
        def add_arguments(self, parser):
            parser.add_argument(
                '--output', '-o', action='store', dest='outputfile',
                help=('Render output file. Type of output dependent on file extensions. '
                      'Use png or jpg to render graph to image.'))
            parser.add_argument(
                '--layout', '-l', action='store', dest='layout', default='dot',
                help=('Layout to be used by GraphViz for visualization. '
                      'Layouts: %s.' % ' '.join(get_graphviz_layouts())))
            parser.add_argument('args', nargs='*',
                                help=('[appname[.model[.field]]]'))

    help = ("Creates a GraphViz dot file with transitions for selected fields")

    def render_output(self, graph, **options):
        """Render *graph* to ``options['outputfile']``; its extension picks the format.

        :raises CommandError: if the output file name has no extension.
        """
        import os.path
        from django.core.management.base import CommandError

        filename, extension = os.path.splitext(options['outputfile'])
        if len(extension) < 2:
            raise CommandError(
                "--output needs a file extension naming the output format, e.g. graph.png")
        graph.engine = options['layout']
        graph.format = extension[1:]
        graph.render(filename)

    def _fields_data_for_spec(self, spec):
        """Return the ``(field, model)`` pairs selected by one ``app[.model[.field]]`` spec.

        :raises CommandError: for a spec with more than three parts, or a
            field part that names no FSM field on the model.
        """
        from django.core.management.base import CommandError

        parts = spec.split('.')
        if len(parts) == 1:
            if NEW_META_API:
                # The app registry has no get_app(), and apps.get_models(app)
                # would pass the app as the include_auto_created flag; the
                # AppConfig lists its own models.
                models = apps.get_app_config(parts[0]).get_models()
            else:
                models = get_models(get_app(parts[0]))
            fields_data = []
            for model in models:
                fields_data += all_fsm_fields_data(model)
            return fields_data

        if len(parts) in (2, 3):
            if NEW_META_API:
                model = apps.get_model(parts[0], parts[1])
            else:
                model = get_model(parts[0], parts[1])
            fields_data = all_fsm_fields_data(model)
            if len(parts) == 3:
                # Honour the field part instead of exporting every FSM field.
                fields_data = [(field, owner) for field, owner in fields_data if field.name == parts[2]]
                if not fields_data:
                    raise CommandError("%s.%s has no FSM field named %r" % (parts[0], parts[1], parts[2]))
            return fields_data

        raise CommandError("Invalid field spec %r; expected appname[.model[.field]]" % spec)

    def handle(self, *args, **options):
        """Export the FSM fields selected by *args*, or every FSM field when none are given."""
        fields_data = []
        if args:
            for arg in args:
                fields_data += self._fields_data_for_spec(arg)
        elif NEW_META_API:
            for model in apps.get_models():
                fields_data += all_fsm_fields_data(model)
        else:
            for app in get_apps():
                for model in get_models(app):
                    fields_data += all_fsm_fields_data(model)

        generate_mermaid(fields_data)
        dotdata = generate_dot(fields_data)
        if options['outputfile']:
            self.render_output(dotdata, **options)
        else:
            print(dotdata)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.