-
-
Save AutoScrape123TX/174dc0f0f610019ab2cd53ce7fd4163f to your computer and use it in GitHub Desktop.
PrevNextNodePostprocessor tailored to QdrantVectorStore for TextNode relationships
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Dict, List, Optional | |
from llama_index.core.schema import NodeRelationship, NodeWithScore, QueryBundle | |
from llama_index.core.bridge.pydantic import Field, validator | |
from llama_index.core.postprocessor.types import BaseNodePostprocessor | |
from llama_index.vector_stores.qdrant import QdrantVectorStore | |
def qdrant_vector_store_get_forward_nodes(
    node_with_score: NodeWithScore, num_nodes: int, vector_store: QdrantVectorStore
) -> Dict[str, NodeWithScore]:
    """Collect a node together with up to ``num_nodes`` of its successors.

    Starting from ``node_with_score.node``, the NEXT relationship is followed
    one hop at a time and each successor is fetched from ``vector_store`` by id.

    Args:
        node_with_score: The seed node; it is always part of the result.
        num_nodes: Maximum number of successors to fetch.
        vector_store: Store queried through ``get_nodes([node_id])``.

    Returns:
        A dict keyed by node id. The seed entry is ``node_with_score`` itself;
        every fetched successor is wrapped in a ``NodeWithScore`` without a
        score. The walk stops early when a node has no NEXT relationship,
        when the referenced node cannot be fetched, or when the chain loops
        back onto a node that was already collected.
    """
    node = node_with_score.node
    nodes: Dict[str, NodeWithScore] = {node.node_id: node_with_score}
    cur_count = 0
    # Walk the NEXT chain iteratively, one store lookup per hop.
    while cur_count < num_nodes:
        if NodeRelationship.NEXT not in node.relationships:
            break
        next_node_info = node.next_node
        if next_node_info is None:
            break
        next_node_id = next_node_info.node_id
        # A cyclic chain would otherwise replace an already collected entry
        # (including the scored seed) with an unscored wrapper.
        if next_node_id in nodes:
            break
        fetched = vector_store.get_nodes([next_node_id])
        # The relationship may reference a node that is not in the store;
        # indexing an empty result would raise IndexError.
        if not fetched or fetched[0] is None:
            break
        next_node = fetched[0]
        nodes[next_node.node_id] = NodeWithScore(node=next_node)
        node = next_node
        cur_count += 1
    return nodes
def qdrant_vector_store_get_backward_nodes(
    node_with_score: NodeWithScore, num_nodes: int, vector_store: QdrantVectorStore
) -> Dict[str, NodeWithScore]:
    """Collect a node together with up to ``num_nodes`` of its predecessors.

    Starting from ``node_with_score.node``, ``prev_node`` is followed one hop
    at a time and each predecessor is fetched from ``vector_store`` by id.

    Args:
        node_with_score: The seed node; it is always part of the result.
        num_nodes: Maximum number of predecessors to fetch.
        vector_store: Store queried through ``get_nodes([node_id])``.

    Returns:
        A dict keyed by node id. The seed entry is ``node_with_score`` itself;
        every fetched predecessor is wrapped in a ``NodeWithScore`` without a
        score. The walk stops early when a node has no previous node, when
        the referenced node cannot be fetched, or when the chain loops back
        onto a node that was already collected.
    """
    node = node_with_score.node
    nodes: Dict[str, NodeWithScore] = {node.node_id: node_with_score}
    cur_count = 0
    # Walk the previous-node chain iteratively, one store lookup per hop.
    while cur_count < num_nodes:
        prev_node_info = node.prev_node
        if prev_node_info is None:
            break
        prev_node_id = prev_node_info.node_id
        # A cyclic chain would otherwise replace an already collected entry
        # (including the scored seed) with an unscored wrapper.
        if prev_node_id in nodes:
            break
        fetched = vector_store.get_nodes([prev_node_id])
        # Check the result before indexing: an empty list means the referenced
        # node is not in the store, and `[0]` on it would raise IndexError.
        if not fetched or fetched[0] is None:
            break
        prev_node = fetched[0]
        nodes[prev_node.node_id] = NodeWithScore(node=prev_node)
        node = prev_node
        cur_count += 1
    return nodes
class QdrantVectorPrevNextNodePostprocessor(BaseNodePostprocessor):
    """Postprocessor that adds neighbouring nodes fetched from a Qdrant store.

    For every incoming node, its successors and/or predecessors (depending on
    ``mode``) are looked up in ``vector_store`` and merged into the result,
    which is then arranged so that linked nodes sit next to each other.
    """

    # Store the neighbouring nodes are fetched from.
    vector_store: QdrantVectorStore
    # Direction to expand in: "next", "previous" or "both".
    mode: str = Field(default="next")
    # Maximum number of neighbours fetched per direction for each node.
    num_nodes: int = Field(default=1)

    @validator("mode")
    def _validate_mode(cls, v: str) -> str:
        """Reject any mode other than "next", "previous" or "both"."""
        if v in ("next", "previous", "both"):
            return v
        raise ValueError(f"Invalid mode: {v}")

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string of this postprocessor."""
        return "VectorPrevNextNodePostprocessor"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Expand ``nodes`` with their neighbours and order them by linkage.

        Args:
            nodes: Nodes to expand.
            query_bundle: Accepted for interface compatibility; not used here.

        Returns:
            The input nodes plus the fetched neighbours, deduplicated by node
            id, with each node placed directly before or after an already
            placed node that references it as previous/next when one exists.

        Raises:
            ValueError: If ``mode`` is not "next", "previous" or "both" while
                at least one node is being processed.
        """
        collected: Dict[str, NodeWithScore] = {}
        for seed in nodes:
            collected[seed.node.node_id] = seed

            expand_forward = self.mode in ("next", "both")
            expand_backward = self.mode in ("previous", "both")
            if not (expand_forward or expand_backward):
                raise ValueError(f"Invalid mode: {self.mode}")

            # Forward neighbours are merged before backward ones so the dict
            # keeps the same insertion order in "both" mode.
            if expand_forward:
                forward = qdrant_vector_store_get_forward_nodes(
                    seed, self.num_nodes, self.vector_store
                )
                collected.update(forward)
            if expand_backward:
                backward = qdrant_vector_store_get_backward_nodes(
                    seed, self.num_nodes, self.vector_store
                )
                collected.update(backward)

        ordered: List[NodeWithScore] = []
        for entry in collected.values():
            entry_id = entry.node.node_id
            for position, placed in enumerate(ordered):
                before = placed.node.prev_node
                after = placed.node.next_node
                if before is not None and before.node_id == entry_id:
                    # entry precedes the placed node: slot it in front.
                    ordered.insert(position, entry)
                    break
                if after is not None and after.node_id == entry_id:
                    # entry follows the placed node: slot it right behind.
                    ordered.insert(position + 1, entry)
                    break
            else:
                # No placed node references this entry; put it at the end.
                ordered.append(entry)
        return ordered
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment