@cristipufu · Last active May 1, 2025 11:26
LangGraph GitHub Docs Agent
import ast
import os
import re
from typing import List, Tuple
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.messages import SystemMessage
from langchain_core.output_parsers import PydanticOutputParser
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langgraph.graph import END, START, MessagesState, StateGraph
from pydantic import BaseModel, Field
# --- Dependency Graph Functions ---
# These helpers are called from the LangGraph nodes defined further below as needed.
def extract_imports(tree: ast.Module):
"""
Extract all import statements from an AST tree.
Args:
tree: The AST tree of a Python file.
Returns:
A tuple containing:
- direct_imports (list of str): Names of modules imported directly (e.g., 'os', 'sys').
- from_imports (dict): Dictionary where keys are module names (e.g., 'my_package', '.models')
and values are lists of names imported from that module (e.g., ['User', 'Post']).
"""
direct_imports = []
from_imports = {}
for node in ast.walk(tree):
# Extract direct imports (import x, import x as y)
if isinstance(node, ast.Import):
for alias in node.names:
direct_imports.append(alias.name) # Get the original module name
# Extract from imports (from x import y, from x import y as z)
elif isinstance(node, ast.ImportFrom):
# node.module is the module being imported from (e.g., 'my_package', '.models')
# node.level indicates the level for relative imports (0 for absolute, 1 for ., 2 for .., etc.)
module_name = (
node.module or ""
) # Handle cases like 'from . import module' where module is None
level = node.level # Number of leading dots
# For relative imports, reconstruct the full relative path string
if level > 0:
# Prepend dots based on the level
relative_prefix = "." * level
full_module_name = (
f"{relative_prefix}{module_name}"
if module_name
else relative_prefix
)
else:
full_module_name = module_name # Absolute import
# Get the names being imported from the module
imported_names = [alias.name for alias in node.names]
# Store in the dictionary
# We store the full module name (including dots for relative imports)
if full_module_name not in from_imports:
from_imports[full_module_name] = []
from_imports[full_module_name].extend(imported_names)
return direct_imports, from_imports
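# A minimal, illustrative sketch (never called anywhere): shows the shape of the
# tuple extract_imports returns for a small in-memory snippet; the sample source
# is hypothetical.
def _demo_extract_imports() -> None:
    sample = "import os\nfrom .models import User, Post\nfrom ..utils import helpers"
    direct, relative = extract_imports(ast.parse(sample))
    # direct   -> ["os"]
    # relative -> {".models": ["User", "Post"], "..utils": ["helpers"]}
    print(direct, relative)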
def resolve_relative_import_path(current_file_path, relative_module_string):
"""
Resolve a relative module import string (e.g., ".models", "..utils.helpers")
to a potential absolute file path within the project.
Args:
current_file_path: The absolute path of the file containing the import.
relative_module_string: The relative import string starting with '.' or '..'.
Returns:
The resolved absolute file path (pointing to a .py or __init__.py)
or None if the path cannot be resolved to an existing file.
"""
if not relative_module_string.startswith("."):
# This function is specifically for relative imports
return None
parts = relative_module_string.split(".")
# Count the number of leading dots to determine the base directory level
dot_count = 0
for part in parts:
if part == "":
dot_count += 1
else:
break # Stop counting dots when a non-empty part is found
    # The actual module/package path parts after the dots. Splitting a pure-dot
    # string like "." yields an extra empty part, so drop empties and correct the
    # dot count to the real relative-import level.
    module_parts = [part for part in parts[dot_count:] if part]
    if not module_parts and dot_count > 0:
        dot_count -= 1
# Get the directory of the current file
current_dir = os.path.dirname(current_file_path)
# Determine the base directory for the relative import
# A single dot (level 1) means the current package.
# The directory containing the current file is part of the current package.
# So, for level N, we go up N-1 directories from the current file's directory
# to reach the package directory indicated by the dots.
base_dir = current_dir
for _ in range(dot_count - 1):
base_dir = os.path.dirname(base_dir)
# Note: This could potentially go above the project root if the relative import is malformed.
# For simplicity, we assume valid relative imports within a project structure.
    # Resolve the import target relative to base_dir. Three cases are covered:
    # 1. A module file (e.g. ".utils.helpers" -> base_dir/utils/helpers.py)
    # 2. A package (e.g. ".utils" -> base_dir/utils/__init__.py)
    # 3. The package itself (e.g. "from . import models" is recorded as ".",
    #    so the target is base_dir/__init__.py)
    if module_parts:
        module_path_from_base = os.path.join(*module_parts)
        potential_file = os.path.join(base_dir, module_path_from_base + ".py")
        potential_init = os.path.join(base_dir, module_path_from_base, "__init__.py")
        if os.path.exists(potential_file):
            return potential_file
        if os.path.exists(potential_init):
            return potential_init
    else:
        # Pure-dot imports leave module_parts empty; os.path.join(*[]) would raise
        # a TypeError, so handle this case explicitly and look for the package's
        # __init__.py in base_dir.
        potential_package_init = os.path.join(base_dir, "__init__.py")
        if os.path.exists(potential_package_init):
            return potential_package_init
    # If none of the above match, the import could not be resolved to a local file.
    return None
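# A minimal, illustrative sketch (never called anywhere): resolving the relative
# import ".models" as seen from a hypothetical src/app/views.py; the paths are
# assumptions, and the call simply returns None if they do not exist on disk.
def _demo_resolve_relative_import() -> None:
    views_path = os.path.join("src", "app", "views.py")
    resolved = resolve_relative_import_path(views_path, ".models")
    # -> "src/app/models.py" if that module file exists,
    #    "src/app/models/__init__.py" if models is a package, otherwise None
    print(resolved)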
def get_dependency_relationships(base_dir: str) -> List[Tuple[str, str]]:
"""
Scans the directory and identifies local dependency relationships based on imports.
Args:
base_dir: The base directory to scan for Python files.
Returns:
A list of (source_file_rel_path, target_file_rel_path) tuples.
"""
file_paths = []
relationships = set() # Use a set to avoid duplicate edges
# Find all Python files
print(f"Scanning directory for dependencies: {base_dir}")
for root, _, files in os.walk(base_dir):
for file in files:
if file.endswith(".py"):
file_path = os.path.join(root, file)
file_paths.append(file_path)
print(f"Found {len(file_paths)} Python files for dependency analysis.")
# Process each file to find dependencies
for file_path in file_paths:
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
# Parse the file content into an AST
tree = ast.parse(content)
# Extract imports
_, from_imports = extract_imports(
tree
) # We only care about 'from' imports for relative paths
# Process from imports to find relative dependencies
for module_name in from_imports.keys():
if module_name.startswith("."):
# This is a relative import. Resolve its path.
resolved_path = resolve_relative_import_path(file_path, module_name)
if (
resolved_path
and os.path.exists(resolved_path)
and resolved_path != file_path
):
# Found a valid local dependency. Add the relationship.
# Use relative paths from the base_dir for graph clarity.
rel_source = os.path.relpath(file_path, base_dir)
rel_target = os.path.relpath(resolved_path, base_dir)
relationships.add((rel_source, rel_target))
except FileNotFoundError:
print(f"Error: File not found {file_path}")
except SyntaxError as e:
print(f"Error parsing syntax in {file_path}: {str(e)}")
except Exception as e:
print(
f"An unexpected error occurred processing {file_path} for dependencies: {str(e)}"
)
return list(relationships) # Return as a list
def create_mermaid_id(file_rel_path):
"""
Creates a valid Mermaid ID from a file path.
Replaces non-alphanumeric characters (except underscore) with underscore.
"""
# Replace path separators and dots with underscores, keep alphanumeric and underscore
return re.sub(r"[^a-zA-Z0-9_]", "_", file_rel_path)
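# e.g. (illustrative) create_mermaid_id("app/models.py") -> "app_models_py"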
def build_mermaid_graph_from_relationships(relationships: List[Tuple[str, str]]) -> str:
"""
Generates a Mermaid graph string from a list of dependency relationships.
Args:
relationships: A list of (source_file_rel_path, target_file_rel_path) tuples.
Returns:
A string containing the Mermaid graph definition.
"""
mermaid = "graph TD\n"
# Collect all unique files involved in relationships
all_files = set()
for source, target in relationships:
all_files.add(source)
all_files.add(target)
# Add nodes (files) to the graph
for file in sorted(list(all_files)): # Sort for consistent output
file_id = create_mermaid_id(file)
        # Create the label, replacing path separators with dots
file_label = file.replace(os.sep, ".")
# Mermaid node syntax: ID["Label"]
mermaid += f' {file_id}["{file_label}"]\n'
# Add edges (dependencies) to the graph
for source, target in sorted(relationships): # Sort for consistent output
source_id = create_mermaid_id(source)
target_id = create_mermaid_id(target)
# Mermaid edge syntax: SourceID --> TargetID
mermaid += f" {source_id} --> {target_id}\n"
return mermaid
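# A minimal, illustrative sketch (never called anywhere): ties the two helpers
# above together; "src" mirrors the directory assumption used by the graph nodes
# below, but any project root works.
def _demo_dependency_graph() -> None:
    relationships = get_dependency_relationships("src")
    # relationships is a list of (source, target) tuples of paths relative to
    # "src", e.g. [("app/views.py", "app/models.py")]
    print(build_mermaid_graph_from_relationships(relationships))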
# --- LangGraph State and Model Definitions ---
class FileMetadata(BaseModel):
path: str
imports: List[str]
method_names: List[str]
docstrings: List[str]
class Concept(BaseModel):
name: str
description: str
files: List[str] = Field(default_factory=list)
class Concepts(BaseModel):
concepts: List[Concept]
class GraphInput(BaseModel):
topic: str # Input for the graph
# Define the state for the graph; it also carries the list of dependency relationships
class GraphState(MessagesState):
files_metadata: List[FileMetadata] = Field(default_factory=list)
concepts: List[Concept] = Field(default_factory=list)
dependencies: List[Tuple[str, str]] = Field(
default_factory=list
) # Store dependency relationships
class GraphOutput(BaseModel):
documentation: str
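# Note (illustrative of LangGraph behavior): each node below returns a partial
# state update; LangGraph merges the returned keys into the running state, so
# values such as "dependencies" set in discover_key_concepts remain available
# later in generate_documentation.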
# Initialize embeddings and the vector store (assumes OPENAI_API_KEY is set in the environment)
embeddings = OpenAIEmbeddings()
# Initialize ChromaDB - it will store embeddings in memory or persist if configured
# For this example, it's in-memory by default
vector_store = Chroma(
embedding_function=embeddings,
collection_name="code_analysis", # Optional name for your collection
)
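# Optional (illustrative): pass persist_directory if the embeddings should
# survive between runs instead of living only in memory, e.g.
#
#   vector_store = Chroma(
#       embedding_function=embeddings,
#       collection_name="code_analysis",
#       persist_directory="./chroma_db",  # assumed location, adjust as needed
#   )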
# Initialize Parsers for structured output from LLM
concepts_parser = PydanticOutputParser(pydantic_object=Concepts)
concept_parser = PydanticOutputParser(pydantic_object=Concept)
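# Illustrative shape of the JSON the Concepts parser expects back from the LLM
# (field names come from the Pydantic models above; the values are made up):
#
#   {"concepts": [{"name": "Dependency analysis",
#                  "description": "Builds a local import graph from the AST.",
#                  "files": ["graph_builder.py"]}]}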
def extract_file_metadata(tree: ast.Module, path: str) -> FileMetadata:
"""Extract all metadata from an AST tree in a single pass."""
imports = []
method_names = []
docstrings = []
# Get module docstring
module_doc = ast.get_docstring(tree)
if module_doc:
docstrings.append(f"MODULE: {module_doc}")
# Single walk through the AST
for node in ast.walk(tree):
# Extract imports
if isinstance(node, ast.Import):
for name in node.names:
imports.append(f"import {name.name}")
elif isinstance(node, ast.ImportFrom):
module = node.module or ""
names = ", ".join(name.name for name in node.names)
imports.append(f"from {module} import {names}")
# Extract function and class names
elif isinstance(node, ast.FunctionDef):
method_names.append(node.name)
# Get function docstring
doc = ast.get_docstring(node)
if doc:
docstrings.append(f"{node.name}: {doc}")
elif isinstance(node, ast.ClassDef):
method_names.append(f"class {node.name}")
# Get class docstring
doc = ast.get_docstring(node)
if doc:
docstrings.append(f"class {node.name}: {doc}")
elif isinstance(node, ast.AsyncFunctionDef):
method_names.append(f"async {node.name}")
# Get async function docstring
doc = ast.get_docstring(node)
if doc:
docstrings.append(f"async {node.name}: {doc}")
return FileMetadata(
path=path,
imports=imports,
method_names=method_names,
docstrings=docstrings,
)
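# A minimal, illustrative sketch (never called anywhere): the metadata extracted
# from a tiny hypothetical module.
def _demo_extract_file_metadata() -> None:
    sample = (
        '"""Demo module."""\n'
        "import os\n\n"
        "class User:\n"
        '    """A user."""\n\n'
        "def load() -> None:\n"
        '    """Load users."""\n'
    )
    meta = extract_file_metadata(ast.parse(sample), "demo.py")
    # meta.imports      -> ["import os"]
    # meta.method_names -> ["class User", "load"]
    # meta.docstrings   -> ["MODULE: Demo module.", "class User: A user.", "load: Load users."]
    print(meta)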
async def discover_key_concepts(state: GraphState) -> GraphState:
"""
Scans the directory, extracts file metadata, generates dependency relationships,
and uses an LLM to discover key concepts.
"""
directory = "src" # Define the directory to scan
all_files_metadata = []
documents = []
# Check if the directory exists
if not os.path.isdir(directory):
print(f"Error: Directory '{directory}' not found.")
# Return current state or raise an error, depending on desired behavior
return state # Or raise ValueError(f"Directory '{directory}' not found.")
# --- Step 1: Extract File Metadata and Create Documents for Vector Store ---
print(f"Scanning directory for metadata: {directory}")
for root, _, files in os.walk(directory):
for file in files:
if file.endswith(".py"):
file_path = os.path.join(root, file)
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
# Parse the file content into an AST
tree = ast.parse(content)
# Extract all metadata in a single pass
relative_path = os.path.relpath(file_path, directory)
file_metadata = extract_file_metadata(tree, relative_path)
all_files_metadata.append(file_metadata)
# Create document for file content
documents.append(
Document(
page_content=content,
metadata={"path": relative_path, "type": "content"},
)
)
# Create document for file metadata
metadata_content = (
f"File: {relative_path}\n"
f"Imports: {', '.join(file_metadata.imports)}\n"
f"Methods: {', '.join(file_metadata.method_names)}\n"
f"Docstrings: {' | '.join(file_metadata.docstrings)}"
)
documents.append(
Document(
page_content=metadata_content,
metadata={"path": relative_path, "type": "metadata"},
)
)
except FileNotFoundError:
print(f"Error: File not found {file_path}")
except SyntaxError as e:
print(f"Error parsing syntax in {file_path}: {str(e)}")
except Exception as e:
print(
f"An unexpected error occurred processing {file_path} for metadata: {str(e)}"
)
# Add documents to the vector store
if documents:
vector_store.add_documents(documents)
print(f"Added {len(documents)} documents to the vector store.")
else:
print("No documents found to add to the vector store.")
# --- Step 2: Generate Dependency Relationships ---
# Call the function to get the raw relationships
relationships = get_dependency_relationships(directory)
print(f"Identified {len(relationships)} dependency relationships.")
# --- Step 3: Use LLM to Discover Key Concepts based on Metadata ---
llm_model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# Prepare metadata summary for the LLM
metadata_summary = "\n".join(
[
f"File: {meta.path}\nImports: {', '.join(meta.imports)}\nMethods: {', '.join(meta.method_names)}"
for meta in all_files_metadata
]
)
    # Summarize the dependency relationships so the LLM can take them into account
    # alongside the file metadata (the raw relationship list is also kept in the state).
    dependency_summary = "\n".join([f"{s} --> {t}" for s, t in relationships])
system_prompt = f"""
Based on these file names, imported modules, and defined classes/functions:
{metadata_summary}
And these relationships between modules:\n{dependency_summary}
What are likely to be the top 5 most relevant concepts or abstractions in this codebase?
For each concept, provide:
1. A name
2. A brief description
3. Which files likely implement this concept
Focus on identifying high-level ideas and components.
MUST ALWAYS output JUST the json with this format:
{concepts_parser.get_format_instructions()}
"""
print("Asking LLM to discover key concepts...")
output = await llm_model.ainvoke([SystemMessage(system_prompt)])
try:
concepts = concepts_parser.parse(output.content)
print(f"LLM identified {len(concepts.concepts)} initial concepts.")
except Exception as e:
print(f"Error parsing concepts from LLM output: {e}")
print(f"LLM Output: {output.content}")
concepts = Concepts(concepts=[]) # Return empty list if parsing fails
# Return the updated state, including the dependency relationships
return GraphState(
files_metadata=all_files_metadata,
concepts=concepts.concepts,
dependencies=relationships, # Store the list of relationships in the state
)
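# After discover_key_concepts runs, the state carries (illustrative values):
#   files_metadata -> [FileMetadata(path="app/models.py", ...), ...]
#   concepts       -> [Concept(name="Dependency analysis", ...), ...]
#   dependencies   -> [("app/views.py", "app/models.py"), ...]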
async def refine_concepts(state: GraphState) -> GraphState:
"""
Refines discovered concepts using code samples retrieved from the vector store.
"""
llm_model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
refined_concepts = []
initial_concepts = state.get(
"concepts", []
) # Get concepts from state, default to empty list
if not initial_concepts:
print("No initial concepts found to refine.")
return state # Return current state if no concepts
print(f"Refining {len(initial_concepts)} concepts...")
for concept in initial_concepts:
print(f"Refining concept: {concept.name}")
query = f"{concept.name} {concept.description}"
# Search for relevant code content documents
results = vector_store.similarity_search(query, k=5, filter={"type": "content"})
code_samples = []
for doc in results:
code_samples.append(
f"File: {doc.metadata['path']}\n```python\n{doc.page_content[:1500]}...\n```" # Use code block
)
# Refine concept with code evidence
system_prompt = f"""
Initial concept:
Name: {concept.name}
Description: {concept.description}
Files: {", ".join(concept.files)}
Relevant code samples:
{code_samples if code_samples else "No relevant code samples found."}
Based on these code samples, refine the concept description.
Provide:
1. A more accurate name if needed (keep it concise).
2. A detailed description that references the actual implementation and functionality.
3. The description should contain the functionality this concept provides.
4. The description should contain brief implementation details.
5. Which files likely implement this concept (referencing the provided file paths).
Ensure the refined concept is grounded in the code evidence.
MUST ALWAYS output JUST the json with this format:
{concept_parser.get_format_instructions()}
"""
# Use structured output parser to get proper format
print(f"Asking LLM to refine concept: {concept.name}")
output = await llm_model.ainvoke([SystemMessage(system_prompt)])
try:
refined_concept = concept_parser.parse(output.content)
refined_concepts.append(refined_concept)
print(f"Successfully refined concept: {refined_concept.name}")
except Exception as e:
print(
f"Error parsing refined concept from LLM output for '{concept.name}': {e}"
)
print(f"LLM Output: {output.content}")
# Append the original concept if refinement fails
refined_concepts.append(concept)
# Return the updated state with refined concepts (preserving dependencies)
return GraphState(
files_metadata=state.get("files_metadata"),
concepts=refined_concepts,
dependencies=state.get("dependencies"),
)
async def generate_documentation(state: GraphState) -> GraphOutput:
"""
Generate comprehensive documentation for the codebase based on concepts and includes the dependency graph.
Dependency graph is built here from relationships in the state.
The final documentation is saved to disk.
"""
llm_model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
refined_concepts = state.get("concepts", []) # Get concepts from state
dependencies = state.get(
"dependencies", []
) # Get dependency relationships from state
# Prepare documentation structure
doc = "# Codebase Documentation\n\n"
# --- Add Dependency Graph Section ---
if dependencies:
# Build the mermaid graph string from the relationships
mermaid_graph = build_mermaid_graph_from_relationships(dependencies)
doc += "## Dependency Graph\n\n"
doc += "This graph visualizes the dependencies between files based on local imports.\n\n"
doc += "```mermaid\n"
doc += mermaid_graph
doc += "```\n\n---\n\n"
print(
f"Built and added dependency graph with {len(dependencies)} edges to documentation."
)
else:
print("No dependency relationships found to build graph.")
doc += f"## Dependency Graph\n\nCould not generate dependency graph: No dependency relationships found.\n\n---\n\n"
# --- Add Project Overview Section ---
doc += "## Project Overview\n\n"
if refined_concepts:
overview_prompt = f"""
Based on these concepts and their descriptions:
{[{"name": c.name, "description": c.description} for c in refined_concepts]}
Write a concise overview of what this project does, its architecture, and how the main components work together.
Keep it under 4 paragraphs and focus on making it clear to a new developer.
"""
print("Asking LLM to generate project overview...")
overview_response = await llm_model.ainvoke(
[SystemMessage(content=overview_prompt)]
)
doc += f"{overview_response.content}\n\n"
else:
doc += "Project overview could not be generated as no concepts were identified.\n\n"
print("Skipped project overview generation due to no concepts.")
# --- Add Table of Contents ---
if refined_concepts:
doc += "## Table of Contents\n\n"
for i, concept in enumerate(refined_concepts):
# Ensure valid markdown links
link_text = concept.name.replace(" ", "-").lower() # Simple slug
doc += f"{i + 1}. [{concept.name}](#{link_text}-{i + 1})\n"
doc += "\n"
print("Added table of contents.")
else:
print("Skipped table of contents generation due to no concepts.")
# --- Document Each Concept ---
if refined_concepts:
print("Generating documentation for each concept...")
for i, concept in enumerate(refined_concepts):
print(f"Documenting concept: {concept.name}")
query = f"{concept.name} {concept.description} example usage"
# Search for relevant code content documents
results = vector_store.similarity_search(
query, k=3, filter={"type": "content"}
)
code_samples = []
for doc_result in results:
code_samples.append(
f"File: {doc_result.metadata['path']}\n```python\n{doc_result.page_content[:1000]}...\n```" # Use code block, limit length
)
concept_prompt = f"""
Concept: {concept.name}
Description: {concept.description}
Files: {", ".join(concept.files)}
Relevant code samples:
{code_samples if code_samples else "No relevant code samples found."}
Write comprehensive documentation for this concept including:
1. A detailed explanation of what it does
2. Its implementation details (how it works internally)
3. How to use it (with code examples)
4. Any important configuration options or parameters
Format using Markdown with proper headings, code blocks, and examples.
Make the examples practical and helpful for a developer new to this codebase.
"""
concept_response = await llm_model.ainvoke(
[SystemMessage(content=concept_prompt)]
)
# Add to documentation
link_text = concept.name.replace(" ", "-").lower() # Simple slug for anchor
doc += f"## {concept.name} <a name='{link_text}-{i + 1}'></a>\n\n" # Add anchor for TOC
doc += f"{concept_response.content}\n\n"
if concept.files:
doc += "**Implemented in files:**\n\n"
for file in concept.files:
doc += f"- `{file}`\n"
doc += "\n"
doc += "---\n\n"
print(f"Finished documenting concept: {concept.name}")
else:
print("Skipped concept documentation generation due to no concepts.")
# --- Add Getting Started Guide ---
# Provide the generated documentation so far as context
start_prompt = f"""
Based on the concepts and documentation about this codebase, write a getting started guide that includes:
1. Prerequisites
2. Installation steps
3. A simple example showing the main functionality
4. Common gotchas or things to watch out for
Make it practical and easy to follow for a new developer. Refer to the concepts and files mentioned in the documentation.
Here is the current documentation content for context:
---
{doc}
---
"""
print("Asking LLM to generate getting started guide...")
start_response = await llm_model.ainvoke([SystemMessage(content=start_prompt)])
doc += "## Getting Started\n\n"
doc += f"{start_response.content}\n\n"
print("Added getting started guide.")
# --- Save Documentation to Disk ---
output_filename = "code_documentation.md"
try:
with open(output_filename, "w", encoding="utf-8") as f:
f.write(doc)
print(f"\nFinal documentation saved to '{output_filename}'")
except IOError as e:
print(f"Error writing documentation to file {output_filename}: {str(e)}")
# Return the final documentation string in the output
return GraphOutput(
documentation=doc,
)
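# Rough outline of the generated code_documentation.md (sections as written above;
# concept names are whatever the LLM produced):
#   # Codebase Documentation
#   ## Dependency Graph        (mermaid block)
#   ## Project Overview
#   ## Table of Contents
#   ## <one section per refined concept>
#   ## Getting Started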
# --- Build the LangGraph StateGraph ---
builder = StateGraph(state_schema=GraphState, input=GraphInput, output=GraphOutput)
builder.add_node("discover_key_concepts", discover_key_concepts)
builder.add_node("refine_concepts", refine_concepts)
builder.add_node("generate_documentation", generate_documentation)
# Define the flow of the graph
builder.add_edge(START, "discover_key_concepts")
builder.add_edge("discover_key_concepts", "refine_concepts")
builder.add_edge("refine_concepts", "generate_documentation")
builder.add_edge("generate_documentation", END)
# Compile the graph
graph = builder.compile()
# To run the graph, you would typically do:
#
#   import asyncio
#
#   async def run_graph():
#       # Replace the topic with the actual goal of the analysis
#       result = await graph.ainvoke({"topic": "Explain the core functionality"})
#       # The final documentation is in result["documentation"]; it is also
#       # saved to code_documentation.md
#       print(result["documentation"])
#
#   if __name__ == "__main__":
#       # Set the OPENAI_API_KEY environment variable before running, e.g.
#       # os.environ["OPENAI_API_KEY"] = "YOUR_API_KEY"
#       asyncio.run(run_graph())