@aaron-prindle
Created February 18, 2025 19:14
skaffold-issues-server/backend/main.py
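"""Flask backend for a Skaffold-issues Q&A server.

At startup the script reads the text files under issues_files/ (skipping images),
splits them into chunks, and embeds each chunk with Vertex AI's
textembedding-gecko model. The single GET route embeds the incoming question,
retrieves the closest chunks by dot product, and asks text-bison-32k or
code-bison-32k to answer using that context, returning the prediction and the
names of the matched files.
"""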
import logging
import glob
import os
import textwrap
import re
import numpy as np
import pandas as pd
from flask import Flask, request, jsonify
from flask_cors import CORS
from vertexai.preview.language_models import TextEmbeddingModel, TextGenerationModel, CodeGenerationModel
from tenacity import retry, stop_after_attempt, wait_random_exponential
app = Flask(__name__)
CORS(app)
logging.basicConfig(level=logging.INFO)
# Globals
chunk_size = 500000  # max characters per chunk when splitting issue files
extracted_data = []
cache = {}        # caches predictions, keyed by question + model + context
cache_files = {}  # caches the matched file names for each cached prediction
# Functions
@retry(wait=wait_random_exponential(min=10, max=120), stop=stop_after_attempt(5))
def embedding_model_with_backoff(text=[]):
    logging.info("Fetching embeddings")
    embeddings = embedding_model.get_embeddings(text)
    logging.info("Fetched embeddings")
    return [each.values for each in embeddings][0]
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
def text_generation_text_bison_with_backoff(**kwargs):
    logging.info("Generating text with text-bison")
    result = text_bison_model.predict(**kwargs).text
    logging.info("Generated text with text-bison")
    return result
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
def text_generation_code_bison_with_backoff(**kwargs):
    logging.info("Generating text with code-bison")
    result = code_bison_model.predict(**kwargs).text
    logging.info("Generated text with code-bison")
    return result
def get_context_from_question(question: str, vector_store: pd.DataFrame, sort_index_value: int = 2) -> tuple:
    logging.info("Getting context from question")
    # Embed the question and score every chunk by its dot product with the query vector.
    query_vector = np.array(embedding_model_with_backoff([question]))
    vector_store["dot_product"] = vector_store["embedding"].apply(lambda row: np.dot(row, query_vector))
    # Keep the top `sort_index_value` chunks and join them into a single context string.
    top_matched = vector_store.sort_values(by="dot_product", ascending=False)[:sort_index_value].index
    top_matched_df = vector_store.loc[top_matched, ["file_name", "chunks"]]
    context = "\n".join(top_matched_df["chunks"].values)
    return context, top_matched_df
# Setup
# List of common image file extensions
image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp', '.ico']
for path in glob.glob("issues_files/**", recursive=True):
    # Check if the path is a file and not an image
    if os.path.isfile(path) and os.path.splitext(path)[1].lower() not in image_extensions:
        file_name, file_type = os.path.splitext(path)
        logging.info(f"Processing {file_name}")
        with open(path, 'r', encoding='utf-8') as file:
            content = file.read()
        document_chunks = textwrap.wrap(content, width=chunk_size)
        for chunk_number, chunk_content in enumerate(document_chunks, start=1):
            extracted_data.append({
                "file_name": path,
                # "file_name": file_name,
                "file_type": file_type,
                "chunk_number": chunk_number,
                "content": chunk_content,
            })
logging.info("Finished processing text files")
text_data = pd.DataFrame.from_dict(extracted_data).sort_values(by=["file_name"]).reset_index(drop=True)
text_data_sample = text_data.copy()
text_data_sample["content"] = text_data_sample["content"].apply(lambda x: re.sub("[^A-Za-z0-9]+", " ", x))
text_data_sample["chunks"] = text_data_sample["content"].apply(lambda row: textwrap.wrap(row, width=chunk_size))
text_data_sample = text_data_sample.explode("chunks").sort_values(by=["file_name"]).reset_index(drop=True)
logging.info("Loading models")
text_bison_model = TextGenerationModel.from_pretrained("text-bison-32k")
code_bison_model = CodeGenerationModel.from_pretrained("code-bison-32k")
embedding_model = TextEmbeddingModel.from_pretrained("textembedding-gecko")
logging.info("Applying embeddings")
text_data_sample["embedding"] = text_data_sample["chunks"].apply(lambda x: embedding_model_with_backoff([x]))
text_data_sample["embedding"] = text_data_sample.embedding.apply(np.array)
# Web server routes
@app.route('/', methods=['GET'])
def get_prediction():
    question = request.args.get('question')
    # Default model/context to strings so the cache key below is always well-formed.
    model = request.args.get('model', '')
    ctx = request.args.get('context', 'none')
    if not question:
        logging.info("error: question parameter not provided")
        return jsonify({"error": "Please provide a question."}), 400
    cache_key = question + "_" + model + "_" + ctx
    # Check if the prediction for the question exists in the cache
    if cache_key in cache:
        logging.info("Returning cached prediction")
        # TODO(aaron-prindle) make it so caching returns files
        return jsonify({"prediction": cache[cache_key], "files": cache_files[cache_key]})
    if ctx == "none":
        prompt = f"""
Answer the question in great detail.
Question: \n{question}
Answer:
"""
        # prompt = f"""
        # Answer the question as precise as possible.
        # Question: \n{question}
        # Answer:
        # """
        output_files = []
    else:
        context, top_matched_df = get_context_from_question(
            question,
            vector_store=text_data_sample,
            sort_index_value=2,
        )
        # prompt = f"""
        # Answer the question as precise as possible using the provided context.
        # Context: \n{context}
        # Question: \n{question}
        # Answer:
        # """
        prompt = f"""
Answer the question in great detail using the provided context.
Context: \n{context}
Question: \n{question}
Answer:
"""
        output_files = top_matched_df["file_name"].values.tolist()
    if model == "text-bison-32k":
        prediction = text_generation_text_bison_with_backoff(prompt=prompt)
    else:
        prediction = text_generation_code_bison_with_backoff(prefix=prompt)
    # file_name = "\n".join(top_matched_df["file_name"].values)
    # logging.info("Got context from question: %s", context)
    # logging.info("Got file_name from question: %s", file_name)
    # Cache the prediction and the matched files
    cache[cache_key] = prediction
    cache_files[cache_key] = output_files
    logging.info("Returning prediction")
    return jsonify({"prediction": prediction, "files": output_files})
# Run server
if __name__ == "__main__":
    app.run(host='0.0.0.0', port=5005)
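# Example request (illustrative; assumes the server is running locally):
#   curl "http://localhost:5005/?question=How+do+I+debug+a+skaffold+dev+loop%3F&model=text-bison-32k&context=issues"
# Passing context=none skips retrieval; any other value answers using the top-matching issue chunks.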