Created
December 1, 2020 12:56
-
-
Save funnydman/376480af6500cc43f706df94e6e6ab33 to your computer and use it in GitHub Desktop.
Django helper: CombinedExpressions — join any number of query expressions with a single SQL operator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
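# https://stackoverflow.com/questions/58877390/how-to-collect-results-into-array-from-annotation
"""
Ways to collect annotation results into a single array, e.g. by joining
several ArrayAgg aggregates with a SQL connector such as ``||``.
"""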
class CombinedExpressions(SQLiteNumericMixin, Expression):
    """Join an arbitrary number of expressions with one SQL connector.

    The operands are compiled in order and rendered as
    ``(expr1 <connector> expr2 <connector> ...)`` via
    ``connection.ops.combine_expression``.

    Example::

        CombinedExpressions(ArrayAgg(...), ArrayAgg(...), connector='||')

    :param expressions: Expressions to combine, in the order given.
    :param connector: SQL operator placed between operands (keyword-only).
    :param output_field: Optional output field for the combined result.
    """

    # Internal field types this helper refuses to combine (checked in as_sql).
    # NOTE(review): presumably because date/time arithmetic needs
    # backend-specific SQL that combine_expression alone does not produce.
    _NOT_SUPPORTED_FIELDS = frozenset(
        {'DateField', 'DateTimeField', 'TimeField', 'DurationField'}
    )

    def __init__(self, *expressions, connector, output_field=None):
        super().__init__(output_field=output_field)
        self.connector = connector
        # Django's expression API passes source expressions around as lists.
        self.expressions = list(expressions)

    def get_source_expressions(self):
        return self.expressions

    def set_source_expressions(self, exprs):
        # Needed so Django internals (e.g. relabeled_clone) can replace the
        # operands; BaseExpression's default only accepts an empty list.
        self.expressions = list(exprs)

    def as_sql(self, compiler, connection):
        """Compile every operand and join them with ``self.connector``.

        :returns: ``(sql, params)``; the SQL is parenthesised so the combined
            operands keep their precedence inside a larger expression.
        :raises NotImplementedError: if any operand's output field is one of
            the unsupported date/time/duration types.
        """
        used_types = set()
        for expr in self.expressions:
            field = getattr(expr, 'output_field')
            if field:
                used_types.add(field.get_internal_type())
        offending = used_types & self._NOT_SUPPORTED_FIELDS
        if offending:
            # Report only the types actually present, not the whole deny-list.
            raise NotImplementedError(
                'Combining expressions with these output fields is not '
                f'supported: {", ".join(sorted(offending))}'
            )

        sql_parts = []
        params = []
        for expr in self.expressions:
            part_sql, part_params = compiler.compile(expr)
            sql_parts.append(part_sql)
            params.extend(part_params)

        sql = connection.ops.combine_expression(self.connector, sql_parts)
        # Parenthesise to preserve order of precedence when embedded.
        return '(%s)' % sql, params

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        """Return a copy whose operands are resolved against ``query``."""
        c = self.copy()
        c.is_summary = summarize
        c.expressions = [
            expr.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for expr in c.expressions
        ]
        return c
class FruitManager(Manager):
    """Manager whose querysets carry a combined ``result`` array annotation.

    ``result`` joins two ``ArrayAgg`` aggregates with the ``||`` connector
    through ``CombinedExpressions``:

    * the first aggregates ``'This fruit is tropical...'`` for rows whose
      ``type`` is ``'tropicals'`` (no ``default`` is given for other rows);
    * the second aggregates ``'This fruit is citrus...'`` for rows whose
      ``country_of_import`` is ``'Africa'``, and ``'here we go'`` otherwise.

    NOTE(review): the "citrus" message is keyed on country_of_import='Africa';
    confirm this pairing is intended.
    """

    def get_queryset(self):
        tropical_notes = ArrayAgg(
            Case(
                When(type='tropicals', then=Value('This fruit is tropical...')),
                output_field=CharField(),
            )
        )
        import_notes = ArrayAgg(
            Case(
                When(country_of_import='Africa', then=Value('This fruit is citrus...')),
                output_field=CharField(),
                default=Value('here we go'),
            )
        )
        combined = CombinedExpressions(tropical_notes, import_notes, connector='||')
        return super().get_queryset().annotate(result=combined)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment