LangGraph GitHub Docs Agent

import ast
import os
import re
from typing import List, Tuple

from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.messages import SystemMessage
from langchain_core.output_parsers import PydanticOutputParser
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langgraph.graph import END, START, MessagesState, StateGraph
from pydantic import BaseModel, Field


# --- Dependency Graph Functions ---
# These functions are called within different nodes as needed.


def extract_imports(tree: ast.Module):
    """
    Extract all import statements from an AST tree.

    Args:
        tree: The AST tree of a Python file.

    Returns:
        A tuple containing:
        - direct_imports (list of str): Names of modules imported directly (e.g., 'os', 'sys').
        - from_imports (dict): Dictionary where keys are module names (e.g., 'my_package', '.models')
          and values are lists of names imported from that module (e.g., ['User', 'Post']).
    """
    direct_imports = []
    from_imports = {}
    for node in ast.walk(tree):
        # Extract direct imports (import x, import x as y)
        if isinstance(node, ast.Import):
            for alias in node.names:
                direct_imports.append(alias.name)  # Get the original module name
        # Extract from imports (from x import y, from x import y as z)
        elif isinstance(node, ast.ImportFrom):
            # node.module is the module being imported from (e.g., 'my_package', '.models')
            # node.level indicates the level for relative imports (0 for absolute, 1 for ., 2 for .., etc.)
            module_name = node.module or ""  # Handle cases like 'from . import module' where module is None
            level = node.level  # Number of leading dots
            # For relative imports, reconstruct the full relative path string
            if level > 0:
                # Prepend dots based on the level
                relative_prefix = "." * level
                full_module_name = (
                    f"{relative_prefix}{module_name}" if module_name else relative_prefix
                )
            else:
                full_module_name = module_name  # Absolute import
            # Get the names being imported from the module
            imported_names = [alias.name for alias in node.names]
            # Store the full module name (including dots for relative imports)
            if full_module_name not in from_imports:
                from_imports[full_module_name] = []
            from_imports[full_module_name].extend(imported_names)
    return direct_imports, from_imports

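
# Illustrative usage (a sketch, not part of the pipeline): parsing a small
# snippet shows the shape of the return value.
#
#     tree = ast.parse("import os\nfrom ..utils import helpers, misc")
#     extract_imports(tree)
#     # -> (["os"], {"..utils": ["helpers", "misc"]})
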
def resolve_relative_import_path(current_file_path, relative_module_string):
    """
    Resolve a relative module import string (e.g., ".models", "..utils.helpers")
    to a potential absolute file path within the project.

    Args:
        current_file_path: The absolute path of the file containing the import.
        relative_module_string: The relative import string starting with '.' or '..'.

    Returns:
        The resolved absolute file path (pointing to a .py or __init__.py)
        or None if the path cannot be resolved to an existing file.
    """
    if not relative_module_string.startswith("."):
        # This function is specifically for relative imports
        return None

    # Count the leading dots to determine the base directory level. Counting
    # them directly avoids the off-by-one that str.split(".") produces for
    # pure-dot strings such as "." or "..".
    stripped = relative_module_string.lstrip(".")
    dot_count = len(relative_module_string) - len(stripped)
    # The actual module/package path parts after the dots
    module_parts = [part for part in stripped.split(".") if part]

    # Get the directory of the current file
    current_dir = os.path.dirname(current_file_path)
    # Determine the base directory for the relative import.
    # A single dot (level 1) means the current package, and the directory
    # containing the current file is part of that package. So for level N,
    # we go up N-1 directories from the current file's directory to reach
    # the package directory indicated by the dots.
    base_dir = current_dir
    for _ in range(dot_count - 1):
        base_dir = os.path.dirname(base_dir)
    # Note: this could go above the project root if the relative import is
    # malformed. For simplicity, we assume valid relative imports within a
    # project structure.

    if module_parts:
        # Two possibilities:
        # 1. A module file (e.g., ..utils.helpers -> base_dir/utils/helpers.py)
        # 2. A package (e.g., .utils -> base_dir/utils/__init__.py)
        module_path_from_base = os.path.join(*module_parts)
        potential_file = os.path.join(base_dir, module_path_from_base + ".py")
        potential_init = os.path.join(base_dir, module_path_from_base, "__init__.py")
        if os.path.exists(potential_file):
            return potential_file
        if os.path.exists(potential_init):
            return potential_init
    else:
        # Case like 'from .' or 'from ..', importing the package itself.
        # base_dir already points to the directory of the package being
        # imported, so look for its __init__.py. (Guarding on module_parts
        # also avoids calling os.path.join(*[]), which raises TypeError.)
        potential_package_init = os.path.join(base_dir, "__init__.py")
        if os.path.exists(potential_package_init):
            return potential_package_init
    # If none of the above match, we couldn't resolve it to a local file
    return None

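
# Illustrative resolution (a sketch with hypothetical paths): for a file
# /project/src/app/views.py containing "from .models import User", the call
#     resolve_relative_import_path("/project/src/app/views.py", ".models")
# returns "/project/src/app/models.py" if that file exists (or
# "/project/src/app/models/__init__.py" if models is a package), else None.
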
def get_dependency_relationships(base_dir: str) -> List[Tuple[str, str]]:
    """
    Scans the directory and identifies local dependency relationships based on imports.

    Args:
        base_dir: The base directory to scan for Python files.

    Returns:
        A list of (source_file_rel_path, target_file_rel_path) tuples.
    """
    file_paths = []
    relationships = set()  # Use a set to avoid duplicate edges

    # Find all Python files
    print(f"Scanning directory for dependencies: {base_dir}")
    for root, _, files in os.walk(base_dir):
        for file in files:
            if file.endswith(".py"):
                file_path = os.path.join(root, file)
                file_paths.append(file_path)
    print(f"Found {len(file_paths)} Python files for dependency analysis.")

    # Process each file to find dependencies
    for file_path in file_paths:
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()
            # Parse the file content into an AST
            tree = ast.parse(content)
            # We only care about 'from' imports for relative paths
            _, from_imports = extract_imports(tree)
            # Process from imports to find relative dependencies
            for module_name in from_imports.keys():
                if module_name.startswith("."):
                    # This is a relative import. Resolve its path.
                    resolved_path = resolve_relative_import_path(file_path, module_name)
                    if (
                        resolved_path
                        and os.path.exists(resolved_path)
                        and resolved_path != file_path
                    ):
                        # Found a valid local dependency. Add the relationship,
                        # using paths relative to base_dir for graph clarity.
                        rel_source = os.path.relpath(file_path, base_dir)
                        rel_target = os.path.relpath(resolved_path, base_dir)
                        relationships.add((rel_source, rel_target))
        except FileNotFoundError:
            print(f"Error: File not found {file_path}")
        except SyntaxError as e:
            print(f"Error parsing syntax in {file_path}: {str(e)}")
        except Exception as e:
            print(
                f"An unexpected error occurred processing {file_path} for dependencies: {str(e)}"
            )

    return list(relationships)  # Return as a list

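
# Example output (a sketch, assuming a hypothetical POSIX-style src/ layout
# where src/app/views.py does "from .models import User"):
#     get_dependency_relationships("src")
#     # -> [("app/views.py", "app/models.py")]
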
def create_mermaid_id(file_rel_path):
    """
    Creates a valid Mermaid ID from a file path.
    Replaces non-alphanumeric characters (except underscore) with underscore.
    """
    # Replace path separators and dots with underscores, keep alphanumeric and underscore
    return re.sub(r"[^a-zA-Z0-9_]", "_", file_rel_path)

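
# Illustrative mapping (a sketch): "app/models.py" -> "app_models_py",
# so slashes and dots never leak into Mermaid node identifiers.
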
def build_mermaid_graph_from_relationships(relationships: List[Tuple[str, str]]) -> str:
    """
    Generates a Mermaid graph string from a list of dependency relationships.

    Args:
        relationships: A list of (source_file_rel_path, target_file_rel_path) tuples.

    Returns:
        A string containing the Mermaid graph definition.
    """
    mermaid = "graph TD\n"

    # Collect all unique files involved in relationships
    all_files = set()
    for source, target in relationships:
        all_files.add(source)
        all_files.add(target)

    # Add nodes (files) to the graph
    for file in sorted(all_files):  # Sort for consistent output
        file_id = create_mermaid_id(file)
        # Create the label, replacing path separators with dots
        file_label = file.replace(os.sep, ".")
        # Mermaid node syntax: ID["Label"]
        mermaid += f'    {file_id}["{file_label}"]\n'

    # Add edges (dependencies) to the graph
    for source, target in sorted(relationships):  # Sort for consistent output
        source_id = create_mermaid_id(source)
        target_id = create_mermaid_id(target)
        # Mermaid edge syntax: SourceID --> TargetID
        mermaid += f"    {source_id} --> {target_id}\n"

    return mermaid

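
# For relationships like [("app/views.py", "app/models.py")], the generated
# definition would look like (a sketch):
#     graph TD
#         app_models_py["app.models.py"]
#         app_views_py["app.views.py"]
#         app_views_py --> app_models_py
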
# --- LangGraph Definitions ---
class FileMetadata(BaseModel):
    path: str
    imports: List[str]
    method_names: List[str]
    docstrings: List[str]


class Concept(BaseModel):
    name: str
    description: str
    files: List[str] = Field(default_factory=list)


class Concepts(BaseModel):
    concepts: List[Concept]


class GraphInput(BaseModel):
    topic: str  # Input for the graph


# Define the state for the graph, including the list of dependency
# relationships. MessagesState is a TypedDict, so plain annotations are used
# (pydantic Field defaults have no effect on a TypedDict); nodes read missing
# keys defensively via state.get(...).
class GraphState(MessagesState):
    files_metadata: List[FileMetadata]
    concepts: List[Concept]
    dependencies: List[Tuple[str, str]]  # Store dependency relationships


class GraphOutput(BaseModel):
    documentation: str

# Initialize LLM and vector store (assumes the OpenAI API key is set in the
# environment)
embeddings = OpenAIEmbeddings()

# Initialize ChromaDB - in-memory by default; it persists only if configured
vector_store = Chroma(
    embedding_function=embeddings,
    collection_name="code_analysis",  # Optional name for your collection
)

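
# If the embeddings should survive across runs, Chroma can also be given a
# persist_directory (a sketch; the path is an arbitrary choice):
#
#     vector_store = Chroma(
#         embedding_function=embeddings,
#         collection_name="code_analysis",
#         persist_directory="./chroma_db",
#     )
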
# Initialize parsers for structured output from the LLM
concepts_parser = PydanticOutputParser(pydantic_object=Concepts)
concept_parser = PydanticOutputParser(pydantic_object=Concept)


def extract_file_metadata(tree: ast.Module, path: str) -> FileMetadata:
    """Extract all metadata from an AST tree in a single pass."""
    imports = []
    method_names = []
    docstrings = []

    # Get module docstring
    module_doc = ast.get_docstring(tree)
    if module_doc:
        docstrings.append(f"MODULE: {module_doc}")

    # Single walk through the AST
    for node in ast.walk(tree):
        # Extract imports
        if isinstance(node, ast.Import):
            for name in node.names:
                imports.append(f"import {name.name}")
        elif isinstance(node, ast.ImportFrom):
            module = node.module or ""
            names = ", ".join(name.name for name in node.names)
            imports.append(f"from {module} import {names}")
        # Extract function and class names
        elif isinstance(node, ast.FunctionDef):
            method_names.append(node.name)
            # Get function docstring
            doc = ast.get_docstring(node)
            if doc:
                docstrings.append(f"{node.name}: {doc}")
        elif isinstance(node, ast.ClassDef):
            method_names.append(f"class {node.name}")
            # Get class docstring
            doc = ast.get_docstring(node)
            if doc:
                docstrings.append(f"class {node.name}: {doc}")
        elif isinstance(node, ast.AsyncFunctionDef):
            method_names.append(f"async {node.name}")
            # Get async function docstring
            doc = ast.get_docstring(node)
            if doc:
                docstrings.append(f"async {node.name}: {doc}")

    return FileMetadata(
        path=path,
        imports=imports,
        method_names=method_names,
        docstrings=docstrings,
    )

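
# Illustrative result (a sketch with hypothetical file contents): for a file
# app/models.py containing "from dataclasses import dataclass" and a class
# User whose docstring is "A registered user.", this returns roughly
#     FileMetadata(
#         path="app/models.py",
#         imports=["from dataclasses import dataclass"],
#         method_names=["class User"],
#         docstrings=["class User: A registered user."],
#     )
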
async def discover_key_concepts(state: GraphState) -> GraphState:
    """
    Scans the directory, extracts file metadata, generates dependency relationships,
    and uses an LLM to discover key concepts.
    """
    directory = "src"  # Define the directory to scan
    all_files_metadata = []
    documents = []

    # Check if the directory exists
    if not os.path.isdir(directory):
        print(f"Error: Directory '{directory}' not found.")
        # Return current state or raise an error, depending on desired behavior
        return state  # Or raise ValueError(f"Directory '{directory}' not found.")

    # --- Step 1: Extract File Metadata and Create Documents for Vector Store ---
    print(f"Scanning directory for metadata: {directory}")
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith(".py"):
                file_path = os.path.join(root, file)
                try:
                    with open(file_path, "r", encoding="utf-8") as f:
                        content = f.read()
                    # Parse the file content into an AST
                    tree = ast.parse(content)
                    # Extract all metadata in a single pass
                    relative_path = os.path.relpath(file_path, directory)
                    file_metadata = extract_file_metadata(tree, relative_path)
                    all_files_metadata.append(file_metadata)
                    # Create document for file content
                    documents.append(
                        Document(
                            page_content=content,
                            metadata={"path": relative_path, "type": "content"},
                        )
                    )
                    # Create document for file metadata
                    metadata_content = (
                        f"File: {relative_path}\n"
                        f"Imports: {', '.join(file_metadata.imports)}\n"
                        f"Methods: {', '.join(file_metadata.method_names)}\n"
                        f"Docstrings: {' | '.join(file_metadata.docstrings)}"
                    )
                    documents.append(
                        Document(
                            page_content=metadata_content,
                            metadata={"path": relative_path, "type": "metadata"},
                        )
                    )
                except FileNotFoundError:
                    print(f"Error: File not found {file_path}")
                except SyntaxError as e:
                    print(f"Error parsing syntax in {file_path}: {str(e)}")
                except Exception as e:
                    print(
                        f"An unexpected error occurred processing {file_path} for metadata: {str(e)}"
                    )

    # Add documents to the vector store
    if documents:
        vector_store.add_documents(documents)
        print(f"Added {len(documents)} documents to the vector store.")
    else:
        print("No documents found to add to the vector store.")

    # --- Step 2: Generate Dependency Relationships ---
    relationships = get_dependency_relationships(directory)
    print(f"Identified {len(relationships)} dependency relationships.")

    # --- Step 3: Use LLM to Discover Key Concepts Based on Metadata ---
    llm_model = ChatOpenAI(model="gpt-4o-mini", temperature=0)

    # Prepare metadata summary for the LLM
    metadata_summary = "\n".join(
        f"File: {meta.path}\nImports: {', '.join(meta.imports)}\nMethods: {', '.join(meta.method_names)}"
        for meta in all_files_metadata
    )
    # Summarize the dependency relationships so the LLM can consider them too;
    # the raw list is also stored in the state for later nodes.
    dependency_summary = "\n".join(f"{s} --> {t}" for s, t in relationships)

    system_prompt = f"""
    Based on these file names, imported modules, and defined classes/functions:
    {metadata_summary}

    And these relationships between modules:
    {dependency_summary}

    What are likely to be the top 5 most relevant concepts or abstractions in this codebase?
    For each concept, provide:
    1. A name
    2. A brief description
    3. Which files likely implement this concept

    Focus on identifying high-level ideas and components.

    MUST ALWAYS output JUST the json with this format:
    {concepts_parser.get_format_instructions()}
    """

    print("Asking LLM to discover key concepts...")
    output = await llm_model.ainvoke([SystemMessage(system_prompt)])
    try:
        concepts = concepts_parser.parse(output.content)
        print(f"LLM identified {len(concepts.concepts)} initial concepts.")
    except Exception as e:
        print(f"Error parsing concepts from LLM output: {e}")
        print(f"LLM Output: {output.content}")
        concepts = Concepts(concepts=[])  # Fall back to an empty list if parsing fails

    # Return the updated state, including the dependency relationships
    return GraphState(
        files_metadata=all_files_metadata,
        concepts=concepts.concepts,
        dependencies=relationships,  # Store the list of relationships in the state
    )

async def refine_concepts(state: GraphState) -> GraphState:
    """
    Refines discovered concepts using code samples retrieved from the vector store.
    """
    llm_model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    refined_concepts = []
    # Get concepts from state, defaulting to an empty list
    initial_concepts = state.get("concepts", [])

    if not initial_concepts:
        print("No initial concepts found to refine.")
        return state  # Return current state if no concepts

    print(f"Refining {len(initial_concepts)} concepts...")
    for concept in initial_concepts:
        print(f"Refining concept: {concept.name}")
        query = f"{concept.name} {concept.description}"
        # Search for relevant code content documents
        results = vector_store.similarity_search(query, k=5, filter={"type": "content"})
        code_samples = [
            f"File: {doc.metadata['path']}\n```python\n{doc.page_content[:1500]}...\n```"
            for doc in results
        ]
        # Join the samples as text rather than interpolating the raw list repr
        samples_text = (
            "\n\n".join(code_samples) if code_samples else "No relevant code samples found."
        )

        # Refine concept with code evidence
        system_prompt = f"""
        Initial concept:
        Name: {concept.name}
        Description: {concept.description}
        Files: {", ".join(concept.files)}

        Relevant code samples:
        {samples_text}

        Based on these code samples, refine the concept description.
        Provide:
        1. A more accurate name if needed (keep it concise).
        2. A detailed description that references the actual implementation and functionality.
        3. The description should contain the functionality this concept provides.
        4. The description should contain brief implementation details.
        5. Which files likely implement this concept (referencing the provided file paths).

        Ensure the refined concept is grounded in the code evidence.

        MUST ALWAYS output JUST the json with this format:
        {concept_parser.get_format_instructions()}
        """

        # Use the structured output parser to get the proper format
        print(f"Asking LLM to refine concept: {concept.name}")
        output = await llm_model.ainvoke([SystemMessage(system_prompt)])
        try:
            refined_concept = concept_parser.parse(output.content)
            refined_concepts.append(refined_concept)
            print(f"Successfully refined concept: {refined_concept.name}")
        except Exception as e:
            print(
                f"Error parsing refined concept from LLM output for '{concept.name}': {e}"
            )
            print(f"LLM Output: {output.content}")
            # Keep the original concept if refinement fails
            refined_concepts.append(concept)

    # Return the updated state with refined concepts (preserving dependencies)
    return GraphState(
        files_metadata=state.get("files_metadata"),
        concepts=refined_concepts,
        dependencies=state.get("dependencies"),
    )

async def generate_documentation(state: GraphState) -> GraphOutput:
    """
    Generates comprehensive documentation for the codebase based on the refined
    concepts and includes the dependency graph, which is built here from the
    relationships stored in the state. The final documentation is saved to disk.
    """
    llm_model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    refined_concepts = state.get("concepts", [])  # Get concepts from state
    # Get dependency relationships from state
    dependencies = state.get("dependencies", [])

    # Prepare documentation structure
    doc = "# Codebase Documentation\n\n"

    # --- Add Dependency Graph Section ---
    if dependencies:
        # Build the Mermaid graph string from the relationships
        mermaid_graph = build_mermaid_graph_from_relationships(dependencies)
        doc += "## Dependency Graph\n\n"
        doc += "This graph visualizes the dependencies between files based on local imports.\n\n"
        doc += "```mermaid\n"
        doc += mermaid_graph
        doc += "```\n\n---\n\n"
        print(
            f"Built and added dependency graph with {len(dependencies)} edges to documentation."
        )
    else:
        print("No dependency relationships found to build graph.")
        doc += "## Dependency Graph\n\nCould not generate dependency graph: No dependency relationships found.\n\n---\n\n"

    # --- Add Project Overview Section ---
    doc += "## Project Overview\n\n"
    if refined_concepts:
        overview_prompt = f"""
        Based on these concepts and their descriptions:
        {[{"name": c.name, "description": c.description} for c in refined_concepts]}

        Write a concise overview of what this project does, its architecture, and how the main components work together.
        Keep it under 4 paragraphs and focus on making it clear to a new developer.
        """
        print("Asking LLM to generate project overview...")
        overview_response = await llm_model.ainvoke(
            [SystemMessage(content=overview_prompt)]
        )
        doc += f"{overview_response.content}\n\n"
    else:
        doc += "Project overview could not be generated as no concepts were identified.\n\n"
        print("Skipped project overview generation due to no concepts.")

    # --- Add Table of Contents ---
    if refined_concepts:
        doc += "## Table of Contents\n\n"
        for i, concept in enumerate(refined_concepts):
            # Build valid Markdown links from a simple slug
            link_text = concept.name.replace(" ", "-").lower()
            doc += f"{i + 1}. [{concept.name}](#{link_text}-{i + 1})\n"
        doc += "\n"
        print("Added table of contents.")
    else:
        print("Skipped table of contents generation due to no concepts.")

    # --- Document Each Concept ---
    if refined_concepts:
        print("Generating documentation for each concept...")
        for i, concept in enumerate(refined_concepts):
            print(f"Documenting concept: {concept.name}")
            query = f"{concept.name} {concept.description} example usage"
            # Search for relevant code content documents
            results = vector_store.similarity_search(
                query, k=3, filter={"type": "content"}
            )
            code_samples = [
                f"File: {doc_result.metadata['path']}\n```python\n{doc_result.page_content[:1000]}...\n```"
                for doc_result in results
            ]
            samples_text = (
                "\n\n".join(code_samples) if code_samples else "No relevant code samples found."
            )

            concept_prompt = f"""
            Concept: {concept.name}
            Description: {concept.description}
            Files: {", ".join(concept.files)}

            Relevant code samples:
            {samples_text}

            Write comprehensive documentation for this concept including:
            1. A detailed explanation of what it does
            2. Its implementation details (how it works internally)
            3. How to use it (with code examples)
            4. Any important configuration options or parameters

            Format using Markdown with proper headings, code blocks, and examples.
            Make the examples practical and helpful for a developer new to this codebase.
            """
            concept_response = await llm_model.ainvoke(
                [SystemMessage(content=concept_prompt)]
            )

            # Add to documentation, with an anchor matching the TOC slug
            link_text = concept.name.replace(" ", "-").lower()
            doc += f"## {concept.name} <a name='{link_text}-{i + 1}'></a>\n\n"
            doc += f"{concept_response.content}\n\n"
            if concept.files:
                doc += "**Implemented in files:**\n\n"
                for file in concept.files:
                    doc += f"- `{file}`\n"
                doc += "\n"
            doc += "---\n\n"
            print(f"Finished documenting concept: {concept.name}")
    else:
        print("Skipped concept documentation generation due to no concepts.")

    # --- Add Getting Started Guide ---
    # Provide the documentation generated so far as context
    start_prompt = f"""
    Based on the concepts and documentation about this codebase, write a getting started guide that includes:
    1. Prerequisites
    2. Installation steps
    3. A simple example showing the main functionality
    4. Common gotchas or things to watch out for

    Make it practical and easy to follow for a new developer. Refer to the concepts and files mentioned in the documentation.

    Here is the current documentation content for context:
    ---
    {doc}
    ---
    """
    print("Asking LLM to generate getting started guide...")
    start_response = await llm_model.ainvoke([SystemMessage(content=start_prompt)])
    doc += "## Getting Started\n\n"
    doc += f"{start_response.content}\n\n"
    print("Added getting started guide.")

    # --- Save Documentation to Disk ---
    output_filename = "code_documentation.md"
    try:
        with open(output_filename, "w", encoding="utf-8") as f:
            f.write(doc)
        print(f"\nFinal documentation saved to '{output_filename}'")
    except IOError as e:
        print(f"Error writing documentation to file {output_filename}: {str(e)}")

    # Return the final documentation string in the output
    return GraphOutput(
        documentation=doc,
    )

# --- Build the LangGraph Graph ---
builder = StateGraph(state_schema=GraphState, input=GraphInput, output=GraphOutput)
builder.add_node("discover_key_concepts", discover_key_concepts)
builder.add_node("refine_concepts", refine_concepts)
builder.add_node("generate_documentation", generate_documentation)

# Define the flow of the graph
builder.add_edge(START, "discover_key_concepts")
builder.add_edge("discover_key_concepts", "refine_concepts")
builder.add_edge("refine_concepts", "generate_documentation")
builder.add_edge("generate_documentation", END)

# Compile the graph
graph = builder.compile()

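
# The compiled graph itself can be rendered as Mermaid too, which is a quick
# way to check the node wiring (a sketch; assumes a langgraph version that
# exposes get_graph().draw_mermaid()):
#
#     print(graph.get_graph().draw_mermaid())
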

# To run the graph: set the OPENAI_API_KEY environment variable (e.g.
# os.environ["OPENAI_API_KEY"] = "YOUR_API_KEY"), then execute this file.
if __name__ == "__main__":
    import asyncio

    async def run_graph():
        # Replace the topic with the actual topic or goal for analysis
        result = await graph.ainvoke({"topic": "Explain the core functionality"})
        # The final documentation is in result["documentation"]; it is also
        # saved to code_documentation.md by generate_documentation
        print(result["documentation"])

    asyncio.run(run_graph())