Created
October 10, 2024 00:08
-
-
Save ColeMurray/da493e2f57611295af1b02de925d877f to your computer and use it in GitHub Desktop.
A basic text-to-SQL implementation using LLMs (OpenAI, SQL).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import openai | |
from sqlalchemy import create_engine, inspect | |
from sqlalchemy.exc import SQLAlchemyError | |
from sqlalchemy import text | |
import re | |
import json | |
class LLMSQLQueryEngine:
    """Answer natural-language questions about a SQLite database via an LLM.

    The engine reflects the database's tables and columns, asks an OpenAI chat
    model to write a SQL query for the question, and executes that query
    against the database.
    """

    def __init__(self, db_path):
        """Connect to the SQLite file at *db_path* and record its table names.

        The table list is read once here; tables created afterwards are not
        picked up by this instance.
        """
        self.engine = create_engine(f'sqlite:///{db_path}')
        self.inspector = inspect(self.engine)
        self.tables = self.inspector.get_table_names()

    def get_schema(self):
        """Return a plain-text description of every table and its columns.

        Each table contributes a ``Table: <name>`` line followed by one
        `` - <column>: <type>`` line per column. Lines are newline-joined.
        """
        lines = []
        for table in self.tables:
            lines.append(f"Table: {table}")
            lines.extend(
                f" - {column['name']}: {column['type']}"
                for column in self.inspector.get_columns(table)
            )
        return "\n".join(lines)

    @staticmethod
    def _build_prompt(schema, natural_language_query):
        """Assemble the user prompt sent to the model for one question."""
        return (
            "Given the following database schema:\n"
            f"{schema}\n"
            "Generate a SQL query to answer the following question:\n"
            f"{natural_language_query}\n"
            "Requirements:\n"
            "The query should be compatible with SQLite syntax.\n"
            "Include all relevant tables in the query.\n"
            "ONLY USE THE AVAILABLE COLUMNS AND TABLES.\n"
            "Be careful to not query for columns that do not exist.\n"
            "Pay attention to which column is in which table.\n"
            "Also, qualify column names with the table name when needed.\n"
            "Prefer concise queries over verbose ones.\n"
            "Return as a JSON OBJECT, like so:\n"
            # The example must itself be valid JSON (the original was missing
            # the opening quote on "sql"), otherwise the model is shown a
            # malformed template for the object it is asked to produce.
            '{"reasoning": "we need to fetch the users", '
            '"sql": "SELECT * FROM ..."}\n'
        )

    @staticmethod
    def _extract_sql(raw_response):
        """Pull the ``sql`` field out of the model's JSON reply.

        Raises:
            ValueError: if the reply is missing/blank or is not valid JSON
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
            KeyError: if the JSON object has no ``sql`` key.
        """
        if raw_response is None or not raw_response.strip():
            raise ValueError("The model returned no content to parse for SQL.")
        return json.loads(raw_response)['sql']

    def generate_sql(self, natural_language_query):
        """Ask the model for a SQL query answering *natural_language_query*.

        The prompt and the raw model reply are printed to stdout.

        Returns:
            The value of the ``sql`` field of the model's JSON reply.

        Raises:
            ValueError: if the model reply has no content or is not JSON.
            KeyError: if the JSON reply lacks a ``sql`` key.
        """
        prompt = self._build_prompt(self.get_schema(), natural_language_query)
        print(f"Sending prompt to OpenAI: {prompt}")
        completion = openai.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": "You are a SQL expert. Generate SQL queries based on natural language questions and the given schema."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.1,
            response_format={"type": "json_object"},
        )
        raw_reply = completion.choices[0].message.content
        # The message content may be absent; fail with a clear error instead
        # of an AttributeError on ``None.strip()``.
        if raw_reply is None:
            raise ValueError("The model returned no content to parse for SQL.")
        raw_reply = raw_reply.strip()
        print(raw_reply)
        return self._extract_sql(raw_reply)

    def execute_sql(self, sql_query):
        """Run *sql_query* and return all rows.

        Args:
            sql_query: either a SQL string or an already-constructed
                SQLAlchemy executable (e.g. the result of ``text(...)``).

        Returns:
            The list of rows from ``fetchall()`` on success. On a SQLAlchemy
            error, a string of the form ``"Error executing SQL query: ..."``
            is returned instead of raising, so callers must check the type.
        """
        # NOTE(review): the statement typically originates from an LLM and is
        # executed as-is; nothing here restricts it to read-only queries.
        # Plain strings are wrapped in text() because Connection.execute()
        # rejects raw strings in SQLAlchemy 2.x.
        statement = text(sql_query) if isinstance(sql_query, str) else sql_query
        try:
            with self.engine.connect() as connection:
                # Fetch inside the ``with`` so rows are read before the
                # connection is released.
                return connection.execute(statement).fetchall()
        except SQLAlchemyError as e:
            return f"Error executing SQL query: {str(e)}"

    def query(self, natural_language_query):
        """Generate SQL for the question, execute it, and bundle the outcome.

        Returns:
            A dict with keys ``natural_language_query`` (the input),
            ``sql_query`` (the generated SQL string) and ``result`` (whatever
            ``execute_sql`` returned: rows, or an error string).
        """
        sql_query = self.generate_sql(natural_language_query)
        result = self.execute_sql(text(sql_query))
        return {
            "natural_language_query": natural_language_query,
            "sql_query": sql_query,
            "result": result,
        }
# Usage example: run one question end-to-end against a local SQLite file and
# print the question, the SQL the model produced, and the rows (or the error
# string) that came back.
if __name__ == "__main__":
    engine = LLMSQLQueryEngine('forum_data.db')
    outcome = engine.query(
        "Identify the top 10 contributors and show their user profiles"
    )

    print("Natural Language Query:", outcome["natural_language_query"])
    # sep="\n" puts the heading and its value on consecutive lines, exactly as
    # two back-to-back print() calls would.
    print("\nGenerated SQL Query:", outcome["sql_query"], sep="\n")
    print("\nQuery Result:", outcome["result"], sep="\n")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment