Created
August 1, 2020 19:57
-
-
Save jtribble/c9512980c5ac827c8108099ba771c88e to your computer and use it in GitHub Desktop.
A simple graph representation in Python, along with some common methods such as breadth-first traversal, depth-first traversal, and shortest path.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from collections import deque | |
from dataclasses import dataclass, field | |
from typing import Set, Optional, List, Generator, Dict, Deque | |
from unittest import TestCase, main | |
@dataclass(eq=False)
class GraphNode:
    """A labelled vertex in a graph.

    The ``label`` is the node's identity: both ``__hash__`` and ``__eq__`` use it,
    so a node stays hashable (and findable in sets/dicts) while its mutable
    ``neighbors`` and ``color`` change.

    Attributes:
        label: Name identifying the node.
        neighbors: Nodes one edge away from this one. Edges are stored per node,
            so an undirected edge has to be added to both endpoints.
        color: Color assigned by ``Graph.apply_colors``, or ``None``.
    """

    label: str
    neighbors: Set['GraphNode'] = field(default_factory=set)
    color: Optional[str] = None

    def __eq__(self, other: object) -> bool:
        # Compare by label only, consistent with __hash__. A field-by-field
        # comparison would also compare ``neighbors``, which walks the graph and
        # recurses forever when two distinct same-labelled nodes lie on a cycle.
        if not isinstance(other, GraphNode):
            return NotImplemented
        return self.label == other.label

    def __hash__(self) -> int:
        return hash(self.label)

    def __repr__(self) -> str:
        # Sort neighbor labels so the output does not depend on set iteration order.
        neighbor_labels = ", ".join(sorted(n.label for n in self.neighbors))
        return f'GraphNode(label={self.label}, color={self.color}, neighbors=[{neighbor_labels}])'
@dataclass
class Graph:
    """A set of ``GraphNode`` objects plus traversal, coloring and path helpers.

    Edges are stored on the nodes themselves (``GraphNode.neighbors``); the graph
    only records which nodes belong to it. ``GraphNode`` is named as a string in
    annotations so this class does not depend on definition order.
    """

    nodes: Set['GraphNode'] = field(default_factory=set)

    def _start_node(self, first_node: Optional['GraphNode']) -> Optional['GraphNode']:
        """Return ``first_node``, else an arbitrary graph node, else ``None`` for an empty graph."""
        if first_node is not None:
            return first_node
        # A bare next(iter(...)) on an empty set raises StopIteration, which inside
        # a generator surfaces as RuntimeError (PEP 479); use a default instead.
        return next(iter(self.nodes), None)

    def breadth_first_traversal(self, first_node: Optional['GraphNode'] = None) -> Generator['GraphNode', None, None]:
        """Yield every node reachable from ``first_node`` in breadth-first order.

        Args:
            first_node: Node to start from. Defaults to an arbitrary node of the
                graph.

        Yields:
            Each reachable node exactly once; nodes fewer edges away come first.
            Nothing is yielded for an empty graph when no start node is given.
        """
        start = self._start_node(first_node)
        if start is None:
            return
        visited = {start}  # type: Set[GraphNode]
        queue = deque([start])  # type: Deque[GraphNode]
        while queue:
            node = queue.popleft()
            yield node
            for neighbor in node.neighbors:
                # Mark on enqueue so no node is queued twice.
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)

    def depth_first_traversal(self, first_node: Optional['GraphNode'] = None) -> Generator['GraphNode', None, None]:
        """Yield every node reachable from ``first_node`` in depth-first pre-order.

        Args:
            first_node: Node to start from. Defaults to an arbitrary node of the
                graph.

        Yields:
            Each reachable node exactly once. Among a node's unvisited neighbors,
            the one pushed last is explored first. Nothing is yielded for an
            empty graph when no start node is given.
        """
        start = self._start_node(first_node)
        if start is None:
            return
        visited = set()  # type: Set[GraphNode]
        stack = [start]  # type: List[GraphNode]
        while stack:
            node = stack.pop()
            # Mark on pop, not on push: marking on push stops a deeper node from
            # descending into a neighbor that an ancestor merely stacked, which
            # produces an order that is not depth-first. The cost is that a node
            # may sit on the stack more than once, so skip repeats here.
            if node in visited:
                continue
            visited.add(node)
            yield node
            for neighbor in node.neighbors:
                if neighbor not in visited:
                    stack.append(neighbor)

    def apply_colors(self, colors: Set[str]) -> None:
        """Greedily give every node a color that its neighbors do not have.

        All node colors are cleared first, so colors left over from an earlier
        call cannot block a choice. Nodes are then processed in set-iteration
        order; each takes the first color (in ``colors`` iteration order) that
        none of its neighbors has at that moment. Because a node only checks its
        own ``neighbors``, the result is a proper coloring only when neighbor sets
        are symmetric.

        Args:
            colors: Palette to choose from.

        Raises:
            ValueError: If some node has no color left. Nodes processed before the
                failure keep their new colors; the remaining nodes stay ``None``.
        """
        for node in self.nodes:
            node.color = None
        for node in self.nodes:
            illegal_colors = {n.color for n in node.neighbors if n.color is not None}
            for color in colors:
                if color not in illegal_colors:
                    node.color = color
                    break
            else:
                raise ValueError(f"Can't apply legal color to {node}")

    @staticmethod
    def shortest_path(source: 'GraphNode', destination: 'GraphNode') -> List['GraphNode']:
        """Return a path with the fewest edges from ``source`` to ``destination``.

        Breadth-first search that records, for each discovered node, the node it
        was reached from. Each node is expanded at most once, so the work is
        linear in the reachable nodes and edges.

        Args:
            source: Node the path starts at.
            destination: Node the path ends at.

        Returns:
            The nodes along the path, ``source`` first and ``destination`` last;
            ``[source]`` when the two are equal.

        Raises:
            ValueError: If ``destination`` is not reachable from ``source``.
        """
        if source == destination:
            return [source]
        # Discovered node -> the node it was first reached from (None for source).
        parents = {source: None}  # type: Dict[GraphNode, Optional[GraphNode]]
        queue = deque([source])  # type: Deque[GraphNode]
        while queue:
            node = queue.popleft()
            for neighbor in node.neighbors:
                if neighbor == destination:
                    # Walk the parent links back to the source, then reverse.
                    path = [destination]
                    step = node  # type: Optional[GraphNode]
                    while step is not None:
                        path.append(step)
                        step = parents[step]
                    path.reverse()
                    return path
                if neighbor not in parents:
                    parents[neighbor] = node
                    queue.append(neighbor)
        raise ValueError(f'Could not find path from {source} to {destination}')
# These are demonstrations as much as tests: each one prints state for you to
# observe. The assertions only check properties that hold for this fixture no
# matter which order the underlying sets happen to iterate in.
class UnitTests(TestCase):
    def setUp(self) -> None:
        """Build a fresh 12-node fixture graph before each test."""
        nodes = {n: GraphNode(str(n)) for n in range(1, 13)}  # type: Dict[int, GraphNode]
        # Neighbor labels per node, exactly as the fixture defines them. Some edges
        # are listed on one endpoint only (e.g. 4 lists 2 but 2 does not list 4),
        # so the graph is effectively directed.
        adjacency = {
            1: (2, 3, 5),
            2: (1, 5, 7),
            3: (1, 5, 6),
            4: (2, 6, 7),
            5: (1, 3, 10),
            6: (3, 4, 5),
            7: (2, 4, 11),
            8: (6, 9, 11),
            9: (8, 10, 12),
            10: (5, 9, 12),
            11: (7, 8, 12),
            12: (9, 10, 11),
        }
        for label, neighbor_labels in adjacency.items():
            nodes[label].neighbors.update(nodes[n] for n in neighbor_labels)
        self.graph = Graph(set(nodes.values()))
        self.nodes = nodes

    def _assert_full_traversal(self, order: List[GraphNode]) -> None:
        """Check that ``order`` starts at node 1 and contains every node exactly once."""
        # Every node is reachable from node 1 along the listed edges.
        self.assertIs(order[0], self.nodes[1])
        self.assertEqual(len(order), len(self.nodes))
        self.assertEqual(set(order), set(self.nodes.values()))

    def _assert_path(self, path: List[GraphNode], source: int, destination: int, edge_count: int) -> None:
        """Check that ``path`` runs source -> destination along edges, using ``edge_count`` edges."""
        self.assertIs(path[0], self.nodes[source])
        self.assertIs(path[-1], self.nodes[destination])
        self.assertEqual(len(path), edge_count + 1)
        for here, there in zip(path, path[1:]):
            self.assertIn(there, here.neighbors)

    def test_breadth_first_traversal(self) -> None:
        print('Breadth-first traversal:')
        order = list(self.graph.breadth_first_traversal(self.nodes[1]))
        for i, node in enumerate(order):
            print(f'{str(i + 1).rjust(2)}. {node}')
        self._assert_full_traversal(order)

    def test_depth_first_traversal(self) -> None:
        print('Depth-first traversal:')
        order = list(self.graph.depth_first_traversal(self.nodes[1]))
        for i, node in enumerate(order):
            print(f'{str(i + 1).rjust(2)}. {node}')
        self._assert_full_traversal(order)

    def test_legal_graph_coloring(self) -> None:
        colors = {'red', 'blue', 'green', 'yellow'}
        print(f'Coloring graph with {len(colors)} colors: {colors}')
        self.graph.apply_colors(colors)
        print('Breadth-first traversal:')
        for i, node in enumerate(self.graph.breadth_first_traversal(self.nodes[1])):
            print(f'{str(i + 1).rjust(2)}. {node}')
        # Each node lists exactly 3 neighbors, so with 4 colors the greedy pass
        # always has one left and every node ends up colored from the palette.
        for node in self.graph.nodes:
            self.assertIn(node.color, colors)

    def test_illegal_graph_coloring(self) -> None:
        colors = {'blue', 'black'}
        print(f'Coloring graph with {len(colors)} colors: {colors}')
        # Nodes 1, 3 and 5 all list each other (a triangle), which 2 colors cannot cover.
        with self.assertRaises(ValueError):
            self.graph.apply_colors(colors)

    def test_shortest_path(self) -> None:
        print('Shortest path from 5 => 6')
        path = self.graph.shortest_path(self.nodes[5], self.nodes[6])
        for i, node in enumerate(path):
            print(f'{i}. {node}')
        self._assert_path(path, 5, 6, edge_count=2)  # 5 -> 3 -> 6
        print('\nShortest path from 2 => 12')
        path = self.graph.shortest_path(self.nodes[2], self.nodes[12])
        for i, node in enumerate(path):
            print(f'{i}. {node}')
        self._assert_path(path, 2, 12, edge_count=3)  # e.g. 2 -> 5 -> 10 -> 12
if __name__ == '__main__':
    # Hand control to unittest's command-line runner, pointed explicitly at this module.
    main(module='__main__')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment