Created
July 19, 2024 21:01
-
-
Save ShawonAshraf/a9f3b632cb16667213c0fa0ac341e9d1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from llama_index.vector_stores.qdrant import QdrantVectorStore | |
def get_forward_nodes(
    node_with_score: NodeWithScore, num_nodes: int, vector_store: QdrantVectorStore
) -> Dict[str, NodeWithScore]:
    """Collect up to ``num_nodes`` successors of a node by following NEXT links.

    Starting from ``node_with_score``, repeatedly reads the current node's
    ``next_node`` relationship and fetches the referenced node from
    ``vector_store``.

    Args:
        node_with_score: The starting node; it is returned with its original score.
        num_nodes: Maximum number of successor nodes to fetch.
        vector_store: Qdrant store used to look up successor nodes by id.

    Returns:
        A dict mapping node id to ``NodeWithScore`` in traversal order. The
        starting node comes first; fetched successors carry no score. The walk
        stops early when the current node has no next-node link or the linked
        node is not found in the store.
    """
    node = node_with_score.node
    nodes: Dict[str, NodeWithScore] = {node.node_id: node_with_score}
    for _ in range(num_nodes):
        next_node_info = node.next_node
        if next_node_info is None:
            break
        # get_nodes takes a *list* of ids and returns a (possibly empty) list;
        # passing the bare id string and indexing [0] was both a type error and
        # an IndexError waiting to happen for nodes missing from the store.
        fetched = vector_store.get_nodes(node_ids=[next_node_info.node_id])
        if not fetched:
            break
        node = fetched[0]
        nodes[node.node_id] = NodeWithScore(node=node)
    return nodes
def get_backward_nodes(
    node_with_score: NodeWithScore, num_nodes: int, vector_store: QdrantVectorStore
) -> Dict[str, NodeWithScore]:
    """Collect up to ``num_nodes`` predecessors of a node by following PREVIOUS links.

    Starting from ``node_with_score``, repeatedly reads the current node's
    ``prev_node`` relationship and fetches the referenced node from
    ``vector_store``.

    Args:
        node_with_score: The starting node; it is returned with its original score.
        num_nodes: Maximum number of predecessor nodes to fetch.
        vector_store: Qdrant store used to look up predecessor nodes by id.

    Returns:
        A dict mapping node id to ``NodeWithScore`` in traversal order (starting
        node first, then nearest predecessor outward); fetched predecessors carry
        no score. The walk stops early when the current node has no
        previous-node link or the linked node is not found in the store.
    """
    node = node_with_score.node
    nodes: Dict[str, NodeWithScore] = {node.node_id: node_with_score}
    for _ in range(num_nodes):
        prev_node_info = node.prev_node
        if prev_node_info is None:
            break
        # get_nodes takes a *list* of ids and returns a (possibly empty) list.
        # The old `[0]` followed by an `is None` check could never reach the
        # check: an empty result raised IndexError first.
        fetched = vector_store.get_nodes(node_ids=[prev_node_info.node_id])
        if not fetched:
            break
        node = fetched[0]
        nodes[node.node_id] = NodeWithScore(node=node)
    return nodes
class VectorPrevNextNodePostprocessor(BaseNodePostprocessor):
    """Expand retrieved nodes with their neighbours fetched from a Qdrant store.

    For every input node, follows the next and/or previous relationships
    (depending on ``mode``) for up to ``num_nodes`` hops, fetching each
    neighbour from ``vector_store``. The combined, de-duplicated nodes are then
    greedily reordered so that linked nodes sit next to each other.

    Attributes:
        vector_store: Qdrant store used to look up neighbouring nodes by id.
        num_nodes: Maximum number of neighbours to fetch in each direction.
        mode: One of ``"next"``, ``"previous"`` or ``"both"``.
    """

    vector_store: QdrantVectorStore
    # Was read as self.num_nodes but never declared, so every call failed with
    # AttributeError. Default of 1 neighbour per direction.
    num_nodes: int = Field(default=1)
    mode: str = Field(default="next")

    @validator("mode")
    def _validate_mode(cls, v: str) -> str:
        """Reject any mode other than "next", "previous" or "both"."""
        if v not in ["next", "previous", "both"]:
            raise ValueError(f"Invalid mode: {v}")
        return v

    @classmethod
    def class_name(cls) -> str:
        """Return this component's class name."""
        return "VectorPrevNextNodePostprocessor"

    def _expand(self, node: NodeWithScore) -> Dict[str, NodeWithScore]:
        """Return ``node`` plus its neighbours in the configured direction(s).

        Forward neighbours are listed before backward ones when ``mode`` is
        ``"both"``.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        if self.mode not in ("next", "previous", "both"):
            raise ValueError(f"Invalid mode: {self.mode}")
        expanded: Dict[str, NodeWithScore] = {}
        if self.mode in ("next", "both"):
            expanded.update(get_forward_nodes(node, self.num_nodes, self.vector_store))
        if self.mode in ("previous", "both"):
            expanded.update(
                get_backward_nodes(node, self.num_nodes, self.vector_store)
            )
        return expanded

    @staticmethod
    def _sort_by_adjacency(candidates: List[NodeWithScore]) -> List[NodeWithScore]:
        """Greedily place each node next to an already-placed linked node.

        Each node is inserted directly before the first placed node whose
        previous-node link points at it, or directly after the first placed node
        whose next-node link points at it; otherwise it is appended.
        """
        sorted_nodes: List[NodeWithScore] = []
        for node in candidates:
            node_id = node.node.node_id  # invariant for the inner scan
            for i, cand in enumerate(sorted_nodes):
                prev_node_info = cand.node.prev_node
                next_node_info = cand.node.next_node
                # prepend to current candidate
                if prev_node_info is not None and node_id == prev_node_info.node_id:
                    sorted_nodes.insert(i, node)
                    break
                # append to current candidate
                if next_node_info is not None and node_id == next_node_info.node_id:
                    sorted_nodes.insert(i + 1, node)
                    break
            else:
                # no linked candidate found
                sorted_nodes.append(node)
        return sorted_nodes

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Expand ``nodes`` with their stored neighbours and order them by adjacency.

        Args:
            nodes: Retrieved nodes (with scores) to expand.
            query_bundle: Unused by this postprocessor.

        Returns:
            The de-duplicated retrieved and neighbour nodes, reordered so that
            linked nodes are adjacent where possible.
        """
        all_nodes: Dict[str, NodeWithScore] = {}
        for node in nodes:
            # A retrieved node always replaces a scoreless copy that was fetched
            # earlier as another node's neighbour.
            all_nodes[node.node.node_id] = node
            for neighbour_id, neighbour in self._expand(node).items():
                # setdefault (not update) so a scoreless neighbour copy never
                # clobbers a node that was retrieved, and scored, earlier. Key
                # order is the same as with update().
                all_nodes.setdefault(neighbour_id, neighbour)
        return self._sort_by_adjacency(list(all_nodes.values()))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Yes sure thanks.