Created
June 26, 2018 14:58
-
-
Save timini/965a5917b741b4c4fba93dbc4ea55f58 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from pprint import pprint | |
class Graph:
    """Graph stored as a dense adjacency matrix over an ordered node list.

    A node's position in ``self._nodes`` is its row/column index in
    ``self._adj_matrix``. A value of 1 at ``[i][j]`` means there is an edge
    from node ``i`` to node ``j``. For undirected graphs (``directed=False``)
    ``add_edge`` and ``remove_edge`` keep ``[i][j]`` and ``[j][i]`` in sync.
    Nodes are looked up with ``list.index`` / ``in``, i.e. compared with ``==``.
    """

    def __init__(self, directed=False):
        self._adj_matrix = np.zeros([0, 0])
        self._nodes = []
        self._directed = directed

    def debug(self):
        """Pretty-print the node list and the adjacency matrix."""
        pprint(self._nodes)
        pprint(self._adj_matrix)

    def add_node(self, N):
        """Add node ``N``; does nothing if it is already present."""
        if not self.has_node(N):
            self._nodes.append(N)
            # Grow the matrix by one zero row and one zero column at the end,
            # so the new node's index is len(self._nodes) - 1.
            self._adj_matrix = np.pad(self._adj_matrix, (0, 1), mode='constant')

    def has_node(self, N):
        """Return True if ``N`` is a node of the graph."""
        return N in self._nodes

    def has_edge(self, A, B):
        """Return True if there is an edge from ``A`` to ``B``.

        Raises:
            ValueError: if ``A`` or ``B`` is not a node of the graph.
        """
        A_index = self._nodes.index(A)
        B_index = self._nodes.index(B)
        return bool(self._adj_matrix[A_index][B_index] == 1)

    def add_edge(self, A, B):
        """Add an edge from ``A`` to ``B`` (both directions if undirected).

        Raises:
            ValueError: if ``A`` or ``B`` is not a node of the graph.
        """
        A_index = self._nodes.index(A)
        B_index = self._nodes.index(B)
        self._adj_matrix[A_index][B_index] = 1
        if not self._directed:
            self._adj_matrix[B_index][A_index] = 1

    def remove_edge(self, A, B):
        """Remove the edge from ``A`` to ``B`` (both directions if undirected).

        Raises:
            ValueError: if ``A`` or ``B`` is not a node of the graph.
        """
        A_index = self._nodes.index(A)
        B_index = self._nodes.index(B)
        self._adj_matrix[A_index][B_index] = 0
        # add_edge stores an undirected edge in both directions, so clear the
        # mirror entry too; otherwise has_edge(B, A) would still be True.
        if not self._directed:
            self._adj_matrix[B_index][A_index] = 0

    def get_connected_nodes(self, N, visited=None):
        """Return the set of nodes reachable from ``N`` by one or more edges.

        ``N`` itself is in the result only if some path leads back to it. In
        an undirected graph that is the case whenever ``N`` has a neighbour.

        Args:
            N: start node. If it is not in the graph, the result is empty.
            visited: optional list of nodes that must not be expanded. If
                ``N`` is in it, the result is empty. Nodes expanded by this
                call are appended to it. A node already in ``visited`` may
                still appear in the result when it is a neighbour of an
                expanded node.

        Returns:
            A ``set`` of nodes (so nodes must be hashable).
        """
        if visited is None:
            visited = []
        if N in visited:
            return set()
        visited.append(N)
        try:
            start_index = self._nodes.index(N)
        except ValueError:
            # A node that is not in the graph has no connections.
            return set()

        connected_nodes = set()
        # Iterative DFS over matrix indices: no recursion-depth limit on long
        # paths.
        stack = [start_index]
        while stack:
            row = self._adj_matrix[stack.pop()]
            for i in np.flatnonzero(row == 1):
                node = self._nodes[i]
                connected_nodes.add(node)
                if node not in visited:
                    visited.append(node)
                    stack.append(i)
        return connected_nodes

    def is_connected(self, A, B, visited=None):
        """Return True if ``B`` is reachable from ``A`` by one or more edges.

        Args:
            A: start node.
            B: target node. If either node is not in the graph, returns False.
            visited: optional list of nodes that must not be expanded. Nodes
                expanded by this call are appended to it. If ``A`` is already
                in it, only a direct ``A -> B`` edge is checked.
        """
        if visited is None:
            visited = []
        try:
            A_index = self._nodes.index(A)
            B_index = self._nodes.index(B)
        except ValueError:
            # A node that is not in the graph is not connected to anything.
            return False
        if A in visited:
            return bool(self._adj_matrix[A_index][B_index] == 1)
        visited.append(A)

        # Iterative DFS from A; stop as soon as any expanded node has an edge
        # to B.
        stack = [A_index]
        while stack:
            row = self._adj_matrix[stack.pop()]
            if row[B_index] == 1:
                return True
            for i in np.flatnonzero(row == 1):
                node = self._nodes[i]
                if node in visited:
                    # Visited nodes are not expanded again, but a direct edge
                    # from one of them to B still counts as a connection.
                    if self._adj_matrix[i][B_index] == 1:
                        return True
                else:
                    visited.append(node)
                    stack.append(i)
        return False
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment