Example using LlamaHub loaders to index GitHub repos into LlamaIndex and query GPTSimpleVectorIndex with GPT-4
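To try it (a sketch of typical usage; the script filename is an assumption, since the gist's filenames are not visible here): install the two dependencies listed below (e.g. `pip install llama-index langchain`), export `OPENAI_API_KEY` and `GITHUB_TOKEN`, then run the script (e.g. `python main.py --debug`) and type questions at the `QUERY:` prompt.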
# main
llama-index
langchain
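The dependency list above appears to be the gist's requirements file. Note that the script targets the pre-0.6 llama-index API (GPTSimpleVectorIndex, ServiceContext, LLMPredictor, download_loader); later releases renamed or removed these, so pinning an older llama-index (and a contemporaneous langchain) is likely required.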
"""Modified llama-hub example for github_repo""" | |
import argparse | |
import logging | |
import os | |
import pickle | |
from langchain.chat_models import ChatOpenAI | |
from llama_index import ( | |
GPTSimpleVectorIndex, | |
LLMPredictor, | |
ServiceContext, | |
download_loader, | |
) | |
# from llama_index.logger.base import LlamaLogger | |
from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingMode | |
from llama_index.langchain_helpers.text_splitter import TokenTextSplitter | |
from llama_index.node_parser.simple import SimpleNodeParser | |
from llama_index.prompts.chat_prompts import CHAT_REFINE_PROMPT | |
assert ( | |
os.getenv("OPENAI_API_KEY") is not None | |
), "Please set the OPENAI_API_KEY environment variable." | |
assert ( | |
os.getenv("GITHUB_TOKEN") is not None | |
), "Please set the GITHUB_TOKEN environment variable." | |
# This is a way to test loaders on different forks/branches. | |
# LLAMA_HUB_CONTENTS_URL = "https://raw.githubusercontent.com/claysauruswrecks/llama-hub/bugfix/github-repo-splitter" # noqa: E501 | |
# LOADER_HUB_PATH = "/loader_hub" | |
# LOADER_HUB_URL = LLAMA_HUB_CONTENTS_URL + LOADER_HUB_PATH | |
download_loader( | |
"GithubRepositoryReader", | |
# loader_hub_url=LOADER_HUB_URL, | |
# refresh_cache=True, | |
) | |
from llama_index.readers.llamahub_modules.github_repo import ( # noqa: E402 | |
GithubClient, | |
GithubRepositoryReader, | |
) | |
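# NOTE (added): download_loader() fetches the loader's source from LlamaHub at
# runtime and exposes it under llama_index.readers.llamahub_modules, which is
# why the import above has to come after the call (hence its `noqa: E402`).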
# TODO: Modify github loader to support exclude list of filenames and unblock .ipynb  # noqa: E501
REPOS = {
    # NOTE: Use this to find long-line filetypes to avoid: `find . -type f -exec sh -c 'awk "BEGIN { max = 0 } { if (length > max) max = length } END { printf \"%s:%d\n\", FILENAME, max }" "{}"' \; | sort -t: -k2 -nr`  # noqa: E501
    "jerryjliu/llama_index@1b739e1fcd525f73af4a7131dd52c7750e9ca247": dict(
        filter_directories=(
            ["docs", "examples", "gpt_index", "tests"],
            GithubRepositoryReader.FilterType.INCLUDE,
        ),
        filter_file_extensions=(
            [
                ".bat",
                ".md",
                # ".ipynb",
                ".py",
                ".rst",
                ".sh",
            ],
            GithubRepositoryReader.FilterType.INCLUDE,
        ),
    ),
    "emptycrown/llama-hub@8312da4ee8fcaf2cbbf5315a2ab8f170d102d081": dict(
        filter_directories=(
            ["loader_hub", "tests"],
            GithubRepositoryReader.FilterType.INCLUDE,
        ),
        filter_file_extensions=(
            [".py", ".md", ".txt"],
            GithubRepositoryReader.FilterType.INCLUDE,
        ),
    ),
    "hwchase17/langchain@d85f57ef9cbbbd5e512e064fb81c531b28c6591c": dict(
        filter_directories=(
            ["docs", "langchain", "tests"],
            GithubRepositoryReader.FilterType.INCLUDE,
        ),
        filter_file_extensions=(
            [
                ".bat",
                ".md",
                # ".ipynb",
                ".py",
                ".rst",
                ".sh",
            ],
            GithubRepositoryReader.FilterType.INCLUDE,
        ),
    ),
}
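# NOTE (added): each REPOS key has the form "owner/repo@commit_sha"; main()
# recovers the pieces with two splits, e.g.:
#   repo_owner, rest = "jerryjliu/llama_index@1b739e1...".split("/")
#   repo_name, commit_sha = rest.split("@")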
# MODEL_NAME = "gpt-3.5-turbo"
MODEL_NAME = "gpt-4"
CHUNK_SIZE_LIMIT = 512
CHUNK_OVERLAP = 200  # default
MAX_TOKENS = None  # Set to None to use the model's maximum

EMBED_MODEL = OpenAIEmbedding(mode=OpenAIEmbeddingMode.SIMILARITY_MODE)
LLM_PREDICTOR = LLMPredictor(
    llm=ChatOpenAI(
        temperature=0.0, model_name=MODEL_NAME, max_tokens=MAX_TOKENS
    )
)

PICKLE_DOCS_DIR = os.path.join(
    os.path.dirname(__file__), "data", "pickled_docs"
)

# Create the directory if it does not exist
if not os.path.exists(PICKLE_DOCS_DIR):
    os.makedirs(PICKLE_DOCS_DIR)


def load_pickle(filename):
    """Load pickled documents from the cache directory."""
    with open(os.path.join(PICKLE_DOCS_DIR, filename), "rb") as f:
        logging.debug(f"Loading pickled docs from {filename}")
        return pickle.load(f)


def save_pickle(obj, filename):
    """Pickle documents into the cache directory."""
    with open(os.path.join(PICKLE_DOCS_DIR, filename), "wb") as f:
        logging.debug(f"Saving pickled docs to {filename}")
        pickle.dump(obj, f)
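# NOTE (added): cached pickles land in data/pickled_docs/ with names like
# "jerryjliu-llama_index-<commit_sha>-docs.pkl", so deleting a file there
# forces a re-fetch of that repo on the next run.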
def main(args):
    """Load the repos, build the index, and answer queries in a loop."""
    g_docs = {}

    for repo in REPOS.keys():
        logging.debug(f"Processing {repo}")
        repo_owner, repo_name_at_sha = repo.split("/")
        repo_name, commit_sha = repo_name_at_sha.split("@")
        docs_filename = f"{repo_owner}-{repo_name}-{commit_sha}-docs.pkl"
        docs_filepath = os.path.join(PICKLE_DOCS_DIR, docs_filename)

        if os.path.exists(docs_filepath):
            logging.debug(f"Path exists: {docs_filepath}")
            g_docs[repo] = load_pickle(docs_filename)

        if not g_docs.get(repo):
            github_client = GithubClient(os.getenv("GITHUB_TOKEN"))
            loader = GithubRepositoryReader(
                github_client,
                owner=repo_owner,
                repo=repo_name,
                filter_directories=REPOS[repo]["filter_directories"],
                filter_file_extensions=REPOS[repo]["filter_file_extensions"],
                verbose=args.debug,
                concurrent_requests=10,
            )
            embedded_docs = loader.load_data(commit_sha=commit_sha)
            g_docs[repo] = embedded_docs
            save_pickle(embedded_docs, docs_filename)
    # NOTE: set a chunk size limit to < 1024 tokens
    service_context = ServiceContext.from_defaults(
        llm_predictor=LLM_PREDICTOR,
        embed_model=EMBED_MODEL,
        node_parser=SimpleNodeParser(
            text_splitter=TokenTextSplitter(
                separator=" ",
                chunk_size=CHUNK_SIZE_LIMIT,
                chunk_overlap=CHUNK_OVERLAP,
                backup_separators=[
                    "\n",
                    "\n\n",
                    "\r\n",
                    "\r",
                    "\t",
                    "\\",
                    "\f",
                    "//",
                    "+",
                    "=",
                    ",",
                    ".",
                    "a",
                    "e",  # TODO: Figure out why lol
                ],
            )
        ),
        # llama_logger=LlamaLogger(),  # TODO: ?
    )
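    # NOTE (added, assumption): TokenTextSplitter falls back to
    # backup_separators when a span contains no primary separator (" ") and
    # would otherwise exceed chunk_size; the single-letter "a"/"e" entries
    # above presumably exist to force a split point inside long unbroken
    # runs of source text.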
    # Collapse all the docs into a single list
    logging.debug("Collapsing all the docs into a single list")
    docs = []
    for repo in g_docs.keys():
        docs.extend(g_docs[repo])

    index = GPTSimpleVectorIndex.from_documents(
        documents=docs, service_context=service_context
    )

    # Ask for CLI input in a loop
    while True:
        print("QUERY:")
        query = input()
        answer = index.query(query, refine_template=CHAT_REFINE_PROMPT)
        print(f"ANSWER: {answer}")
        if args.pdb:
            import pdb

            pdb.set_trace()
if __name__ == "__main__":
    # Parse CLI arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--debug",
        action="store_true",
        default=False,
        help="Enable debug logging.",
    )
    parser.add_argument(
        "--pdb",
        action="store_true",
        help="Invoke PDB after each query.",
    )
    args = parser.parse_args()

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)

    main(args)
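A possible extension (a sketch, not part of the original gist): the llama-index version this script targets also exposed save_to_disk / load_from_disk on GPTSimpleVectorIndex, so the built index itself could be cached instead of being re-embedded from the pickled documents on every run. Assuming that API and an arbitrary path of data/index.json, the index construction inside main() could become:

    # Sketch only (assumed path); reuse a previously built index if present.
    INDEX_PATH = os.path.join(os.path.dirname(__file__), "data", "index.json")

    if os.path.exists(INDEX_PATH):
        # Load the saved index, embeddings included.
        index = GPTSimpleVectorIndex.load_from_disk(
            INDEX_PATH, service_context=service_context
        )
    else:
        index = GPTSimpleVectorIndex.from_documents(
            documents=docs, service_context=service_context
        )
        index.save_to_disk(INDEX_PATH)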