Skip to content

Instantly share code, notes, and snippets.

@shawnho1018
Last active November 6, 2024 14:57
Show Gist options
  • Save shawnho1018/d53b0922a6425dfe00587b18a459aa51 to your computer and use it in GitHub Desktop.
Save shawnho1018/d53b0922a6425dfe00587b18a459aa51 to your computer and use it in GitHub Desktop.
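This script uses FastAPI and Gemini to build a GraphGPT-style query service: it translates a natural-language question into a Cypher query and runs that query against Neo4j.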
import asyncio
import json
import logging
import os
import sys
from typing import Any, Dict, List, Optional

import vertexai
from fastapi import FastAPI, HTTPException, Request
from neo4j import GraphDatabase
from vertexai.generative_models import GenerativeModel, Part, SafetySetting, Tool
app = FastAPI()

# Records that reach the root logger at INFO or above are written to stdout
# in a "time - logger name - level - message" layout.
formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
handler = logging.StreamHandler(stream=sys.stdout)
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)

root = logging.getLogger()
root.setLevel(logging.INFO)
root.addHandler(handler)
class neo4jConnection:
    """Run Cypher queries against a Neo4j server through one shared driver.

    The driver is created once in the constructor and reused by every
    run_query() call; the Neo4j driver documentation recommends a single
    long-lived driver per application rather than one per query.
    Call close() when the connection is no longer needed.
    """

    def __init__(self, uri, user, pwd):
        """Store the connection settings and create the shared driver.

        Args:
            uri: Neo4j connection URI, e.g. ``neo4j+ssc://host:7687``.
            user: Username passed to the driver's auth tuple.
            pwd: Password passed to the driver's auth tuple.
        """
        self.__uri = uri
        self.__auth = (user, pwd)
        self.__driver = GraphDatabase.driver(self.__uri, auth=self.__auth)

    def run_query(
        self,
        query: str,
        parameters: Optional[Dict[str, Any]] = None,
        *,
        database: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Run *query* in a fresh session and return every resulting row.

        Args:
            query: Cypher text to execute.
            parameters: Optional query parameters passed to ``session.run``.
            database: Database to run against; ``None`` lets the driver use
                its default database.

        Returns:
            The records as a list of dicts, as produced by ``Result.data()``.
        """
        # Sessions are cheap and short-lived; the expensive driver is shared.
        # The result must be consumed while its session is still open.
        with self.__driver.session(database=database) as session:
            results = session.run(query, parameters).data()
        logging.info(
            "The following query %s was sent with parameters %s. Return %d results.",
            query,
            parameters,
            len(results),
        )
        return results

    def close(self) -> None:
        """Close the shared driver and release its connections."""
        self.__driver.close()
def close_session(URI, AUTH, database):
    """Open a driver and a session for *database*, then close both again.

    This only manages the driver and session it creates itself; it does not
    close any driver or session opened elsewhere (for example the one held
    by a ``neo4jConnection``).

    Args:
        URI: Neo4j connection URI passed to ``GraphDatabase.driver``.
        AUTH: Value passed as ``auth=`` to the driver, e.g. a
            ``(user, password)`` tuple.
        database: Name of the database the session is opened for.
    """
    # Nested context managers close the session first and then the driver,
    # and still close the driver if opening the session raises.
    with GraphDatabase.driver(URI, auth=AUTH) as driver:
        with driver.session(database=database):
            pass  # session/driver usage goes here
# Module-wide Neo4j connection used by the /query endpoint. Settings are read
# from the environment when present; the literal values are the original
# placeholders and are used only as fallbacks, so secrets need not live here.
conn = neo4jConnection(
    uri=os.environ.get("NEO4J_URI", "neo4j+ssc://neo4j-gke-loadbalancer-ip:7687"),
    user=os.environ.get("NEO4J_USERNAME", "NEO4J_USERNAME"),
    pwd=os.environ.get("NEO4J_PWD", "NEO4J_PWD"),
)
def _parse_cypher_answer(llm_answer: str) -> str:
    """Return the non-empty ``"cypher"`` string from the model's JSON answer.

    Raises:
        ValueError: If *llm_answer* is not JSON, is not a JSON object, or has
            no non-empty string under ``"cypher"``.
    """
    try:
        answer = json.loads(llm_answer)
    except json.JSONDecodeError as err:
        raise ValueError(f"LLM answer is not valid JSON: {llm_answer!r}") from err
    cypher = answer.get("cypher") if isinstance(answer, dict) else None
    if not isinstance(cypher, str) or not cypher.strip():
        raise ValueError(f"LLM answer has no usable 'cypher' field: {llm_answer!r}")
    return cypher


def generate(query: str) -> str:
    """Translate a natural-language question into a Cypher query with Gemini.

    The model gets a system instruction describing the graph (``Employee`` and
    ``Skill`` nodes linked by ``HAS_EXPERIENCE``) plus a one-shot example, and
    is asked for a JSON object of the form ``{"cypher": "..."}``.

    Args:
        query: The user's question in natural language.

    Returns:
        The Cypher query string from the model's ``"cypher"`` field.

    Raises:
        ValueError: If the model returns no candidates, or its answer is not a
            JSON object with a non-empty string ``"cypher"``.
    """
    # Values come from the environment when set; the literals are the
    # original placeholders, used as fallbacks.
    vertexai.init(
        project=os.environ.get("PROJECT_ID", "PROJECT_ID"),
        location=os.environ.get("GEMINI_API_LOCATION", "GEMINI_API_LOCATION"),
    )
    # The example question below means: "List every employee's ID and name
    # together with the skills they have experience with."
    system_instruction = """
You are an expert of Neo4j cypher schema.
You are given a Neo4j graph database with 'Employee' and 'Skill' nodes.
Employees have a 'HAS_EXPERIENCE' relationship with Skills.
Employee nodes have an 'employeeId' property.
Please provide the Cypher query that answers the following natural
language question, in JSON format with a single key "cypher".
---
Question: 列出所有員工的ID、姓名與其有經驗的技能
{"cypher": "MATCH (e:Employee)-[r:HAS_EXPERIENCE]->(s:Skill) RETURN e.employeeId, e.name, s.name"}
"""
    generation_config = {
        "max_output_tokens": 8192,
        # NOTE(review): temperature 1 lets the generated Cypher vary between
        # identical questions; a lower value may suit query translation better.
        "temperature": 1,
        "top_p": 0.95,
        # Ask the model for a bare JSON document so it can be json.loads()-ed.
        "response_mime_type": "application/json",
    }
    # All four listed harm categories use the OFF threshold.
    safety_settings = [
        SafetySetting(
            category=category,
            threshold=SafetySetting.HarmBlockThreshold.OFF,
        )
        for category in (
            SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
        )
    ]
    model = GenerativeModel(
        "gemini-1.5-flash-002",
        system_instruction=[system_instruction],
    )
    response = model.generate_content(
        [query],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )
    if not response.candidates:
        raise ValueError("Gemini returned no candidates for the query")
    # Use only the first candidate: joining several candidates' text would
    # produce concatenated JSON documents that do not parse.
    llm_answer = "".join(part.text for part in response.candidates[0].content.parts)
    cypher = _parse_cypher_answer(llm_answer)
    logging.info("answer: %s", cypher)
    return cypher
@app.post("/query")
async def execute_query(request: Request):
    """Answer a natural-language graph question.

    Expects a JSON body such as ``{"query": "<question>"}``. The question is
    translated to Cypher by ``generate`` and executed with ``conn.run_query``.

    Returns:
        ``{"results": [...]}`` holding the rows returned by the query.

    Raises:
        HTTPException: 400 if the body is not a JSON object with a non-empty
            string ``"query"``; 502 if Cypher generation raises ``ValueError``
            (e.g. the model's answer could not be parsed).
    """
    try:
        prompt = await request.json()
    except ValueError as err:  # json.JSONDecodeError is a ValueError subclass
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from err

    query = prompt.get("query") if isinstance(prompt, dict) else None
    if not isinstance(query, str) or not query.strip():
        raise HTTPException(
            status_code=400,
            detail='Request body must be a JSON object with a non-empty string "query".',
        )

    # Both the Gemini call and the Neo4j query block on network I/O; running
    # them in a worker thread keeps this async endpoint from stalling the
    # event loop, and with it every other in-flight request.
    try:
        cypher = await asyncio.to_thread(generate, query)
    except ValueError as err:
        logging.warning("Could not obtain Cypher for query %r: %s", query, err)
        raise HTTPException(status_code=502, detail="Failed to generate a Cypher query.") from err
    logging.info("cypher: %s", cypher)

    # NOTE(review): the Cypher text comes from the LLM and is executed as-is,
    # so a crafted question could produce a write/delete query. Consider a
    # read-only transaction or a read-only database user.
    results = await asyncio.to_thread(conn.run_query, cypher)
    return {"results": results}
if __name__ == "__main__":
    # Serve the FastAPI app directly when this file is run as a script,
    # listening on every interface on port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment