@grahama1970
Created August 15, 2025 14:21
codebase indexer for an agent
#!/usr/bin/env python3
"""
Codebase Indexer for Semantic Code Search

A tool for indexing code repositories into ArangoDB with semantic embeddings,
enabling intelligent code search beyond simple text matching.

Key Features:
- Extracts functions/classes using tree-sitter AST parsing
- Generates semantic embeddings using the nomic-embed-code model (1024-dim)
- Creates AI file-level summaries via Ollama/LiteLLM, attached to each code chunk
- Supports incremental indexing (--since flag for git changes)
- Multi-language support (Python, JavaScript, TypeScript, Go, Rust)
- Per-project collections in ArangoDB for clean separation
- GPU acceleration when available, CPU fallback

Use Cases:
- Semantic code search: find functions by meaning, not just text
- Code understanding: AI-generated summaries for quick comprehension
- Incremental updates: only re-index changed files in large codebases
- Multi-project support: index multiple repositories separately

Example Usage:
    # Initial indexing
    python codebase_indexer.py index /path/to/repo --project-id myproject

    # Incremental update (only changed files)
    python codebase_indexer.py index /path/to/repo --project-id myproject --since HEAD~10

    # Search for authentication-related code
    python codebase_indexer.py search "user authentication" --project myproject

    # Show system capabilities
    python codebase_indexer.py info
"""
import os
import json
import hashlib
import subprocess
import asyncio
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Any, Union
import time
import textwrap
from datetime import datetime
import typer
from arango.client import ArangoClient
from arango.database import StandardDatabase
from arango.collection import StandardCollection
from sentence_transformers import SentenceTransformer
from tree_sitter_language_pack import get_parser
from dotenv import find_dotenv, load_dotenv
from loguru import logger
import numpy as np
from tqdm.asyncio import tqdm
import torch
import litellm
from tenacity import retry, stop_after_attempt, wait_exponential
try:
    from lean4_prover.utils.json_utils import clean_json_string
except ImportError:
    # Fallback for standalone execution
    def clean_json_string(content: Union[str, dict, list], return_dict: bool = False) -> Union[str, dict, list]:
        try:
            if isinstance(content, (dict, list)):
                return content if return_dict else json.dumps(content)
            if isinstance(content, str):
                return json.loads(content) if return_dict else content
        except (TypeError, ValueError):
            pass
        return content
try:
    import faiss
    FAISS_AVAILABLE = True
except ImportError:
    logger.warning("FAISS not installed - k-NN clustering will be limited")
    FAISS_AVAILABLE = False
    faiss = None  # type: ignore
try:
    from ollama import Client as OllamaClient
    OLLAMA_NATIVE_AVAILABLE = True
except ImportError:
    logger.warning("Native ollama client not installed - will use LiteLLM")
    OLLAMA_NATIVE_AVAILABLE = False
    OllamaClient = None  # type: ignore
# Load environment
load_dotenv(find_dotenv())
# =========================================================
# CONFIGURATION
# =========================================================
EMBED_MODEL = os.getenv("CODE_EMBEDDING_MODEL", "nomic-ai/nomic-embed-code")
SUMMARY_MODEL = os.getenv("CODE_SUMMARY_MODEL", "ollama/gemma3:12b")
OLLAMA_TURBO_KEY = os.getenv("OLLAMA_TURBO_API_KEY", "")
ARANGO_URL = f"http://{os.getenv('ARANGO_HOST', 'localhost:8529')}"
ARANGO_DB = os.getenv("ARANGO_DB", "code_index")
ARANGO_USER = os.getenv("ARANGO_USERNAME", "root")
ARANGO_PASS = os.getenv("ARANGO_PASSWORD", "")
# Embedding batch size - can be overridden via environment
DEFAULT_BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", "32"))
# File batch size for memory efficiency
FILE_BATCH_SIZE = int(os.getenv("FILE_BATCH_SIZE", "50"))
# Concurrency for API calls
CONCURRENT_SUMMARIES = int(os.getenv("CONCURRENT_SUMMARIES", "10"))
# Similarity threshold for search
DEFAULT_SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
# Use native ollama client instead of LiteLLM
USE_NATIVE_OLLAMA = os.getenv("USE_NATIVE_OLLAMA", "false").lower() == "true"
# Native ollama model to use
OLLAMA_NATIVE_MODEL = os.getenv("OLLAMA_NATIVE_MODEL", "gpt-oss:120b")
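
# Example .env for the settings above (illustrative sketch only: the variable
# names match the os.getenv() calls in this section, but the password and any
# values not shown as defaults above are placeholders, not project settings):
#
#   ARANGO_HOST=localhost:8529
#   ARANGO_DB=code_index
#   ARANGO_USERNAME=root
#   ARANGO_PASSWORD=changeme
#   CODE_EMBEDDING_MODEL=nomic-ai/nomic-embed-code
#   CODE_SUMMARY_MODEL=ollama/gemma3:12b
#   EMBED_BATCH_SIZE=32
#   FILE_BATCH_SIZE=50
#   CONCURRENT_SUMMARIES=10
#   SIMILARITY_THRESHOLD=0.7
#   USE_NATIVE_OLLAMA=false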
# =========================================================
# UTILITY FUNCTIONS
# =========================================================
def get_optimal_batch_size() -> int:
    """Calculate optimal batch size based on available GPU memory"""
    if torch.cuda.is_available():
        gpu_memory = torch.cuda.get_device_properties(0).total_memory
        available_memory = gpu_memory - torch.cuda.memory_allocated()
        # Rough estimate: each embedding uses ~4MB of GPU memory
        optimal_size = int(available_memory / (4 * 1024 * 1024))
        # Clamp to reasonable range
        return min(max(optimal_size, 8), 128)
    return DEFAULT_BATCH_SIZE
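
# Worked example (rough arithmetic only): with ~8 GiB of free GPU memory,
# 8 * 1024**3 / (4 * 1024 * 1024) = 2048, which the clamp above reduces to 128;
# on CPU-only machines the function simply returns DEFAULT_BATCH_SIZE.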
# =========================================================
# LANGUAGE MAPPINGS
# =========================================================
LANGUAGES = {
    ".py": "python",
    ".js": "javascript",
    ".ts": "typescript",
    ".go": "go",
    ".rs": "rust",
}
# =========================================================
# CHUNK TYPES TO EXTRACT
# =========================================================
CHUNK_TYPES = {
    "python": {"function_definition", "class_definition"},
    "javascript": {"function_declaration", "class_declaration", "arrow_function"},
    "typescript": {"function_declaration", "class_declaration", "arrow_function"},
    "go": {"function_declaration", "method_declaration"},
    "rust": {"function_item", "impl_item"},
}
# -------------------------------------------------
# DATABASE
# -------------------------------------------------
def get_db() -> StandardDatabase:
    """Get ArangoDB connection"""
    client = ArangoClient(hosts=ARANGO_URL)
    sys_db = client.db("_system", username=ARANGO_USER, password=ARANGO_PASS)
    if not sys_db.has_database(ARANGO_DB):
        sys_db.create_database(ARANGO_DB)
    return client.db(ARANGO_DB, username=ARANGO_USER, password=ARANGO_PASS)
def setup_collection(db: StandardDatabase, project_id: str) -> StandardCollection:
    """Create collection if needed with all required indexes, view, and graph"""
    collection_name = f"code_{project_id}"
    view_name = f"{collection_name}_search"
    graph_name = f"{project_id}_code_graph"

    # Create collection if needed
    if not db.has_collection(collection_name):
        col = db.create_collection(collection_name)
        col.add_hash_index(fields=["file_path"])
        col.add_hash_index(fields=["chunk_type"])
        # Add vector index for semantic search - CRITICAL for performance!
        # Using inverted index which supports ANN (Approximate Nearest Neighbor) search
        try:
            col.add_inverted_index(
                fields={"embedding": {}},
                name="idx_embedding_vector",
                analyzer="identity",
                inBackground=False,
                parallelism=2,
                primarySort={"fields": [{"field": "embedding", "direction": "asc"}]},
                cache=True
            )
        except Exception as e:
            # Try simpler version if advanced options not supported
            logger.warning(f"Advanced inverted index options failed: {e}, trying simple version")
            try:
                col.add_inverted_index(
                    fields={"embedding": {}},
                    name="idx_embedding_vector"
                )
            except Exception as e2:
                logger.warning(f"Could not create inverted index for embeddings: {e2}")
        logger.info(f"Created vector index on embedding field for {collection_name}")
    else:
        col = db.collection(collection_name)

    # Create ArangoSearch view for BM25 search
    views = db.views()
    view_exists = any(v['name'] == view_name for v in views)
    if not view_exists:
        try:
            db.create_view(
                name=view_name,
                view_type='arangosearch',
                properties={
                    'links': {
                        collection_name: {
                            'analyzers': ['text_en'],
                            'fields': {
                                'text': {'analyzers': ['text_en']},
                                'file_summary': {'analyzers': ['text_en']},
                                'file_path': {'analyzers': ['identity']}
                            }
                        }
                    }
                }
            )
            logger.info(f"Created ArangoSearch view: {view_name}")
        except Exception as e:
            logger.warning(f"Could not create search view {view_name}: {e}")

    # Create graph for relationship traversal
    try:
        graphs = db.graphs()
        graph_exists = any(g['name'] == graph_name for g in graphs)
        if not graph_exists:
            edge_collection = f"{collection_name}_relationships"
            # Create edge collection if needed
            if not db.has_collection(edge_collection):
                db.create_collection(edge_collection, edge=True)
                logger.info(f"Created edge collection: {edge_collection}")
            # Create graph
            db.create_graph(
                name=graph_name,
                edge_definitions=[{
                    'edge_collection': edge_collection,
                    'from_vertex_collections': [collection_name],
                    'to_vertex_collections': [collection_name]
                }],
                orphan_collections=[]
            )
            logger.info(f"Created graph: {graph_name}")
    except Exception as e:
        logger.warning(f"Could not create graph {graph_name}: {e}")

    return col
# -------------------------------------------------
# CODE PARSING
# -------------------------------------------------
def extract_chunks(file_path: Path) -> List[Dict[str, Any]]:
    """Extract function/class chunks from a file"""
    language = LANGUAGES.get(file_path.suffix)
    if not language:
        return []
    try:
        parser = get_parser(language)
        tree = parser.parse(file_path.read_bytes())
        chunks: List[Dict[str, Any]] = []

        def walk(node: Any) -> None:
            if node.type in CHUNK_TYPES.get(language, set()):
                text = node.text.decode('utf-8', errors='ignore')
                chunks.append({
                    "text": text,
                    "type": node.type,
                    "start_line": node.start_point[0] + 1,
                    "language": language,
                })
            for child in node.children:
                walk(child)

        walk(tree.root_node)
        return chunks
    except Exception as e:
        logger.error(f"Failed to parse {file_path}: {e}")
        return []
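
# Illustrative example of the chunk shape produced above (keys match the dict
# literal in extract_chunks; the file name and values are hypothetical):
#   extract_chunks(Path("auth.py")) ->
#   [{"text": "def login(user): ...", "type": "function_definition",
#     "start_line": 12, "language": "python"}]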
# -------------------------------------------------
# SUMMARIZATION
# -------------------------------------------------
def generate_file_summary_native(file_content: str) -> str:
    """Generate summary using native ollama client (synchronous)"""
    if not OLLAMA_NATIVE_AVAILABLE:
        return "File summary generation unavailable"
    if not OLLAMA_TURBO_KEY:
        logger.warning("OLLAMA_TURBO_API_KEY not set in environment")
        return "Ollama API key not configured"
    client = OllamaClient(
        host="https://ollama.com",
        headers={'Authorization': f'{OLLAMA_TURBO_KEY}'}
    )
    prompt = textwrap.dedent(f"""
        Provide a brief summary (2–3 sentences) of the purpose and main functionality of the following code file. Focus on its key components, responsibilities, and overall role:
        ```
        {file_content}
        ```
        Summary:
    """).strip()
    response = client.chat(
        model=OLLAMA_NATIVE_MODEL,
        messages=[{"role": "user", "content": prompt}],
        stream=False
    )
    return response['message']['content'].strip()
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
async def generate_file_summary(file_content: str) -> str:
    """Generate a concise summary of the entire file"""
    # Use native ollama if configured
    if USE_NATIVE_OLLAMA and OLLAMA_NATIVE_AVAILABLE:
        # Run the synchronous client in a thread pool
        return await asyncio.to_thread(generate_file_summary_native, file_content)
    # Otherwise use LiteLLM
    prompt = textwrap.dedent(f"""
        Provide a brief summary (2–3 sentences) of the purpose and main functionality of the following code file. Focus on its key components, responsibilities, and overall role:
        ```
        {file_content}
        ```
        Summary:
    """).strip()
    response = await litellm.acompletion(
        model=SUMMARY_MODEL,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=100,
        temperature=0.3,
    )
    summary = clean_json_string(response.choices[0].message.content, return_dict=False)
    return str(summary).strip()
# -------------------------------------------------
# MAIN INDEXING
# -------------------------------------------------
async def index_repository(repo: Path, files: List[Path], project_id: str, batch_files: bool = True) -> None:
    """Index repository files with optional batching for memory efficiency"""
    db = get_db()
    collection = setup_collection(db, project_id)

    # Process files in batches if enabled (default)
    if batch_files and len(files) > FILE_BATCH_SIZE:
        logger.info(f"Processing {len(files)} files in batches of {FILE_BATCH_SIZE}")
        total_chunks = 0
        for i in range(0, len(files), FILE_BATCH_SIZE):
            batch = files[i:i + FILE_BATCH_SIZE]
            logger.info(f"Processing batch {i//FILE_BATCH_SIZE + 1}/{(len(files) + FILE_BATCH_SIZE - 1)//FILE_BATCH_SIZE}")
            # Process this batch
            chunk_count = await _process_file_batch(repo, batch, collection, project_id)
            total_chunks += chunk_count
            # Clear memory between batches
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        logger.info(f"Indexed {total_chunks} chunks total from {len(files)} files")
    else:
        # Process all files at once for small repositories
        chunk_count = await _process_file_batch(repo, files, collection, project_id)
        logger.info(f"Indexed {chunk_count} chunks from {len(files)} files")
async def _process_file_batch(repo: Path, files: List[Path], collection: StandardCollection, project_id: str) -> int:
    """Process a batch of files and return the number of chunks indexed"""
    # First, generate file-level summaries
    logger.info(f"Generating summaries for {len(files)} files...")
    file_summaries: Dict[str, str] = {}
    semaphore = asyncio.Semaphore(CONCURRENT_SUMMARIES)

    async def summarize_file(file_path: Path) -> None:
        async with semaphore:
            try:
                # Safe file reading with encoding handling
                content = file_path.read_text(encoding='utf-8', errors='replace')
                summary = await generate_file_summary(content)
                file_summaries[str(file_path.relative_to(repo))] = summary
            except Exception as e:
                logger.error(f"Failed to summarize {file_path}: {e}")
                file_summaries[str(file_path.relative_to(repo))] = ""

    tasks = [summarize_file(f) for f in files]
    with tqdm(total=len(tasks), desc="Generating file summaries") as pbar:
        for coro in asyncio.as_completed(tasks):
            await coro
            pbar.update(1)

    # Extract all chunks with file summaries
    all_chunks = []
    for file_path in tqdm(files, desc="Parsing files"):
        chunks = extract_chunks(file_path)
        relative_path = str(file_path.relative_to(repo))
        file_summary = file_summaries.get(relative_path, "")
        for chunk in chunks:
            chunk['file_path'] = relative_path
            chunk['file_summary'] = file_summary
            all_chunks.append(chunk)
    if not all_chunks:
        logger.warning("No chunks found in this batch")
        return 0
    logger.info(f"Found {len(all_chunks)} chunks in this batch")

    # Create embeddings with file summary as context
    logger.info("Creating embeddings...")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encoder = SentenceTransformer(EMBED_MODEL, device=device)
    # Prepend file summary to each chunk for better context
    texts = [f"File summary: {chunk['file_summary']}\n\n{chunk['text']}" for chunk in all_chunks]
    batch_size = get_optimal_batch_size()
    logger.debug(f"Using batch size {batch_size} for embeddings")
    embeddings = encoder.encode(texts, batch_size=batch_size, show_progress_bar=True)

    # Store in database
    logger.info("Storing in database...")
    documents = []
    for chunk, embedding in zip(all_chunks, embeddings):
        key = hashlib.sha256(
            f"{chunk['file_path']}:{chunk['start_line']}".encode()
        ).hexdigest()
        documents.append({
            "_key": key,
            "file_path": chunk['file_path'],
            "chunk_type": chunk['type'],
            "start_line": chunk['start_line'],
            "language": chunk['language'],
            "file_summary": chunk['file_summary'],
            "text": chunk['text'],
            "embedding": embedding.tolist(),
        })
    collection.import_bulk(documents, on_duplicate="replace")
    return len(documents)
# -------------------------------------------------
# FAISS INDEX BUILDING
# -------------------------------------------------
def build_faiss_index(collection_name: str, db: Optional[StandardDatabase] = None, force_rebuild: bool = False) -> Optional[Dict[str, Any]]:
    """Build FAISS index for fast k-NN search

    Args:
        collection_name: Collection to index
        db: Database connection (optional)
        force_rebuild: Force rebuild even if index exists

    Returns:
        Dict with index info or None if failed
    """
    if not FAISS_AVAILABLE:
        logger.warning("FAISS not available - cannot build index")
        return None
    # Cache built indexes as an attribute on the function itself
    if not hasattr(build_faiss_index, '_faiss_indexes'):
        build_faiss_index._faiss_indexes = {}
    # Check if index already exists
    if collection_name in build_faiss_index._faiss_indexes and not force_rebuild:
        return build_faiss_index._faiss_indexes[collection_name]
    try:
        if db is None:
            db = get_db()
        # Get all documents with embeddings
        cursor = db.aql.execute(f"""
            FOR doc IN {collection_name}
                FILTER doc.embedding != null
                RETURN {{
                    _key: doc._key,
                    file_path: doc.file_path,
                    start_line: doc.start_line,
                    embedding: doc.embedding,
                    text: doc.text,
                    file_summary: doc.file_summary
                }}
        """)
        docs = list(cursor)
        if not docs:
            logger.warning(f"No documents with embeddings found in {collection_name}")
            return None
        # Extract embeddings and build index
        embeddings = np.array([doc['embedding'] for doc in docs], dtype='float32')
        dimension = embeddings.shape[1]
        # Create FAISS index - IndexFlatIP gives inner product (cosine similarity with normalized vectors)
        index = faiss.IndexFlatIP(dimension)
        # Normalize embeddings for cosine similarity
        faiss.normalize_L2(embeddings)
        # Add to index
        index.add(embeddings)
        # Store index and metadata
        index_info = {
            'index': index,
            'documents': docs,
            'dimension': dimension,
            'num_vectors': len(docs),
            'collection': collection_name,
            'created_at': datetime.now().isoformat()
        }
        build_faiss_index._faiss_indexes[collection_name] = index_info
        logger.info(f"Built FAISS index for {collection_name}: {len(docs)} vectors, {dimension}D")
        return index_info
    except Exception as e:
        logger.error(f"Failed to build FAISS index: {e}")
        return None
def search_with_faiss(
    query_embedding: np.ndarray,
    collection_name: str,
    k: int = 10,
    min_similarity: float = 0.0
) -> List[Dict[str, Any]]:
    """Search using FAISS index for fast k-NN

    Args:
        query_embedding: Query vector
        collection_name: Collection to search in
        k: Number of neighbors
        min_similarity: Minimum similarity threshold

    Returns:
        List of similar documents with scores
    """
    if not FAISS_AVAILABLE:
        return []
    # Get or build index
    index_info = build_faiss_index(collection_name)
    if not index_info:
        return []
    try:
        index = index_info['index']
        documents = index_info['documents']
        # Normalize query for cosine similarity
        query_vec = np.array([query_embedding], dtype='float32')
        faiss.normalize_L2(query_vec)
        # Search
        distances, indices = index.search(query_vec, min(k, index.ntotal))
        # Convert to results
        results = []
        for dist, idx in zip(distances[0], indices[0]):
            if idx >= 0 and dist >= min_similarity:  # Valid index and above threshold
                doc = documents[idx].copy()
                doc['score'] = float(dist)  # Inner product = cosine similarity for normalized vectors
                doc['search_type'] = 'faiss_knn'
                results.append(doc)
        return results
    except Exception as e:
        logger.error(f"FAISS search error: {e}")
        return []
async def build_similarity_relationships(
    collection_name: str,
    similarity_threshold: float = 0.8,
    max_relationships_per_doc: int = 5,
    db: Optional[StandardDatabase] = None
) -> Dict[str, Any]:
    """Build similarity relationships using FAISS k-NN and LLM rationales

    This follows the established pattern:
    1. FAISS finds similar code chunks
    2. LLM generates rationale for why they're related
    3. Relationships stored in edge collection

    Args:
        collection_name: Collection to analyze
        similarity_threshold: Minimum similarity for creating edge
        max_relationships_per_doc: Maximum edges per document
        db: Database connection

    Returns:
        Dict with statistics about created relationships
    """
    if not FAISS_AVAILABLE:
        logger.warning("FAISS not available - cannot build relationships")
        return {"success": False, "error": "FAISS not installed"}
    try:
        if db is None:
            db = get_db()
        # Get or build FAISS index
        index_info = build_faiss_index(collection_name, db)
        if not index_info:
            return {"success": False, "error": "Failed to build FAISS index"}
        documents = index_info['documents']
        index = index_info['index']
        edge_collection = f"{collection_name}_relationships"
        # Ensure edge collection exists
        if not db.has_collection(edge_collection):
            db.create_collection(edge_collection, edge=True)
        created_count = 0
        semaphore = asyncio.Semaphore(CONCURRENT_SUMMARIES)
        # For each document, find similar ones
        for i, doc in enumerate(tqdm(documents, desc="Finding relationships")):
            # Get embedding
            query_vec = np.array([doc['embedding']], dtype='float32')
            faiss.normalize_L2(query_vec)
            # Search for similar documents
            distances, indices = index.search(query_vec, max_relationships_per_doc + 1)
            # Process similar documents
            for dist, idx in zip(distances[0], indices[0]):
                if idx == i or idx >= len(documents):
                    continue  # Skip self and invalid indices
                similarity = float(dist)  # Inner product = cosine similarity
                if similarity >= similarity_threshold:
                    target_doc = documents[idx]
                    # Calculate weight based on hierarchy and similarity
                    weight = calculate_relationship_weight(doc, target_doc, similarity)
                    # Generate LLM rationale and check if relationship should be created
                    async with semaphore:
                        rationale, should_create = await generate_relationship_rationale(
                            doc, target_doc, similarity, weight
                        )
                    if not should_create:
                        logger.debug(f"LLM rejected relationship: {rationale}")
                        continue
                    # Create edge document following project pattern
                    edge_doc = {
                        "_from": f"{collection_name}/{doc['_key']}",
                        "_to": f"{collection_name}/{target_doc['_key']}",
                        "weight": weight,
                        "rationale": rationale,
                        "similarity_score": similarity,
                        "relationship_type": "similar_code",
                        "created_at": datetime.now().isoformat()
                    }
                    try:
                        db.collection(edge_collection).insert(edge_doc)
                        created_count += 1
                    except Exception as e:
                        # Ignore duplicate edge errors
                        if "duplicate" not in str(e).lower():
                            logger.warning(f"Failed to create edge: {e}")
        return {
            "success": True,
            "relationships_created": created_count,
            "documents_analyzed": len(documents),
            "edge_collection": edge_collection
        }
    except Exception as e:
        logger.error(f"Failed to build similarity relationships: {e}")
        return {"success": False, "error": str(e)}
def calculate_relationship_weight(source_doc: Dict[str, Any], target_doc: Dict[str, Any], similarity: float) -> float:
    """Calculate edge weight based on file hierarchy and similarity

    Following the project pattern of using hierarchy to influence weights:
    - Same directory = higher weight
    - Parent/child directories = medium weight
    - Different subtrees = lower weight
    """
    source_path = Path(source_doc['file_path'])
    target_path = Path(target_doc['file_path'])
    # Base weight from similarity
    weight = similarity * 0.6  # 60% from similarity
    # Hierarchy bonus (40%)
    if source_path.parent == target_path.parent:
        # Same directory - highest bonus
        weight += 0.4
    elif source_path.parent in target_path.parents or target_path.parent in source_path.parents:
        # Parent/child relationship - medium bonus
        weight += 0.25
    else:
        # Check common ancestor depth
        common_parts = 0
        for s, t in zip(source_path.parts, target_path.parts):
            if s == t:
                common_parts += 1
            else:
                break
        # More common ancestors = higher bonus
        weight += 0.1 * (common_parts / max(len(source_path.parts), len(target_path.parts)))
    return min(1.0, weight)  # Cap at 1.0
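
# Worked example (illustrative numbers): two chunks with cosine similarity 0.9
# in the same directory score 0.9 * 0.6 + 0.4 = 0.94; the same pair in unrelated
# subtrees with no shared path components stays at 0.9 * 0.6 = 0.54.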
def generate_relationship_rationale_native(
    source_doc: Dict[str, Any],
    target_doc: Dict[str, Any],
    similarity_score: float,
    weight: float
) -> Tuple[str, bool]:
    """Generate rationale using native ollama client (synchronous)"""
    if not OLLAMA_NATIVE_AVAILABLE:
        return "Rationale generation unavailable", False
    if not OLLAMA_TURBO_KEY:
        logger.warning("OLLAMA_TURBO_API_KEY not set in environment")
        return "Ollama API key not configured", False
    client = OllamaClient(
        host="https://ollama.com",
        headers={'Authorization': f'{OLLAMA_TURBO_KEY}'}
    )
    prompt = textwrap.dedent(f"""
        Analyze these two code chunks (similarity: {similarity_score:.3f}, weight: {weight:.3f}).

        Code 1 ({source_doc['file_path']}:{source_doc['start_line']}):
        ```
        {source_doc['text'][:500]}
        ```

        Code 2 ({target_doc['file_path']}:{target_doc['start_line']}):
        ```
        {target_doc['text'][:500]}
        ```

        Determine if these code chunks have a meaningful relationship worth tracking.
        Consider:
        - Are they genuinely related in functionality/purpose?
        - Do they share important patterns or algorithms?
        - Would knowing about one help when working on the other?
        - Or are they just superficially similar (e.g., boilerplate)?

        Respond with:
        1. DECISION: "CREATE" or "REJECT"
        2. RATIONALE: Brief explanation (1-2 sentences)

        Format:
        DECISION: [CREATE/REJECT]
        RATIONALE: [Your explanation]
    """).strip()
    response = client.chat(
        model=OLLAMA_NATIVE_MODEL,
        messages=[{"role": "user", "content": prompt}],
        stream=False
    )
    content = response['message']['content'].strip()
    # Parse response
    lines = content.split('\n')
    decision = "REJECT"  # Default to reject
    rationale = "No clear relationship identified"
    for line in lines:
        if line.startswith("DECISION:"):
            decision = line.replace("DECISION:", "").strip().upper()
        elif line.startswith("RATIONALE:"):
            rationale = line.replace("RATIONALE:", "").strip()
    should_create = decision == "CREATE"
    return rationale, should_create
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
async def generate_relationship_rationale(
    source_doc: Dict[str, Any],
    target_doc: Dict[str, Any],
    similarity_score: float,
    weight: float
) -> Tuple[str, bool]:
    """Generate LLM rationale and decide if relationship should be created

    Returns:
        Tuple of (rationale, should_create)
    """
    # Use native ollama if configured
    if USE_NATIVE_OLLAMA and OLLAMA_NATIVE_AVAILABLE:
        # Run the synchronous client in a thread pool
        return await asyncio.to_thread(
            generate_relationship_rationale_native,
            source_doc, target_doc, similarity_score, weight
        )
    # Otherwise use LiteLLM
    prompt = textwrap.dedent(f"""
        Analyze these two code chunks (similarity: {similarity_score:.3f}, weight: {weight:.3f}).

        Code 1 ({source_doc['file_path']}:{source_doc['start_line']}):
        ```
        {source_doc['text'][:500]}
        ```

        Code 2 ({target_doc['file_path']}:{target_doc['start_line']}):
        ```
        {target_doc['text'][:500]}
        ```

        Determine if these code chunks have a meaningful relationship worth tracking.
        Consider:
        - Are they genuinely related in functionality/purpose?
        - Do they share important patterns or algorithms?
        - Would knowing about one help when working on the other?
        - Or are they just superficially similar (e.g., boilerplate)?

        Respond with:
        1. DECISION: "CREATE" or "REJECT"
        2. RATIONALE: Brief explanation (1-2 sentences)

        Format:
        DECISION: [CREATE/REJECT]
        RATIONALE: [Your explanation]
    """).strip()
    response = await litellm.acompletion(
        model=SUMMARY_MODEL,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=150,
        temperature=0.3,
    )
    content = clean_json_string(response.choices[0].message.content, return_dict=False)
    content = str(content).strip()
    # Parse response
    lines = content.split('\n')
    decision = "REJECT"  # Default to reject
    rationale = "No clear relationship identified"
    for line in lines:
        if line.startswith("DECISION:"):
            decision = line.replace("DECISION:", "").strip().upper()
        elif line.startswith("RATIONALE:"):
            rationale = line.replace("RATIONALE:", "").strip()
    should_create = decision == "CREATE"
    return rationale, should_create
def build_code_clusters(
    collection_name: str,
    num_clusters: int = 20,
    db: Optional[StandardDatabase] = None
) -> Optional[Dict[str, Any]]:
    """Build k-means clusters of code embeddings

    Args:
        collection_name: Collection to cluster
        num_clusters: Number of clusters
        db: Database connection

    Returns:
        Dict with cluster info or None
    """
    if not FAISS_AVAILABLE:
        logger.warning("FAISS not available - cannot build clusters")
        return None
    try:
        # Get or build index
        index_info = build_faiss_index(collection_name, db)
        if not index_info:
            return None
        documents = index_info['documents']
        embeddings = np.array([doc['embedding'] for doc in documents], dtype='float32')
        # Normalize for cosine similarity
        faiss.normalize_L2(embeddings)
        # Perform k-means clustering
        dimension = embeddings.shape[1]
        kmeans = faiss.Kmeans(dimension, num_clusters, niter=20, verbose=True)
        kmeans.train(embeddings)
        # Get cluster assignments
        _, cluster_ids = kmeans.index.search(embeddings, 1)
        # Build cluster info
        clusters = {}
        for doc, cluster_id in zip(documents, cluster_ids.flatten()):
            cluster_id = int(cluster_id)
            if cluster_id not in clusters:
                clusters[cluster_id] = {
                    'documents': [],
                    'centroid': kmeans.centroids[cluster_id].tolist(),
                    'size': 0
                }
            clusters[cluster_id]['documents'].append({
                'file_path': doc['file_path'],
                'start_line': doc['start_line'],
                'text_preview': doc['text'][:200] + '...' if len(doc['text']) > 200 else doc['text']
            })
            clusters[cluster_id]['size'] += 1
        # Find representative samples for each cluster (closest to centroid)
        for cluster_id, cluster_info in clusters.items():
            centroid = kmeans.centroids[cluster_id:cluster_id + 1]
            cluster_embeddings = embeddings[cluster_ids.flatten() == cluster_id]
            distances = np.dot(cluster_embeddings, centroid.T).flatten()
            best_idx = np.argmax(distances)
            cluster_info['representative'] = cluster_info['documents'][best_idx]
        result = {
            'num_clusters': num_clusters,
            'clusters': clusters,
            'collection': collection_name,
            'total_documents': len(documents),
            'created_at': datetime.now().isoformat()
        }
        logger.info(f"Built {num_clusters} clusters for {collection_name}")
        return result
    except Exception as e:
        logger.error(f"Failed to build clusters: {e}")
        return None
# -------------------------------------------------
# SEARCH
# -------------------------------------------------
async def search_code(query: str, project_id: str, limit: int = 10, threshold: float = DEFAULT_SIMILARITY_THRESHOLD, use_faiss: bool = True) -> List[Dict[str, Any]]:
    """Search for code using hybrid approach: FAISS k-NN + BM25 + graph

    This follows the established pattern:
    1. FAISS for fast k-NN semantic search
    2. BM25 for keyword matching
    3. Graph traversal for related code
    4. Results merged with Reciprocal Rank Fusion
    """
    db = get_db()
    collection_name = f"code_{project_id}"
    view_name = f"{collection_name}_search"
    if not db.has_collection(collection_name):
        return []
    # Generate query embedding
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encoder = SentenceTransformer(EMBED_MODEL, device=device)
    query_embedding = encoder.encode(query)
    all_results = []

    # 1. FAISS k-NN search (if available and requested)
    if use_faiss and FAISS_AVAILABLE:
        faiss_results = search_with_faiss(
            query_embedding=query_embedding,
            collection_name=collection_name,
            k=limit * 2,  # Get more candidates
            min_similarity=threshold
        )
        all_results.extend(faiss_results)
        logger.debug(f"FAISS found {len(faiss_results)} results")

    # 2. Try hybrid search with graph (if available)
    try:
        from lean4_prover.core.storage_adapter import hybrid_search_with_graph
        result = await hybrid_search_with_graph(
            collection=collection_name,
            query=query,
            text_fields=["text", "file_summary"],
            view_name=view_name,
            embedding_field="embedding",
            graph_name=f"{project_id}_code_graph",
            top_k=limit * 2,
            bm25_weight=0.3,
            semantic_weight=0.4,
            graph_weight=0.3,
            min_bm25_score=0.1,
            min_semantic_score=threshold,
            max_graph_depth=2,
            rrf_k=60,
            include_metadata=True
        )
        if result.get("success") and result.get("results"):
            for r in result["results"]:
                all_results.append({
                    "file_path": r.get("file_path", ""),
                    "start_line": r.get("start_line", 0),
                    "chunk_type": r.get("chunk_type", ""),
                    "language": r.get("language", ""),
                    "text": r.get("text", ""),
                    "file_summary": r.get("file_summary", ""),
                    "score": r.get("final_score", r.get("semantic_score", 0)),
                    "search_type": "hybrid_graph",
                    "_key": r.get("_key", "")
                })
            logger.debug(f"Hybrid search found {len(result['results'])} results")
    except ImportError:
        logger.warning("Hybrid search not available")
    except Exception as e:
        logger.error(f"Hybrid search error: {e}")

    # 3. If no results yet, fall back to simple semantic search
    if not all_results:
        logger.info("Using fallback semantic search")
        all_results = await _semantic_search_fallback(query, collection_name, limit, threshold)

    # 4. Deduplicate and rank results
    seen_keys = set()
    final_results = []
    for result in sorted(all_results, key=lambda x: x.get('score', 0), reverse=True):
        key = result.get('_key') or f"{result['file_path']}:{result['start_line']}"
        if key not in seen_keys:
            seen_keys.add(key)
            final_results.append(result)
            if len(final_results) >= limit:
                break
    logger.info(f"Search returned {len(final_results)} results")
    return final_results
async def _semantic_search_fallback(query: str, collection_name: str, limit: int, threshold: float) -> List[Dict[str, Any]]:
    """Fallback to simple semantic search if hybrid search unavailable"""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encoder = SentenceTransformer(EMBED_MODEL, device=device)
    query_embedding = encoder.encode(query)
    db = get_db()
    cursor = db.aql.execute(f"""
        LET query_vec = @query_embedding
        FOR doc IN {collection_name}
            LET score = COSINE_SIMILARITY(query_vec, doc.embedding)
            FILTER score > @threshold
            SORT score DESC
            LIMIT @limit
            RETURN {{
                file_path: doc.file_path,
                start_line: doc.start_line,
                chunk_type: doc.chunk_type,
                language: doc.language,
                text: doc.text,
                file_summary: doc.file_summary,
                score: score
            }}
    """, bind_vars={
        'query_embedding': query_embedding.tolist(),
        'threshold': threshold,
        'limit': limit
    })
    results = []
    for doc in cursor:
        results.append(doc)
    return results
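
# Illustrative sketch (not called by this script): how Reciprocal Rank Fusion
# combines ranked result lists. search_code above passes rrf_k=60 to the external
# hybrid_search_with_graph helper, whose exact weighting is not shown here; the
# standard RRF score for a document at 1-based rank r in each list is
# sum(1 / (rrf_k + r)).
def _rrf_fusion_example(ranked_lists: List[List[str]], rrf_k: int = 60) -> Dict[str, float]:
    """Fuse several ranked lists of document keys into a single score per key."""
    fused: Dict[str, float] = {}
    for ranking in ranked_lists:
        for rank, key in enumerate(ranking, start=1):
            fused[key] = fused.get(key, 0.0) + 1.0 / (rrf_k + rank)
    return fused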
# -------------------------------------------------
# CLI
# -------------------------------------------------
app = typer.Typer()
@app.command()
def index(
    repo: Path = typer.Argument(..., help="Repository path"),
    project_id: Optional[str] = typer.Option(None, "--project-id"),
    since: Optional[str] = typer.Option(None, "--since", help="Index only changes since commit"),
) -> None:
    """Index a repository"""
    if not project_id:
        project_id = repo.name.lower().replace("-", "_")
    # Find files
    if since:
        cmd = f"git diff --name-only {since}..HEAD"
        result = subprocess.run(cmd.split(), cwd=repo, capture_output=True, text=True)
        changed = result.stdout.strip().splitlines()
        files = [repo / f for f in changed if (repo / f).suffix in LANGUAGES and (repo / f).exists()]
    else:
        files = [f for f in repo.rglob("*") if f.suffix in LANGUAGES and ".git" not in f.parts]
    if not files:
        typer.secho("No files to index", fg=typer.colors.YELLOW)
        return
    typer.secho(f"Indexing {len(files)} files...", fg=typer.colors.BLUE)
    start = time.time()
    asyncio.run(index_repository(repo, files, project_id))
    elapsed = time.time() - start
    typer.secho(f"✓ Completed in {elapsed:.1f}s", fg=typer.colors.GREEN)
@app.command()
def search(
    query: str = typer.Argument(..., help="Search query"),
    project_id: str = typer.Option(..., "--project"),
    limit: int = typer.Option(10, "--limit"),
    threshold: float = typer.Option(DEFAULT_SIMILARITY_THRESHOLD, "--threshold", help="Similarity threshold (0.0-1.0)"),
) -> None:
    """Search indexed code"""
    results = asyncio.run(search_code(query, project_id, limit, threshold))
    if not results:
        typer.secho("No results found", fg=typer.colors.YELLOW)
        return
    for i, result in enumerate(results, 1):
        typer.secho(f"\n{i}. Score: {result.get('score', 0):.3f}", fg=typer.colors.CYAN)
        typer.secho(f"   {result['file_path']}:{result.get('start_line', '?')}", fg=typer.colors.GREEN)
        # Results carry a file-level summary under 'file_summary'
        if result.get('file_summary'):
            typer.secho(f"   {result['file_summary']}", fg=typer.colors.BLUE)
@app.command()
def build_relationships(
    project_id: str = typer.Option(..., "--project"),
    threshold: float = typer.Option(0.8, "--threshold", help="Similarity threshold"),
    max_per_doc: int = typer.Option(5, "--max-per-doc", help="Max relationships per document"),
) -> None:
    """Build similarity relationships using FAISS and LLM rationales"""
    collection_name = f"code_{project_id}"
    typer.secho(f"Building relationships for {collection_name}...", fg=typer.colors.BLUE)
    result = asyncio.run(build_similarity_relationships(
        collection_name=collection_name,
        similarity_threshold=threshold,
        max_relationships_per_doc=max_per_doc
    ))
    if result.get("success"):
        typer.secho(
            f"✓ Created {result['relationships_created']} relationships from {result['documents_analyzed']} documents",
            fg=typer.colors.GREEN
        )
    else:
        typer.secho(f"✗ Failed: {result.get('error')}", fg=typer.colors.RED)
@app.command()
def cluster(
    project_id: str = typer.Option(..., "--project"),
    num_clusters: int = typer.Option(20, "--clusters", help="Number of clusters"),
) -> None:
    """Build k-means clusters of code embeddings"""
    collection_name = f"code_{project_id}"
    typer.secho(f"Building {num_clusters} clusters for {collection_name}...", fg=typer.colors.BLUE)
    result = build_code_clusters(collection_name, num_clusters)
    if result:
        typer.secho(f"✓ Built {result['num_clusters']} clusters from {result['total_documents']} documents", fg=typer.colors.GREEN)
        # Show cluster summaries
        for cluster_id, info in result['clusters'].items():
            typer.secho(f"\nCluster {cluster_id} ({info['size']} documents):", fg=typer.colors.CYAN)
            if 'representative' in info:
                rep = info['representative']
                typer.secho(f"   Representative: {rep['file_path']}:{rep['start_line']}", fg=typer.colors.BLUE)
    else:
        typer.secho("✗ Failed to build clusters", fg=typer.colors.RED)
@app.command()
def info() -> None:
    """Show system info"""
    gpu = torch.cuda.is_available()
    typer.secho(f"GPU: {'Yes' if gpu else 'No'}")
    if gpu:
        typer.secho(f"GPU Name: {torch.cuda.get_device_name(0)}")
    typer.secho(f"Embedding model: {EMBED_MODEL}")
    typer.secho(f"Summary model: {SUMMARY_MODEL}")
    typer.secho(f"FAISS available: {'Yes' if FAISS_AVAILABLE else 'No'}")
    typer.secho(f"Native Ollama available: {'Yes' if OLLAMA_NATIVE_AVAILABLE else 'No'}")
    typer.secho(f"Use native Ollama: {'Yes' if USE_NATIVE_OLLAMA else 'No'}")
    if USE_NATIVE_OLLAMA and OLLAMA_NATIVE_AVAILABLE:
        typer.secho(f"Native Ollama model: {OLLAMA_NATIVE_MODEL}")


if __name__ == "__main__":
    app()