Created July 24, 2023 20:08
-
-
Save wenqiglantz/b4c37ef0c8c446302621fc262bfdc2df to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from llama_index.prompts.base import Prompt | |
from llama_index.prompts.prompt_type import PromptType | |
# Custom text-to-SQL prompt template.
#
# Placeholders filled in at format time: {dialect}, {schema}, {query_str}.
# The first part is a generic text-to-SQL instruction. The lines after
# "Question: {query_str}" add hard-coded, dataset-specific hints (SEC tables,
# columns, tags, and the 2022 reporting period) that steer the generated SQL.
TEXT_TO_SQL_TMPL = (
    "Given an input question, first create a syntactically correct {dialect} "
    "query to run, then look at the results of the query and return the answer. "
    "You can order the results by a relevant column to return the most "
    "interesting examples in the database.\n"
    "Never query for all the columns from a specific table, only ask for a "
    "few relevant columns given the question.\n"
    "Pay attention to use only the column names that you can see in the schema "
    "description. "
    "Be careful to not query for columns that do not exist. "
    "Pay attention to which column is in which table. "
    "Also, qualify column names with the table name when needed.\n"
    "Use the following format:\n"
    "Question: Question here\n"
    "SQLQuery: SQL Query to run\n"
    "SQLResult: Result of the SQLQuery\n"
    "Answer: Final answer here\n"
    "Only use the tables listed below.\n"
    "{schema}\n"
    "Question: {query_str} \n"
    # Dataset-specific guidance appended after the user's question.
    "Order by revenue value converted to number from highest to lowest. "
    "Please use 'company_name' column from sec_cik_index, use sec_cik_index's SIC_CODE_CATEGORY "
    "of 'Office of Life Sciences' to identify life science companies, use sec_report_attributes's "
    "tag 'Revenues', statement 'Income Statement', metadata IS NULL, value is not null, "
    # Use ASCII hyphens only. These date literals are meant to be copied
    # verbatim into the generated SQL. The original used EN DASH (U+2013),
    # which yields strings that are not ISO dates and would likely never
    # match the stored date values.
    "period_start_date '2022-01-01' and period_end_date '2022-12-31'. \n"
    "SQLQuery: "
)
# Wrap the raw template string in a llama_index Prompt tagged with the
# TEXT_TO_SQL prompt type, so it can be handed to the query engine as its
# `text_to_sql_prompt`.
TEXT_TO_SQL_PROMPT = Prompt(
    TEXT_TO_SQL_TMPL, prompt_type=PromptType.TEXT_TO_SQL
)
# Build the (presumably natural-language -> SQL) query engine over the two SEC
# tables listed in `tables=`, using the custom TEXT_TO_SQL_PROMPT.
# NOTE(review): `NLSQLTableQueryEngine`, `sql_database` and `service_context`
# are not defined in this snippet; they must be provided elsewhere.
query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    service_context=service_context,
    tables=["sec_cik_index", "sec_report_attributes"],
    text_to_sql_prompt=TEXT_TO_SQL_PROMPT,
)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment