Last active
June 4, 2024 11:39
-
-
Save Kautenja/29f13b543ecd210202417dfc9e328249 to your computer and use it in GitHub Desktop.
Uniform Cost Search (UCS) in Python with path backtrace.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"class Node(object):\n", | |
"    \"\"\"\n", | |
"    This class represents a labelled node in a directed, weighted graph.\n", | |
"    \n", | |
"    Outgoing edges are stored in ``children`` as ``Edge`` objects, which are\n", | |
"    created by ``add_child``.\n", | |
"    \"\"\"\n", | |
"    \n", | |
"    def __init__(self, label: str=None):\n", | |
"        \"\"\"\n", | |
"        Initialize a new node that has no outgoing edges yet.\n", | |
"        \n", | |
"        Args:\n", | |
"            label: the string identifier for the node\n", | |
"        \"\"\"\n", | |
"        self.label = label\n", | |
"        # outgoing Edge objects, in the order they were added\n", | |
"        self.children = []\n", | |
"    \n", | |
"    def __lt__(self, other):\n", | |
"        \"\"\"\n", | |
"        Perform the less than operation (self < other) by comparing labels.\n", | |
"        \n", | |
"        Ordering nodes lets lists of nodes (paths) be compared, which is how\n", | |
"        equal-cost (cost, path) entries are ordered in a priority queue.\n", | |
"        \n", | |
"        Args:\n", | |
"            other: the other Node to compare to\n", | |
"        \n", | |
"        Returns: True if this node's label sorts before other's label\n", | |
"        \"\"\"\n", | |
"        return self.label < other.label\n", | |
"    \n", | |
"    def __gt__(self, other):\n", | |
"        \"\"\"\n", | |
"        Perform the greater than operation (self > other) by comparing labels.\n", | |
"        \n", | |
"        Args:\n", | |
"            other: the other Node to compare to\n", | |
"        \n", | |
"        Returns: True if this node's label sorts after other's label\n", | |
"        \"\"\"\n", | |
"        return self.label > other.label\n", | |
"    \n", | |
"    def __repr__(self):\n", | |
"        \"\"\"Return 'label -> [edges]', e.g. 'S -> [1: A, 12: G]'.\"\"\"\n", | |
"        return '{} -> {}'.format(self.label, self.children)\n", | |
"    \n", | |
"    def add_child(self, node, cost=1):\n", | |
"        \"\"\"\n", | |
"        Add a directed edge from this node to node.\n", | |
"        \n", | |
"        Args:\n", | |
"            node: the destination Node of the new edge\n", | |
"            cost: the cost of the edge (default 1)\n", | |
"        \"\"\"\n", | |
"        self.children.append(Edge(self, node, cost))\n", | |
" \n", | |
" \n", | |
"class Edge(object):\n", | |
"    \"\"\"This class represents a directed, weighted edge in a graph.\"\"\"\n", | |
"    \n", | |
"    def __init__(self, source: Node, destination: Node, cost: int=1):\n", | |
"        \"\"\"\n", | |
"        Initialize a new edge.\n", | |
"        \n", | |
"        Args:\n", | |
"            source: the node the edge leaves from\n", | |
"            destination: the node the edge points to\n", | |
"            cost: the cost of traversing the edge (default 1)\n", | |
"        \"\"\"\n", | |
"        self.source = source\n", | |
"        self.destination = destination\n", | |
"        self.cost = cost\n", | |
"    \n", | |
"    def __repr__(self):\n", | |
"        \"\"\"Return 'cost: destination-label', e.g. '1: A'.\"\"\"\n", | |
"        return '{}: {}'.format(self.cost, self.destination.label)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"collapsed": true | |
}, | |
"source": [ | |
"## Build the example graph" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Create the nodes of the example graph, one per label." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# one Node per vertex of the example graph, each named after its label\n", | |
"S, A, B, C, D, G = (Node(label) for label in 'SABCDG')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Create the directed, weighted edges with `parent.add_child(child, cost)`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# (parent, child, cost) for every directed edge of the example graph,\n", | |
"# listed in the order the edges are added\n", | |
"example_edges = [\n", | |
"    (S, A, 1), (S, G, 12),\n", | |
"    (A, B, 3), (A, C, 1),\n", | |
"    (B, D, 3),\n", | |
"    (C, D, 1), (C, G, 2),\n", | |
"    (D, G, 3),\n", | |
"]\n", | |
"for parent, child, cost in example_edges:\n", | |
"    parent.add_child(child, cost)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Print each node's outgoing edges, formatted as `cost: destination`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"S -> [1: A, 12: G]\n", | |
"A -> [3: B, 1: C]\n", | |
"B -> [3: D]\n", | |
"C -> [1: D, 2: G]\n", | |
"D -> [3: G]\n", | |
"G -> []\n" | |
] | |
} | |
], | |
"source": [ | |
"# print each node's adjacency list; a plain loop avoids building a throwaway\n", | |
"# list and does not overwrite IPython's `_` (last output) variable\n", | |
"for node in [S, A, B, C, D, G]:\n", | |
"    print(node)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"```\n", | |
"UCS(root):\n", | |
"    Insert the root into the queue\n", | |
"    While the queue is not empty\n", | |
"        Dequeue the lowest-cost path from the queue\n", | |
"        (If costs are equal, the alphabetically smaller path is chosen)\n", | |
"        If the path ends in the goal state, return the path\n", | |
"        Else\n", | |
"            Insert a path to each child of the path's last node, with the cumulative cost as priority\n", | |
"```" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from queue import PriorityQueue\n", | |
"\n", | |
"\n", | |
"def ucs(root, goal):\n", | |
"    \"\"\"\n", | |
"    Return the uniform cost search path from root to goal.\n", | |
"    \n", | |
"    Paths are expanded in order of increasing cumulative edge cost. Ties are\n", | |
"    broken by comparing the paths node by node (Node orders by label), so\n", | |
"    the alphabetically smaller path is preferred. Each node is expanded at\n", | |
"    most once, which keeps the search finite on graphs that contain cycles.\n", | |
"    \n", | |
"    Args:\n", | |
"        root: the starting node for the search\n", | |
"        goal: the label of the goal node for the search\n", | |
"    \n", | |
"    Returns: a list of Nodes with the cheapest path from root to goal\n", | |
"    \n", | |
"    Raises: ValueError if no node labelled goal is reachable from root\n", | |
"    \"\"\"\n", | |
"    # priority queue of (cumulative cost, path) pairs; the cheapest pops first\n", | |
"    queue = PriorityQueue()\n", | |
"    queue.put((0, [root]))\n", | |
"    # nodes that have already had their outgoing edges pushed\n", | |
"    expanded = set()\n", | |
"    while not queue.empty():\n", | |
"        cost, path = queue.get()\n", | |
"        current = path[-1]\n", | |
"        # testing the goal on dequeue (not on enqueue) is what makes the\n", | |
"        # returned path the cheapest one, given non-negative edge costs\n", | |
"        if current.label == goal:\n", | |
"            return path\n", | |
"        # with non-negative edge costs a node is first dequeued via its\n", | |
"        # cheapest path, so later paths to it cannot improve the result\n", | |
"        if current in expanded:\n", | |
"            continue\n", | |
"        expanded.add(current)\n", | |
"        for edge in current.children:\n", | |
"            queue.put((cost + edge.cost, path + [edge.destination]))\n", | |
"    raise ValueError('goal {!r} is not reachable from {!r}'.format(goal, root.label))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[S -> [1: A, 12: G], A -> [3: B, 1: C], C -> [1: D, 2: G], G -> []]" | |
] | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# cheapest path from S to the node labelled 'G'\n", | |
"ucs(root=S, goal='G')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Reference\n", | |
"\n", | |
"Visualization and algorithm courtesy of: [algorithmicthoughts](https://algorithmicthoughts.wordpress.com/2012/12/15/artificial-intelligence-uniform-cost-searchucs/)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment