Last active
November 6, 2024 14:57
-
-
Save shawnho1018/d53b0922a6425dfe00587b18a459aa51 to your computer and use it in GitHub Desktop.
This script uses FastAPI and Gemini to build a GraphGPT-style query service that turns natural-language questions into Cypher queries for Neo4j.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from neo4j import GraphDatabase | |
from typing import Dict, Any, Optional, List | |
import logging, sys, json | |
import vertexai | |
from vertexai.generative_models import GenerativeModel, Part, SafetySetting, Tool | |
from fastapi import FastAPI, Request, HTTPException | |
# ASGI application object; the route handlers in this file register on it.
app = FastAPI()

# Build the stdout handler first, then hang it on the root logger so that
# INFO-and-above records reaching the root logger are printed with a
# timestamp, logger name and level.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)

root = logging.getLogger()
root.addHandler(handler)
root.setLevel(logging.INFO)
class neo4jConnection:
    """Small helper that runs Cypher statements against one Neo4j endpoint.

    Only the connection settings are stored on the instance; a driver and a
    session are opened for each ``run_query`` call and closed again when the
    call finishes.
    """

    def __init__(self, uri, user, pwd):
        """Store the connection URI and the ``(user, pwd)`` credential pair."""
        self.__auth = (user, pwd)
        self.__uri = uri

    def run_query(self, query: str, parameters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Run ``query`` and return every record as a dictionary.

        Args:
            query: Cypher statement to execute.
            parameters: Optional query parameters passed through to the driver.

        Returns:
            The list produced by the driver result's ``data()`` call.
        """
        with GraphDatabase.driver(self.__uri, auth=self.__auth) as driver:
            with driver.session() as session:
                # Materialise the records while the session is still open.
                results = session.run(query, parameters).data()
                count = len(results)
                logging.info(f"The following query {query} was sent with parameters {parameters}. Return {count} results.")
                return results
def close_session(URI, AUTH, database):
    """Open a driver and a session on ``database`` and close both again.

    The driver and the session are used as context managers so that both are
    closed even when opening the session (or any code later added inside the
    block) raises; the previous manual ``close()`` calls were skipped in that
    case and leaked the driver.

    Args:
        URI: Connection URI handed to ``GraphDatabase.driver``.
        AUTH: Credentials handed to ``GraphDatabase.driver`` as ``auth``.
        database: Name of the database the session is opened against.
    """
    with GraphDatabase.driver(URI, auth=AUTH) as driver:
        with driver.session(database=database):
            # session/driver usage
            pass
# Module-wide connection helper used by the request handler.
# The URI host, user name and password below are placeholder values.
conn = neo4jConnection(
    uri="neo4j+ssc://neo4j-gke-loadbalancer-ip:7687",
    user="NEO4J_USERNAME",
    pwd="NEO4J_PWD",
)
# System prompt sent to Gemini. It is runtime text consumed by the model, so
# it is kept verbatim (wording included).
_CYPHER_SYSTEM_INSTRUCTION = """
You are an expert of Neo4j cypher schema.
You are given a Neo4j graph database with \'Employee\' and \'Skill\' nodes.
Employees have a \'HAS_EXPERIENCE\' relationship with Skills.
Employee nodes have an \'employeeId\' property.
Please provide the cypher schema to to help translate the following natural
language into cypher in the json format.
---
Question: 列出所有員工的ID、姓名與其有經驗的技能
{"cypher": "MATCH (e:Employee)-[r:HAS_EXPERIENCE]->(s:Skill) RETURN e.employeeId, e.name, s.name"}
"""


def _parse_cypher(llm_answer: str) -> str:
    """Extract the ``cypher`` string from the model's JSON reply or raise ``ValueError``."""
    try:
        answer = json.loads(llm_answer)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Model response is not valid JSON: {llm_answer!r}") from exc

    cypher = answer.get("cypher") if isinstance(answer, dict) else None
    if not isinstance(cypher, str) or not cypher.strip():
        raise ValueError(f"Model response has no usable 'cypher' field: {llm_answer!r}")
    return cypher


def generate(query: str) -> str:
    """Translate a natural-language question into a Cypher statement via Gemini.

    Args:
        query: The question to translate.

    Returns:
        The string stored under the ``cypher`` key of the model's JSON reply.

    Raises:
        ValueError: If the reply is not valid JSON, is not a JSON object, or
            does not contain a non-empty string ``cypher`` field.
    """
    vertexai.init(project="PROJECT_ID", location="GEMINI_API_LOCATION")

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 1,
        "top_p": 0.95,
        # Request a JSON reply so it can be parsed below.
        "response_mime_type": "application/json",
    }

    # Same four categories, same order, all with blocking switched off.
    safety_settings = [
        SafetySetting(
            category=category,
            threshold=SafetySetting.HarmBlockThreshold.OFF,
        )
        for category in (
            SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
        )
    ]

    model = GenerativeModel(
        "gemini-1.5-flash-002",
        system_instruction=[_CYPHER_SYSTEM_INSTRUCTION],
    )
    responses = model.generate_content(
        [query],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    # Concatenate the text of every part of every returned candidate.
    llm_answer = "".join(
        part.text
        for candidate in responses.candidates
        for part in candidate.content.parts
    )

    cypher = _parse_cypher(llm_answer)
    logging.info("answer: %s", cypher)
    return cypher
@app.post("/query")
async def execute_query(request: Request):
    """Answer a natural-language question with rows from the graph database.

    Expects a JSON object body with a non-empty string ``query`` field. The
    question is translated to Cypher by ``generate`` and the statement is run
    through ``conn.run_query``.

    Returns:
        ``{"results": [...]}`` with the records returned by the database.

    Raises:
        HTTPException: 400 when the body is not valid JSON or lacks a usable
            ``query`` field; 502 when the model reply cannot be turned into a
            Cypher statement.
    """
    # Function-scope import keeps this handler self-contained.
    import asyncio

    try:
        payload = await request.json()
    except ValueError as exc:  # json.JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from exc

    query = payload.get("query") if isinstance(payload, dict) else None
    if not isinstance(query, str) or not query.strip():
        raise HTTPException(
            status_code=400,
            detail='Request body must be a JSON object with a non-empty string "query" field.',
        )

    # generate() and run_query() are blocking calls; run them in a worker
    # thread so the event loop stays free for other requests.
    try:
        cypher = await asyncio.to_thread(generate, query)
    except (ValueError, KeyError) as exc:
        logging.exception("Could not obtain a Cypher statement for query %r", query)
        raise HTTPException(
            status_code=502,
            detail="The language model did not return a usable Cypher statement.",
        ) from exc

    logging.info("cypher: %s", cypher)

    # NOTE(review): the model-generated Cypher is executed as-is. Consider
    # connecting with a read-only database user so that a crafted prompt
    # cannot modify or delete data.
    results = await asyncio.to_thread(conn.run_query, cypher)
    return {"results": results}
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 8000. uvicorn is imported here so it is only needed when the
    # file is run directly.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment