Skip to content

Instantly share code, notes, and snippets.

@cristipufu
Created May 1, 2025 10:45
Show Gist options
  • Save cristipufu/ab7e1ce5c67a36ad4085b96f9f435efe to your computer and use it in GitHub Desktop.
Save cristipufu/ab7e1ce5c67a36ad4085b96f9f435efe to your computer and use it in GitHub Desktop.
Generate a Mermaid dependency graph from relative file imports
import ast
import os
import re
def extract_imports(tree: ast.Module):
    """
    Collect every import statement found anywhere in a parsed module.

    Args:
        tree: The parsed AST of a Python file (as produced by ``ast.parse``).

    Returns:
        A ``(direct_imports, from_imports)`` pair:
        - direct_imports (list of str): module names from ``import x`` /
          ``import x as y`` statements, using the original (un-aliased) name,
          in ``ast.walk`` order.
        - from_imports (dict): maps the source module of each
          ``from ... import`` statement to the list of names imported from it.
          Relative modules keep their leading dots, e.g. ``from .models import
          User`` is keyed ``".models"`` and ``from .. import x`` is keyed ``".."``.
    """
    direct_imports = []
    from_imports = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            direct_imports.extend(alias.name for alias in node.names)
            continue
        if not isinstance(node, ast.ImportFrom):
            continue
        # node.level is the number of leading dots (0 for absolute imports);
        # node.module is None for "from . import x", leaving just the dots.
        key = "." * node.level + (node.module or "")
        from_imports.setdefault(key, []).extend(alias.name for alias in node.names)
    return direct_imports, from_imports
def resolve_relative_import_path(current_file_path, relative_module_string):
    """
    Resolve a relative module import string (e.g. ".models", "..utils.helpers",
    ".", "..") to a file within the project.

    Args:
        current_file_path: Path of the file containing the import.
        relative_module_string: The relative import string; must start with '.'.

    Returns:
        The path of the matching ``.py`` module or package ``__init__.py``,
        built by joining onto the directory of ``current_file_path`` (so it is
        relative or absolute in the same way that argument is), or None if the
        string is not relative or no matching file exists.
    """
    if not relative_module_string.startswith("."):
        # This function is specifically for relative imports.
        return None

    # Count the leading dots directly. Splitting on "." over-counts bare-dot
    # strings by one (".".split(".") == ['', '']), and then leaves no module
    # parts, which made os.path.join(*[]) raise TypeError.
    module_path = relative_module_string.lstrip(".")
    dot_count = len(relative_module_string) - len(module_path)

    # One dot refers to the package containing the current file (its own
    # directory); each additional dot moves up one directory.
    base_dir = os.path.dirname(current_file_path)
    for _ in range(dot_count - 1):
        base_dir = os.path.dirname(base_dir)
    # Note: a malformed import with too many dots can climb above the project
    # root; valid relative imports within the project are assumed.

    if not module_path:
        # "from . import x" / "from .. import x": the package itself.
        candidates = [os.path.join(base_dir, "__init__.py")]
    else:
        # ".utils.helpers" -> base_dir/utils/helpers.py (module) or
        # base_dir/utils/helpers/__init__.py (package); module file wins.
        target = os.path.join(base_dir, *module_path.split("."))
        candidates = [target + ".py", os.path.join(target, "__init__.py")]

    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    # Could not resolve to a local file.
    return None
def generate_dependency_graph(base_dir: str):
    """
    Generate a Mermaid dependency graph showing file dependencies based on local imports.

    Only relative ``from`` imports (``from .x import y``, ``from .. import z``)
    are considered; absolute imports and plain ``import x`` statements are
    ignored. When an imported name resolves to a submodule (e.g.
    ``from . import models`` next to a ``models.py``), the edge points at that
    submodule; names that are not submodules produce an edge to the module or
    package named in the ``from`` clause.

    Args:
        base_dir: The base directory to scan for Python files.

    Returns:
        A tuple ``(mermaid, edge_count)``: the Mermaid graph definition (starting
        with ``graph TD``, one node/edge per line, newline-terminated) and the
        number of unique file-to-file dependencies found.
    """
    # (source, target) pairs relative to base_dir; a set drops duplicate edges.
    relationships = set()

    # Find all Python files
    print(f"Scanning directory: {base_dir}")
    file_paths = [
        os.path.join(root, file)
        for root, _, files in os.walk(base_dir)
        for file in files
        if file.endswith(".py")
    ]
    print(f"Found {len(file_paths)} Python files.")

    # Process each file to find dependencies
    for file_path in file_paths:
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()
            tree = ast.parse(content)
            # Only 'from' imports can be relative, so direct imports are ignored.
            _, from_imports = extract_imports(tree)
            rel_source = os.path.relpath(file_path, base_dir)
            for module_name, imported_names in from_imports.items():
                if not module_name.startswith("."):
                    continue
                for resolved_path in _resolve_relative_targets(
                    file_path, module_name, imported_names
                ):
                    if resolved_path != file_path:  # skip self-references
                        rel_target = os.path.relpath(resolved_path, base_dir)
                        relationships.add((rel_source, rel_target))
            print(f"Processed: {rel_source}")
        except FileNotFoundError:
            print(f"Error: File not found {file_path}")
        except SyntaxError as e:
            print(f"Error parsing syntax in {file_path}: {str(e)}")
        except Exception as e:
            # Top-level boundary: report and keep scanning the remaining files.
            print(f"An unexpected error occurred processing {file_path}: {str(e)}")

    return _render_mermaid(relationships), len(relationships)


def _resolve_relative_targets(file_path, module_name, imported_names):
    """Return the local files that one relative ``from`` import depends on.

    Each imported name is first tried as a submodule of ``module_name``
    (``from . import models`` -> ``.models``). If any name is not a submodule
    (or is ``*``), the ``from`` module/package itself is resolved as well.
    """
    prefix = module_name if module_name.endswith(".") else module_name + "."
    targets = []
    needs_module = not imported_names
    for name in imported_names:
        submodule = (
            None if name == "*" else resolve_relative_import_path(file_path, prefix + name)
        )
        if submodule:
            targets.append(submodule)
        else:
            needs_module = True
    if needs_module:
        module = resolve_relative_import_path(file_path, module_name)
        if module:
            targets.append(module)
    return targets


def _mermaid_node_ids(files):
    """Map each relative file path to a unique, Mermaid-safe node ID.

    Non-alphanumeric characters (except underscore) become underscores; if two
    paths sanitize to the same ID (e.g. ``a_b.py`` and ``a/b.py``), later ones
    get a numeric suffix so their nodes are not merged.
    """
    ids = {}
    used = set()
    for file in files:
        base = re.sub(r"[^a-zA-Z0-9_]", "_", file)
        candidate, suffix = base, 2
        while candidate in used:
            candidate = f"{base}_{suffix}"
            suffix += 1
        used.add(candidate)
        ids[file] = candidate
    return ids


def _render_mermaid(relationships):
    """Render (source, target) path pairs as a sorted Mermaid ``graph TD`` string."""
    all_files = sorted({path for edge in relationships for path in edge})
    node_ids = _mermaid_node_ids(all_files)
    lines = ["graph TD"]
    for file in all_files:  # sorted for consistent output
        # Label: path separators replaced with dots, e.g. "pkg.models.py".
        file_label = file.replace(os.sep, ".")
        # Mermaid node syntax: ID["Label"]
        lines.append(f' {node_ids[file]}["{file_label}"]')
    for source, target in sorted(relationships):  # sorted for consistent output
        # Mermaid edge syntax: SourceID --> TargetID
        lines.append(f" {node_ids[source]} --> {node_ids[target]}")
    return "\n".join(lines) + "\n"
def main(base_directory_to_scan="src", output_filename="dependency_graph.md"):
    """Run the dependency graph generation and save the result as Markdown.

    Args:
        base_directory_to_scan: Directory to scan for Python files. Defaults to
            ``"src"``; point it at your project root or any subdirectory.
        output_filename: Markdown file to write. The graph is wrapped in a
            ``mermaid`` fenced code block so Mermaid-aware viewers render it.
    """
    if not os.path.isdir(base_directory_to_scan):
        print(f"Error: Directory '{base_directory_to_scan}' not found.")
        print(
            "Please update the 'base_directory_to_scan' variable in the script to point to your project directory."
        )
        return

    mermaid_graph, edge_count = generate_dependency_graph(base_directory_to_scan)

    try:
        with open(output_filename, "w", encoding="utf-8") as f:
            # The graph text ends with a newline, so the closing fence
            # lands on its own line.
            f.write("```mermaid\n")
            f.write(mermaid_graph)
            f.write("```\n")
    except OSError as e:  # IOError is an alias of OSError
        print(f"Error writing to file {output_filename}: {str(e)}")
    else:
        print(
            f"\nDependency graph generated successfully and saved to '{output_filename}'"
        )
        print(f"Found {edge_count} dependencies between files.")
        print(
            f"Open '{output_filename}' in a Markdown viewer that supports Mermaid to see the graph."
        )


if __name__ == "__main__":
    import sys

    # Optional positional CLI arguments: [base_dir [output_file]].
    # With no arguments this scans "src" and writes "dependency_graph.md".
    main(*sys.argv[1:3])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment