Created
May 3, 2013 14:16
-
-
Save calebsmith/5509360 to your computer and use it in GitHub Desktop.
FSM example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"name": "Untitled1" | |
}, | |
"nbformat": 3, | |
"nbformat_minor": 0, | |
"worksheets": [ | |
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"from functools import partial\n", | |
"\n", | |
"\n", | |
"class FSM(object):\n", | |
" \"\"\"\n", | |
" A simple finite state machine that defines a set of states and the\n", | |
" transitions between those states. The constructor takes the name of the\n", | |
" initial state, a list of transitions, and a dictionary of callbacks.\n", | |
" `transitions` is a list of dictionaries with 'source', 'destination', and\n", | |
" 'name' keys. Transitions of the same name may have different sources or\n", | |
" map to different destinations.\n", | |
"\n", | |
" Callbacks have keys that correspond to transition names and are prefixed\n", | |
" with 'on_' or 'on_before_'. For example, a transition named 'enter'\n", | |
" could have the callbacks 'on_enter' and 'on_before_enter'. The\n", | |
" 'on_before' callbacks return True if the transition should occur or False\n", | |
" to short circuit the impending transition. Return values of the 'on'\n", | |
" callbacks are passed to the caller of the transition (e.g. `result` in the\n", | |
" expression `result = fsm.enter()` would be set to the return of `on_enter`).\n", | |
"\n", | |
" Transitions may take any positional or keyword arguments. These are passed\n", | |
" to the associated callback(s) for the transition. In the above example, a\n", | |
" call to fsm.enter(score=10) would pass the score=10 kwarg to the\n", | |
" `on_enter` and `on_before_enter` callbacks for handling.\n", | |
" \"\"\"\n", | |
"\n", | |
" # Tracks the set of possible states or \"nodes\" that the machine may enter\n", | |
" possible_states = set()\n", | |
" # Maintains the mapping of a source/transition name pair to a destination\n", | |
" transitions = {}\n", | |
" # Maintains the mapping of each callback's name to the callback function\n", | |
" callbacks = {}\n", | |
" # Maintains the mapping of each source to the list of possible transitions\n", | |
" _source_to_names = {}\n", | |
"\n", | |
" class IllegalNameException(Exception):\n", | |
" pass\n", | |
"\n", | |
" class IllegalCallbackException(Exception):\n", | |
" pass\n", | |
"\n", | |
" class IllegalTransitionException(Exception):\n", | |
" pass\n", | |
"\n", | |
" def __init__(self, initial, transitions=None, callbacks=None):\n", | |
" # Define the initial state and store the transitions and callbacks if\n", | |
" # provided\n", | |
" self.state = initial\n", | |
" self.possible_states.add(initial)\n", | |
" callbacks = callbacks or {}\n", | |
" transitions = transitions or []\n", | |
" map(self.add_transition, transitions)\n", | |
" map(\n", | |
" lambda pair: self.add_callback(pair[0], pair[1]), callbacks.items()\n", | |
" )\n", | |
"\n", | |
" def add_transition(self, transition):\n", | |
" \"\"\"\n", | |
" Given a transition dictionary that defines a `source`, `name`, and\n", | |
" `destination`, add the transition function with the `name` that moves\n", | |
" the state from the `source` to the `destination`.\n", | |
"\n", | |
" An IllegalNameException is thrown if the transition name would override\n", | |
" another method.\n", | |
" \"\"\"\n", | |
" source, name = transition['source'], transition['name']\n", | |
" destination = transition['destination']\n", | |
" # Assure transition names won't override existing methods\n", | |
" transition_names = [\n", | |
" t_name for t_source, t_name in self.transitions.keys()\n", | |
" ]\n", | |
" reserved_methods = set(dir(self)).difference(transition_names)\n", | |
" if name in reserved_methods:\n", | |
" err_msg = u'The transition name `{0}` shadows an existing method'\n", | |
" raise self.IllegalNameException(\n", | |
" err_msg.format(name)\n", | |
" )\n", | |
" # Update transitions, possible_states\n", | |
" self.transitions.update({\n", | |
" (source, name): destination\n", | |
" })\n", | |
" existing_names = self._source_to_names.get(source, [])\n", | |
" existing_names.append(name)\n", | |
" self._source_to_names[source] = existing_names\n", | |
" self.possible_states.add(source)\n", | |
" self.possible_states.add(destination)\n", | |
" if not hasattr(self, name):\n", | |
" # Create a function for the transition\n", | |
" func = partial(self._transition_function_factory, source, name)\n", | |
" setattr(self, name, func)\n", | |
"\n", | |
" def add_callback(self, name, func):\n", | |
" \"\"\"\n", | |
" Given a `name` and `func`, registers the `func` as a callback for the\n", | |
" transition associated with `name`.\n", | |
"\n", | |
" An IllegalCallbackException is thrown if the callback name does not\n", | |
" correspond to an existing transition. This is meant to safeguard\n", | |
" against registering callbacks with incorrect names, which will never be\n", | |
" called.\n", | |
" \"\"\"\n", | |
" # Determine the name of the associated transition\n", | |
" transition_name = name[3:]\n", | |
" if transition_name.startswith('before_'):\n", | |
" transition_name = transition_name[7:]\n", | |
" transition_names = [\n", | |
" t_name for t_source, t_name in self.transitions.keys()\n", | |
" ]\n", | |
" if transition_name not in transition_names:\n", | |
" err_msg = u'Callback {0} can not be registered because {1} is not a transition name'\n", | |
" raise self.IllegalCallbackException(\n", | |
" err_msg.format(name, transition_name)\n", | |
" )\n", | |
" self.callbacks.update({\n", | |
" name: func\n", | |
" })\n", | |
"\n", | |
" def _transition_function_factory(self, source, name, *args, **kwargs):\n", | |
" \"\"\"\n", | |
" Perform the transition registered for the given `source` and `name`:\n", | |
" call any registered callbacks and move the state to the destination\n", | |
" (unless the `on_before_` callback returns a false value). Raises\n", | |
" IllegalTransitionException if no such transition is registered.\n", | |
" NOTE: the current state is not compared against `source` here.\n", | |
" \"\"\"\n", | |
" destination = self.transitions.get((source, name), None)\n", | |
" if destination is not None:\n", | |
" if self.callbacks:\n", | |
" resume = self._call_callback(name, 'before', *args, **kwargs)\n", | |
" if resume:\n", | |
" self.state = destination\n", | |
" return self._call_callback(name, '', *args, **kwargs)\n", | |
" else:\n", | |
" self.state = destination\n", | |
" else:\n", | |
" err_msg = '{0} called when current state was {1}'\n", | |
" raise self.IllegalTransitionException(\n", | |
" err_msg.format(name, self.state)\n", | |
" )\n", | |
"\n", | |
" def _call_callback(self, transition_name, prefix, *args, **kwargs):\n", | |
" \"\"\"Calls the callback on behalf of the transition function\"\"\"\n", | |
" if prefix:\n", | |
" name_parts = ('on', prefix, transition_name)\n", | |
" else:\n", | |
" name_parts = ('on', transition_name)\n", | |
" callback_name = '_'.join(name_parts)\n", | |
" callback = self.callbacks.get(callback_name, None)\n", | |
" if callback:\n", | |
" return callback(*args, **kwargs)\n", | |
" return True\n", | |
"\n", | |
" def is_state(self, check_state):\n", | |
" \"\"\"Checks if the current state is `check_state`\"\"\"\n", | |
" return self.state == check_state\n", | |
"\n", | |
" def can(self, name):\n", | |
" \"\"\"\n", | |
" Checks if the given `name` is a possible transition from the current\n", | |
" state\n", | |
" \"\"\"\n", | |
" return name in self._source_to_names[self.state]\n", | |
"\n", | |
" def __repr__(self):\n", | |
" return u'State machine: ({0}) '.format(self.state) + u' '.join([\n", | |
" state\n", | |
" for state in self.possible_states\n", | |
" if state != self.state\n", | |
" ])\n", | |
"\n", | |
" def callbacks_display(self):\n", | |
" return self.callbacks.keys()\n", | |
"\n", | |
" def transitions_display(self):\n", | |
" return sorted([\n", | |
" '{0}: {1} -> {2}'.format(name, source, destination)\n", | |
" for (source, name), destination in self.transitions.items()\n", | |
" ])\n" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 2 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"\"\"\"Defines transitions and callbacks for the game's fsm\"\"\"\n", | |
"\n", | |
"from fsm import FSM\n", | |
"from game_assets import Level\n", | |
"\n", | |
"# FSM transitions and callbacks\n", | |
"transitions = [\n", | |
" {\n", | |
" 'name': 'answer',\n", | |
" 'source': 'question',\n", | |
" 'destination': 'main',\n", | |
" },\n", | |
" {\n", | |
" 'name': 'item_collected',\n", | |
" 'source': 'item',\n", | |
" 'destination': 'main',\n", | |
" },\n", | |
" {\n", | |
" 'name': 'popup_question',\n", | |
" 'source': 'main',\n", | |
" 'destination': 'question'\n", | |
" },\n", | |
" {\n", | |
" 'name': 'popup_item',\n", | |
" 'source': 'main',\n", | |
" 'destination': 'item'\n", | |
" },\n", | |
" {\n", | |
" 'name': 'exit',\n", | |
" 'source': 'main',\n", | |
" 'destination': 'exit'\n", | |
" },\n", | |
" {\n", | |
" 'name': 'exit',\n", | |
" 'source': 'question',\n", | |
" 'destination': 'exit'\n", | |
" },\n", | |
" {\n", | |
" 'name': 'exit',\n", | |
" 'source': 'item',\n", | |
" 'destination': 'exit'\n", | |
" },\n", | |
" {\n", | |
" 'name': 'exit',\n", | |
" 'source': 'splash',\n", | |
" 'destination': 'exit'\n", | |
" },\n", | |
" {\n", | |
" 'name': 'exit',\n", | |
" 'source': 'endscreen',\n", | |
" 'destination': 'exit'\n", | |
" },\n", | |
" {\n", | |
" 'name': 'start',\n", | |
" 'source': 'splash',\n", | |
" 'destination': 'main'\n", | |
" },\n", | |
" {\n", | |
" 'name': 'end',\n", | |
" 'source': 'main',\n", | |
" 'destination': 'endscreen'\n", | |
" },\n", | |
" {\n", | |
" 'name': 'popup_info',\n", | |
" 'source': 'main',\n", | |
" 'destination': 'info'\n", | |
" },\n", | |
" {\n", | |
" 'name': 'info_closed',\n", | |
" 'source': 'info',\n", | |
" 'destination': 'main'\n", | |
" },\n", | |
"]\n", | |
"\n", | |
"\n", | |
"def add_item_to_inventory(player, item):\n", | |
" player.items.append(item)\n", | |
" player.current_item = item\n", | |
" return True\n", | |
"\n", | |
"\n", | |
"def handle_answer(is_correct, level, player):\n", | |
" if is_correct:\n", | |
" # Remove all monsters on the player's current location\n", | |
" for monster in level.monsters:\n", | |
" if (monster.x, monster.y) == (player.x, player.y):\n", | |
" level.monsters.remove(monster)\n", | |
" else:\n", | |
" level.reset_items(player)\n", | |
" level.reset_monsters()\n", | |
" game_state.popup_info()\n", | |
"\n", | |
"\n", | |
"def handle_item_collected(level, player):\n", | |
" if len(level.items) == len(player.items):\n", | |
" # no more items. you win!\n", | |
" game_state.end()\n", | |
"\n", | |
"# No callbacks for now. Refer to fsm.py when implementing\n", | |
"callbacks = {\n", | |
" 'on_before_popup_item': add_item_to_inventory,\n", | |
" 'on_answer': handle_answer,\n", | |
" 'on_item_collected': handle_item_collected,\n", | |
"}\n", | |
"\n", | |
"game_state = FSM('splash', transitions, callbacks)\n" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 4 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"print game_state" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"State machine: (splash) info endscreen question item exit main\n" | |
] | |
} | |
], | |
"prompt_number": 5 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"game_state.transitions_display()" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "pyout", | |
"prompt_number": 7, | |
"text": [ | |
"['answer: question -> main',\n", | |
" 'end: main -> endscreen',\n", | |
" 'exit: endscreen -> exit',\n", | |
" 'exit: item -> exit',\n", | |
" 'exit: main -> exit',\n", | |
" 'exit: question -> exit',\n", | |
" 'exit: splash -> exit',\n", | |
" 'info_closed: info -> main',\n", | |
" 'item_collected: item -> main',\n", | |
" 'popup_info: main -> info',\n", | |
" 'popup_item: main -> item',\n", | |
" 'popup_question: main -> question',\n", | |
" 'start: splash -> main']" | |
] | |
} | |
], | |
"prompt_number": 7 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"game_state.callbacks_display()" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "pyout", | |
"prompt_number": 8, | |
"text": [ | |
"['on_answer', 'on_item_collected', 'on_before_popup_item']" | |
] | |
} | |
], | |
"prompt_number": 8 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [] | |
} | |
], | |
"metadata": {} | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
http://nbviewer.ipython.org/5509360