Created June 25, 2024 at 12:03.
Save paulwababu/d7210340158c924e40b3112485c3acc8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import gradio as gr
from langchain.chains import RetrievalQA
from langchain.document_loaders import DirectoryLoader, UnstructuredFileLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain_openai import OpenAI
# Directory holding the code files the assistant answers questions about.
CODE_DIR = os.path.join(os.getcwd(), "code")
def load_code_files(directory, file_extensions=(".py", ".js", ".java")):
    """Load source documents from *directory*, keeping only wanted extensions.

    Tries the bulk ``DirectoryLoader`` first; if its optional dependencies
    are missing (``ImportError``), falls back to walking the tree and loading
    each matching file with ``UnstructuredFileLoader``.  Any other loading
    error is reported and whatever loaded so far is kept (best-effort).

    Args:
        directory: Path of the folder to scan.
        file_extensions: Extensions (including the dot) to keep.  A tuple is
            used as the default instead of a list so the default is immutable
            and never shared mutable state across calls.

    Returns:
        A (possibly empty) list of loaded documents whose source path ends
        with one of ``file_extensions``.
    """
    documents = []
    try:
        loader = DirectoryLoader(path=directory)
        documents = loader.load()
    except ImportError:
        # DirectoryLoader needs optional extras; walk the tree ourselves.
        print("Using fallback UnstructuredFileLoader for each file")
        for root, _, files in os.walk(directory):
            for file in files:
                if any(file.endswith(ext) for ext in file_extensions):
                    file_path = os.path.join(root, file)
                    loader = UnstructuredFileLoader(file_path)
                    documents.extend(loader.load())
    except Exception as e:
        # Best-effort: report the failure but still return what we have.
        print(f"Error loading files: {e}")
    # DirectoryLoader may pick up unrelated files; filter to wanted extensions.
    filtered_documents = [
        doc
        for doc in documents
        if os.path.splitext(doc.metadata['source'])[1] in file_extensions
    ]
    return filtered_documents
# Read the API key from the environment so it is never hard-coded in source.
openai_api_key = os.getenv("OPENAI_API_KEY")
# Completion-style model used to answer questions over the retrieved code.
llm = OpenAI(model="gpt-3.5-turbo-instruct", openai_api_key=openai_api_key)
def create_retrieval_chain(directory):
    """Build a RetrievalQA chain over the code files found in *directory*."""
    docs = load_code_files(directory)
    # Embed each document and index the vectors for similarity search.
    embedder = OpenAIEmbeddings(openai_api_key=openai_api_key)
    index = FAISS.from_documents(docs, embedder)
    # "map_reduce" answers per retrieved document, then combines the answers.
    return RetrievalQA.from_chain_type(
        llm,
        retriever=index.as_retriever(),
        chain_type="map_reduce",
    )
def chat_with_codebase(query):
    """Answer *query* using a retrieval chain built over CODE_DIR.

    The original implementation rebuilt the whole chain — re-reading every
    file and re-embedding it via the OpenAI API — on every single question.
    The chain is now built once on first use and cached on the function
    object, so subsequent queries reuse the existing FAISS index.

    Args:
        query: The user's natural-language question about the codebase.

    Returns:
        The chain's answer string.
    """
    chain = getattr(chat_with_codebase, "_chain", None)
    if chain is None:
        chain = create_retrieval_chain(CODE_DIR)
        chat_with_codebase._chain = chain
    response = chain.run({"query": query})
    return response
# Minimal web UI: one text box in, the chain's answer text out.
iface = gr.Interface(
    fn=chat_with_codebase,
    inputs="text",
    outputs="text",
    title="Chat with Your Codebase",
    description="Ask questions about the code in the 'code' folder.",
    examples=[
        "Explain the function calculate_sum",
        "What does the main function do?",
    ],
)
# Only start the server when run as a script; share=True also exposes the
# app through a temporary public gradio.live URL.
if __name__ == "__main__":
    iface.launch(share=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.