Last active
November 2, 2019 10:11
-
-
Save hassanselim0/8123bf889c4df4879993ae11bab335f6 to your computer and use it in GitHub Desktop.
Windowing Functions for Django 1.11
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Back-porting DB windowing functions from Django 2 😀
# Note: this is only partially tested on PostgreSQL.
# New DB opts have been replaced with their string values.
from django.db import connection | |
from django.db.models import Func, Value, FloatField, IntegerField | |
from django.db.models.expressions import BaseExpression, Expression | |
class Window(Expression):
    """
    Render ``<expression> OVER (<window>)`` for a window-compatible expression.

    The window definition is assembled from up to three optional parts:
    ``partition_by`` (rendered as ``PARTITION BY ...``), ``order_by``
    (rendered as ``ORDER BY ...``) and ``frame`` (compiled as-is, e.g. one of
    the ``WindowFrame`` subclasses).
    """
    template = '%(expression)s OVER (%(window)s)'
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    filterable = False

    def __init__(
            self, expression, partition_by=None, order_by=None, frame=None, output_field=None):
        """
        Validate and normalise the window parts.

        ``expression`` must expose a truthy ``window_compatible`` attribute,
        otherwise ValueError is raised. ``partition_by`` may be a single
        expression or a tuple/list; it is always wrapped in an
        ExpressionList. ``order_by`` may be a tuple/list (wrapped in an
        ExpressionList) or a single BaseExpression instance; anything else
        raises ValueError.
        """
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, 'window_compatible', False):
            raise ValueError(
                "Expression '%s' isn't compatible with OVER clauses." %
                expression.__class__.__name__
            )

        if self.partition_by is not None:
            if not isinstance(self.partition_by, (tuple, list)):
                self.partition_by = (self.partition_by,)
            self.partition_by = ExpressionList(*self.partition_by)

        if self.order_by is not None:
            if isinstance(self.order_by, (list, tuple)):
                self.order_by = ExpressionList(*self.order_by)
            elif not isinstance(self.order_by, BaseExpression):
                raise ValueError(
                    'order_by must be either an Expression or a sequence of '
                    'expressions.'
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self):
        """Adopt the wrapped expression's output field, store and return it."""
        self._output_field = self.source_expression.output_field
        return self._output_field

    def get_source_expressions(self):
        """
        Return ``[source, partition_by, order_by]`` plus ``frame`` when set.

        NOTE(review): ``partition_by`` and ``order_by`` may be None in this
        list; confirm the base ``resolve_expression`` of the targeted Django
        version tolerates None entries.
        """
        if self.frame:
            return [self.source_expression, self.partition_by, self.order_by, self.frame]
        else:
            return [self.source_expression, self.partition_by, self.order_by]

    def set_source_expressions(self, exprs):
        """Inverse of get_source_expressions(): a 4-item list carries a frame."""
        if len(exprs) == 4:
            self.source_expression, self.partition_by, self.order_by, self.frame = exprs
        else:
            self.source_expression, self.partition_by, self.order_by = exprs

    def as_sql(self, compiler, connection, function=None, template=None):
        """
        Compile to ``(sql, params)``.

        Parameters are ordered to match the SQL: wrapped expression first,
        then partition, ordering and frame parameters.
        """
        connection.ops.check_expression_support(self)
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], []

        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler, connection=connection,
                template='PARTITION BY %(expressions)s',
            )
            # append(), not extend(): extend() would add the string one
            # character at a time.
            window_sql.append(sql_expr)
            window_params.extend(sql_params)

        if self.order_by is not None:
            window_sql.append(' ORDER BY ')
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params.extend(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(' ' + frame_sql)
            window_params.extend(frame_params)

        params.extend(window_params)
        template = template or self.template

        return template % {
            'expression': expr_sql,
            # strip() removes the leading space left when there is no
            # PARTITION BY part in front of ORDER BY / the frame.
            'window': ''.join(window_sql).strip()
        }, params

    def __str__(self):
        """Human-readable form; the present clauses are separated by spaces."""
        clauses = []
        if self.partition_by:
            clauses.append('PARTITION BY ' + str(self.partition_by))
        if self.order_by:
            clauses.append('ORDER BY ' + str(self.order_by))
        if self.frame:
            clauses.append(str(self.frame))
        return '{} OVER ({})'.format(str(self.source_expression), ' '.join(clauses))

    def __repr__(self):
        return '<%s: %s>' % (self.__class__.__name__, self)

    def get_group_by_cols(self):
        """A window expression contributes no GROUP BY columns."""
        return []
class WindowFrame(Expression):
    """
    Model the frame clause in window expressions. There are two types of frame
    clauses which are subclasses, however, all processing and validation (by no
    means intended to be complete) is done here. Thus, providing an end for a
    frame is optional (the default is UNBOUNDED FOLLOWING, which is the last
    row in the frame).

    ``start`` and ``end`` are stored as given (int or None), but
    set_source_expressions() may later replace them with ``Value`` wrappers.
    Every reader of the bounds therefore goes through _bound_value() so that
    as_sql(), __str__() and __repr__() work in both states.
    """
    template = '%(frame_type)s BETWEEN %(start)s AND %(end)s'

    def __init__(self, start=None, end=None):
        self.start = start
        self.end = end

    @staticmethod
    def _bound_value(bound):
        """Return the plain boundary, unwrapping a ``Value`` if necessary."""
        return bound.value if isinstance(bound, Value) else bound

    def set_source_expressions(self, exprs):
        self.start, self.end = exprs

    def get_source_expressions(self):
        """
        Return the two bounds as ``Value`` expressions.

        A bound that is already a ``Value`` is returned unchanged so that
        repeated get/set round trips do not nest wrappers.
        """
        return [
            bound if isinstance(bound, Value) else Value(bound)
            for bound in (self.start, self.end)
        ]

    def as_sql(self, compiler, connection):
        """Compile to ``(sql, [])``; the frame clause carries no parameters."""
        connection.ops.check_expression_support(self)
        start, end = self.window_frame_start_end(
            connection,
            self._bound_value(self.start),
            self._bound_value(self.end),
        )
        return self.template % {
            'frame_type': self.frame_type,
            'start': start,
            'end': end,
        }, []

    def __repr__(self):
        return '<%s: %s>' % (self.__class__.__name__, self)

    def get_group_by_cols(self):
        """A frame clause contributes no GROUP BY columns."""
        return []

    def __str__(self):
        """
        Describe the frame without a database connection.

        Unlike as_sql(), out-of-range bounds are not rejected here: anything
        that is not negative/zero (start) or positive/zero (end) is shown as
        UNBOUNDED.
        """
        start_value = self._bound_value(self.start)
        end_value = self._bound_value(self.end)

        if start_value is not None and start_value < 0:
            start = '%d %s' % (abs(start_value), 'PRECEDING')
        elif start_value is not None and start_value == 0:
            start = 'CURRENT ROW'
        else:
            start = 'UNBOUNDED PRECEDING'

        if end_value is not None and end_value > 0:
            end = '%d %s' % (end_value, 'FOLLOWING')
        elif end_value is not None and end_value == 0:
            end = 'CURRENT ROW'
        else:
            end = 'UNBOUNDED FOLLOWING'
        return self.template % {
            'frame_type': self.frame_type,
            'start': start,
            'end': end,
        }

    def window_frame_start_end(self, connection, start, end):
        """Return the ``(start_sql, end_sql)`` pair; subclasses must override."""
        raise NotImplementedError('Subclasses must implement window_frame_start_end().')
class ExpressionList(Func):
    """
    Bundle several expressions into one expression.

    Lets a sequence of expressions be handed to another expression as a
    single argument, for example as an ordering clause.
    """
    template = '%(expressions)s'

    def __init__(self, *expressions, **extra):
        """Reject an empty list, then defer to ``Func``."""
        if len(expressions) == 0:
            message = '%s requires at least one expression.' % self.__class__.__name__
            raise ValueError(message)
        super().__init__(*expressions, **extra)

    def __str__(self):
        """Join the string form of every source expression with ``arg_joiner``."""
        rendered = [str(arg) for arg in self.source_expressions]
        return self.arg_joiner.join(rendered)
class RowRange(WindowFrame):
    """Frame clause rendered with the ``ROWS`` frame type."""
    frame_type = 'ROWS'

    def window_frame_start_end(self, connection, start, end):
        """Build the start/end SQL via window_frame_rows_start_end()."""
        ops = connection.ops
        return window_frame_rows_start_end(ops, start, end)
class ValueRange(WindowFrame):
    """Frame clause rendered with the ``RANGE`` frame type."""
    frame_type = 'RANGE'

    def window_frame_start_end(self, connection, start, end):
        """Build the start/end SQL via window_frame_range_start_end()."""
        ops = connection.ops
        return window_frame_range_start_end(ops, start, end)
def window_frame_rows_start_end(self, start=None, end=None):
    """
    Return the SQL fragments for both ends of an OVER clause window frame.

    ``self`` is only forwarded to window_frame_start() / window_frame_end();
    the frame classes in this file pass ``connection.ops`` for it. The result
    is a ``(start_sql, end_sql)`` tuple.
    """
    lower_bound = window_frame_start(self, start)
    upper_bound = window_frame_end(self, end)
    return lower_bound, upper_bound
def window_frame_range_start_end(self, start=None, end=None):
    """
    Return the ``(start_sql, end_sql)`` pair for a RANGE frame.

    RANGE frames are rendered exactly like ROWS frames here, so this simply
    delegates to window_frame_rows_start_end().
    """
    bounds = window_frame_rows_start_end(self, start, end)
    return bounds
def window_frame_start(self, start):
    """
    Return the SQL fragment for the start of a window frame.

    ``None`` gives ``UNBOUNDED PRECEDING``, ``0`` gives ``CURRENT ROW`` and a
    negative integer ``-n`` gives ``n PRECEDING``. Any other value raises
    ValueError. ``self`` is not used.
    """
    if start is None:
        return 'UNBOUNDED PRECEDING'
    if isinstance(start, int) and start <= 0:
        if start == 0:
            return 'CURRENT ROW'
        return '%d %s' % (abs(start), 'PRECEDING')
    raise ValueError("start argument must be a negative integer, zero, or None, but got '%s'." % start)
def window_frame_end(self, end):
    """
    Return the SQL fragment for the end of a window frame.

    ``None`` gives ``UNBOUNDED FOLLOWING``, ``0`` gives ``CURRENT ROW`` and a
    positive integer ``n`` gives ``n FOLLOWING``. Any other value raises
    ValueError. ``self`` is not used.
    """
    if end is None:
        return 'UNBOUNDED FOLLOWING'
    if isinstance(end, int) and end >= 0:
        if end == 0:
            return 'CURRENT ROW'
        return '%d %s' % (end, 'FOLLOWING')
    raise ValueError("end argument must be a positive integer, zero, or None, but got '%s'." % end)
class CumeDist(Func):
    """SQL ``CUME_DIST`` function, usable inside ``Window``; typed as a float."""
    # Flag checked by Window.__init__ before accepting the expression.
    window_compatible = True
    name = 'CumeDist'
    function = 'CUME_DIST'
    output_field = FloatField()
class DenseRank(Func):
    """SQL ``DENSE_RANK`` function, usable inside ``Window``; typed as an integer."""
    # Flag checked by Window.__init__ before accepting the expression.
    window_compatible = True
    name = 'DenseRank'
    function = 'DENSE_RANK'
    output_field = IntegerField()
class FirstValue(Func):
    """SQL ``FIRST_VALUE`` function, usable inside ``Window``; takes one argument."""
    # Flag checked by Window.__init__ before accepting the expression.
    window_compatible = True
    name = 'FirstValue'
    function = 'FIRST_VALUE'
    arity = 1
class LagLeadFunction(Func):
    """
    Shared base for the ``Lag`` and ``Lead`` window functions.

    Subclasses only supply ``function`` and ``name``. The SQL arguments are
    ``(expression, offset)`` followed by ``default`` when one is given.
    """
    window_compatible = True

    def __init__(self, expression, offset=1, default=None, **extra):
        """
        Validate the arguments and pass them on to ``Func``.

        Raises ValueError when ``expression`` is None or when ``offset`` is
        None, zero or negative. ``default`` is only appended as a third
        argument when it is not None.
        """
        if expression is None:
            raise ValueError(
                '%s requires a non-null source expression.' %
                self.__class__.__name__
            )
        if offset is None or offset <= 0:
            raise ValueError(
                '%s requires a positive integer for the offset.' %
                self.__class__.__name__
            )
        args = (expression, offset)
        if default is not None:
            args += (default,)
        super().__init__(*args, **extra)

    def _resolve_output_field(self):
        """
        Use the first source expression's output field.

        The field is stored on ``self._output_field`` as well as returned,
        matching ``Window._resolve_output_field`` in this module, so the
        result is available whether the caller reads the attribute or the
        return value.
        """
        sources = self.get_source_expressions()
        self._output_field = sources[0].output_field
        return self._output_field
class Lag(LagLeadFunction):
    """SQL ``LAG`` function; arguments are handled by ``LagLeadFunction``."""
    name = 'Lag'
    function = 'LAG'
class LastValue(Func):
    """SQL ``LAST_VALUE`` function, usable inside ``Window``; takes one argument."""
    # Flag checked by Window.__init__ before accepting the expression.
    window_compatible = True
    name = 'LastValue'
    function = 'LAST_VALUE'
    arity = 1
class Lead(LagLeadFunction):
    """SQL ``LEAD`` function; arguments are handled by ``LagLeadFunction``."""
    name = 'Lead'
    function = 'LEAD'
class NthValue(Func):
    """
    SQL ``NTH_VALUE`` function, usable inside ``Window``.

    The SQL arguments are ``(expression, nth)``.
    """
    function = 'NTH_VALUE'
    name = 'NthValue'
    window_compatible = True

    def __init__(self, expression, nth=1, **extra):
        """
        Validate the arguments and pass them on to ``Func``.

        Raises ValueError when ``expression`` is None or when ``nth`` is
        None, zero or negative.
        """
        if expression is None:
            raise ValueError(
                '%s requires a non-null source expression.' % self.__class__.__name__)
        if nth is None or nth <= 0:
            raise ValueError(
                '%s requires a positive integer as for nth.' % self.__class__.__name__)
        super().__init__(expression, nth, **extra)

    def _resolve_output_field(self):
        """
        Use the first source expression's output field.

        The field is stored on ``self._output_field`` as well as returned,
        matching ``Window._resolve_output_field`` in this module, so the
        result is available whether the caller reads the attribute or the
        return value.
        """
        sources = self.get_source_expressions()
        self._output_field = sources[0].output_field
        return self._output_field
class Ntile(Func):
    """SQL ``NTILE`` function, usable inside ``Window``; typed as an integer."""
    # Flag checked by Window.__init__ before accepting the expression.
    window_compatible = True
    name = 'Ntile'
    function = 'NTILE'
    output_field = IntegerField()

    def __init__(self, num_buckets=1, **extra):
        """Pass ``num_buckets`` to ``Func``; zero or negative raises ValueError."""
        bucket_count_invalid = num_buckets <= 0
        if bucket_count_invalid:
            raise ValueError('num_buckets must be greater than 0.')
        super().__init__(num_buckets, **extra)
class PercentRank(Func):
    """SQL ``PERCENT_RANK`` function, usable inside ``Window``; typed as a float."""
    # Flag checked by Window.__init__ before accepting the expression.
    window_compatible = True
    name = 'PercentRank'
    function = 'PERCENT_RANK'
    output_field = FloatField()
class Rank(Func):
    """SQL ``RANK`` function, usable inside ``Window``; typed as an integer."""
    # Flag checked by Window.__init__ before accepting the expression.
    window_compatible = True
    name = 'Rank'
    function = 'RANK'
    output_field = IntegerField()
class RowNumber(Func):
    """SQL ``ROW_NUMBER`` function, usable inside ``Window``; typed as an integer."""
    # Flag checked by Window.__init__ before accepting the expression.
    window_compatible = True
    name = 'RowNumber'
    function = 'ROW_NUMBER'
    output_field = IntegerField()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment