Last active: August 29, 2015 14:03
Save everilae/3d746e090a3084324316 to your computer and use it in GitHub Desktop.
An attempt at aggregate FILTER clauses for SQLAlchemy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
from sqlalchemy import util, and_, case | |
from sqlalchemy.ext.compiler import compiles | |
from sqlalchemy.sql import ColumnElement | |
from sqlalchemy.sql.elements import ClauseList, _clone | |
from sqlalchemy.sql.functions import FunctionElement | |
class AggregateFilter(ColumnElement):
    """Represent an aggregate FILTER clause.

    This is a special operator against aggregate functions which
    restricts the rows fed to the aggregate to those matching a
    criterion, e.g. ``COUNT(1) FILTER (WHERE x > 5)``.  Native FILTER
    syntax is supported only by certain database backends (PostgreSQL
    9.4+); other backends need an emulation.
    """

    __visit_name__ = 'aggregatefilter'

    # The FILTER criteria as a ClauseList, or None when none were given.
    criterion = None

    def __init__(self, func, *criterion):
        """Produce an :class:`.AggregateFilter` object against a function.

        Used against aggregate functions, for database backends that
        support the aggregate "FILTER" clause.

        E.g.::

            AggregateFilter(func.count(1), MyClass.name == 'some name')

        On a backend with native support this would produce
        "COUNT(1) FILTER (WHERE myclass.name = 'some name')".

        :param func: a :class:`.FunctionElement` construct, typically
          generated by :data:`~.expression.func`.
        :param criterion: one or more column elements or strings, or
          lists/tuples of such, that will be used as the FILTER clause
          of the aggregate construct.

        This construct can also be reached from the
        :data:`~.expression.func` construct itself once the module's
        ``filter_`` helper is installed as ``FunctionElement.filter``.
        """
        self.func = func
        # Accept both AggregateFilter(f, a, b) and AggregateFilter(f, [a, b]):
        # flatten one level of list/tuple nesting so a list argument does
        # not end up as a single (invalid) element of the ClauseList.
        clauses = []
        for item in criterion:
            if isinstance(item, (list, tuple)):
                clauses.extend(item)
            else:
                clauses.append(item)
        if clauses:
            self.criterion = ClauseList(*clauses)

    @util.memoized_property
    def type(self):
        # A filtered aggregate has the same SQL type as the wrapped function.
        return self.func.type

    def get_children(self, **kwargs):
        # The function and (if present) the criterion are the sub-elements
        # visited by SQLAlchemy's traversal machinery.
        return [c for c in (self.func, self.criterion) if c is not None]

    def _copy_internals(self, clone=_clone, **kw):
        # Replace sub-elements with their clones (used during cloning/adaption).
        self.func = clone(self.func, **kw)
        if self.criterion is not None:
            self.criterion = clone(self.criterion, **kw)

    @property
    def _from_objects(self):
        # Collect FROM objects from both the function and the criterion so
        # tables referenced only in the criterion still reach the FROM list.
        return list(itertools.chain.from_iterable(
            c._from_objects for c in self.get_children()
        ))
# FIXME: this skips the normal function compilation, possibly causing all kinds of
# weird or unexpected behaviour. Seems to work for simple count() and sum() cases;
# functions taking several arguments will likely render invalid SQL, because the
# grouped argument list "(a, b)" ends up as the THEN value of the CASE.
@compiles(AggregateFilter)
def visit_aggregatefilter(aggfilter, compiler, **kwargs):
    """Render an :class:`AggregateFilter` for backends without native FILTER.

    The filter is emulated by moving the criterion into the aggregate's
    argument: ``FUNC(args) FILTER (WHERE crit)`` is rendered as
    ``FUNC(CASE WHEN crit THEN (args) END)``.  Rows that fail the criterion
    make the CASE yield NULL.

    :param aggfilter: the :class:`AggregateFilter` being compiled.
    :param compiler: the active SQL compiler.
    :param kwargs: compilation flags; forwarded to nested elements so that
      options such as ``literal_binds`` are honoured.
    :return: the rendered SQL string.
    """
    if aggfilter.criterion is None:
        # No FILTER criteria: nothing to emulate, render the function as-is
        # (and avoid and_(*None) raising TypeError).
        return compiler.process(aggfilter.func, **kwargs)

    name = ".".join(list(aggfilter.func.packagenames) + [aggfilter.func.name])
    filtered_args = case([
        (and_(*aggfilter.criterion), aggfilter.func.clause_expr),
    ])
    return "%s(%s)" % (name, compiler.process(filtered_args, **kwargs))
# Uncomment to enable, if using PostgreSQL >= 9.4
#@compiles(AggregateFilter, "postgresql")
def pg_visit_aggregatefilter(aggfilter, compiler, **kwargs):
    """Render an :class:`AggregateFilter` using native FILTER syntax.

    Produces ``FUNC(args) FILTER (WHERE crit)``, where multiple criteria
    are combined with AND.

    :param aggfilter: the :class:`AggregateFilter` being compiled.
    :param compiler: the active SQL compiler.
    :param kwargs: compilation flags; forwarded to nested elements so that
      options such as ``literal_binds`` are honoured.
    :return: the rendered SQL string.
    """
    # Process the function before the criterion so bind parameters are
    # collected in the same order as they appear in the SQL text.
    func_sql = compiler.process(aggfilter.func, **kwargs)
    if aggfilter.criterion is None:
        # No FILTER criteria: emit the bare function instead of crashing
        # on and_(*None).
        return func_sql
    return "%s FILTER (WHERE %s)" % (
        func_sql,
        compiler.process(and_(*aggfilter.criterion), **kwargs),
    )
def filter_(self, *criterion):
    """Wrap this aggregate function in a FILTER clause.

    Intended to be attached to ``FunctionElement`` as its ``filter``
    method (see the commented-out monkeypatch below), so ``self`` is the
    aggregate function being filtered.  Only useful for database backends
    that support the aggregate "FILTER" clause, or together with an
    emulating compiler.

    :param criterion: the FILTER condition(s), passed through unchanged
      to :class:`AggregateFilter`.
    :return: a new :class:`AggregateFilter` wrapping ``self``.
    """
    filtered = AggregateFilter(self, *criterion)
    return filtered


# Monkeypatching, uncomment to enable
#FunctionElement.filter = filter_
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.