Skip to content

Instantly share code, notes, and snippets.

@pmeier
Last active May 10, 2024 07:14
Show Gist options
  • Save pmeier/38ee90be6c30ecdf9bbec086a0dabafe to your computer and use it in GitHub Desktop.
Save pmeier/38ee90be6c30ecdf9bbec086a0dabafe to your computer and use it in GitHub Desktop.
from typing import Any
from metadata_filter import MetadataFilter, MetadataFilterOperator
# https://docs.trychroma.com/usage-guide#using-where-filters
OPERATOR_MAP = {
    MetadataFilterOperator.AND: "$and",
    MetadataFilterOperator.OR: "$or",
    MetadataFilterOperator.EQ: "$eq",
    MetadataFilterOperator.NE: "$ne",
    MetadataFilterOperator.LT: "$lt",
    MetadataFilterOperator.LE: "$lte",
    MetadataFilterOperator.GT: "$gt",
    MetadataFilterOperator.GE: "$gte",
    MetadataFilterOperator.IN: "$in",
    MetadataFilterOperator.NOT_IN: "$nin",
}


def translate(filter: MetadataFilter) -> dict[str, Any]:
    """Translate a generic ``MetadataFilter`` tree into a Chroma ``where`` dict.

    RAW nodes are returned verbatim, so they must already be valid Chroma
    syntax. AND/OR nodes become ``{"$and"|"$or": [...]}``, and comparison
    nodes become ``{key: {"$op": value}}``.

    Raises:
        ValueError: if an AND/OR node has no children, because Chroma cannot
            express an empty ``$and`` / ``$or``.
    """
    operator = filter.operator
    if operator is MetadataFilterOperator.RAW:
        return filter.value

    if operator in {MetadataFilterOperator.AND, MetadataFilterOperator.OR}:
        children = list(filter.value)
        if not children:
            raise ValueError(
                f"Cannot translate an empty {operator.name} filter to the Chroma dialect"
            )
        if len(children) == 1:
            # Chroma rejects $and / $or lists with fewer than two entries, and
            # a one-child AND/OR is logically the child itself.
            return translate(children[0])
        return {OPERATOR_MAP[operator]: [translate(child) for child in children]}

    value = filter.value
    if operator in {MetadataFilterOperator.IN, MetadataFilterOperator.NOT_IN}:
        # Chroma requires a list operand for $in / $nin (tuples, sets, ... are rejected).
        value = list(value)
    return {filter.key: {OPERATOR_MAP[operator]: value}}
# Demo driver (excerpt). Builds generic and backend-specific filters and prints
# each one, its JSON form and its Chroma / LanceDB translations.
# NOTE(review): `pprint` here takes (title, obj), so it is a local helper, not
# `pprint.pprint`. Judging by the sample output in this gist, it prints a
# '#'-banner with the title followed by the object. `generic_filter`,
# `translate_chroma` and `translate_lance_db` are defined or imported outside
# this excerpt — presumably the `translate` functions of the Chroma and LanceDB
# modules; confirm.
pprint("generic filter", generic_filter)
pprint("generic filter jsonified", generic_filter.to_json())
# Serialize and deserialize again; the printed tree should match the original.
pprint(
"generic filter json roundtrip", MetadataFilter.from_json(generic_filter.to_json())
)
pprint("generic filter translated to Chroma dialect", translate_chroma(generic_filter))
pprint(
"generic filter translated to LanceDB dialect", translate_lance_db(generic_filter)
)
# Generic filter OR'ed with a RAW clause written in Chroma's where-dict syntax.
# Only the Chroma translator can meaningfully consume it.
# NOTE(review): `$exists` is not among Chroma's documented where operators as
# far as I know — verify before relying on this example.
chroma_specific_filter = MetadataFilter.or_(
[generic_filter, MetadataFilter.raw({"spam": {"$exists": True}})]
)
pprint("Chroma specific filter", chroma_specific_filter)
pprint(
"Chroma specific filter translated to its dialect",
translate_chroma(chroma_specific_filter),
)
# Generic filter OR'ed with a RAW clause written as a LanceDB SQL fragment.
# NOTE(review): in SQL, 'spam' (single-quoted) is a string literal, not the
# `spam` column. The pattern '*am$' also starts with a quantifier, which most
# regex engines reject. Probably intended: regexp_match(spam, '.*am$') — verify.
lancedb_specific_filter = MetadataFilter.or_(
[generic_filter, MetadataFilter.raw("regexp_match('spam', '*am$')")]
)
pprint("LanceDB specific filter", lancedb_specific_filter)
pprint(
"LanceDB specific filter translated to its dialect",
translate_lance_db(lancedb_specific_filter),
)
from metadata_filter import MetadataFilter, MetadataFilterOperator
# https://lancedb.github.io/lancedb/sql/
OPERATOR_MAP = {
    MetadataFilterOperator.AND: "AND",
    MetadataFilterOperator.OR: "OR",
    MetadataFilterOperator.EQ: "=",
    MetadataFilterOperator.LT: "<",
    MetadataFilterOperator.LE: "<=",
    MetadataFilterOperator.GT: ">",
    MetadataFilterOperator.GE: ">=",
    MetadataFilterOperator.IN: "IN",
}
# NE and NOT_IN are intentionally absent: `translate` renders them as
# NOT (...) around the EQ / IN translation.


def _sql_literal(value: object) -> str:
    """Render a single Python value as a SQL literal.

    Python's ``repr`` is not a SQL serializer. For a string that contains a
    single quote it switches to double quotes, which SQL reads as an
    identifier. It also spells ``None`` / ``True`` / ``False`` in Python
    syntax.
    """
    if value is None:
        # Note: in SQL `key = NULL` never matches; use a RAW filter with
        # `IS NULL` for null checks.
        return "NULL"
    if isinstance(value, bool):
        return "TRUE" if value else "FALSE"
    if isinstance(value, str):
        # SQL escapes an embedded single quote by doubling it.
        return "'" + value.replace("'", "''") + "'"
    # Numbers (and anything else) keep the original repr-based rendering.
    return repr(value)


def _sql_list(values: list) -> str:
    """Render a non-empty list of values as a parenthesized SQL list."""
    # Unlike repr(tuple(...)), a single element has no trailing comma.
    return "(" + ", ".join(_sql_literal(item) for item in values) + ")"


def translate(filter: MetadataFilter) -> str:
    """Translate a generic ``MetadataFilter`` tree into a LanceDB SQL predicate.

    RAW nodes are returned verbatim, so they must already be valid SQL. AND/OR
    children are each parenthesized and joined by the operator. NE and NOT_IN
    are expressed as ``NOT (...)`` around the EQ / IN translation. Comparison
    values are rendered as SQL literals.

    An AND with no children renders as ``TRUE``. An OR with no children, or an
    IN with an empty value list, renders as ``FALSE``, because SQL has no empty
    conjunction, disjunction or ``IN ()`` list.
    """
    operator = filter.operator
    if operator is MetadataFilterOperator.RAW:
        return filter.value

    if operator in {MetadataFilterOperator.AND, MetadataFilterOperator.OR}:
        children = list(filter.value)
        if not children:
            return "TRUE" if operator is MetadataFilterOperator.AND else "FALSE"
        return f" {OPERATOR_MAP[operator]} ".join(
            f"({translate(child)})" for child in children
        )

    if operator is MetadataFilterOperator.NE:
        return f"NOT ({translate(MetadataFilter.eq(filter.key, filter.value))})"
    if operator is MetadataFilterOperator.NOT_IN:
        return f"NOT ({translate(MetadataFilter.in_(filter.key, filter.value))})"

    if operator is MetadataFilterOperator.IN:
        values = list(filter.value)
        if not values:
            # `key IN ()` is not valid SQL; an empty membership test is false.
            return "FALSE"
        return f"{filter.key} {OPERATOR_MAP[operator]} {_sql_list(values)}"

    return f"{filter.key} {OPERATOR_MAP[operator]} {_sql_literal(filter.value)}"
from __future__ import annotations
import enum
import textwrap
from typing import Any, Literal, Sequence, cast
import json
class MetadataFilterOperator(enum.Enum):
    """Operators that can appear in a :class:`MetadataFilter` tree.

    ``RAW`` carries a backend-specific filter verbatim, ``AND``/``OR`` combine
    child filters, and the remaining members compare a metadata key against a
    value.
    """

    RAW = enum.auto()
    AND = enum.auto()
    OR = enum.auto()
    EQ = enum.auto()
    NE = enum.auto()
    LT = enum.auto()
    LE = enum.auto()
    GT = enum.auto()
    GE = enum.auto()
    IN = enum.auto()
    NOT_IN = enum.auto()


class MetadataFilter:
    """Backend-agnostic metadata filter, stored as a tree of nodes.

    Every node has an ``operator``, a ``key`` and a ``value``:

    * ``RAW``: ``value`` is a backend-specific filter passed through as-is,
      and ``key`` is the placeholder ``_RAW_KEY``.
    * ``AND`` / ``OR``: ``value`` is a list of child filters, and ``key`` is
      the placeholder ``_CHILDREN_KEY``.
    * comparisons (``EQ`` ... ``NOT_IN``): ``key`` names the metadata field and
      ``value`` is the operand.

    Build instances with the classmethod constructors (``eq``, ``and_``,
    ``raw``, ...). Serialize them with ``to_json`` / ``from_json``.
    """

    # Placeholder keys for node kinds without a real metadata key. They only
    # need to be consistent; the actual values have no effect.
    _RAW_KEY = "filter"
    _CHILDREN_KEY = "children"

    def __init__(self, operator: MetadataFilterOperator, key: str, value: Any) -> None:
        self.operator = operator
        self.key = key
        self.value = value

    def __repr__(self) -> str:
        if self.operator is MetadataFilterOperator.RAW:
            return f"{self.operator.name}({self.value!r})"
        elif self.operator in {MetadataFilterOperator.AND, MetadataFilterOperator.OR}:
            # One child per line, indented by two spaces, so nested trees stay readable.
            return "\n".join(
                [
                    f"{self.operator.name}(",
                    *[
                        f"{textwrap.indent(repr(child), prefix=' ' * 2)},"
                        for child in self.value
                    ],
                    ")",
                ]
            )
        else:
            return f"{self.operator.name}({self.key!r}, {self.value!r})"

    def _to_json(self) -> dict[str, Any]:
        """Return the JSON-compatible form ``{operator_name: {key: value}}``."""
        if self.operator in {MetadataFilterOperator.AND, MetadataFilterOperator.OR}:
            value = [child._to_json() for child in self.value]
        else:
            value = self.value
        return {self.operator.name: {self.key: value}}

    def to_json(self) -> str:
        """Serialize this filter tree to a JSON string."""
        return json.dumps(self._to_json())

    @staticmethod
    def _single_entry(obj: Any, what: str) -> tuple[str, Any]:
        """Return the only ``(key, value)`` pair of ``obj``, a one-entry dict.

        Raises:
            ValueError: if ``obj`` is not a dict with exactly one entry.
        """
        # Silently taking the first entry would drop filter conditions.
        if not isinstance(obj, dict) or len(obj) != 1:
            raise ValueError(f"Expected a mapping with exactly one {what}, got {obj!r}")
        ((key, value),) = obj.items()
        return key, value

    @classmethod
    def _from_json(cls, json_obj: dict[str, Any]) -> MetadataFilter:
        """Rebuild a filter from the ``{operator_name: {key: value}}`` form.

        Raises:
            ValueError: if either nesting level does not hold exactly one entry.
            KeyError: if the operator name is not a ``MetadataFilterOperator`` member.
        """
        operator_name, key_value = cls._single_entry(json_obj, "operator")
        operator = MetadataFilterOperator.__members__[operator_name]
        key, value = cls._single_entry(key_value, "key")
        if operator in {MetadataFilterOperator.AND, MetadataFilterOperator.OR}:
            value = [cls._from_json(child) for child in value]
        return cls(operator, key, value)

    @classmethod
    def from_json(cls, json_str: str) -> MetadataFilter:
        """Parse a filter tree from a JSON string produced by ``to_json``."""
        return cls._from_json(json.loads(json_str))

    @classmethod
    def raw(cls, value: Any) -> MetadataFilter:
        """Wrap a backend-specific filter that translators pass through verbatim."""
        return cls(MetadataFilterOperator.RAW, cls._RAW_KEY, value)

    @staticmethod
    def _flatten(
        operator: Literal[MetadataFilterOperator.OR, MetadataFilterOperator.AND],
        children: Sequence[MetadataFilter],
    ) -> list[MetadataFilter]:
        """Inline the children of any child that uses the same ``operator``.

        For example, ``OR(a, OR(b, c))`` becomes ``OR(a, b, c)``.
        """
        flat_children = []
        for child in children:
            if child.operator is operator:
                flat_children.extend(child.value)
            else:
                flat_children.append(child)
        return flat_children

    @classmethod
    def and_(cls, children: Sequence[MetadataFilter]) -> MetadataFilter:
        """Conjunction of ``children``; nested ANDs are flattened."""
        return cls(
            MetadataFilterOperator.AND,
            cls._CHILDREN_KEY,
            cls._flatten(MetadataFilterOperator.AND, children),
        )

    @classmethod
    def or_(cls, children: Sequence[MetadataFilter]) -> MetadataFilter:
        """Disjunction of ``children``; nested ORs are flattened."""
        return cls(
            MetadataFilterOperator.OR,
            cls._CHILDREN_KEY,
            cls._flatten(MetadataFilterOperator.OR, children),
        )

    @classmethod
    def eq(cls, key: str, value: Any) -> MetadataFilter:
        return cls(MetadataFilterOperator.EQ, key, value)

    @classmethod
    def ne(cls, key: str, value: Any) -> MetadataFilter:
        return cls(MetadataFilterOperator.NE, key, value)

    @classmethod
    def lt(cls, key: str, value: Any) -> MetadataFilter:
        return cls(MetadataFilterOperator.LT, key, value)

    @classmethod
    def le(cls, key: str, value: Any) -> MetadataFilter:
        return cls(MetadataFilterOperator.LE, key, value)

    @classmethod
    def gt(cls, key: str, value: Any) -> MetadataFilter:
        return cls(MetadataFilterOperator.GT, key, value)

    @classmethod
    def ge(cls, key: str, value: Any) -> MetadataFilter:
        return cls(MetadataFilterOperator.GE, key, value)

    @classmethod
    def in_(cls, key: str, value: Any) -> MetadataFilter:
        return cls(MetadataFilterOperator.IN, key, value)

    @classmethod
    def not_in(cls, key: str, value: Any) -> MetadataFilter:
        return cls(MetadataFilterOperator.NOT_IN, key, value)
@pmeier
Copy link
Author

pmeier commented Feb 12, 2024

Running `demo.py` prints:

################################################################################
# generic filter
################################################################################

OR(
  EQ('tag', 'a'),
  AND(
    EQ('doc', 'b'),
    EQ('doc', 'c'),
  ),
  NOT_IN('id', ['id1', 'id2']),
  GE('count', 6),
)

################################################################################
# generic filter jsonified
################################################################################

{"OR": {"children": [{"EQ": {"tag": "a"}}, {"AND": {"children": [{"EQ": {"doc": "b"}}, {"EQ": {"doc": "c"}}]}}, {"NOT_IN": {"id": ["id1", "id2"]}}, {"GE": {"count": 6}}]}}

################################################################################
# generic filter json roundtrip
################################################################################

OR(
  EQ('tag', 'a'),
  AND(
    EQ('doc', 'b'),
    EQ('doc', 'c'),
  ),
  NOT_IN('id', ['id1', 'id2']),
  GE('count', 6),
)

################################################################################
# generic filter translated to Chroma dialect
################################################################################

{'$or': [{'tag': {'$eq': 'a'}}, {'$and': [{'doc': {'$eq': 'b'}}, {'doc': {'$eq': 'c'}}]}, {'id': {'$nin': ['id1', 'id2']}}, {'count': {'$gte': 6}}]}

################################################################################
# generic filter translated to LanceDB dialect
################################################################################

(tag = 'a') OR ((doc = 'b') AND (doc = 'c')) OR (NOT (id IN ('id1', 'id2'))) OR (count >= 6)

################################################################################
# Chroma specific filter
################################################################################

OR(
  EQ('tag', 'a'),
  AND(
    EQ('doc', 'b'),
    EQ('doc', 'c'),
  ),
  NOT_IN('id', ['id1', 'id2']),
  GE('count', 6),
  RAW({'spam': {'$exists': True}}),
)

################################################################################
# Chroma specific filter translated to its dialect
################################################################################

{'$or': [{'tag': {'$eq': 'a'}}, {'$and': [{'doc': {'$eq': 'b'}}, {'doc': {'$eq': 'c'}}]}, {'id': {'$nin': ['id1', 'id2']}}, {'count': {'$gte': 6}}, {'spam': {'$exists': True}}]}

################################################################################
# LanceDB specific filter
################################################################################

OR(
  EQ('tag', 'a'),
  AND(
    EQ('doc', 'b'),
    EQ('doc', 'c'),
  ),
  NOT_IN('id', ['id1', 'id2']),
  GE('count', 6),
  RAW("regexp_match('spam', '*am$')"),
)

################################################################################
# LanceDB specific filter translated to its dialect
################################################################################

(tag = 'a') OR ((doc = 'b') AND (doc = 'c')) OR (NOT (id IN ('id1', 'id2'))) OR (count >= 6) OR (regexp_match('spam', '*am$'))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment