Last active
March 8, 2024 07:23
-
-
Save HirbodBehnam/54b708cd1ac21c0606c631e8fe957a87 to your computer and use it in GitHub Desktop.
Convert NFA to DFA and visualize it with networkx
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import networkx as nx | |
import matplotlib.pyplot as plt | |
EPSILON = "e"  # transition label used for epsilon (empty-string) moves


class NFA:
    """A nondeterministic finite automaton stored as an adjacency list.

    States are integer indices into ``adjacency_list``. Entry ``i`` maps an
    input symbol (or ``EPSILON``) to the list of states reachable from state
    ``i`` on that symbol.
    """

    def __init__(self, alphabet: list[str], adjacency_list: list[dict[str, list[int]]]):
        """
        alphabet: input symbols for which DFA transitions are computed.
        adjacency_list: one dict per state; each dict maps a symbol (or
            EPSILON) to the list of states that transition leads to.
        """
        self.adjacency_list = adjacency_list
        self.alphabet = alphabet

    def to_dfa(
        self,
        start: int = 1,
    ) -> tuple[frozenset[int], dict[frozenset[int], dict[str, frozenset[int]]]]:
        """Convert this NFA to a DFA with the subset construction.

        start: index of the NFA start state (defaults to 1, matching the
            convention of a dummy state 0).

        Returns ``(start_state, transitions)``. ``start_state`` is the epsilon
        closure of ``start``. ``transitions`` maps every reachable DFA state
        (a frozenset of NFA states) to a dict of symbol -> target DFA state.
        Symbols whose target would be the empty set are omitted, so no
        explicit dead state is produced.
        """
        start_state = self.calculate_epsilon_moves(frozenset([start]))
        result: dict[frozenset[int], dict[str, frozenset[int]]] = {}
        pending: list[frozenset[int]] = [start_state]
        # Every DFA state ever queued; prevents re-queuing a state (including
        # one that loops to itself) before its row has been stored.
        discovered: set[frozenset[int]] = {start_state}
        while pending:
            current = pending.pop()
            row: dict[str, frozenset[int]] = {}
            for symbol in self.alphabet:
                target = self.next_states(current, symbol)
                if not target:
                    continue
                row[symbol] = target
                if target not in discovered:
                    discovered.add(target)
                    pending.append(target)
            result[current] = row
        return start_state, result

    def calculate_epsilon_moves(self, current_states: frozenset[int]) -> frozenset[int]:
        """Return the epsilon closure of ``current_states``.

        Any iterable of state indices is accepted. The result contains every
        state reachable from them using only EPSILON transitions (including
        the states themselves).
        """
        closure = set(current_states)
        worklist = list(closure)
        # Each state is expanded at most once: it enters the worklist only
        # when it is first added to the closure.
        while worklist:
            state = worklist.pop()
            for reachable in self.adjacency_list[state].get(EPSILON, ()):
                if reachable not in closure:
                    closure.add(reachable)
                    worklist.append(reachable)
        return frozenset(closure)

    def next_states(
        self, current_states: frozenset[int], transition: str
    ) -> frozenset[int]:
        """Return the epsilon closure of the states reachable from
        ``current_states`` by one ``transition`` move."""
        result: set[int] = set()
        for state in current_states:
            if transition in self.adjacency_list[state]:
                result.update(self.adjacency_list[state][transition])
        return self.calculate_epsilon_moves(result)
# Q1
NFA_Q1 = [
    {},  # Zeroth state is dummy
    {"a": [2], EPSILON: [5, 17]},  # 1
    {"a": [3]},  # 2
    {"a": [4]},  # 3
    {EPSILON: [1, 5]},  # 4
    {EPSILON: [6, 9, 13]},  # 5
    {"a": [7]},  # 6
    {"b": [8]},  # 7
    {"a": [12]},  # 8
    {"b": [10]},  # 9
    {"a": [11]},  # 10
    {"b": [12]},  # 11
    {EPSILON: [5, 13]},  # 12
    {"b": [14], EPSILON: [17]},  # 13
    {"b": [15]},  # 14
    {"b": [16]},  # 15
    {EPSILON: [13, 17]},  # 16
    {EPSILON: [1]},  # 17
]
# Q2
NFA_Q2 = [
    {},  # Dummy
    {"a": [2], "b": [2], EPSILON: [2]},  # 1
    {EPSILON: [3, 5, 7]},  # 2
    {"a": [4]},  # 3
    {"a": [7]},  # 4
    {"b": [6]},  # 5
    {"b": [7]},  # 6
    {EPSILON: [8, 10, 12]},  # 7
    {"a": [9]},  # 8
    {"b": [12]},  # 9
    {"b": [11]},  # 10
    {"a": [12]},  # 11
    {EPSILON: [7]},  # 12
]
# Only one automaton is converted per run. Q2 is the one selected here;
# set ``nfa_stuff = NFA_Q1`` to convert Q1 instead.
nfa_stuff = NFA_Q2
nfa = NFA(["a", "b"], nfa_stuff)
start_state, dfa = nfa.to_dfa()
print(dfa)
# Visualize
def my_draw_networkx_edge_labels(
    G,
    pos,
    edge_labels=None,
    label_pos=0.5,
    font_size=10,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    rotate=True,
    clip_on=True,
    rad=0,
):
    """Draw edge labels, placing each one on a curved (quadratic Bezier) edge.

    Variant of ``networkx.draw_networkx_edge_labels``. The curve's control
    point is computed in display coordinates as the chord midpoint offset by
    ``rad`` times the chord rotated by 90 degrees. This is intended to line up
    with edges drawn using ``connectionstyle="arc3, rad=<rad>"``.

    Parameters
    ----------
    G : graph
       A networkx graph
    pos : dictionary
       A dictionary with nodes as keys and positions as values.
       Positions should be sequences of length 2.
    edge_labels : dictionary or None (default=None)
       Edge labels in a dictionary of labels keyed by edge two-tuple.
       Only labels for the keys in the dictionary are drawn. If None, every
       edge of G is labelled with its attribute dict.
    label_pos : float (default=0.5)
       Position of edge label along the curved edge
       (0=head, 0.5=center, 1=tail)
    font_size : int (default=10)
       Font size for text labels
    font_color : string (default='k' black)
       Font color string
    font_weight : string (default='normal')
       Font weight
    font_family : string (default='sans-serif')
       Font family
    alpha : float or None (default=None)
        The text transparency
    bbox : Matplotlib bbox, optional
        Specify text box properties (e.g. shape, color etc.) for edge labels.
        Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}.
    horizontalalignment : string (default='center')
        Horizontal alignment {'center', 'right', 'left'}
    verticalalignment : string (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}
    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes. Defaults to
        ``plt.gca()``.
    rotate : bool (default=True)
        Rotate edge labels to lie parallel to the straight chord between
        the edge's endpoints
    clip_on : bool (default=True)
        Turn on clipping of edge labels at axis boundaries
    rad : float (default=0)
        Curvature of the edge; 0 places labels on the straight line.

    Returns
    -------
    dict
        `dict` of matplotlib Text objects keyed by edge

    Notes
    -----
    Axis ticks and tick labels of ``ax`` are switched off as a side effect.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> edge_labels = nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    """
    import matplotlib.pyplot as plt
    import numpy as np

    if ax is None:
        ax = plt.gca()
    if edge_labels is None:
        labels = {(u, v): d for u, v, d in G.edges(data=True)}
    else:
        labels = edge_labels
    # Use a default box of white with a white border. The default is loop
    # invariant, so it is resolved once.
    if bbox is None:
        bbox = dict(boxstyle="round", ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0))
    # Maps a vector (dx, dy) to (dy, -dx): the chord's perpendicular.
    rotation_matrix = np.array([(0, 1), (-1, 0)])

    text_items = {}
    for (n1, n2), label in labels.items():
        (x1, y1) = pos[n1]
        (x2, y2) = pos[n2]
        # The control point is built in display coordinates.
        pos_1 = ax.transData.transform(np.array(pos[n1]))
        pos_2 = ax.transData.transform(np.array(pos[n2]))
        linear_mid = 0.5 * pos_1 + 0.5 * pos_2
        d_pos = pos_2 - pos_1
        ctrl_1 = linear_mid + rad * rotation_matrix @ d_pos
        # De Casteljau evaluation of the Bezier curve (pos_1, ctrl_1, pos_2),
        # weighting the tail by label_pos: 1 lands on n1, 0 on n2, and 0.5 on
        # the curve's midpoint.
        ctrl_mid_1 = label_pos * pos_1 + (1.0 - label_pos) * ctrl_1
        ctrl_mid_2 = label_pos * ctrl_1 + (1.0 - label_pos) * pos_2
        bezier_point = label_pos * ctrl_mid_1 + (1.0 - label_pos) * ctrl_mid_2
        (x, y) = ax.transData.inverted().transform(bezier_point)

        if rotate:
            # in degrees
            angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360
            # make label orientation "right-side-up"
            if angle > 90:
                angle -= 180
            if angle < -90:
                angle += 180
            # transform data coordinate angle to screen coordinate angle
            xy = np.array((x, y))
            trans_angle = ax.transData.transform_angles(
                np.array((angle,)), xy.reshape((1, 2))
            )[0]
        else:
            trans_angle = 0.0

        if not isinstance(label, str):
            label = str(label)  # this makes "1" and 1 labeled the same

        t = ax.text(
            x,
            y,
            label,
            size=font_size,
            color=font_color,
            family=font_family,
            weight=font_weight,
            alpha=alpha,
            horizontalalignment=horizontalalignment,
            verticalalignment=verticalalignment,
            rotation=trans_angle,
            transform=ax.transData,
            bbox=bbox,
            zorder=1,
            clip_on=clip_on,
        )
        text_items[(n1, n2)] = t

    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    return text_items
G = nx.DiGraph()
for state in dfa:
    G.add_node(state, label=str(sorted(state)))
for current_state, transitions in dfa.items():
    for transition, next_state in transitions.items():
        # A DiGraph keeps a single edge per ordered pair. Several symbols
        # leading to the same target therefore share one comma-separated
        # label, instead of each overwriting the previous one.
        if G.has_edge(current_state, next_state):
            G[current_state][next_state]["label"] += "," + transition
        else:
            G.add_edge(current_state, next_state, label=transition)
# Start state is green, every other DFA state blue. Computed after all edges
# exist so the colour list matches G's final node order.
color_map = ["green" if node == start_state else "blue" for node in G]
pos = nx.spring_layout(G)
nx.draw_networkx_nodes(G, pos, node_color=color_map)
# Pairs of opposite edges (u->v and v->u) are drawn as arcs so they don't
# overlap. A self-loop is its own reverse, so it lands here too.
curved_edges = [(u, v) for u, v in G.edges() if G.has_edge(v, u)]
straight_edges = [(u, v) for u, v in G.edges() if not G.has_edge(v, u)]
nx.draw_networkx_edges(G, pos, edgelist=straight_edges)
nx.draw_networkx_edges(G, pos, edgelist=curved_edges, connectionstyle="arc3, rad = 0.1")
nx.draw_networkx_labels(G, pos, labels=nx.get_node_attributes(G, "label"))
edge_weights = nx.get_edge_attributes(G, "label")
curved_edge_labels = {edge: edge_weights[edge] for edge in curved_edges}
straight_edge_labels = {edge: edge_weights[edge] for edge in straight_edges}
my_draw_networkx_edge_labels(
    G, pos, edge_labels=curved_edge_labels, rotate=False, rad=0.1
)
nx.draw_networkx_edge_labels(G, pos, edge_labels=straight_edge_labels, rotate=False)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment