Skip to content

Instantly share code, notes, and snippets.

@ShawonAshraf
Created July 19, 2024 21:01
Show Gist options
  • Save ShawonAshraf/a9f3b632cb16667213c0fa0ac341e9d1 to your computer and use it in GitHub Desktop.
Save ShawonAshraf/a9f3b632cb16667213c0fa0ac341e9d1 to your computer and use it in GitHub Desktop.
from llama_index.vector_stores.qdrant import QdrantVectorStore
def get_forward_nodes(
    node_with_score: NodeWithScore, num_nodes: int, vector_store: QdrantVectorStore
) -> Dict[str, NodeWithScore]:
    """Collect up to ``num_nodes`` successors of a node by following NEXT links.

    Starting from ``node_with_score.node``, repeatedly read the node's
    ``next_node`` relationship and fetch the referenced node from
    ``vector_store``. The walk stops early when a node has no NEXT link or when
    the store returns nothing for the referenced id.

    Args:
        node_with_score: Node to start from; it is included in the result
            unchanged, keeping its score.
        num_nodes: Maximum number of successor nodes to fetch.
        vector_store: Store used to look successor nodes up by id.

    Returns:
        Mapping of node id to ``NodeWithScore``, holding the start node followed
        by the fetched successors in walk order (successors carry no score).
    """
    node = node_with_score.node
    nodes: Dict[str, NodeWithScore] = {node.node_id: node_with_score}
    for _ in range(num_nodes):
        next_node_info = node.next_node
        if next_node_info is None:
            break
        # get_nodes takes a *list* of ids; a bare string is not a valid argument.
        fetched = vector_store.get_nodes([next_node_info.node_id])
        if not fetched:
            # Referenced node is not in the store: stop rather than IndexError.
            break
        next_node = fetched[0]
        nodes[next_node.node_id] = NodeWithScore(node=next_node)
        node = next_node
    return nodes
def get_backward_nodes(
    node_with_score: NodeWithScore, num_nodes: int, vector_store: QdrantVectorStore
) -> Dict[str, NodeWithScore]:
    """Collect up to ``num_nodes`` predecessors of a node by following PREVIOUS links.

    Starting from ``node_with_score.node``, repeatedly read the node's
    ``prev_node`` relationship and fetch the referenced node from
    ``vector_store``. The walk stops early when a node has no PREVIOUS link or
    when the store returns nothing for the referenced id.

    Args:
        node_with_score: Node to start from; it is included in the result
            unchanged, keeping its score.
        num_nodes: Maximum number of predecessor nodes to fetch.
        vector_store: Store used to look predecessor nodes up by id.

    Returns:
        Mapping of node id to ``NodeWithScore``, holding the start node followed
        by the fetched predecessors in walk order (predecessors carry no score).
    """
    node = node_with_score.node
    nodes: Dict[str, NodeWithScore] = {node.node_id: node_with_score}
    for _ in range(num_nodes):
        prev_node_info = node.prev_node
        if prev_node_info is None:
            break
        # get_nodes takes a *list* of ids; a bare string is not a valid argument.
        fetched = vector_store.get_nodes([prev_node_info.node_id])
        if not fetched:
            # Referenced node is not in the store. Indexing [0] first would
            # raise IndexError before any None check could run.
            break
        prev_node = fetched[0]
        nodes[prev_node.node_id] = NodeWithScore(node=prev_node)
        node = prev_node
    return nodes
class VectorPrevNextNodePostprocessor(BaseNodePostprocessor):
    """Expand retrieved nodes with their linked neighbours fetched from a vector store.

    For every input node, follow its NEXT and/or PREVIOUS relationships
    (depending on ``mode``) and fetch up to ``num_nodes`` neighbours per
    direction from ``vector_store``. Return the deduplicated set ordered so
    that linked nodes sit next to each other where possible.
    """

    # Store queried for neighbouring nodes by id.
    vector_store: QdrantVectorStore
    # Maximum number of neighbours fetched per direction for each input node.
    num_nodes: int = Field(default=1)
    # Which links to follow: "next", "previous" or "both".
    mode: str = Field(default="next")

    @validator("mode")
    def _validate_mode(cls, v: str) -> str:
        """Reject any mode other than "next", "previous" or "both"."""
        if v not in ["next", "previous", "both"]:
            raise ValueError(f"Invalid mode: {v}")
        return v

    @classmethod
    def class_name(cls) -> str:
        return "VectorPrevNextNodePostprocessor"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Add linked neighbours to ``nodes`` and order the result by links.

        Args:
            nodes: Retrieved nodes to expand.
            query_bundle: Not used by this postprocessor.

        Returns:
            Input nodes plus fetched neighbours, deduplicated by node id. Each
            node is inserted directly before an already-placed node whose
            PREVIOUS link points at it, or directly after one whose NEXT link
            points at it; otherwise it is appended.

        Raises:
            ValueError: If ``mode`` is not "next", "previous" or "both".
        """
        all_nodes: Dict[str, NodeWithScore] = {}
        for node in nodes:
            all_nodes[node.node.node_id] = node
            if self.mode == "next":
                neighbours = get_forward_nodes(node, self.num_nodes, self.vector_store)
            elif self.mode == "previous":
                neighbours = get_backward_nodes(
                    node, self.num_nodes, self.vector_store
                )
            elif self.mode == "both":
                neighbours = {
                    **get_forward_nodes(node, self.num_nodes, self.vector_store),
                    **get_backward_nodes(node, self.num_nodes, self.vector_store),
                }
            else:
                raise ValueError(f"Invalid mode: {self.mode}")
            for neighbour_id, neighbour in neighbours.items():
                # Keep an existing entry: it may be a retrieved node with a
                # score, whereas a fetched neighbour copy has none.
                all_nodes.setdefault(neighbour_id, neighbour)

        sorted_nodes: List[NodeWithScore] = []
        for node in all_nodes.values():
            node_id = node.node.node_id
            for i, cand in enumerate(sorted_nodes):
                prev_node_info = cand.node.prev_node
                next_node_info = cand.node.next_node
                if prev_node_info is not None and node_id == prev_node_info.node_id:
                    # ``node`` precedes ``cand``: insert just before it.
                    sorted_nodes.insert(i, node)
                    break
                if next_node_info is not None and node_id == next_node_info.node_id:
                    # ``node`` follows ``cand``: insert just after it.
                    sorted_nodes.insert(i + 1, node)
                    break
            else:
                # No placed node links to this one; append at the end.
                sorted_nodes.append(node)
        return sorted_nodes
@ShawonAshraf
Copy link
Author

By the way, could you please also post your solution in the llama-index issue?

@AutoScrape123TX
Copy link

Yes, sure. Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment