Created
September 3, 2024 20:38
-
-
Save Shavvimal/b6e6c90787d8d081bc23161284ecbcd7 to your computer and use it in GitHub Desktop.
Demonstrates how to visualize a knowledge graph generated using GraphRAG with UMAP embeddings. The script loads entity and relationship data from parquet files and then visualizes the graph using networkx and matplotlib.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import networkx as nx | |
import matplotlib.pyplot as plt | |
import graspologic as gc | |
# Load environment variables | |
from dotenv import load_dotenv | |
_ = load_dotenv() | |
# Define paths | |
FOLDER = "20240903-194043" | |
INPUT_DIR = f"../bin/output/{FOLDER}/artifacts" | |
OUTPUT_DIR = f"../bin/output/{FOLDER}/artifacts" | |
ENTITY_TABLE = "create_final_nodes" | |
RELATIONSHIP_TABLE = "create_final_relationships" | |
# Function to add color column to entities based on community | |
def add_color_column_to_entities(entity_df: pd.DataFrame, community_column: str = "community") -> pd.DataFrame: | |
""" | |
Adds a 'color' column to the entity_df DataFrame based on unique community values. | |
""" | |
# Ensure the community column exists in the dataframe | |
if community_column not in entity_df.columns: | |
raise ValueError(f"The specified community column '{community_column}' does not exist in the DataFrame.") | |
# Get unique communities | |
unique_communities = entity_df[community_column].unique() | |
# Generate a color map for the unique communities | |
color_map = gc.layouts.categorical_colors({str(community): community for community in unique_communities}) | |
# Map each community to its corresponding color | |
entity_df['color'] = entity_df[community_column].map(lambda community: color_map[str(community)]) | |
return entity_df | |
# Convert entities to dictionaries for easier processing | |
def convert_entities_to_dicts(df): | |
"""Convert the entities dataframe to a list of dicts for yfiles-jupyter-graphs.""" | |
nodes_dict = {} | |
df.drop(columns=["size"], inplace=True) | |
# Make sure x and y are float and nor NaN | |
df["x"] = df["x"].astype(float) | |
df["y"] = df["y"].astype(float) | |
df = df[~df["x"].isna()] | |
df = df[~df["y"].isna()] | |
for _, row in df.iterrows(): | |
node_id = row["title"] # Use 'title' as the unique identifier | |
if node_id not in nodes_dict: | |
nodes_dict[node_id] = { | |
"id": node_id, | |
"properties": row.to_dict(), | |
} | |
return list(nodes_dict.values()) | |
# Convert relationships to dictionaries | |
def convert_relationships_to_dicts(df): | |
"""Convert the relationships dataframe to a list of dicts for yfiles-jupyter-graphs.""" | |
relationships = [] | |
for _, row in df.iterrows(): | |
relationships.append({ | |
"start": row["source"], | |
"end": row["target"], | |
"properties": row.to_dict(), | |
}) | |
return relationships | |
# Load the dataframes from parquet files | |
entity_df = pd.read_parquet(f"{OUTPUT_DIR}/{ENTITY_TABLE}.parquet") | |
# Filter by "level" column to avoid "id" duplicates | |
entity_df = entity_df[entity_df["level"] == 0] | |
relationship_df = pd.read_parquet(f"{OUTPUT_DIR}/{RELATIONSHIP_TABLE}.parquet") | |
# Add color to each node based on the community | |
entity_df = add_color_column_to_entities(entity_df) | |
# Convert dataframes to required formats | |
nodes = convert_entities_to_dicts(entity_df) | |
edges = convert_relationships_to_dicts(relationship_df) | |
# Initialize a directed graph | |
G = nx.DiGraph() | |
# Add nodes and their attributes to the graph | |
for node in nodes: | |
node_id = node["id"] | |
properties = node["properties"] | |
G.add_node( | |
node_id, | |
pos=(properties["x"], properties["y"]), | |
size=properties["degree"], | |
node_color=properties["color"], | |
label=properties["title"], # Add the title as a label | |
**properties | |
) | |
# Check if all nodes referenced by edges have positions | |
missing_nodes = set() | |
filtered_edges = [] | |
# Filter edges and add them to the graph | |
for edge in edges: | |
source = edge["start"] | |
target = edge["end"] | |
# Ensure both source and target nodes exist in the graph | |
if source not in G.nodes or target not in G.nodes: | |
missing_nodes.update([source, target]) | |
continue | |
edge_properties = edge["properties"] | |
filtered_edges.append((source, target, edge_properties)) | |
G.add_edge(source, target, **edge_properties) | |
# Debugging: Print any missing nodes | |
if missing_nodes: | |
print(f"Missing nodes detected: {missing_nodes}") | |
print("Some edges will not be drawn due to missing node positions.") | |
# Prepare node position, size, and color for visualization | |
node_position_dict = {node["id"]: (node["properties"]["x"], node["properties"]["y"]) for node in nodes} | |
node_sizes = [node["properties"]["degree"] * 10 for node in nodes] # Scale sizes for better visibility | |
node_colors = [node["properties"]["color"] for node in nodes] # Get color from the node properties | |
assert len(node_position_dict) == len(node_sizes) == len(node_colors), f"Lengths of node attributes do not match ({len(node_position_dict)}, {len(node_sizes)}, {len(node_colors)})" | |
# Prepare node labels for visualization | |
node_labels = {node["id"]: node["properties"]["title"] for node in nodes} | |
# Plot the graph | |
plt.clf() | |
figure = plt.gcf() | |
ax = plt.gca() | |
ax.set_axis_off() | |
figure.set_size_inches(20, 10) | |
figure.set_dpi(400) | |
# Draw nodes with networkx | |
nx.draw_networkx_nodes( | |
G, | |
pos=node_position_dict, | |
node_color=node_colors, | |
node_size=node_sizes, | |
alpha=1.0, | |
linewidths=0.01, | |
node_shape="o", | |
ax=ax, | |
) | |
# Draw edges only for nodes with defined positions | |
nx.draw_networkx_edges(G, pos=node_position_dict, edgelist=filtered_edges, ax=ax, alpha=0.5) | |
# Draw node labels | |
nx.draw_networkx_labels(G, pos=node_position_dict, labels=node_labels, ax=ax, font_size=2) | |
# Display the plot | |
plt.show() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment