Skip to content

Instantly share code, notes, and snippets.

@RohanAwhad
Last active February 9, 2025 19:41
Show Gist options
  • Save RohanAwhad/ffafb6d78ad1d6ffe5d7b72031c1ebde to your computer and use it in GitHub Desktop.
One-shot Perplexity Pro
#!/Users/rohan/miniconda3/bin/python
import asyncio
import dataclasses
import json
import openai
import os
import requests
import sys
import time
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from crawl4ai.content_filter_strategy import PruningContentFilter
from loguru import logger
from typing import List, Dict, Any, Tuple
# Route all logs to files under ~/utils; the DEBUG env var picks the verbosity.
logger.remove()
_log_dir = os.path.join(os.environ['HOME'], "utils")
if os.environ.get("DEBUG") == "1":
    logger.add(sink=os.path.join(_log_dir, "researcher.logs.debug"), level="DEBUG")
else:
    logger.add(sink=os.path.join(_log_dir, "researcher.logs.info"), level="INFO")
# cost per 1 M tokens
@dataclasses.dataclass
class CostPolicy:
    """Pricing and connection details for one LLM model.

    Costs are stored per single token (the published per-1M price divided by 1e6,
    as seen at the construction sites below).
    """
    model_name: str  # model identifier passed to the chat-completions API
    input_cost: float  # $ per uncached prompt token
    cached_input_cost: float  # $ per cached prompt token
    output_cost: float  # $ per completion token
    api_key_env_var: str  # name of the env var holding this provider's API key
    base_url: str | None = None  # alternate API endpoint; None -> provider default
# Price tables for the supported models (per-token; per-1M price / 1e6).
gpt4o = CostPolicy('gpt-4o', 2.50/1e6, 1.25/1e6, 10.00/1e6, 'OPENAI_API_KEY')
o3mini = CostPolicy('o3-mini', 1.1/1e6, 0.55/1e6, 4.4/1e6, 'OPENAI_API_KEY')
# Model used for every LLM call in this script.
CURR_MODEL = o3mini
class TokensAndCosts:
    """Running tally of token usage and dollar cost across LLM calls.

    If any usage payload fails to parse, `all_good` flips to False and the
    accumulated totals are treated as unreliable.
    """

    def __init__(self):
        self.total_uncached_input_tokens = 0
        self.total_cached_input_tokens = 0
        self.total_completion_tokens = 0
        self.total_cost = 0.0
        self.all_good = True

    def update(self, usage: Any, model: CostPolicy) -> None:
        """Fold one API response's `usage` object into the totals, priced by `model`."""
        try:
            cached = usage.prompt_tokens_details.cached_tokens
            uncached = usage.prompt_tokens - cached
            completion = usage.completion_tokens
            self.total_uncached_input_tokens += uncached
            self.total_cached_input_tokens += cached
            self.total_completion_tokens += completion
            self.total_cost += (
                uncached * model.input_cost
                + cached * model.cached_input_cost
                + completion * model.output_cost
            )
            logger.debug(f"Usage - Prompt tokens: {usage.prompt_tokens}, Cached: {cached}, Completion: {completion}")
            logger.debug(f"Updated totals - Uncached: {self.total_uncached_input_tokens}, Cached: {self.total_cached_input_tokens}")
            logger.debug(f"Total cost so far: ${self.total_cost:.6f}")
        except Exception:
            # Best-effort accounting: log and mark totals as suspect instead of crashing.
            logger.exception('Failed to update usage.')
            self.all_good = False

    def __str__(self) -> str:
        if not self.all_good:
            return "Check logs, something messed up!!"
        return (
            f"Uncached input tokens: {self.total_uncached_input_tokens}\n"
            f"Cached input tokens: {self.total_cached_input_tokens}\n"
            f"Completion tokens: {self.total_completion_tokens}\n"
            f"Total cost: ${self.total_cost:.6f}"
        )
# ===
# Scrapper
# ===
async def fetch_markdown(url: str) -> str | None:
    """Crawl `url` and return a pruned markdown rendering of the page, or None on any failure."""
    try:
        assert url is not None, 'url is None'
        assert len(url) > 0, 'len(url) is 0'
        # The PruningContentFilter removes low-density text blocks (fluff).
        markdown_strategy = DefaultMarkdownGenerator(
            options={"ignore_links": True, "escape_html": True, "body_width": 80},
            content_filter=PruningContentFilter(threshold=0.5, min_word_threshold=50),
        )
        run_cfg = CrawlerRunConfig(
            markdown_generator=markdown_strategy,
            word_count_threshold=10,
            excluded_tags=["nav", "header", "footer"],
            exclude_external_links=True,
            verbose=False,
        )
        async with AsyncWebCrawler(config=BrowserConfig(verbose=False)) as crawler:
            crawl_result = await crawler.arun(url, config=run_cfg)
            logger.debug(f'Markdown:\n{crawl_result.markdown}')
            return crawl_result.markdown
    except Exception:
        logger.exception(f'Failed to fetch markdown for url: {url}')
        return None
# ===
# Brave Search tool
# ===
@dataclasses.dataclass
class SearchResult:
    """
    One structured result from the Brave Search API.

    :param title: The title of the search result.
    :param url: The URL of the search result.
    :param description: A brief description of the search result.
    :param extra_snippets: Additional snippets related to the search result.
    :param markdown: A pruned and filtered markdown of the webpage (may not contain all the details).
    """
    title: str
    url: str
    description: str
    extra_snippets: list
    markdown: str

    def __str__(self) -> str:
        """Render the result as a multi-line, human-readable summary."""
        lines = [
            f"Title: {self.title}",
            f"URL: {self.url}",
            f"Description: {self.description}",
            f"Extra Snippets: {', '.join(self.extra_snippets)}",
            f"Markdown: {self.markdown}",
        ]
        return "\n".join(lines)
def search_brave(query: str, count: int = 5) -> list[SearchResult]:
    """
    Searches the web using Brave Search API and returns structured search results.

    Each result's page is also crawled concurrently so the returned objects carry
    a markdown rendering of the page content.

    :param query: The search query string.
    :param count: The number of search results to return.
    :return: A list of SearchResult objects containing the search results
             (empty on missing query/API key or after exhausting retries).
    """
    if not query:
        return []
    url: str = "https://api.search.brave.com/res/v1/web/search"
    headers: dict = {
        "Accept": "application/json",
        "X-Subscription-Token": os.environ.get('BRAVE_SEARCH_AI_API_KEY', '')
    }
    if not headers['X-Subscription-Token']:
        logger.error("Error: Missing Brave Search API key.")
        return []
    params: dict = {
        "q": query,
        "count": count
    }
    retries: int = 0
    max_retries: int = 3
    backoff_factor: int = 2
    while retries < max_retries:
        try:
            # FIX: a timeout keeps an unresponsive endpoint from hanging the script forever.
            response = requests.get(url, headers=headers, params=params, timeout=30)
            response.raise_for_status()
            results_json: dict = response.json()
            logger.debug('Got results')
            break
        except requests.exceptions.RequestException as e:
            logger.exception(f"HTTP Request failed: {e}, retrying...")
            retries += 1
            if retries < max_retries:
                time.sleep(backoff_factor ** retries)  # exponential backoff: 2s, 4s
            else:
                return []

    async def fetch_all_markdown(urls):
        # Crawl all result pages concurrently.
        tasks = [fetch_markdown(u) for u in urls]
        return await asyncio.gather(*tasks)

    web_results = results_json.get('web', {}).get('results', [])
    urls = [item.get('url', '') for item in web_results]
    markdowns = asyncio.run(fetch_all_markdown(urls))
    results: List[SearchResult] = []
    for item, md in zip(web_results, markdowns):
        results.append(SearchResult(
            title=item.get('title', ''),
            url=item.get('url', ''),
            description=item.get('description', ''),
            extra_snippets=item.get('extra_snippets', []),
            markdown=md if md is not None else 'Failed to get markdown.',
        ))
    return results
# OpenAI function-calling schema that exposes search_brave() to the answering LLM.
# Only `query` is surfaced; `count` keeps its Python-side default.
brave_search_tools = [{
    "type": "function",
    "function": {
        "name": "search_brave",
        "description": "Search the web using Brave Search API and returns structured search results.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "the search query string."},
            },
            "required": ["query"]
        }
    }
}]
# ===
# Main System
# ===
# System prompt for the primary answering LLM: it may issue multiple brave-search
# tool calls across turns before composing an answer.
first_llm_system_prompt = '''
You are a language model. Your task is to answer the complex queries of the user. You can use brave search to search internet and get not just links and title and small description, but also a deep dive into the original content of certain pages.
You can perform multiple search requests with breakdown'd queries, and do multi-turn requests before answering the user's queries.
'''.strip()
# NOTE(review): second_llm_system_prompt is not referenced anywhere in the visible
# code — presumably kept for a planned result-augmentation stage; verify before removing.
second_llm_system_prompt = '''
You are a language model and your task is to look at the search results received from searching the internet for a user given query. You have to decide, whether the results need to be augmented with more information from actual webpage. If you do decide that you need to add more information for a certain result, you can call the tool.
Your final response back to the user should be a long compilation of the entire search result input and the content gathered from surfing webpages. Make sure that you include all the required information and additional related content that is mentioned.
'''.strip()
# System prompt for the checker LLM: it must call exactly one of
# return_answer_to_user / regenerate_answer (see answer_checker_tools).
answer_checker_system_prompt = '''
You are a language model and your task is to evaluate the answer that is generated by another llm for the given user query. Check if the answer does answer the entire question or list of questions that the user is asking.
If you think answer is sufficient, then call return_answer_to_user tool, else if you think the answer needs some more information, then call the regenerate_answer tool with your suggestion as to how to modify the answer, and what else needs to be included.
You can only call one function at a time.
For answer sufficiency, look if the answer is appropriately answered. Look for sections which could benefit from further internet research. Sometimes the llm might just answer based on the one google search, you can ask the llm to research down another avenue to fulfill question.
'''.strip()
# Tool schema for the checker LLM: approve the current answer, or send a
# suggestion back for regeneration. Exactly one is expected per checker turn.
answer_checker_tools = [
    {
        "type": "function",
        "function": {
            "name": "return_answer_to_user",
            # FIX: description previously read "This is function will return ..." (garbled grammar
            # in a string the LLM reads when deciding which tool to call).
            "description": "This function will return the current answer back to the user",
            "parameters": {}
        }
    },
    {
        "type": "function",
        "function": {
            "name": "regenerate_answer",
            "description": "This function will send the suggestion back to the llm for modifications to the answer based on the user queries",
            "parameters": {
                "type": "object",
                "properties": {
                    "suggestion": {"type": "string", "description": "the suggestion for the llm"},
                },
                "required": ["suggestion"]
            }
        }
    }
]
def ask_llm(messages: List[Dict[str, str]], tools: List[Dict[str, Any]] | None = None, tool_choice: str = 'auto') -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Send a chat-completion request to the model selected by CURR_MODEL.

    :param messages: List of chat messages (OpenAI chat format).
    :param tools: List of tool/function definitions (optional).
    :param tool_choice: Tool-choice mode forwarded to the API ('auto', 'required', ...).
    :return: Tuple of (response message, usage object).
    :raises KeyError: if the model's API-key env var is not set.
    """
    # FIX: dropped the redundant function-local `import os` that shadowed the
    # module-level import.
    # A fresh client per call keeps this function stateless; credentials and
    # endpoint come from the CURR_MODEL cost policy.
    llm_api_key = os.environ[CURR_MODEL.api_key_env_var]
    client = openai.OpenAI(api_key=llm_api_key, base_url=CURR_MODEL.base_url)
    response = client.chat.completions.create(
        model=CURR_MODEL.model_name,
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
    )
    return response.choices[0].message, response.usage
def process_user_query(query: str) -> Tuple[str, Dict[str, float]]:
    """
    Process the user's query and return a response.

    Loop: the answering LLM may issue brave-search tool calls; once it emits a
    plain answer, a checker LLM either approves it (return_answer_to_user) or
    feeds a suggestion back into the conversation (regenerate_answer).

    :param query: User's question or query
    :return: Tuple of (final answer text, cost-detail dict). NOTE: the return
             annotation previously claimed `-> str`, but a 2-tuple has always
             been returned (see the __main__ caller).
    """
    # Step 1: Ask LLM to formulate a search query
    history = [{'role': 'system', 'content': first_llm_system_prompt}]
    initial_message = {"role": "user", "content": query}
    history.append(initial_message)
    usage_track = TokensAndCosts()
    while True:
        assistant_message, usage = ask_llm(history, brave_search_tools)
        history.append(assistant_message)
        usage_track.update(usage, CURR_MODEL)
        if assistant_message.tool_calls:
            if assistant_message.content:
                logger.debug(assistant_message.content)
            for function_call in assistant_message.tool_calls:
                args = json.loads(function_call.function.arguments)
                tool_call_response = {"role": "tool", "tool_call_id": function_call.id, "content": ""}
                if function_call.function.name == 'search_brave':
                    if 'query' in args:
                        search_result = search_brave(args['query'])
                        tool_call_response['content'] = json.dumps([dataclasses.asdict(x) for x in search_result], indent=2)
                    else:
                        tool_call_response['content'] = 'No query found in arguments'
                history.append(tool_call_response)
        else:
            # No tool call => candidate answer; have the checker LLM vet it.
            logger.debug(f'LLM Generated Answer:\n{assistant_message.content}')
            input_to_answer_checker = f'User query: {query}\n\nLLM Answer:\n{assistant_message.content}'
            answer_checker_history = [
                {'role': 'system', 'content': answer_checker_system_prompt},
                {'role': 'user', 'content': input_to_answer_checker},
            ]
            answer_checker_response, usage = ask_llm(answer_checker_history, answer_checker_tools, tool_choice='required')
            usage_track.update(usage, CURR_MODEL)
            function_call = answer_checker_response.tool_calls[0]
            if function_call.function.name == 'return_answer_to_user':
                if usage_track.all_good:
                    cost_details: Dict[str, float] = {
                        "total_uncached_input_tokens_cost": usage_track.total_uncached_input_tokens * CURR_MODEL.input_cost,
                        "total_cached_input_tokens_cost": usage_track.total_cached_input_tokens * CURR_MODEL.cached_input_cost,
                        "total_completion_tokens_cost": usage_track.total_completion_tokens * CURR_MODEL.output_cost,
                        "total_cost": usage_track.total_cost
                    }
                else:
                    # Usage parsing failed at some point; report zeros rather than bogus numbers.
                    cost_details = {
                        "total_uncached_input_tokens_cost": 0,
                        "total_cached_input_tokens_cost": 0,
                        "total_completion_tokens_cost": 0,
                        "total_cost": 0
                    }
                logger.debug(f'Usage:\n{usage_track}')
                return assistant_message.content, cost_details
            elif function_call.function.name == 'regenerate_answer':
                suggestion = json.loads(function_call.function.arguments).get('suggestion', '')
                # FIX: the suggestion was passed as a positional arg that loguru ignores
                # (the message has no {} placeholder), so it was never logged.
                logger.debug(f'LLM Suggestion: {suggestion}')
                history.append({'role': 'user', 'content': suggestion})
if __name__ == "__main__":
    # Question comes from the CLI args if given, otherwise prompt interactively.
    user_question: str = ' '.join(sys.argv[1:]) if len(sys.argv) > 1 else input("What's your question? ")
    response, cost_details = process_user_query(user_question)
    print(f"Response: {response}")
    print(f"Total cost: ${cost_details['total_cost']:.4f}")
    # Persist outputs into the invocation directory for later reference.
    with open('response.md', 'w') as f:
        f.write(response)
    with open('cost_details.json', 'w') as f:
        json.dump(cost_details, f, indent=2)
    print('Response and Cost details also saved in current dir as response.md and cost_details.json, respectively.')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment