Created
February 5, 2025 18:22
-
-
Save lukestanley/414312b8196f97269b8c0990af6c9d55 to your computer and use it in GitHub Desktop.
Fast Exa search, Mermaid diagrams, LLM chat: DeepSeek R1 70B distilled via Groq using FastAPI (mostly single file web app)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run with: uvicorn mermaid_chat:app --reload --port <free port>
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import HTMLResponse, JSONResponse | |
from pydantic import BaseModel | |
import json, re, requests | |
from datetime import datetime | |
# Import keys from your keys module | |
from keys import GROQ_API_KEY, EXA_API_KEY, NVAPI_KEY | |
app = FastAPI()

# --- Config ---
# Model identifiers and chat-completions endpoints for each supported provider.
GROQ_MODEL = "deepseek-r1-distill-llama-70b"
NV_MODEL = "deepseek-ai/deepseek-r1"
GROQ_URL = "https://api.groq.com/openai/v1/chat/completions"
NVIDIA_URL = "https://integrate.api.nvidia.com/v1/chat/completions"

# Whether llm_query attaches a native `tools` schema to each request.
# Switched on below for Groq only.
NATIVE_TOOL_USE = False

# Selects which provider's URL / API key / model the app talks to.
PROVIDER = "groq"

if PROVIDER == "groq":
    URL = GROQ_URL
    API_KEY = GROQ_API_KEY
    MODEL = GROQ_MODEL
    NATIVE_TOOL_USE = True
elif PROVIDER in ("nv", "nvidia"):
    URL = NVIDIA_URL
    API_KEY = NVAPI_KEY
    MODEL = NV_MODEL
else:
    # Fail at import time: otherwise URL / API_KEY / MODEL stay undefined and
    # the mistake only shows up as a NameError on the first chat request.
    raise ValueError(
        f"Unsupported PROVIDER {PROVIDER!r}; expected 'groq', 'nv' or 'nvidia'"
    )
# --- System prompt ---
# Sent as the "system" message at the start of every chat request. It tells
# the model it may ask for a web search either with inline <tool_call> tags
# (the format extract_tool_calls parses) or through the native tools schema.
SYSTEM_PROMPT = """You are a helpful assistant that can provide information and answer questions.
You can search for information using either method:
1. Using special tags:
<tool_call>{"name":"exa_search","arguments":{"query":"your search query"}}</tool_call>
2. Or the built-in tools system
Example using tags:
"Let me search for that information:
<tool_call>{"name":"exa_search","arguments":{"query":"current UK Prime Minister 2024"}}</tool_call>"
Be factual and precise, and search for information that might change over time."""
# --- Utility functions --- | |
def extract_tool_calls(content: str) -> list:
    """Parse inline ``<tool_call>...</tool_call>`` JSON payloads out of model text.

    Args:
        content: Raw assistant text. May be empty or ``None``.

    Returns:
        The decoded JSON objects (dicts) in order of appearance. Payloads
        that are not valid JSON, or that decode to something other than a
        JSON object, are reported with ``print`` and skipped.
    """
    if not content:
        return []
    tool_calls = []
    # DOTALL lets "." cross newlines, so a payload the model pretty-printed
    # over several lines is still captured.
    for raw in re.findall(r'<tool_call>(.*?)</tool_call>', content, flags=re.DOTALL):
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            print(f"Failed to parse tool call: {raw}")
            continue
        # Callers use dict access (.get) on each entry; a bare list/string/number
        # would crash them, so only JSON objects are passed through.
        if isinstance(parsed, dict):
            tool_calls.append(parsed)
        else:
            print(f"Ignoring non-object tool call: {raw}")
    return tool_calls
def get_timestamp():
    """Return the current local date and time as ``YYYY-MM-DD HH:MM:SS``."""
    stamp_format = "%Y-%m-%d %H:%M:%S"
    current = datetime.now()
    return format(current, stamp_format)
def exa_search(q: str, timeout: float = 30.0):
    """Run a web search through the Exa API and format the hits as plain text.

    Args:
        q: The search query.
        timeout: Seconds to wait for the Exa HTTP request before giving up.

    Returns:
        One ``Result N:`` section per hit, separated by blank lines, listing
        whichever of title, URL, author, publish date and text are non-empty;
        or ``"No relevant information found."`` when the response carries no
        results.

    Raises:
        requests.HTTPError: If Exa answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    print("Searching for:", q)
    exa_api_url = 'https://api.exa.ai/search'
    payload = {
        "query": q,
        "useAutoprompt": False,
        "contents": {
            "text": {"maxCharacters": 1000}
        }
    }
    headers = {
        'x-api-key': EXA_API_KEY,
        "Content-Type": "application/json"
    }
    # Without a timeout, requests waits indefinitely on a stalled connection.
    response = requests.post(exa_api_url, json=payload, headers=headers, timeout=timeout)
    response.raise_for_status()
    results = response.json().get('results')
    if not results:
        return "No relevant information found."
    property_labels = {
        'title': 'Title',
        'url': 'URL',
        'author': 'Author',
        'publishedDate': 'Published',
        'text': 'Text'
    }
    formatted_results = []
    for i, result in enumerate(results, 1):
        lines = [f"Result {i}:"]
        for prop, label in property_labels.items():
            # A key can be present with a JSON null (or a non-string) value;
            # dict.get's default does not apply then, so normalise explicitly
            # instead of calling .strip() on None.
            value = result.get(prop)
            text = '' if value is None else str(value).strip()
            if text:
                lines.append(f"{label}: {text}")
        formatted_results.append("\n".join(lines) + "\n")
    return "\n\n".join(formatted_results)
def _exa_tool_payload():
    """Return the ``tools`` / ``tool_choice`` request fields describing exa_search."""
    return {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "exa_search",
                    "description": "Search the web using Exa search",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "Search query"
                            }
                        },
                        "required": ["query"],
                        "additionalProperties": False
                    },
                    "strict": True
                }
            }
        ],
        "tool_choice": "auto"
    }


def _message_text(message):
    """Return the message's stripped ``content``, falling back to ``reasoning``."""
    # Either field may be absent or an explicit null (e.g. content is null on
    # a native tool-call message), so coerce to "" before stripping.
    return (message.get('content') or '').strip() or (message.get('reasoning') or '').strip()


def _requested_search_query(tool_calls):
    """Return the query of the first ``exa_search`` call, or ``None`` if there is none."""
    for tool_call in tool_calls:
        if not isinstance(tool_call, dict):
            continue
        # Native tool calls nest name/arguments under "function"; tag-based
        # calls carry them at the top level.
        function = tool_call.get('function')
        if not isinstance(function, dict):
            function = {}
        if (tool_call.get('name') or function.get('name')) != 'exa_search':
            continue
        arguments = tool_call.get('arguments') or function.get('arguments')
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except ValueError:
                arguments = {}
        if not isinstance(arguments, dict):
            arguments = {}
        return arguments.get('query', '')
    return None


def llm_query(prompt, messages, *, max_tool_rounds: int = 5, timeout: float = 120.0):
    """Send the conversation to the configured LLM, running web searches it asks for.

    After each reply the function looks for an ``exa_search`` tool call, either
    in the message's native ``tool_calls`` or as ``<tool_call>`` tags in the
    text. If one is found, the search is executed, its results are appended to
    the conversation as a user message, and the model is queried again.

    Args:
        prompt: Optional user text appended to ``messages`` before the first
            request; pass ``None`` when the user turn is already in the list.
        messages: Chat messages (``role`` / ``content`` dicts). The list is
            mutated in place: the prompt, intermediate assistant text and
            search results are appended to it.
        max_tool_rounds: Maximum number of searches to run for one call. Once
            exhausted, the native tools schema is no longer offered and the
            model's next reply is returned as-is.
        timeout: Seconds to wait for each LLM HTTP request.

    Returns:
        The final reply text (``content``, or ``reasoning`` when content is
        empty); may be an empty string.

    Raises:
        requests.HTTPError: If the LLM or search endpoint answers with an
            error status.
        requests.Timeout: If a request exceeds its timeout.
    """
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }
    if prompt:
        messages.append({"role": "user", "content": prompt})

    rounds_left = max_tool_rounds
    # Loop instead of recursing so a model that keeps requesting searches
    # cannot drive unbounded recursion / unbounded API spend.
    while True:
        data = {
            "model": MODEL,
            "messages": messages,
            "temperature": 0.3,
            "top_p": 1
        }
        if NATIVE_TOOL_USE and rounds_left > 0:
            data.update(_exa_tool_payload())
        print("LLM Request:", data)
        response = requests.post(URL, json=data, headers=headers, timeout=timeout)
        response.raise_for_status()
        response_data = response.json()
        print("LLM Response:", response_data)
        message = response_data['choices'][0]['message']
        content = _message_text(message)

        # Prefer native tool calls; otherwise look for tags in the text.
        tool_calls = message.get('tool_calls') or extract_tool_calls(content)
        search_query = _requested_search_query(tool_calls)
        if search_query is None or rounds_left <= 0:
            return content

        rounds_left -= 1
        search_result = exa_search(search_query)
        if content:
            messages.append({"role": "assistant", "content": content})
        messages.append({
            "role": "user",
            "content": f"Search results: {search_result}\nPlease provide a final answer based on these results."
        })
# --- Pydantic model --- | |
class ChatRequest(BaseModel):
    """JSON body accepted by the POST /chat endpoint."""
    # The new user message to answer.
    message: str
    # Earlier conversation turns sent by the browser; chat_endpoint places
    # them between the system prompt and `message`. Entries are not validated
    # here (plain `list`).
    history: list = []  # Client-side conversation history
# --- Endpoints --- | |
@app.get("/", response_class=HTMLResponse)
async def get_index():
    """Serve the single-page chat UI.

    The page is a Vue 3 app (Vue, marked and Mermaid are loaded from a CDN)
    that keeps the conversation in the browser, posts each turn to ``/chat``,
    renders replies as Markdown, and draws fenced ``mermaid`` code blocks as
    diagrams. When a diagram fails to parse, the page asks the model for a
    corrected one, limited to 3 retries per 10-second window.

    Returns:
        An ``HTMLResponse`` containing the whole page.
    """
    html_content = """
<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <title>LLM Chat with Markdown & Mermaid</title>
</head>
<body>
  <div id="app">
    <div v-for="(msg, index) in conversation" :key="index" class="message" :class="msg.role">
      <div v-html="renderMarkdown(msg.content)" class="message-content"></div>
    </div>
    <div class="input-container">
      <input v-model="newMessage" placeholder="Type your message" @keyup.enter="sendMessage" />
      <button @click="sendMessage">Send</button>
    </div>
  </div>
  <script src="https://cdn.jsdelivr.net/npm/vue@3/dist/vue.global.prod.js"></script>
  <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
  <script src="https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"></script>
  <script>
    const { createApp, nextTick } = Vue;
    mermaid.initialize({ startOnLoad: false });

    // Retry state management: at most maxRetries repair requests per windowMs.
    const retryState = {
      count: 0,
      startTime: 0,
      maxRetries: 3,
      windowMs: 10000,
      reset() {
        this.count = 0;
        this.startTime = Date.now();
      },
      canRetry() {
        const now = Date.now();
        if (now - this.startTime > this.windowMs) {
          this.reset();
        }
        return this.count < this.maxRetries;
      },
      increment() {
        this.count++;
      }
    };

    // POST one turn to /chat and return the assistant's reply text.
    // Throws on an HTTP error or a body without a reply, so callers never
    // push `undefined` content into the conversation.
    async function postChat(message, history) {
      const response = await fetch("/chat", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ message, history })
      });
      if (!response.ok) {
        throw new Error(`Chat request failed with HTTP ${response.status}`);
      }
      const data = await response.json();
      if (typeof data.response !== "string") {
        throw new Error("Chat response did not contain a reply");
      }
      return data.response;
    }

    // Helper to check if error is a Mermaid error
    function isMermaidError(error) {
      return error && (
        error.hash === "UnknownDiagramError" ||
        error.message?.includes("Parse error") ||
        error.str?.includes("Parse error") ||
        error.message?.includes("No diagram type detected") ||
        error.str?.includes("No diagram type detected")
      );
    }

    // Mermaid rendering helper: rethrows only Mermaid syntax errors so the
    // caller can ask the model for a corrected diagram.
    async function renderMermaidDiagram() {
      const diagrams = document.querySelectorAll(".message-content code.language-mermaid");
      if (diagrams.length === 0) return;
      try {
        await mermaid.init(undefined, diagrams);
      } catch (error) {
        console.error("Mermaid render error:", error);
        if (isMermaidError(error)) {
          throw error;
        }
      }
    }

    const appInstance = createApp({
      data() {
        return {
          conversation: [],
          newMessage: ""
        };
      },
      methods: {
        renderMarkdown(content) {
          // NOTE(review): the result is injected with v-html and is not
          // sanitised; model- or search-derived text could carry HTML or
          // script. Consider sanitising before rendering.
          return marked.parse(content ?? "");
        },
        async tryRenderMermaid() {
          await nextTick();
          try {
            await renderMermaidDiagram();
          } catch (error) {
            if (!isMermaidError(error) || !retryState.canRetry()) return;
            retryState.increment();
            const errorMsg = error.str || error.message || "Unknown Mermaid syntax error";
            try {
              const reply = await postChat(
                `The Mermaid diagram is invalid. Error: "${errorMsg}". Carefully review the error using <think></think> tags until you are confident, and then provide a diagram with working syntax. Common errors to check for are brackets in brackets, wrong diagram type, but other errors are possible too.`,
                this.conversation
              );
              this.conversation.push({ role: "assistant", content: reply });
              await new Promise(resolve => setTimeout(resolve, 100));
              await this.tryRenderMermaid();
            } catch (err) {
              // Handled here because the watcher and mounted() call this
              // method without their own catch.
              console.error("Error while requesting a corrected diagram:", err);
            }
          }
        },
        async sendMessage() {
          const text = this.newMessage;
          if (!text.trim()) return;
          // Snapshot the history BEFORE adding the new turn: the server
          // appends `message` itself, so including it in `history` too would
          // send the same user turn twice.
          const history = this.conversation.slice();
          this.conversation.push({ role: "user", content: text });
          // Clear the box immediately so a second Enter while waiting does
          // not resend the same text.
          this.newMessage = "";
          try {
            const reply = await postChat(text, history);
            this.conversation.push({ role: "assistant", content: reply });
            await this.tryRenderMermaid();
          } catch (error) {
            console.error("Error in message chain:", error);
          }
        }
      },
      watch: {
        conversation: {
          deep: true,
          async handler() {
            await this.tryRenderMermaid();
          }
        }
      },
      mounted() {
        this.tryRenderMermaid();
      }
    }).mount("#app");

    // Global error handler for Mermaid
    window.addEventListener('error', async (event) => {
      const error = event.error;
      if (isMermaidError(error)) {
        event.preventDefault();
        if (retryState.canRetry()) {
          retryState.increment();
          const errorMsg = error.str || error.message || "Unknown Mermaid syntax error";
          try {
            const reply = await postChat(
              `The Mermaid diagram is invalid. Error: "${errorMsg}". Please fix the diagram syntax and provide a valid Mermaid diagram with a proper diagram type directive (like 'graph TD', 'flowchart LR', etc).`,
              appInstance.conversation
            );
            appInstance.conversation.push({ role: "assistant", content: reply });
            await new Promise(resolve => setTimeout(resolve, 100));
            await appInstance.tryRenderMermaid();
          } catch (err) {
            console.error("Error in global handler:", err);
          }
        }
      }
    });
  </script>
  <style>
    body { font-family: Arial, sans-serif; background: #f4f4f4; padding: 20px; }
    .message { padding: 10px; margin: 8px 0; border-radius: 8px; max-width: 600px; }
    .user { background: #e0f7fa; text-align: left; }
    .assistant { background: #e8f5e9; text-align: left; }
    .message-content { word-wrap: break-word; }
    .input-container { margin-top: 10px; }
    input { padding: 8px; width: 400px; }
    button { padding: 8px 12px; margin-left: 5px; }
  </style>
</body>
</html>
"""
    return HTMLResponse(content=html_content)
@app.post("/chat")
async def chat_endpoint(chat_req: ChatRequest):
    """Answer one chat turn.

    Builds the message list as: system prompt, the client-supplied history,
    then the new user message, and passes it to ``llm_query``.

    Args:
        chat_req: The new message plus the client's conversation history.

    Returns:
        ``JSONResponse`` of the form ``{"response": <reply text>}``.

    Raises:
        HTTPException: With status 500 and the error text as detail when the
            LLM query fails for any reason.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    # History comes from the browser, so forward only well-formed user and
    # assistant turns, reduced to the two fields the chat API expects.
    for turn in chat_req.history:
        if (
            isinstance(turn, dict)
            and turn.get("role") in ("user", "assistant")
            and isinstance(turn.get("content"), str)
        ):
            messages.append({"role": turn["role"], "content": turn["content"]})

    # A client may already have included the new message as the last history
    # entry; appending it again would send the same user turn twice.
    last = messages[-1]
    if not (last["role"] == "user" and last["content"] == chat_req.message):
        messages.append({"role": "user", "content": chat_req.message})

    try:
        # llm_query makes blocking HTTP calls; run it in a worker thread so
        # the event loop stays free to serve other requests meanwhile.
        response = await run_in_threadpool(llm_query, None, messages)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return JSONResponse(content={"response": response})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment