Created
March 30, 2023 07:16
-
-
Save hlwhl/0499d665c7e9db48d7b0a8056f0deccb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import os

from flask import Flask, request
from flask_cors import CORS
from langchain.chat_models import ChatOpenAI
from llama_index import Document, GPTSimpleVectorIndex, LLMPredictor, QuestionAnswerPrompt, ServiceContext, PromptHelper
# Flask application serving the LLM question-answering endpoints below.
app = Flask(__name__)
# Enable CORS for all routes so browser front-ends on other origins can call the API.
CORS(app)
@app.route('/', methods=["POST"])
def home():
    """Liveness endpoint: POST / returns a fixed greeting string."""
    greeting = 'Hello, World!'
    return greeting
# Chat-model predictor backing every index query.
# The API key is read from the environment rather than hard-coded into the
# source; the empty-string fallback preserves the original behaviour when
# OPENAI_API_KEY is unset.
llm_predictor = LLMPredictor(llm=ChatOpenAI(
    openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
    temperature=0,  # deterministic sampling
))

# Prompt-helper sizing parameters.
max_input_size = 4096   # maximum prompt size sent to the model
num_output = 512        # tokens reserved for the model's answer
max_chunk_overlap = 20  # overlap between consecutive context chunks
prompt_helper = PromptHelper(max_input_size, num_output, max_chunk_overlap)

# Shared llama_index configuration used by gen_index()/query() below.
service_context = ServiceContext.from_defaults(
    llm_predictor=llm_predictor, prompt_helper=prompt_helper)
def gen_index(document):
    """Build an in-memory vector index over a single document string.

    Before indexing, the text is re-joined with a space after every 100
    characters (same chunk-spacing transformation as before).
    """
    doc = Document(document)
    pieces = [doc.text[pos:pos + 100]
              for pos in range(0, len(doc.text), 100)]
    doc.text = ' '.join(pieces)
    return GPTSimpleVectorIndex.from_documents(
        [doc], service_context=service_context)
def getQAPrompt():
    """Return the QuestionAnswerPrompt used by the /query endpoint.

    The template shows the retrieved context between separator rules and
    places the user's query immediately after it.
    """
    template = (
        "Context information is below. \n"
        "---------------------\n"
        "{context_str}"
        "\n---------------------\n"
        "{query_str}\n"
    )
    return QuestionAnswerPrompt(template)
# Define an endpoint named /query: a POST request whose JSON body carries the parameters `context` and `prompt`.
@app.route("/query", methods=["POST"])
def query():
    """POST /query — answer `prompt` using `context` as the knowledge base.

    Expects a JSON body {"context": str, "prompt": str}. Builds a one-off
    vector index from the context and queries it. Returns the model's
    answer as plain text, or 400 when either field is missing/invalid
    (previously a missing field crashed with an unhandled 500).
    """
    body = request.get_json(silent=True) or {}
    context = body.get('context')
    prompt = body.get('prompt')
    # Reject malformed requests up front instead of failing deep inside llama_index.
    if not isinstance(context, str) or not isinstance(prompt, str):
        return 'Both "context" and "prompt" must be provided as strings.', 400
    index = gen_index(context)
    # The appended instruction asks the model to answer in Chinese.
    response = index.query(prompt + " 请使用中文回答。",
                           text_qa_template=getQAPrompt(), response_mode="tree_summarize")
    return str(response), 200
if __name__ == '__main__':
    # Verbose logging for development runs.
    logging.basicConfig(level=logging.DEBUG)
    # NOTE(review): binds on all interfaces (0.0.0.0) with Flask's dev server —
    # fine locally; confirm before exposing this host publicly.
    app.run(host='0.0.0.0', port=5601)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment