Last active
May 10, 2024 07:14
-
-
Save pmeier/38ee90be6c30ecdf9bbec086a0dabafe to your computer and use it in GitHub Desktop.
Sample implementation for MetadataFilter in https://github.com/Quansight/ragna/issues/256#issuecomment-1933808962
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any | |
from metadata_filter import MetadataFilter, MetadataFilterOperator | |
# https://docs.trychroma.com/usage-guide#using-where-filters | |
# Chroma spells every operator as "$" followed by a short lowercase name, so
# only the bare names are listed here and the prefix is added in one place.
OPERATOR_MAP: dict[MetadataFilterOperator, str] = {
    operator: f"${name}"
    for operator, name in (
        # logical combinators
        (MetadataFilterOperator.AND, "and"),
        (MetadataFilterOperator.OR, "or"),
        # comparisons
        (MetadataFilterOperator.EQ, "eq"),
        (MetadataFilterOperator.NE, "ne"),
        (MetadataFilterOperator.LT, "lt"),
        (MetadataFilterOperator.LE, "lte"),
        (MetadataFilterOperator.GT, "gt"),
        (MetadataFilterOperator.GE, "gte"),
        # membership
        (MetadataFilterOperator.IN, "in"),
        (MetadataFilterOperator.NOT_IN, "nin"),
    )
}
def translate(filter: MetadataFilter) -> dict[str, Any]:
    """Translate a generic metadata filter into a Chroma ``where`` dictionary.

    Args:
        filter: The filter tree to translate.

    Returns:
        - For ``RAW`` filters, ``filter.value`` unchanged.
        - For ``AND`` / ``OR`` filters with exactly one child, the translation
          of that child (see the comment in the body).
        - For other ``AND`` / ``OR`` filters, ``{"$and" | "$or": [...]}`` with
          the translated children. An empty child list is passed through as an
          empty list.
        - For every other operator, ``{key: {chroma_operator: value}}``.

    Raises:
        KeyError: If the operator has no entry in ``OPERATOR_MAP``.
    """
    operator = filter.operator

    if operator is MetadataFilterOperator.RAW:
        return filter.value

    if operator in {MetadataFilterOperator.AND, MetadataFilterOperator.OR}:
        children = [translate(child) for child in filter.value]
        # Chroma's where-filter validation requires at least two expressions
        # under $and / $or. A combinator over a single child is logically that
        # child, so emit it directly instead of an expression Chroma rejects.
        if len(children) == 1:
            return children[0]
        return {OPERATOR_MAP[operator]: children}

    return {filter.key: {OPERATOR_MAP[operator]: filter.value}}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): `pprint` is called with (label, object) throughout, so it is
# presumably this demo's own helper defined above this excerpt rather than
# `pprint.pprint` from the standard library -- confirm.

# --- generic filter: show it, its JSON form, and the JSON round trip ---------
pprint("generic filter", generic_filter)

generic_filter_json = generic_filter.to_json()
pprint("generic filter jsonified", generic_filter_json)

generic_filter_roundtrip = MetadataFilter.from_json(generic_filter.to_json())
pprint("generic filter json roundtrip", generic_filter_roundtrip)

# --- the same generic filter rendered for each backend ------------------------
generic_as_chroma = translate_chroma(generic_filter)
pprint("generic filter translated to Chroma dialect", generic_as_chroma)

generic_as_lance_db = translate_lance_db(generic_filter)
pprint("generic filter translated to LanceDB dialect", generic_as_lance_db)

# --- Chroma: OR the generic filter with a raw, Chroma-style dict --------------
chroma_raw_part = MetadataFilter.raw({"spam": {"$exists": True}})
chroma_specific_filter = MetadataFilter.or_([generic_filter, chroma_raw_part])
pprint("Chroma specific filter", chroma_specific_filter)

chroma_specific_translated = translate_chroma(chroma_specific_filter)
pprint(
    "Chroma specific filter translated to its dialect",
    chroma_specific_translated,
)

# --- LanceDB: OR the generic filter with a raw SQL fragment -------------------
lancedb_raw_part = MetadataFilter.raw("regexp_match('spam', '*am$')")
lancedb_specific_filter = MetadataFilter.or_([generic_filter, lancedb_raw_part])
pprint("LanceDB specific filter", lancedb_specific_filter)

lancedb_specific_translated = translate_lance_db(lancedb_specific_filter)
pprint(
    "LanceDB specific filter translated to its dialect",
    lancedb_specific_translated,
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from metadata_filter import MetadataFilter, MetadataFilterOperator | |
# https://lancedb.github.io/lancedb/sql/ | |
# SQL spelling of each operator that has a direct counterpart here. NE and
# NOT_IN are intentionally absent from this table.
OPERATOR_MAP: dict[MetadataFilterOperator, str] = {
    # logical combinators
    MetadataFilterOperator.AND: "AND",
    MetadataFilterOperator.OR: "OR",
    # comparisons
    MetadataFilterOperator.EQ: "=",
    MetadataFilterOperator.LT: "<",
    MetadataFilterOperator.LE: "<=",
    MetadataFilterOperator.GT: ">",
    MetadataFilterOperator.GE: ">=",
    # membership
    MetadataFilterOperator.IN: "IN",
}
def _sql_literal(value):
    """Render a single Python value as a SQL literal."""
    if isinstance(value, str):
        # SQL string literals use single quotes, and an embedded single quote
        # is escaped by doubling it. Python's repr() would switch to double
        # quotes for such strings, which SQL reads as an identifier.
        return "'" + value.replace("'", "''") + "'"
    return repr(value)


def translate(filter: MetadataFilter) -> str:
    """Translate a generic metadata filter into a SQL filter string.

    Args:
        filter: The filter tree to translate.

    Returns:
        - For ``RAW`` filters, ``filter.value`` unchanged.
        - For ``AND`` / ``OR`` filters, the parenthesized translations of the
          children joined by the SQL operator.
        - For ``NE`` / ``NOT_IN`` filters, ``NOT (...)`` wrapped around the
          translation of the corresponding ``EQ`` / ``IN`` filter.
        - For every other operator, ``"<key> <sql operator> <literal>"``. For
          ``IN`` the literal is a parenthesized, comma-separated list.

    Raises:
        KeyError: If the operator has no entry in ``OPERATOR_MAP`` and is not
            handled explicitly.
    """
    operator = filter.operator

    if operator is MetadataFilterOperator.RAW:
        return filter.value

    if operator in {MetadataFilterOperator.AND, MetadataFilterOperator.OR}:
        return f" {OPERATOR_MAP[operator]} ".join(
            f"({translate(child)})" for child in filter.value
        )

    # Negated operators are expressed by negating their positive counterpart.
    if operator is MetadataFilterOperator.NE:
        return f"NOT ({translate(MetadataFilter.eq(filter.key, filter.value))})"
    if operator is MetadataFilterOperator.NOT_IN:
        return f"NOT ({translate(MetadataFilter.in_(filter.key, filter.value))})"

    if operator is MetadataFilterOperator.IN:
        # Join the items explicitly: repr() of a one-element tuple is "(x,)",
        # and the trailing comma is not valid SQL.
        items = ", ".join(_sql_literal(item) for item in filter.value)
        literal = f"({items})"
    else:
        literal = _sql_literal(filter.value)

    return f"{filter.key} {OPERATOR_MAP[operator]} {literal}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import enum | |
import textwrap | |
from typing import Any, Literal, Sequence, cast | |
import json | |
class MetadataFilterOperator(enum.Enum):
    """Operators that a metadata filter node can carry.

    The numeric values are assigned by ``enum.auto()`` in declaration order.
    """

    # Filter supplied verbatim by the caller.
    RAW = enum.auto()

    # Combinators over child filters.
    AND = enum.auto()
    OR = enum.auto()

    # Comparisons of a key against a value.
    EQ = enum.auto()
    NE = enum.auto()
    LT = enum.auto()
    LE = enum.auto()
    GT = enum.auto()
    GE = enum.auto()

    # Membership tests of a key against a collection of values.
    IN = enum.auto()
    NOT_IN = enum.auto()
class MetadataFilter:
    """A node in a tree that describes a metadata filter.

    Every node has an ``operator``, a ``key`` and a ``value``:

    - ``AND`` / ``OR`` nodes hold a list of child filters in ``value``.
    - ``RAW`` nodes hold a caller-supplied object in ``value``.
    - All other nodes hold the metadata ``key`` and the ``value`` it is
      compared against.

    Use the alternate constructors (``raw``, ``and_``, ``or_``, ``eq``, ...)
    rather than calling ``__init__`` directly.
    """

    # These are just to be consistent. The actual values have no effect.
    _RAW_KEY = "filter"
    _CHILDREN_KEY = "children"

    def __init__(self, operator: MetadataFilterOperator, key: str, value: Any) -> None:
        self.operator = operator
        self.key = key
        self.value = value

    def __repr__(self) -> str:
        """Return a readable, multi-line representation of the filter tree."""
        if self.operator is MetadataFilterOperator.RAW:
            return f"{self.operator.name}({self.value!r})"
        elif self.operator in {MetadataFilterOperator.AND, MetadataFilterOperator.OR}:
            # One child per line, indented by two spaces per nesting level.
            return "\n".join(
                [
                    f"{self.operator.name}(",
                    *[
                        f"{textwrap.indent(repr(child), prefix=' ' * 2)},"
                        for child in self.value
                    ],
                    ")",
                ]
            )
        else:
            return f"{self.operator.name}({self.key!r}, {self.value!r})"

    def _to_json(self) -> dict[str, Any]:
        """Return the filter as nested ``{operator name: {key: value}}`` dicts."""
        if self.operator in {MetadataFilterOperator.AND, MetadataFilterOperator.OR}:
            value = [child._to_json() for child in self.value]
        else:
            value = self.value
        return {self.operator.name: {self.key: value}}

    def to_json(self) -> str:
        """Serialize the filter to a JSON string."""
        return json.dumps(self._to_json())

    @classmethod
    def _from_json(cls, json_obj: dict[str, Any]) -> MetadataFilter:
        """Build a filter from the structure produced by ``_to_json``.

        Only the first entry of each mapping is read.

        Raises:
            ValueError: If ``json_obj`` or its inner mapping is empty.
            KeyError: If the operator name is not a ``MetadataFilterOperator``
                member.
        """
        # Check for emptiness up front: next(iter(...)) on an empty mapping
        # raises StopIteration, which turns into a RuntimeError when raised
        # inside a generator and carries no useful message otherwise.
        if not json_obj:
            raise ValueError(
                "Expected a JSON object of the form {operator: {key: value}}, "
                f"but got {json_obj!r}"
            )
        operator_name, key_value = next(iter(json_obj.items()))
        operator = MetadataFilterOperator[operator_name]

        key_value = cast(dict[str, Any], key_value)
        if not key_value:
            raise ValueError(
                f"Expected a non-empty {{key: value}} object for operator "
                f"{operator_name!r}, but got {key_value!r}"
            )
        key, value = next(iter(key_value.items()))

        if operator in {MetadataFilterOperator.AND, MetadataFilterOperator.OR}:
            value = [cls._from_json(child) for child in value]
        return cls(operator, key, value)

    @classmethod
    def from_json(cls, json_str: str) -> MetadataFilter:
        """Deserialize a filter from a JSON string produced by ``to_json``."""
        return cls._from_json(json.loads(json_str))

    @classmethod
    def raw(cls, value: Any) -> MetadataFilter:
        """Create a ``RAW`` filter that stores ``value`` as is."""
        return cls(MetadataFilterOperator.RAW, cls._RAW_KEY, value)

    @staticmethod
    def _flatten(
        operator: Literal[MetadataFilterOperator.OR, MetadataFilterOperator.AND],
        children: Sequence[MetadataFilter],
    ) -> list[MetadataFilter]:
        """Inline the children of any child that has the same ``operator``.

        Always returns a new list; ``children`` is not modified.
        """
        flat_children = []
        for child in children:
            if child.operator == operator:
                flat_children.extend(child.value)
            else:
                flat_children.append(child)
        return flat_children

    @classmethod
    def and_(cls, children: Sequence[MetadataFilter]) -> MetadataFilter:
        """Create an ``AND`` filter; nested ``AND`` children are flattened."""
        return cls(
            MetadataFilterOperator.AND,
            cls._CHILDREN_KEY,
            cls._flatten(MetadataFilterOperator.AND, children),
        )

    @classmethod
    def or_(cls, children: Sequence[MetadataFilter]) -> MetadataFilter:
        """Create an ``OR`` filter; nested ``OR`` children are flattened."""
        return cls(
            MetadataFilterOperator.OR,
            cls._CHILDREN_KEY,
            cls._flatten(MetadataFilterOperator.OR, children),
        )

    @classmethod
    def eq(cls, key: str, value: Any) -> MetadataFilter:
        """Create an ``EQ`` filter for ``key`` and ``value``."""
        return cls(MetadataFilterOperator.EQ, key, value)

    @classmethod
    def ne(cls, key: str, value: Any) -> MetadataFilter:
        """Create an ``NE`` filter for ``key`` and ``value``."""
        return cls(MetadataFilterOperator.NE, key, value)

    @classmethod
    def lt(cls, key: str, value: Any) -> MetadataFilter:
        """Create an ``LT`` filter for ``key`` and ``value``."""
        return cls(MetadataFilterOperator.LT, key, value)

    @classmethod
    def le(cls, key: str, value: Any) -> MetadataFilter:
        """Create an ``LE`` filter for ``key`` and ``value``."""
        return cls(MetadataFilterOperator.LE, key, value)

    @classmethod
    def gt(cls, key: str, value: Any) -> MetadataFilter:
        """Create a ``GT`` filter for ``key`` and ``value``."""
        return cls(MetadataFilterOperator.GT, key, value)

    @classmethod
    def ge(cls, key: str, value: Any) -> MetadataFilter:
        """Create a ``GE`` filter for ``key`` and ``value``."""
        return cls(MetadataFilterOperator.GE, key, value)

    @classmethod
    def in_(cls, key: str, value: Any) -> MetadataFilter:
        """Create an ``IN`` filter for ``key`` and ``value``."""
        return cls(MetadataFilterOperator.IN, key, value)

    @classmethod
    def not_in(cls, key: str, value: Any) -> MetadataFilter:
        """Create a ``NOT_IN`` filter for ``key`` and ``value``."""
        return cls(MetadataFilterOperator.NOT_IN, key, value)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Running
demo.py
prints