@danielpodrazka
Last active December 12, 2024 18:41
Create a dependency graph of arbitrary search terms in a repository
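The script walks a repository, uses Python's ast module to record where each search term appears in .py files (imports, names, class and function definitions, and string constants treated as embedded SQL), scans .sql files line by line, flags any occurrence whose line also mentions a refresh keyword, saves the categorized occurrences to multi_occurrences.json, and renders a Graphviz graph linking files to the terms they reference.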
import os
import ast
import re
import json
from graphviz import Digraph
from collections import defaultdict
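# Requires the `graphviz` Python package and the Graphviz system binaries
# (e.g. `dot` and `neato` on PATH) to render the output.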
# Define your repository root
REPO_DIR = '/home/daniel/PycharmProjects/api'
PARENT_DIR = os.path.dirname(REPO_DIR) # This will be '/home/daniel/PycharmProjects'
# Define the search terms here
SEARCH_TERMS = ['some_table_name', 'another_term']
REFRESH_KEYWORD = 'refresh' # A search term occurrence is flagged when this word appears on the same line
def get_relative_path(path):
"""
Get the relative path starting from PARENT_DIR.
"""
path = os.path.abspath(path)
relative_path = os.path.relpath(path, PARENT_DIR)
return relative_path
def sanitize_label(label):
"""
Generate a label by getting the relative path from PARENT_DIR.
"""
return get_relative_path(label)
def sanitize_node_id(node_id):
"""
Generate a node ID from the relative path, sanitized to remove characters Graphviz might misinterpret.
"""
relative_path = get_relative_path(node_id)
# Replace special characters with underscores
sanitized_id = re.sub(r'[^a-zA-Z0-9_]', '_', relative_path)
return sanitized_id
def should_exclude(file_path):
"""
Determine if a file should be excluded based on its path.
"""
relative_path = get_relative_path(file_path).lower()
filename = os.path.basename(relative_path)
return 'test' in relative_path or 'alembic' in relative_path or filename.startswith('.')
class MultiSearchVisitor(ast.NodeVisitor):
"""
AST Node Visitor to find occurrences of multiple SEARCH_TERMS in Python code.
"""
def __init__(self, file_path, search_terms, file_code):
self.file_path = file_path
self.search_terms = search_terms
self.file_code_lines = file_code.split('\n')
self.found = []
def line_contains_refresh(self, lineno):
# Check if the given line (1-based) contains the refresh keyword
if 1 <= lineno <= len(self.file_code_lines):
line = self.file_code_lines[lineno - 1]
return REFRESH_KEYWORD.lower() in line.lower()
return False
def record_occurrences(self, text, node_type, lineno, **kwargs):
# Check if any search terms are present in the given text
# If multiple search terms appear, record all of them
for term in self.search_terms:
if re.search(r'\b' + re.escape(term) + r'\b', text, re.IGNORECASE):
refresh_present = self.line_contains_refresh(lineno) or (REFRESH_KEYWORD.lower() in text.lower())
occurrence = {
'search_term': term,
'type': node_type,
'file': self.file_path,
'lineno': lineno,
'refresh_present': refresh_present
}
# Merge extra kwargs into the occurrence
occurrence.update(kwargs)
self.found.append(occurrence)
def visit_Import(self, node):
for alias in node.names:
# check alias.name and alias.asname
if alias.name:
self.record_occurrences(alias.name, 'import', node.lineno, name=alias.name, asname=alias.asname)
if alias.asname:
self.record_occurrences(alias.asname, 'import', node.lineno, name=alias.name, asname=alias.asname)
self.generic_visit(node)
def visit_ImportFrom(self, node):
module = node.module or ''
# Check module name first
self.record_occurrences(module, 'from_import', node.lineno, module=module)
# Check each alias
for alias in node.names:
self.record_occurrences(alias.name, 'from_import', node.lineno, module=module, name=alias.name, asname=alias.asname)
if alias.asname:
self.record_occurrences(alias.asname, 'from_import', node.lineno, module=module, name=alias.name, asname=alias.asname)
self.generic_visit(node)
def visit_Name(self, node):
self.record_occurrences(node.id, 'variable', node.lineno, name=node.id)
self.generic_visit(node)
def visit_ClassDef(self, node):
        self.record_occurrences(node.name, 'class', node.lineno, name=node.name)
self.generic_visit(node)
def visit_FunctionDef(self, node):
        self.record_occurrences(node.name, 'function', node.lineno, name=node.name)
self.generic_visit(node)
    def visit_Constant(self, node):
        # ast.Constant covers string literals on Python 3.8+ (where visit_Str is deprecated);
        # any string constant mentioning a search term is treated as embedded SQL
        if isinstance(node.value, str):
            self.record_occurrences(node.value, 'embedded_sql', node.lineno)
self.generic_visit(node)
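# Illustrative sketch of what the visitor records for a small snippet:
#   sample = "import some_table_name\nquery = 'REFRESH some_table_name'"
#   v = MultiSearchVisitor('<sample>', ['some_table_name'], sample)
#   v.visit(ast.parse(sample))
#   # v.found then holds an 'import' occurrence and an 'embedded_sql' occurrence,
#   # the latter with refresh_present=True since its line contains 'refresh'.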
# Map each search term to the list of its occurrences
occurrences = {term: [] for term in SEARCH_TERMS}
# Walk through all files in the repository
for root, dirs, files in os.walk(REPO_DIR):
for file in files:
file_path = os.path.join(root, file)
if should_exclude(file_path):
continue
if file.endswith('.py'):
with open(file_path, 'r', encoding='utf-8') as f:
code = f.read()
try:
tree = ast.parse(code, filename=file_path)
visitor = MultiSearchVisitor(file_path, SEARCH_TERMS, code)
visitor.visit(tree)
# Add found occurrences to respective search_terms
for occ in visitor.found:
occurrences[occ['search_term']].append(occ)
except SyntaxError as e:
print(f"SyntaxError in file {file_path}: {e}")
elif file.endswith('.sql'):
# For SQL files, we check line by line for each search term
with open(file_path, 'r', encoding='utf-8') as f:
sql_lines = f.readlines()
for idx, line in enumerate(sql_lines, start=1):
for term in SEARCH_TERMS:
if re.search(r'\b' + re.escape(term) + r'\b', line, re.IGNORECASE):
refresh_present = (REFRESH_KEYWORD.lower() in line.lower())
occurrences[term].append({
'search_term': term,
'type': 'sql_file',
'file': file_path,
'lineno': idx,
'refresh_present': refresh_present
})
# Categorize the occurrences per search term
categorized = {
term: {
'import': [],
'from_import': [],
'variable': [],
'class': [],
'function': [],
'embedded_sql': [],
'sql_file': []
} for term in SEARCH_TERMS
}
for term, occ_list in occurrences.items():
for item in occ_list:
categorized[term][item['type']].append(item)
# Save the occurrences to a JSON file (optional)
with open('multi_occurrences.json', 'w', encoding='utf-8') as f:
json.dump(categorized, f, ensure_ascii=False, indent=4)
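# The resulting JSON is shaped roughly like:
#   {"some_table_name": {"import": [...], "from_import": [...], "variable": [...],
#                        "class": [...], "function": [...], "embedded_sql": [...],
#                        "sql_file": [...]},
#    "another_term": {...}}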
# Prepare to create the graph
dot = Digraph(
comment='Multiple Search Terms Dependencies',
engine='neato',
graph_attr={
'rankdir': 'TB',
'nodesep': '0.1',
'ranksep': '0.5',
'overlap': 'false',
'dpi': '300'
}
)
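# Note: rank-based attributes such as 'rankdir' and 'ranksep' only take effect with
# the 'dot' layout engine; 'neato' ignores them and relies on 'overlap' instead.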
dot.node_attr.update({
'fontsize': '10',
})
dot.edge_attr.update({
'fontsize': '8'
})
# Collect the files and import edges discovered across all terms
graph_files = set()
edges = []
# Process 'import' and 'from_import' occurrences for all terms to determine file dependencies
def resolve_module(module_name):
    """
    Resolve a dotted module name to a .py file or package __init__.py inside REPO_DIR, if one exists.
    """
    module_path = module_name.replace('.', os.sep)
    for path in (
        os.path.join(REPO_DIR, module_path + '.py'),
        os.path.join(REPO_DIR, module_path, '__init__.py'),
    ):
        if os.path.exists(path):
            return path
    return None
def add_file_dependencies(cat):
    # 'import' items store the module under 'name'; 'from_import' items store it under 'module'
    for item_type, module_key in (('import', 'name'), ('from_import', 'module')):
        for item in cat[item_type]:
            source_file = item['file']
            module_name = item.get(module_key)
            if not module_name:
                continue
            target_module = resolve_module(module_name)
            if target_module and not (should_exclude(source_file) or should_exclude(target_module)):
                graph_files.add(source_file)
                graph_files.add(target_module)
                edges.append((source_file, target_module, False))
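# Note: modules are resolved naively against REPO_DIR, so only absolute imports that
# map to files inside the repo produce edges; relative imports and third-party
# packages are not tracked.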
# Add file dependencies from all search terms
for term in SEARCH_TERMS:
add_file_dependencies(categorized[term])
# Add nodes for files
for file in graph_files:
if should_exclude(file):
continue
label = sanitize_label(file)
node_id = sanitize_node_id(file)
dot.node(node_id, label)
# Add edges between files
for source, target, _ in edges:
if should_exclude(source) or should_exclude(target):
continue
source_id = sanitize_node_id(source)
target_id = sanitize_node_id(target)
dot.edge(source_id, target_id)
# Add nodes for each search term
# We'll create a separate node for each search term's references
term_nodes = {}
for term in SEARCH_TERMS:
term_node_id = f'search_term_node_{sanitize_node_id(term)}'
dot.node(term_node_id, label=f'References to "{term}"', shape='cylinder', style='filled', color='lightgrey')
term_nodes[term] = term_node_id
def make_edge_label(item_type, refresh_present):
# Create an appropriate label based on the item type and refresh presence
if item_type == 'embedded_sql':
label = 'embedded SQL referencing search term'
else:
label = 'references search term'
if refresh_present:
label += ' (includes refresh)'
return label
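# e.g. make_edge_label('embedded_sql', True) == 'embedded SQL referencing search term (includes refresh)'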
# Add edges from files to search term nodes
def add_reference_edges(cat, term):
for item_type in cat:
# Skip import and from_import since they reference modules, handled above
if item_type in ['import', 'from_import']:
continue
for item in cat[item_type]:
source_file = item['file']
if should_exclude(source_file):
continue
source_id = sanitize_node_id(source_file)
edge_color = 'red' if item.get('refresh_present') else 'black'
edge_label = make_edge_label(item_type, item.get('refresh_present', False))
dot.edge(source_id, term_nodes[term], label=edge_label, color=edge_color)
# Add reference edges for all terms
for term in SEARCH_TERMS:
add_reference_edges(categorized[term], term)
# Render the graph
dot.render(f"{'_'.join(SEARCH_TERMS)}.gv", view=True)
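Running the script produces multi_occurrences.json, the Graphviz source file (here some_table_name_another_term.gv), and, with the library's default PDF format, a rendered some_table_name_another_term.gv.pdf that opens automatically because view=True.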