One-shot Perplexity Pro
#!/Users/rohan/miniconda3/bin/python
import asyncio
import dataclasses
import json
import os
import sys
import time

import openai
import requests
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from crawl4ai.content_filter_strategy import PruningContentFilter
from loguru import logger
from typing import List, Dict, Any, Tuple

logger.remove()
if os.environ.get("DEBUG") == "1":
    logger.add(sink=os.path.join(os.environ['HOME'], "utils", "researcher.logs.debug"), level="DEBUG")
else:
    logger.add(sink=os.path.join(os.environ['HOME'], "utils", "researcher.logs.info"), level="INFO")

# Prices are quoted per 1M tokens; stored here as per-token costs.
@dataclasses.dataclass
class CostPolicy:
    model_name: str
    input_cost: float
    cached_input_cost: float
    output_cost: float
    api_key_env_var: str
    base_url: str | None = None


gpt4o = CostPolicy('gpt-4o', 2.50/1e6, 1.25/1e6, 10.00/1e6, 'OPENAI_API_KEY')
o3mini = CostPolicy('o3-mini', 1.1/1e6, 0.55/1e6, 4.4/1e6, 'OPENAI_API_KEY')
CURR_MODEL = o3mini
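

# TokensAndCosts accumulates prompt/completion token counts and the running
# dollar cost across every chat-completion call made for a single user query.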
class TokensAndCosts:
    def __init__(self):
        self.total_uncached_input_tokens = 0
        self.total_cached_input_tokens = 0
        self.total_completion_tokens = 0
        self.total_cost = 0.0
        self.all_good = True
    def update(self, usage: Any, model: CostPolicy) -> None:
        try:
            uncached_tokens = usage.prompt_tokens - usage.prompt_tokens_details.cached_tokens
            self.total_uncached_input_tokens += uncached_tokens
            self.total_cached_input_tokens += usage.prompt_tokens_details.cached_tokens
            self.total_completion_tokens += usage.completion_tokens

            uncached_input_cost = uncached_tokens * model.input_cost
            cached_input_cost = usage.prompt_tokens_details.cached_tokens * model.cached_input_cost
            output_cost = usage.completion_tokens * model.output_cost
            self.total_cost += uncached_input_cost + cached_input_cost + output_cost

            logger.debug(
                f"Usage - Prompt tokens: {usage.prompt_tokens}, "
                f"Cached: {usage.prompt_tokens_details.cached_tokens}, "
                f"Completion: {usage.completion_tokens}"
            )
            logger.debug(
                f"Updated totals - Uncached: {self.total_uncached_input_tokens}, "
                f"Cached: {self.total_cached_input_tokens}"
            )
            logger.debug(f"Total cost so far: ${self.total_cost:.6f}")
        except Exception:
            logger.exception('Failed to update usage.')
            self.all_good = False
    def __str__(self) -> str:
        if self.all_good:
            return (
                f"Uncached input tokens: {self.total_uncached_input_tokens}\n"
                f"Cached input tokens: {self.total_cached_input_tokens}\n"
                f"Completion tokens: {self.total_completion_tokens}\n"
                f"Total cost: ${self.total_cost:.6f}"
            )
        return "Check logs, something messed up!!"


# ===
# Scraper
# ===
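# fetch_markdown uses crawl4ai's AsyncWebCrawler to render a page and convert it
# to markdown; PruningContentFilter drops low-density "fluff" blocks first.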
async def fetch_markdown(url: str) -> str | None:
    try:
        assert url is not None, 'url is None'
        assert len(url) > 0, 'len(url) is 0'

        # The PruningContentFilter removes low-density text blocks (fluff).
        md_generator = DefaultMarkdownGenerator(
            options={
                "ignore_links": True,
                "escape_html": True,
                "body_width": 80
            },
            content_filter=PruningContentFilter(
                threshold=0.5,
                min_word_threshold=50
            )
        )
        browser_config = BrowserConfig(verbose=False)
        crawler_config = CrawlerRunConfig(
            markdown_generator=md_generator,
            word_count_threshold=10,
            excluded_tags=["nav", "header", "footer"],
            exclude_external_links=True,
            verbose=False
        )
        async with AsyncWebCrawler(config=browser_config) as crawler:
            result = await crawler.arun(url, config=crawler_config)
            logger.debug(f'Markdown:\n{result.markdown}')
            return result.markdown
    except Exception:
        logger.exception(f'Failed to fetch markdown for url: {url}')
        return None


# ===
# Brave Search tool
# ===
@dataclasses.dataclass
class SearchResult:
    """
    Dataclass to represent the search results from the Brave Search API.

    :param title: The title of the search result.
    :param url: The URL of the search result.
    :param description: A brief description of the search result.
    :param extra_snippets: Additional snippets related to the search result.
    :param markdown: A pruned and filtered markdown of the webpage (may not contain all the details).
    """
    title: str
    url: str
    description: str
    extra_snippets: list
    markdown: str

    def __str__(self) -> str:
        """
        Returns a string representation of the search result.

        :return: A string representation of the search result.
        """
        return (
            f"Title: {self.title}\n"
            f"URL: {self.url}\n"
            f"Description: {self.description}\n"
            f"Extra Snippets: {', '.join(self.extra_snippets)}\n"
            f"Markdown: {self.markdown}"
        )
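

# search_brave queries the Brave Search web API, then concurrently fetches a
# pruned markdown rendering of each result URL so the model sees page content,
# not just titles and snippets.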
def search_brave(query: str, count: int = 5) -> list[SearchResult]:
    """
    Searches the web using the Brave Search API and returns structured search results.

    :param query: The search query string.
    :param count: The number of search results to return.
    :return: A list of SearchResult objects containing the search results.
    """
    if not query:
        return []

    url: str = "https://api.search.brave.com/res/v1/web/search"
    headers: dict = {
        "Accept": "application/json",
        "X-Subscription-Token": os.environ.get('BRAVE_SEARCH_AI_API_KEY', '')
    }
    if not headers['X-Subscription-Token']:
        logger.error("Error: Missing Brave Search API key.")
        return []

    params: dict = {
        "q": query,
        "count": count
    }
    retries: int = 0
    max_retries: int = 3
    backoff_factor: int = 2
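    # Retry with exponential backoff (2s, then 4s) for up to 3 attempts.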
    while retries < max_retries:
        try:
            response = requests.get(url, headers=headers, params=params)
            response.raise_for_status()
            results_json: dict = response.json()
            logger.debug('Got results')
            break
        except requests.exceptions.RequestException as e:
            logger.exception(f"HTTP Request failed: {e}, retrying...")
            retries += 1
            if retries < max_retries:
                time.sleep(backoff_factor ** retries)
            else:
                return []
    async def fetch_all_markdown(urls):
        tasks = [fetch_markdown(url) for url in urls]
        return await asyncio.gather(*tasks)

    urls = [item.get('url', '') for item in results_json.get('web', {}).get('results', [])]
    markdowns = asyncio.run(fetch_all_markdown(urls))

    results: List[SearchResult] = []
    for item, md in zip(results_json.get('web', {}).get('results', []), markdowns):
        if md is None:
            md = 'Failed to get markdown.'
        result = SearchResult(
            title=item.get('title', ''),
            url=item.get('url', ''),
            description=item.get('description', ''),
            extra_snippets=item.get('extra_snippets', []),
            markdown=md
        )
        results.append(result)
    return results
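

# OpenAI function-calling schema for search_brave; passed as `tools` to
# chat.completions.create so the model can request web searches.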
brave_search_tools = [{
    "type": "function",
    "function": {
        "name": "search_brave",
        "description": "Searches the web using the Brave Search API and returns structured search results.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "the search query string."},
            },
            "required": ["query"]
        }
    }
}]


# ===
# Main System
# ===
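# Overall flow: a "searcher" LLM answers the user query, calling search_brave as
# needed; an "answer checker" LLM then either returns the draft to the user or
# sends a suggestion back, and the loop repeats until the answer is accepted.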
first_llm_system_prompt = '''
You are a language model. Your task is to answer the user's complex queries. You can use Brave Search to search the internet and get not just links, titles, and short descriptions, but also a deep dive into the original content of certain pages.
You can perform multiple search requests with broken-down queries, and make multi-turn requests before answering the user's queries.
'''.strip()
second_llm_system_prompt = '''
You are a language model and your task is to look at the search results received from searching the internet for a user-given query. You have to decide whether the results need to be augmented with more information from the actual webpages. If you decide that you need to add more information for a certain result, you can call the tool.
Your final response back to the user should be a long compilation of the entire search result input and the content gathered from surfing webpages. Make sure that you include all the required information and any additional related content that is mentioned.
'''.strip()
answer_checker_system_prompt = '''
You are a language model and your task is to evaluate the answer that another LLM generated for the given user query. Check whether the answer addresses the entire question, or list of questions, that the user is asking.
If you think the answer is sufficient, call the return_answer_to_user tool. If you think the answer needs more information, call the regenerate_answer tool with a suggestion for how to modify the answer and what else needs to be included.
You can only call one function at a time.
When judging sufficiency, check that the query is fully answered and look for sections that could benefit from further internet research. Sometimes the LLM answers based on a single search; in that case you can ask it to research another avenue to fulfill the question.
'''.strip()
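
# Tools exposed to the answer checker: accept the current answer as-is, or send a
# suggestion back so the searcher LLM regenerates the answer.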
answer_checker_tools = [
    {
        "type": "function",
        "function": {
            "name": "return_answer_to_user",
            "description": "This function returns the current answer back to the user.",
            "parameters": {}
        }
    },
    {
        "type": "function",
        "function": {
            "name": "regenerate_answer",
            "description": "This function sends the suggestion back to the LLM for modifications to the answer based on the user queries.",
            "parameters": {
                "type": "object",
                "properties": {
                    "suggestion": {"type": "string", "description": "the suggestion for the llm"},
                },
                "required": ["suggestion"]
            }
        }
    }
]
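

# ask_llm is a thin wrapper around the OpenAI chat-completions endpoint; it returns
# the assistant message (which may contain tool calls) along with the usage object.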
def ask_llm(messages: List[Dict[str, str]], tools: List[Dict[str, Any]] | None = None, tool_choice: str = 'auto') -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Ask the LLM a question and get a response.

    :param messages: List of message dictionaries
    :param tools: List of tool definitions (optional)
    :param tool_choice: Tool-choice mode passed to the API ('auto', 'required', etc.)
    :return: Tuple of (assistant message, usage)
    """
    LLM_API_KEY = os.environ[CURR_MODEL.api_key_env_var]
    client = openai.OpenAI(api_key=LLM_API_KEY, base_url=CURR_MODEL.base_url)
    response = client.chat.completions.create(
        model=CURR_MODEL.model_name,
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
    )
    return response.choices[0].message, response.usage
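

# process_user_query runs the agent loop: the searcher LLM drafts an answer
# (issuing search_brave tool calls along the way); the answer checker then either
# approves the draft or feeds a suggestion back into the conversation.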
def process_user_query(query: str) -> Tuple[str, Dict[str, float]]:
    """
    Process the user's query and return a response.

    :param query: User's question or query
    :return: Tuple of (final response to the user, cost details)
    """
    # Step 1: Ask LLM to formulate a search query
    history = [{'role': 'system', 'content': first_llm_system_prompt}]
    initial_message = {"role": "user", "content": query}
    history.append(initial_message)
    usage_track = TokensAndCosts()

    while True:
        assistant_message, usage = ask_llm(history, brave_search_tools)
        history.append(assistant_message)
        usage_track.update(usage, CURR_MODEL)

        if assistant_message.tool_calls:
            if assistant_message.content:
                logger.debug(assistant_message.content)
            for function_call in assistant_message.tool_calls:
                args = json.loads(function_call.function.arguments)
                tool_call_response = {"role": "tool", "tool_call_id": function_call.id, "content": ""}
                if function_call.function.name == 'search_brave':
                    if 'query' in args:
                        search_result = search_brave(args['query'])
                        tool_call_response['content'] = json.dumps([dataclasses.asdict(x) for x in search_result], indent=2)
                    else:
                        tool_call_response['content'] = 'No query found in arguments'
                history.append(tool_call_response)
        else:
            logger.debug(f'LLM Generated Answer:\n{assistant_message.content}')
            input_to_answer_checker = f'User query: {query}\n\nLLM Answer:\n{assistant_message.content}'
            answer_checker_history = [
                {'role': 'system', 'content': answer_checker_system_prompt},
                {'role': 'user', 'content': input_to_answer_checker},
            ]
            answer_checker_response, usage = ask_llm(answer_checker_history, answer_checker_tools, tool_choice='required')
            usage_track.update(usage, CURR_MODEL)

            function_call = answer_checker_response.tool_calls[0]
            if function_call.function.name == 'return_answer_to_user':
                if usage_track.all_good:
                    cost_details: Dict[str, float] = {
                        "total_uncached_input_tokens_cost": usage_track.total_uncached_input_tokens * CURR_MODEL.input_cost,
                        "total_cached_input_tokens_cost": usage_track.total_cached_input_tokens * CURR_MODEL.cached_input_cost,
                        "total_completion_tokens_cost": usage_track.total_completion_tokens * CURR_MODEL.output_cost,
                        "total_cost": usage_track.total_cost
                    }
                else:
                    cost_details = {
                        "total_uncached_input_tokens_cost": 0,
                        "total_cached_input_tokens_cost": 0,
                        "total_completion_tokens_cost": 0,
                        "total_cost": 0
                    }
                logger.debug(f'Usage:\n{usage_track}')
                return assistant_message.content, cost_details
            elif function_call.function.name == 'regenerate_answer':
                suggestion = json.loads(function_call.function.arguments).get('suggestion', '')
                logger.debug(f'LLM Suggestion: {suggestion}')
                history.append({'role': 'user', 'content': suggestion})


if __name__ == "__main__":
    user_question: str = ' '.join(sys.argv[1:]) if len(sys.argv) > 1 else input("What's your question? ")
    response, cost_details = process_user_query(user_question)
    print(f"Response: {response}")
    print(f"Total cost: ${cost_details['total_cost']:.4f}")

    with open('response.md', 'w') as f:
        f.write(response)
    with open('cost_details.json', 'w') as f:
        json.dump(cost_details, f, indent=2)
    print('Response and Cost details also saved in current dir as response.md and cost_details.json, respectively.')
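
# Example usage (a sketch; assumes the script is saved as, e.g., researcher.py,
# and that OPENAI_API_KEY and BRAVE_SEARCH_AI_API_KEY are set in the environment;
# set DEBUG=1 for debug-level logs):
#   ./researcher.py "What changed between o1 and o3-mini?"
# The answer is printed and also written to response.md and cost_details.json.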