Created
June 24, 2023 21:33
-
-
Save skirdey/872402c20cd4ec34250be1ce7ce09ad5 to your computer and use it in GitHub Desktop.
haystack
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class AnswersToDocs(BaseComponent):
    """
    This Node is used to convert predicted answers into the documents format.
    It is useful for situations where a pipeline produces `Answer` objects but the consumer
    of its output (for example, a client calling the pipeline via REST API) expects `Document`s.
    This ensures that your output is in a compatible format.

    Each `Answer` is turned into a `Document` whose `content` is the answer text and whose
    `score` and `meta` are copied from the answer. Objects that are already `Document`s are
    passed through unchanged.

    :param progress_bar: Whether to show a progress bar (used by `run_batch`).
    """

    outgoing_edges = 1

    def __init__(self, progress_bar: bool = True):
        super().__init__()
        self.progress_bar = progress_bar

    def run(self, query: str, answers: List[Answer]):  # type: ignore
        """
        Convert the answers for a single query into documents.

        :param query: The query; copied unchanged into the output.
        :param answers: The answers to convert. `Document`s in the list are passed through as-is.
        :return: A tuple of the output dict `{"query": query, "documents": [...]}` and the
                 outgoing edge name `"output_1"`.
        """
        documents: List[Document] = [self._convert_answer_to_doc(answer) for answer in answers]
        output = {"query": query, "documents": documents}
        return output, "output_1"

    def run_batch(self, queries: List[str], answers: Union[List[Answer], List[List[Answer]]]):  # type: ignore
        """
        Convert the answers for a batch of queries into documents.

        Two input shapes are supported:
        1. A flat list of answers: each answer becomes its own single-element list of documents.
        2. A list of lists of answers: each inner list becomes one list of documents.

        As in `run`, `Document`s in the input are passed through unchanged.

        :param queries: The queries; copied unchanged into the output.
        :param answers: The answers to convert, in one of the two shapes above.
        :return: A tuple of the output dict `{"queries": queries, "documents": [[...], ...]}`
                 and the outgoing edge name `"output_1"`.
        :raises HaystackError: If the input mixes shapes or contains elements of an unsupported type.
        """
        output: Dict = {"queries": queries, "documents": []}
        if not answers:
            return output, "output_1"

        # Answers case 1: single flat list of Answers (Documents are accepted, consistent with `run`)
        if isinstance(answers[0], (Answer, Document)):
            for answer in tqdm(answers, disable=not self.progress_bar, desc="Converting to docs"):
                if not isinstance(answer, (Answer, Document)):
                    raise HaystackError(f"answer was of type {type(answer)}, but expected an Answer.")
                output["documents"].append([self._convert_answer_to_doc(answer)])

        # Answers case 2: list of lists of Answers
        elif isinstance(answers[0], list):
            for answer_list in tqdm(answers, disable=not self.progress_bar, desc="Converting to documents"):
                if not isinstance(answer_list, list):
                    raise HaystackError(
                        f"answers was of type {type(answer_list)}, but expected a list of Answers."
                    )
                output["documents"].append([self._convert_answer_to_doc(answer) for answer in answer_list])

        # Anything else previously fell through and silently produced no documents.
        else:
            raise HaystackError(
                f"answers contained elements of type {type(answers[0])}, "
                f"but expected Answers or lists of Answers."
            )

        return output, "output_1"

    @staticmethod
    def _convert_answer_to_doc(answ: Union[Answer, Document]) -> Document:
        """
        Convert a single `Answer` into a `Document`; return `Document`s unchanged.

        The document's `content` is the answer text, and `score` and `meta` are copied from the answer.
        """
        if isinstance(answ, Document):
            return answ
        return Document(content=answ.answer, score=answ.score, meta=answ.meta)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment