Created
June 12, 2020 22:50
-
-
Save edelooff/b71436f137a390545b42d1f18a7ad4f8 to your computer and use it in GitHub Desktop.
Evaluating SQLAlchemy expressions in Python (for ORM objects and such)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import operator | |
from collections import deque | |
from enum import Enum, auto | |
from typing import Any, Dict, Iterator, NamedTuple, Optional, Tuple | |
from sqlalchemy import Column | |
from sqlalchemy.sql import operators | |
from sqlalchemy.sql.elements import ( | |
AsBoolean, | |
BinaryExpression, | |
BindParameter, | |
BooleanClauseList, | |
ColumnElement, | |
Grouping, | |
Null, | |
UnaryExpression, | |
) | |
# Python callables that stand in for SQLAlchemy's AsBoolean operators when an
# expression is evaluated in plain Python. A None entry means no wrapping call
# is needed: the wrapped value is used as it is.
OPERATOR_MAP = dict(
    [
        (operators.istrue, None),
        (operators.isfalse, operator.not_),
    ]
)
class Expression:
    """An SQLAlchemy expression compiled into a flat, evaluable symbol list.

    The expression tree is serialized once, at construction time, into a
    tuple of ``Symbol`` objects in Polish (prefix) notation. ``evaluate`` can
    then compute the expression's value in plain Python for any mapping of
    columns to values.
    """

    def __init__(self, expression: ColumnElement, force_bool: bool = False):
        """Serialize *expression*.

        :param expression: the SQLAlchemy expression to compile.
        :param force_bool: when true, bare column references are rewritten to
            ``IS NOT NULL`` checks and ``~column`` to ``IS NULL`` checks.
        """
        self.serialized = tuple(self._serialize(expression, force_bool=force_bool))

    def __eq__(self, other: Any) -> bool:
        """Two expressions are equal when they serialize to the same symbols."""
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.serialized == other.serialized

    def __repr__(self) -> str:
        return f"<{type(self).__name__} {list(self.serialized)!r}>"

    def evaluate(self, column_values: Dict[Column, Any]) -> Any:
        """Compute the expression's value for the given column values.

        :param column_values: maps every ``Column`` referenced by the
            expression to the Python value substituted for it.
        :returns: the result of the expression.
        :raises KeyError: when a referenced column is missing from
            *column_values*.
        """
        stack = Stack()
        # Walking prefix notation backwards guarantees that every operator
        # finds its operands already on the stack, first operand on top.
        for itype, arity, value in reversed(self.serialized):
            if itype is SymbolType.literal:
                stack.push(value)
            elif itype is SymbolType.column:
                stack.push(column_values[value])
            else:
                stack.push(value(*stack.popn(arity)))
        return stack.pop()

    def _serialize(self, expr, force_bool=False) -> Iterator[Symbol]:
        """Serialize an SQLAlchemy expression tree into Symbols.

        Symbols are yielded in Polish prefix notation: each operator symbol
        is followed by the serialized forms of its operands. This lets the
        expression be evaluated later with column values substituted in.

        :param expr: the (sub-)expression to serialize.
        :param force_bool: coerce bare columns to NULL checks (see __init__).
        :raises TypeError: for unsupported expression types.
        """
        # Simple and direct value types
        if isinstance(expr, BindParameter):
            yield Symbol(expr.value)
        elif isinstance(expr, Grouping):
            inner = expr.element
            if isinstance(inner, ColumnElement):
                # A parenthesised sub-expression such as ``(a | b) & c`` or
                # ``(a + b) * c``: serialize what is inside the parentheses.
                yield from self._serialize(inner, force_bool=force_bool)
            else:
                # A grouped list of bound literals, e.g. the right-hand side
                # of an IN clause.
                yield Symbol([element.value for element in inner])
        elif isinstance(expr, Null):
            yield Symbol(None)
        # Columns and column-wrapping functions
        elif isinstance(expr, Column):
            if force_bool:
                yield from self._serialize(expr.isnot(None))
            else:
                yield Symbol(expr)
        elif isinstance(expr, AsBoolean):
            # A None entry in OPERATOR_MAP means the wrapped value is used
            # unchanged, so no operator symbol is emitted for it.
            if (func := OPERATOR_MAP[expr.operator]) is not None:
                yield Symbol(func, arity=1)
            # AsBoolean may wrap any boolean-typed ColumnElement, not only a
            # Column, so the element is serialized rather than emitted raw.
            # A plain Column still serializes to a single column Symbol.
            yield from self._serialize(expr.element)
        elif isinstance(expr, UnaryExpression):
            target = expr.element
            target_is_column = isinstance(target, Column)
            if force_bool and expr.operator == operator.inv and target_is_column:
                yield from self._serialize(target.is_(None))
            else:
                yield Symbol(expr.operator, arity=1)
                yield from self._serialize(target, force_bool=force_bool)
        # Multi-clause expressions
        elif isinstance(expr, BooleanClauseList):
            clauses = expr.clauses
            # Python's operator.and_ / operator.or_ accept exactly two
            # arguments, so an n-ary clause list becomes a left-nested chain
            # of binary operations: ``op op a b c`` == ``op(op(a, b), c)``.
            for _ in range(len(clauses) - 1):
                yield Symbol(expr.operator, arity=2)
            for clause in clauses:
                yield from self._serialize(clause, force_bool=force_bool)
        elif isinstance(expr, BinaryExpression):
            yield Symbol(expr.operator, arity=2)
            yield from self._serialize(expr.left)
            yield from self._serialize(expr.right)
        else:
            raise TypeError(
                f"Unsupported expression {expr} of type {type(expr).__name__}"
            )
class Stack:
    """A minimal LIFO stack holding operands while an expression is evaluated."""

    def __init__(self):
        self._stack = deque()

    def push(self, frame: Any) -> None:
        """Place *frame* on top of the stack."""
        self._stack.append(frame)

    def pop(self) -> Any:
        """Remove and return the topmost item (IndexError when empty)."""
        return self._stack.pop()

    def popn(self, argcount) -> Iterator[Any]:
        """Lazily pop *argcount* items, topmost first.

        Items leave the stack only as the returned iterator is consumed.
        """
        take = self._stack.pop
        return (take() for _ in range(argcount))
class Symbol:
    """One element of a serialized expression: a literal, column or operator.

    ``type`` is derived from ``value`` on construction. ``arity`` defaults to
    None; for operators it is the number of operands the call consumes.
    """

    __slots__ = "value", "type", "arity"

    def __init__(self, value: Any, arity: Optional[int] = None):
        self.value = value
        self.type = self._determine_type(value)
        self.arity = arity

    def _determine_type(self, value: Any) -> SymbolType:
        """Classify *value*: columns first, then callables, else a literal."""
        if isinstance(value, Column):
            return SymbolType.column
        if callable(value):
            return SymbolType.operator
        return SymbolType.literal

    def __eq__(self, other: Any) -> bool:
        """Symbols are equal when their (type, arity, value) triples match."""
        if not isinstance(other, type(self)):
            return NotImplemented
        return tuple(self) == tuple(other)

    def __iter__(self):
        """Unpack as a ``(type, arity, value)`` triple."""
        yield from (self.type, self.arity, self.value)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.value!r}, arity={self.arity!r})"
class SymbolType(Enum):
    """Kind of a serialized Symbol, deciding how evaluation treats it."""

    # Explicit values, identical to what auto() assigned (1, 2, 3 in order).
    column = 1  # looked up in the column -> value mapping
    literal = 2  # pushed onto the evaluation stack unchanged
    operator = 3  # called with ``arity`` operands popped from the stack
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
from sqlalchemy import Table, MetaData, Column, Boolean, Integer, Text, null | |
from expression import Expression | |
class TestBooleanExpressions:
    """Truth-table checks for evaluating boolean column expressions."""

    BOOL = Table(
        "bool_expr",
        MetaData(),
        Column("a", Boolean),
        Column("b", Boolean),
        Column("c", Boolean),
    )
    # Short aliases used to spell out the parametrized truth tables below.
    A, B, C = BOOL.c.a, BOOL.c.b, BOOL.c.c

    @pytest.fixture
    def cols(self):
        return self.BOOL.c

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({A: False}, False),
            ({A: True}, True),
        ],
    )
    def test_direct_bool(self, cols, inputs, expected):
        assert Expression(cols.a).evaluate(inputs) == expected

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({A: False}, True),
            ({A: True}, False),
        ],
    )
    def test_negation(self, cols, inputs, expected):
        assert Expression(~cols.a).evaluate(inputs) == expected

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({A: False, B: False}, False),
            ({A: True, B: False}, False),
            ({A: False, B: True}, False),
            ({A: True, B: True}, True),
        ],
    )
    def test_conjunction(self, cols, inputs, expected):
        assert Expression(cols.a & cols.b).evaluate(inputs) == expected

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({A: False, B: False}, False),
            ({A: True, B: False}, True),
            ({A: False, B: True}, True),
            ({A: True, B: True}, True),
        ],
    )
    def test_disjunction(self, cols, inputs, expected):
        assert Expression(cols.a | cols.b).evaluate(inputs) == expected

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({A: False, B: False, C: False}, False),
            ({A: True, B: False, C: False}, True),
            ({A: False, B: True, C: False}, False),
            ({A: True, B: True, C: False}, False),
            ({A: False, B: False, C: True}, True),
            ({A: True, B: False, C: True}, True),
            ({A: False, B: True, C: True}, True),
            ({A: True, B: True, C: True}, True),
        ],
    )
    def test_mixed_expression(self, cols, inputs, expected):
        mixed = (cols.a & ~cols.b) | cols.c
        assert Expression(mixed).evaluate(inputs) == expected
class TestBooleanCoercion:
    """Structural checks: coerced expressions serialize like explicit SQL forms."""

    COERCE = Table(
        "expr_coerce",
        MetaData(),
        Column("bool", Boolean),
        Column("number", Integer),
        Column("text", Text),
    )

    @pytest.fixture
    def cols(self):
        return self.COERCE.c

    def test_sql_equivalences(self, cols):
        # Comparing against NULL is expected to serialize like IS NOT NULL.
        assert Expression(cols.text != null()) == Expression(cols.text.isnot(None))

    def test_coerce_simple_expression(self, cols):
        coerced = Expression(cols.text, force_bool=True)
        assert coerced == Expression(cols.text.isnot(None))

    def test_coerce_negated_expression(self, cols):
        coerced = Expression(~cols.text, force_bool=True)
        assert coerced == Expression(cols.text.is_(None))

    def test_do_not_coerce_nonbool(self, cols):
        # force_bool must leave a comparison (NOT IN) untouched.
        coerced = Expression(~cols.text.in_(["foo", "bar"]), force_bool=True)
        assert coerced == Expression(~cols.text.in_(["foo", "bar"]))
class TestMathExpressions:
    """Arithmetic expressions evaluated against integer column values."""

    MATH = Table(
        "expr_math",
        MetaData(),
        Column("a", Integer),
        Column("b", Integer),
        Column("c", Integer),
    )
    # Short aliases used to spell out the parametrized cases below.
    A, B, C = MATH.c.a, MATH.c.b, MATH.c.c

    @pytest.fixture
    def cols(self):
        return self.MATH.c

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({A: 0, B: 0}, 0),
            ({A: 2, B: 0}, 2),
            ({A: 0, B: 3}, 3),
            ({A: 5, B: 5}, 10),
            ({A: -2, B: -2}, -4),
        ],
    )
    def test_addition(self, cols, inputs, expected):
        assert Expression(cols.a + cols.b).evaluate(inputs) == expected

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({A: 0, B: 0}, 0),
            ({A: 2, B: 0}, 2),
            ({A: 0, B: 3}, -3),
            ({A: 10, B: 5}, 5),
        ],
    )
    def test_subtraction(self, cols, inputs, expected):
        assert Expression(cols.a - cols.b).evaluate(inputs) == expected

    @pytest.mark.parametrize(
        "inputs, expected",
        [
            ({A: 0, B: 0, C: 0}, 0),
            ({A: 2, B: 3, C: 0}, 6),
            ({A: 3, B: 4, C: 6}, 6),
            ({A: 3, B: 3, C: 10}, -1),
        ],
    )
    def test_mixed_match(self, cols, inputs, expected):
        # Mixed arithmetic: multiplication binds tighter than subtraction.
        product_minus = cols.a * cols.b - cols.c
        assert Expression(product_minus).evaluate(inputs) == expected
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment