Created
May 1, 2025 10:45
-
-
Save cristipufu/ab7e1ce5c67a36ad4085b96f9f435efe to your computer and use it in GitHub Desktop.
Generate a Mermaid dependency graph from relative file imports
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import ast
import os
import re
def extract_imports(tree: ast.Module):
    """
    Collect every import statement found anywhere in a parsed module.

    ``ast.walk`` visits nested nodes as well, so imports inside functions,
    classes or ``try`` blocks are included.

    Args:
        tree: Parsed AST of a Python source file (e.g. from ``ast.parse``).

    Returns:
        A 2-tuple ``(direct_imports, from_imports)``:
          - direct_imports: list of module names from ``import x`` /
            ``import x as y`` statements, in encounter order. The original
            module name is kept; ``as`` aliases are ignored.
          - from_imports: dict mapping the source module of each
            ``from ... import ...`` statement to the list of names imported
            from it, accumulated across statements. Relative imports keep
            their leading dots, e.g. ``".models"`` or ``".."``.
    """
    plain_modules = []
    names_by_module = {}

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            plain_modules.extend(alias.name for alias in node.names)
        elif isinstance(node, ast.ImportFrom):
            # node.level is the number of leading dots (0 = absolute import).
            # node.module is None for bare-dot forms such as "from . import x".
            source = "." * node.level + (node.module or "")
            names_by_module.setdefault(source, []).extend(
                alias.name for alias in node.names
            )

    return plain_modules, names_by_module
def resolve_relative_import_path(current_file_path, relative_module_string):
    """
    Resolve a relative module string (e.g. ".models", "..utils.helpers", ".")
    to the local file that implements it.

    One leading dot means the directory containing ``current_file_path``.
    Each additional dot climbs one more directory. The dotted remainder is
    looked up first as a module file (``<name>.py``) and then as a package
    (``<name>/__init__.py``). A bare-dot string (``"."``, ``".."``) refers to
    the package itself and resolves to its ``__init__.py``.

    To target a submodule named in ``from . import name``, pass ``".name"``.
    Passing ``"."`` resolves only the enclosing package.

    Args:
        current_file_path: Path of the file containing the import.
        relative_module_string: Import source starting with one or more dots.

    Returns:
        The path of an existing ``.py`` / ``__init__.py`` file, or None if the
        string is not relative, is malformed, or matches no local file.
    """
    if not relative_module_string.startswith("."):
        # Only relative imports are handled here.
        return None

    # Count leading dots directly. Splitting on "." would overcount the level
    # for bare-dot strings: "." -> ['', ''] reads as two dots.
    remainder = relative_module_string.lstrip(".")
    dot_count = len(relative_module_string) - len(remainder)
    module_parts = remainder.split(".") if remainder else []
    if "" in module_parts:
        # Malformed input such as ".a..b" or ".a." cannot name a module.
        return None

    # Level 1 is the current file's directory; every extra dot goes up once.
    # NOTE(review): an import with more dots than the directory depth climbs
    # past the scan root; such imports are assumed not to occur.
    base_dir = os.path.dirname(current_file_path)
    for _ in range(dot_count - 1):
        base_dir = os.path.dirname(base_dir)

    if not module_parts:
        # "from . import x" / "from .. import x": the package itself.
        package_init = os.path.join(base_dir, "__init__.py")
        return package_init if os.path.isfile(package_init) else None

    module_path = os.path.join(base_dir, *module_parts)
    for candidate in (
        module_path + ".py",  # plain module, e.g. utils/helpers.py
        os.path.join(module_path, "__init__.py"),  # package, e.g. utils/__init__.py
    ):
        if os.path.isfile(candidate):
            return candidate

    # No matching local file (e.g. the name is not part of this project).
    return None
def _python_files_under(base_dir):
    """Return the paths of all ``.py`` files below ``base_dir`` (recursively)."""
    return [
        os.path.join(root, name)
        for root, _, files in os.walk(base_dir)
        for name in files
        if name.endswith(".py")
    ]


def _resolve_local_targets(file_path, module_name, imported_names):
    """Return the local files a ``from <module_name> import <names>`` statement depends on.

    Each imported name is first tried as a submodule, so ``from . import models``
    points at ``models.py`` rather than the package ``__init__.py``. Names that
    are not submodules (e.g. ``User`` in ``from .models import User``) fall back
    to the module itself. Self-references are dropped.
    """
    targets = set()
    needs_module_itself = not imported_names
    # "." + "x" -> ".x", while ".pkg" + "." + "x" -> ".pkg.x".
    prefix = module_name if module_name.endswith(".") else module_name + "."
    for name in imported_names:
        submodule = (
            resolve_relative_import_path(file_path, prefix + name)
            if name != "*"
            else None
        )
        if submodule:
            targets.add(submodule)
        else:
            needs_module_itself = True
    if needs_module_itself:
        module_file = resolve_relative_import_path(file_path, module_name)
        if module_file:
            targets.add(module_file)
    targets.discard(file_path)
    return targets


def _render_mermaid(relationships):
    """Render ``(source, target)`` relative-path pairs as a Mermaid ``graph TD``."""
    all_files = sorted({path for edge in relationships for path in edge})

    # Sanitize each path into a Mermaid-safe ID. Different paths can sanitize
    # to the same ID (e.g. "a_b.py" and "a/b.py"), so append a numeric suffix
    # on collision; sorted iteration keeps the assignment deterministic.
    node_ids = {}
    used_ids = set()
    for path in all_files:
        base_id = re.sub(r"[^a-zA-Z0-9_]", "_", path)
        node_id, suffix = base_id, 2
        while node_id in used_ids:
            node_id = f"{base_id}_{suffix}"
            suffix += 1
        used_ids.add(node_id)
        node_ids[path] = node_id

    lines = ["graph TD"]
    for path in all_files:
        # Label shows the path with separators replaced by dots.
        label = path.replace(os.sep, ".")
        lines.append(f' {node_ids[path]}["{label}"]')
    for source, target in sorted(relationships):
        lines.append(f" {node_ids[source]} --> {node_ids[target]}")
    return "\n".join(lines) + "\n"


def generate_dependency_graph(base_dir: str):
    """
    Build a Mermaid graph of local (relative-import) dependencies between files.

    Every ``.py`` file under ``base_dir`` is parsed, and its relative
    ``from ... import ...`` statements are resolved to files on disk. Only
    files that take part in at least one dependency appear as nodes. Files
    that cannot be read or parsed are reported and skipped.

    Args:
        base_dir: The base directory to scan for Python files.

    Returns:
        A tuple ``(mermaid, edge_count)``: the Mermaid ``graph TD`` definition
        (node labels and edges use paths relative to ``base_dir``) and the
        number of distinct file-to-file dependencies found.
    """
    relationships = set()  # (source, target) pairs; a set avoids duplicate edges

    print(f"Scanning directory: {base_dir}")
    file_paths = _python_files_under(base_dir)
    print(f"Found {len(file_paths)} Python files.")

    for file_path in file_paths:
        try:
            rel_source = os.path.relpath(file_path, base_dir)
            with open(file_path, "r", encoding="utf-8") as f:
                tree = ast.parse(f.read(), filename=file_path)
            # Only "from" imports can be relative, so plain imports are ignored.
            _, from_imports = extract_imports(tree)
            for module_name, imported_names in from_imports.items():
                if not module_name.startswith("."):
                    continue
                for target in _resolve_local_targets(
                    file_path, module_name, imported_names
                ):
                    relationships.add(
                        (rel_source, os.path.relpath(target, base_dir))
                    )
            print(f"Processed: {rel_source}")
        except FileNotFoundError:
            print(f"Error: File not found {file_path}")
        except SyntaxError as e:
            print(f"Error parsing syntax in {file_path}: {str(e)}")
        except Exception as e:
            # Best effort: report the problem and continue with the other files.
            print(f"An unexpected error occurred processing {file_path}: {str(e)}")

    return _render_mermaid(relationships), len(relationships)
def main(argv=None):
    """
    Command-line entry point: scan a directory and write its dependency graph.

    With no arguments this behaves as before: it scans ``src`` and writes
    ``dependency_graph.md``. Both can now be overridden without editing the
    script.

    Args:
        argv: Optional argument list. If None, argparse reads ``sys.argv[1:]``.
            Accepts an optional positional base directory (default ``"src"``)
            and ``-o/--output`` for the Markdown file to write
            (default ``"dependency_graph.md"``).
    """
    parser = argparse.ArgumentParser(
        description="Generate a Mermaid dependency graph from relative imports."
    )
    parser.add_argument(
        "base_dir",
        nargs="?",
        default="src",
        help="directory to scan for Python files (default: src)",
    )
    parser.add_argument(
        "-o",
        "--output",
        default="dependency_graph.md",
        help="Markdown file to write (default: dependency_graph.md)",
    )
    args = parser.parse_args(argv)
    base_directory_to_scan = args.base_dir
    output_filename = args.output

    if not os.path.isdir(base_directory_to_scan):
        print(f"Error: Directory '{base_directory_to_scan}' not found.")
        print(
            "Please pass your project directory as the first command-line argument."
        )
        return

    mermaid_graph, edge_count = generate_dependency_graph(base_directory_to_scan)

    try:
        with open(output_filename, "w", encoding="utf-8") as f:
            # Wrap the graph in a mermaid fence so Markdown viewers render it.
            f.write("```mermaid\n")
            f.write(mermaid_graph)
            f.write("```\n")
    except OSError as e:
        print(f"Error writing to file {output_filename}: {str(e)}")
        return

    print(
        f"\nDependency graph generated successfully and saved to '{output_filename}'"
    )
    print(f"Found {edge_count} dependencies between files.")
    print(
        f"Open '{output_filename}' in a Markdown viewer that supports Mermaid to see the graph."
    )
# Run the generator only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment