RAG_IN_CONTEXT_PROMPT = """
Given a QUERY below, your task is to come up with a maximum of 25
STATISTICAL QUESTIONS that help in answering QUERY.
Here are the only forms of STATISTICAL QUESTIONS you can generate:
1. "What is $METRIC in $PLACE?"
2. "What is $METRIC in $PLACE $PLACE_TYPE?"
3. "How has $METRIC changed over time in $PLACE $PLACE_TYPE?"
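The prompt restricts the model to three question shapes and a cap of 25, but a model can still drift from those rules. A minimal sketch of a post-filter follows, assuming the model returns one question per line. The helper name `filter_questions` and the regular expressions are hypothetical and not part of the original code. Forms 1 and 2 have the same surface shape, so a single pattern covers both.

```python
import re

# Hypothetical patterns mirroring the three allowed forms. Forms 1 and 2
# cannot be told apart by shape alone, so one pattern covers both.
_ALLOWED_FORMS = (
    re.compile(r"^What is .+ in .+\?$"),
    re.compile(r"^How has .+ changed over time in .+\?$"),
)
MAX_QUESTIONS = 25  # the cap stated in the prompt


def filter_questions(lines: list[str], limit: int = MAX_QUESTIONS) -> list[str]:
    """Keep unique lines that match an allowed form, up to `limit` (assumed >= 1)."""
    kept: list[str] = []
    for line in lines:
        question = line.strip()
        if question in kept:
            continue  # skip duplicates
        if any(pattern.match(question) for pattern in _ALLOWED_FORMS):
            kept.append(question)
        if len(kept) == limit:
            break
    return kept
```

Filtering before the Data Commons call keeps malformed lines from being sent as queries; whether the original pipeline does this is not shown in the excerpt.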
@staticmethod
def pretty_print(q2resp: dict[str, dg.base.DataCommonsCall]):
    markdown_output = "# Data Commons Response\n"
    for k, v in q2resp.items():
        markdown_output += f"**{k}**\n\n"
        markdown_output += f"{v.answer()}\n\n"
    return markdown_output
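To see what this formatter produces without calling Data Commons, the sketch below feeds it a stand-in object. `StubCall` is hypothetical; it only assumes that the real response object exposes an `answer()` method, as the excerpt above uses. The type hint is loosened so the example runs without the `dg` package.

```python
class StubCall:
    """Hypothetical stand-in for a Data Commons call result."""

    def __init__(self, text: str):
        self._text = text

    def answer(self) -> str:
        return self._text


def pretty_print(q2resp: dict) -> str:
    # Same formatting logic as the excerpt, with the type hint loosened
    # so the stub can be passed in.
    markdown_output = "# Data Commons Response\n"
    for k, v in q2resp.items():
        markdown_output += f"**{k}**\n\n"
        markdown_output += f"{v.answer()}\n\n"
    return markdown_output


report = pretty_print(
    {"What is the population in Kenya?": StubCall("placeholder answer")}
)
print(report)
```

Each question becomes a bold line followed by its answer, with blank lines between entries so the result renders as separate Markdown paragraphs.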
class DataCommonsClient:
    def __init__(self):
        self.data_fetcher = dg.DataCommons(api_key=DC_API_KEY)

    def call_dc(self, questions: list[str]) -> dict[str, dg.base.DataCommonsCall]:
        try:
            q2resp = self.data_fetcher.calln(questions, self.data_fetcher.point)
        except Exception as e:
            logging.warning(e)
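The excerpt stops inside the `except` branch, so what `call_dc` returns after a failure is not shown. One common way to complete this pattern is to log the error and fall back to an empty mapping so callers always receive a dict. The sketch below illustrates that choice as an assumption, not as the original behaviour; `call_with_fallback` and the `fetch` parameter are hypothetical names.

```python
import logging
from typing import Callable


def call_with_fallback(
    fetch: Callable[[list[str]], dict[str, str]],
    questions: list[str],
) -> dict[str, str]:
    """Run `fetch`; on any error, log a warning and return an empty dict."""
    try:
        return fetch(questions)
    except Exception as e:
        # Assumption: an empty result is an acceptable fallback for callers.
        logging.warning(e)
        return {}
```

Returning `{}` keeps a downstream formatter such as `pretty_print` working, since iterating an empty dict simply yields no entries.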
class DataGemma:
    def __init__(
        self,
        model_id: str = "bartowski/datagemma-rag-27b-it-GGUF",
        model_file: str = "datagemma-rag-27b-it-Q2_K.gguf",
    ):
        self.generation_kwargs = {
            "max_tokens": 4096,  # Max number of new tokens to generate
        }
        self.model_path = hf_hub_download(model_id, model_file)
        self.llm = Llama(
            self.model_path
        )
        self.name = "DataGemma"
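The excerpt stores `generation_kwargs` but does not show where they are used. A plausible use is to unpack them into the model call and read the generated text from the response. The sketch below assumes the completion-style response shape that llama-cpp-python returns (`choices[0]["text"]`); `generate_text` is a hypothetical helper, and the model is passed in as a plain callable so the example runs without downloading weights.

```python
from typing import Any, Callable


def generate_text(
    llm: Callable[..., dict[str, Any]],
    prompt: str,
    generation_kwargs: dict[str, Any],
) -> str:
    """Call the model with the stored kwargs and return the generated text."""
    response = llm(prompt, **generation_kwargs)
    # Assumed response shape: {"choices": [{"text": "..."}], ...}
    return response["choices"][0]["text"]


def fake_llm(prompt: str, **kwargs: Any) -> dict[str, Any]:
    # Hypothetical stub that echoes its inputs in place of a real model.
    return {"choices": [{"text": f"{prompt} (max_tokens={kwargs['max_tokens']})"}]}


print(generate_text(fake_llm, "QUERY", {"max_tokens": 4096}))
```

In the class above, the equivalent call would pass `self.llm` and `self.generation_kwargs`, assuming the method that does so follows this shape.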