# skaffold-issues-server/backend/main.py
import logging
import glob
import os
import textwrap
import re

import numpy as np
import pandas as pd
from flask import Flask, request, jsonify
from flask_cors import CORS
from vertexai.preview.language_models import TextEmbeddingModel, TextGenerationModel, CodeGenerationModel
from tenacity import retry, stop_after_attempt, wait_random_exponential

app = Flask(__name__)
CORS(app)
logging.basicConfig(level=logging.INFO)

# Globals
chunk_size = 500000
extracted_data = []
cache = {}  # Maps "question_model_context" keys to cached predictions
cache_files = {}  # Maps the same keys to the source files behind each cached prediction

# Functions
@retry(wait=wait_random_exponential(min=10, max=120), stop=stop_after_attempt(5))
def embedding_model_with_backoff(text):
    """Embed a list of texts with exponential-backoff retries; returns the first embedding's values."""
    logging.info("Fetching embeddings")
    embeddings = embedding_model.get_embeddings(text)
    logging.info("Fetched embeddings")
    return embeddings[0].values


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
def text_generation_text_bison_with_backoff(**kwargs):
    logging.info("Generating text with text-bison")
    result = text_bison_model.predict(**kwargs).text
    logging.info("Generated text with text-bison")
    return result


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
def text_generation_code_bison_with_backoff(**kwargs):
    logging.info("Generating text with code-bison")
    result = code_bison_model.predict(**kwargs).text
    logging.info("Generated text with code-bison")
    return result


def get_context_from_question(question: str, vector_store: pd.DataFrame, sort_index_value: int = 2) -> tuple:
    """Return the top-matching chunks (joined as one context string) and their source rows."""
    logging.info("Getting context from question")
    query_vector = np.array(embedding_model_with_backoff([question]))
    # Rank chunks by dot-product similarity between the query embedding and each chunk embedding
    vector_store["dot_product"] = vector_store["embedding"].apply(lambda row: np.dot(row, query_vector))
    top_matched = vector_store.sort_values(by="dot_product", ascending=False)[:sort_index_value].index
    top_matched_df = vector_store.loc[top_matched, ["file_name", "chunks"]]
    context = "\n".join(top_matched_df["chunks"].values)
    return context, top_matched_df
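
# A hypothetical call, shown for illustration only (the question text is made up;
# text_data_sample is the vector store built in the setup section below):
#
#   context, matches = get_context_from_question(
#       "How do I debug a skaffold build failure?",
#       vector_store=text_data_sample,
#       sort_index_value=2,
#   )
#   # context -> the two best-matching chunks joined with newlines
#   # matches -> DataFrame with the "file_name" and "chunks" columns of those rows
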
# Setup
# List of common image file extensions to skip when indexing
image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp', '.ico']
for path in glob.glob("issues_files/**", recursive=True):
    # Check if the path is a file and not an image
    if os.path.isfile(path) and os.path.splitext(path)[1].lower() not in image_extensions:
        file_name, file_type = os.path.splitext(path)
        logging.info(f"Processing {file_name}")
        with open(path, 'r', encoding='utf-8') as file:
            content = file.read()
        document_chunks = textwrap.wrap(content, width=chunk_size)
        for chunk_number, chunk_content in enumerate(document_chunks, start=1):
            extracted_data.append({
                "file_name": path,
                "file_type": file_type,
                "chunk_number": chunk_number,
                "content": chunk_content,
            })
logging.info("Finished processing text files")

text_data = pd.DataFrame.from_dict(extracted_data).sort_values(by=["file_name"]).reset_index(drop=True)
text_data_sample = text_data.copy()
# Collapse runs of non-alphanumeric characters to single spaces before embedding
text_data_sample["content"] = text_data_sample["content"].apply(lambda x: re.sub("[^A-Za-z0-9]+", " ", x))
text_data_sample["chunks"] = text_data_sample["content"].apply(lambda row: textwrap.wrap(row, width=chunk_size))
text_data_sample = text_data_sample.explode("chunks").sort_values(by=["file_name"]).reset_index(drop=True)

logging.info("Loading models") | |
text_bison_model = TextGenerationModel.from_pretrained("text-bison-32k") | |
code_bison_model = CodeGenerationModel.from_pretrained("code-bison-32k") | |
embedding_model = TextEmbeddingModel.from_pretrained("textembedding-gecko") | |
logging.info("Applying embeddings") | |
text_data_sample["embedding"] = text_data_sample["chunks"].apply(lambda x: embedding_model_with_backoff([x])) | |
text_data_sample["embedding"] = text_data_sample.embedding.apply(np.array) | |
# Web server routes
@app.route('/', methods=['GET'])
def get_prediction():
    question = request.args.get('question')
    model = request.args.get('model', '')
    ctx = request.args.get('context', '')
    if not question:
        logging.info("error: question parameter not provided")
        return jsonify({"error": "Please provide a question."}), 400
    # Check if the prediction for the question exists in the cache
    cache_key = f"{question}_{model}_{ctx}"
    if cache_key in cache:
        logging.info("Returning cached prediction")
        # TODO(aaron-prindle) make it so caching returns files
        return jsonify({"prediction": cache[cache_key], "files": cache_files[cache_key]})
if ctx == "none": | |
prompt = f""" | |
Answer the question in great detail. | |
Question: \n{question} | |
Answer: | |
# """ | |
# prompt = f""" | |
# Answer the question as precise as possible. | |
# Question: \n{question} | |
# Answer: | |
# """ | |
output_files = [] | |
    else:
        context, top_matched_df = get_context_from_question(
            question,
            vector_store=text_data_sample,
            sort_index_value=2,
        )
        prompt = f"""
        Answer the question in great detail using the provided context.
        Context: \n{context}
        Question: \n{question}
        Answer:
        """
        output_files = top_matched_df["file_name"].values.tolist()
    if model == "text-bison-32k":
        prediction = text_generation_text_bison_with_backoff(prompt=prompt)
    else:
        prediction = text_generation_code_bison_with_backoff(prefix=prompt)
    # Cache the prediction and the files that produced it
    cache[cache_key] = prediction
    cache_files[cache_key] = output_files
    logging.info("Returning prediction")
    return jsonify({"prediction": prediction, "files": output_files})


# Run server
if __name__ == "__main__":
    app.run(host='0.0.0.0', port=5005)
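
# Example request, assuming the server is running locally on port 5005 (the question
# is illustrative; any "context" value other than "none" enables retrieval, and any
# "model" value other than "text-bison-32k" falls through to code-bison):
#
#   curl "http://localhost:5005/?question=How+do+I+debug+a+skaffold+build+failure&model=text-bison-32k&context=files"
#
# Response shape: {"prediction": "<generated answer>", "files": ["<matched file>", ...]}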