Skip to content

Instantly share code, notes, and snippets.

@hassanselim0
Last active November 2, 2019 10:11
Show Gist options
  • Save hassanselim0/8123bf889c4df4879993ae11bab335f6 to your computer and use it in GitHub Desktop.
Save hassanselim0/8123bf889c4df4879993ae11bab335f6 to your computer and use it in GitHub Desktop.
Windowing Functions for Django 1.11
# Back-porting DB Windowing Functions from Django 2 😀
# Note, this is only partially-tested on PostgreSQL
# New DB opts have been replaced with their string values
from django.db import connection
from django.db.models import Func, Value, FloatField, IntegerField
from django.db.models.expressions import BaseExpression, Expression
class Window(Expression):
    """
    Wrap a window-compatible expression in an ``OVER (...)`` clause.

    Back-ported from Django 2.0.

    :param expression: the function to evaluate over the window. It must have
        a truthy ``window_compatible`` attribute, e.g. the ranking/value
        function classes in this module.
    :param partition_by: a single expression or a list/tuple of expressions,
        rendered as ``PARTITION BY ...``.
    :param order_by: an expression or a list/tuple of expressions, rendered as
        ``ORDER BY ...``.
    :param frame: optional frame clause object (``RowRange`` / ``ValueRange``).
    :param output_field: explicit output field. When omitted, the source
        expression's output field is used.
    :raises ValueError: if ``expression`` is not window-compatible or
        ``order_by`` is neither an expression nor a sequence.
    """
    template = '%(expression)s OVER (%(window)s)'
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    # SQL does not allow window functions in WHERE; mirrors Django 2.0.
    filterable = False

    def __init__(
            self, expression, partition_by=None, order_by=None, frame=None, output_field=None):
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame
        if not getattr(expression, 'window_compatible', False):
            raise ValueError(
                "Expression '%s' isn't compatible with OVER clauses." %
                expression.__class__.__name__
            )
        if self.partition_by is not None:
            if not isinstance(self.partition_by, (tuple, list)):
                self.partition_by = (self.partition_by,)
            self.partition_by = ExpressionList(*self.partition_by)
        if self.order_by is not None:
            if isinstance(self.order_by, (list, tuple)):
                self.order_by = ExpressionList(*self.order_by)
            elif not isinstance(self.order_by, BaseExpression):
                raise ValueError(
                    'order_by must be either an Expression or a sequence of '
                    'expressions.'
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self):
        # Assign as well as return: the 1.11-era base class reads
        # self._output_field rather than the return value.
        self._output_field = self.source_expression.output_field
        return self._output_field

    def get_source_expressions(self):
        if self.frame:
            return [self.source_expression, self.partition_by, self.order_by, self.frame]
        else:
            return [self.source_expression, self.partition_by, self.order_by]

    def set_source_expressions(self, exprs):
        # The frame is only part of the source list when one was given.
        if len(exprs) == 4:
            self.source_expression, self.partition_by, self.order_by, self.frame = exprs
        else:
            self.source_expression, self.partition_by, self.order_by = exprs

    def as_sql(self, compiler, connection, function=None, template=None):
        connection.ops.check_expression_support(self)
        expr_sql, params = compiler.compile(self.source_expression)
        # Copy so the child's params list is never mutated in place and a
        # tuple of params can still be extended below.
        params = list(params)
        window_clauses, window_params = [], []
        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler, connection=connection,
                template='PARTITION BY %(expressions)s',
            )
            window_clauses.append(sql_expr)
            window_params.extend(sql_params)
        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_clauses.append('ORDER BY ' + order_sql)
            window_params.extend(order_params)
        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_clauses.append(frame_sql)
            window_params.extend(frame_params)
        params.extend(window_params)
        template = template or self.template
        return template % {
            'expression': expr_sql,
            'window': ' '.join(window_clauses).strip(),
        }, params

    def __str__(self):
        # Separate the clauses with spaces; concatenating them directly
        # produced e.g. "PARTITION BY xORDER BY y".
        clauses = []
        if self.partition_by:
            clauses.append('PARTITION BY ' + str(self.partition_by))
        if self.order_by:
            clauses.append('ORDER BY ' + str(self.order_by))
        if self.frame:
            clauses.append(str(self.frame))
        return '{} OVER ({})'.format(str(self.source_expression), ' '.join(clauses))

    def __repr__(self):
        return '<%s: %s>' % (self.__class__.__name__, self)

    def get_group_by_cols(self):
        # Window expressions never contribute to GROUP BY.
        return []
class WindowFrame(Expression):
    """
    Model the frame clause in window expressions. There are two types of frame
    clauses which are subclasses, however, all processing and validation (by no
    means intended to be complete) is done here. Thus, providing an end for a
    frame is optional (the default is UNBOUNDED FOLLOWING, which is the last
    row of the partition).

    ``start`` and ``end`` are ints or None:

    * a negative ``start`` means that many rows PRECEDING;
    * a positive ``end`` means that many rows FOLLOWING;
    * 0 means CURRENT ROW;
    * None means UNBOUNDED.

    Subclasses set ``frame_type`` and implement ``window_frame_start_end()``.
    """
    template = '%(frame_type)s BETWEEN %(start)s AND %(end)s'

    def __init__(self, start=None, end=None):
        super().__init__()
        self.start = start
        self.end = end

    @staticmethod
    def _wrap(bound):
        """Return ``bound`` as a Value, wrapping a raw int/None exactly once."""
        return bound if isinstance(bound, Value) else Value(bound)

    @staticmethod
    def _unwrap(bound):
        """Return the raw int/None behind ``bound``, which may be a Value."""
        return bound.value if isinstance(bound, Value) else bound

    def set_source_expressions(self, exprs):
        # After resolution these are Value instances. _wrap/_unwrap cope with
        # both those and the raw values stored by __init__, so repeated
        # resolve/relabel round-trips never nest Value(Value(...)).
        self.start, self.end = exprs

    def get_source_expressions(self):
        return [self._wrap(self.start), self._wrap(self.end)]

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        start, end = self.window_frame_start_end(
            connection, self._unwrap(self.start), self._unwrap(self.end),
        )
        return self.template % {
            'frame_type': self.frame_type,
            'start': start,
            'end': end,
        }, []

    def __repr__(self):
        return '<%s: %s>' % (self.__class__.__name__, self)

    def get_group_by_cols(self):
        return []

    def __str__(self):
        start_value = self._unwrap(self.start)
        end_value = self._unwrap(self.end)
        if start_value is not None and start_value < 0:
            start = '%d %s' % (abs(start_value), 'PRECEDING')
        elif start_value is not None and start_value == 0:
            start = 'CURRENT ROW'
        else:
            start = 'UNBOUNDED PRECEDING'
        if end_value is not None and end_value > 0:
            end = '%d %s' % (end_value, 'FOLLOWING')
        elif end_value is not None and end_value == 0:
            end = 'CURRENT ROW'
        else:
            end = 'UNBOUNDED FOLLOWING'
        return self.template % {
            'frame_type': self.frame_type,
            'start': start,
            'end': end,
        }

    def window_frame_start_end(self, connection, start, end):
        """Return ``(start_sql, end_sql)`` for the raw bounds; subclass hook."""
        raise NotImplementedError('Subclasses must implement window_frame_start_end().')
class ExpressionList(Func):
    """
    Bundle several expressions into one, rendered as their comma-joined SQL.

    Used as the argument of PARTITION BY / ORDER BY in window clauses.

    :raises ValueError: when constructed with no expressions.
    """
    template = '%(expressions)s'

    def __init__(self, *expressions, **extra):
        if len(expressions) == 0:
            raise ValueError('%s requires at least one expression.' % type(self).__name__)
        super().__init__(*expressions, **extra)

    def __str__(self):
        rendered = [str(member) for member in self.source_expressions]
        return self.arg_joiner.join(rendered)
class RowRange(WindowFrame):
    """Frame clause rendered as ``ROWS BETWEEN <start> AND <end>``."""
    frame_type = 'ROWS'

    def window_frame_start_end(self, connection, start, end):
        ops = connection.ops
        return window_frame_rows_start_end(ops, start, end)
class ValueRange(WindowFrame):
    """Frame clause rendered as ``RANGE BETWEEN <start> AND <end>``."""
    frame_type = 'RANGE'

    def window_frame_start_end(self, connection, start, end):
        ops = connection.ops
        return window_frame_range_start_end(ops, start, end)
def window_frame_rows_start_end(self, start=None, end=None):
    """
    Return SQL for start and end points in an OVER clause window frame.

    ``self`` stands in for the ``connection.ops`` object this was a method of
    in Django 2.0. It is forwarded unchanged to the per-bound helpers.

    :returns: a ``(start_sql, end_sql)`` tuple.
    """
    start_sql = window_frame_start(self, start)
    end_sql = window_frame_end(self, end)
    return start_sql, end_sql
def window_frame_range_start_end(self, start=None, end=None):
    """Return ``(start_sql, end_sql)`` for a RANGE frame; same bound SQL as ROWS."""
    return window_frame_rows_start_end(self, start=start, end=end)
def window_frame_start(self, start):
    """
    Return the SQL for the start bound of a window frame.

    ``self`` is unused. It stands in for the ``connection.ops`` object this
    was a method of in Django 2.0.

    * a negative int ``n`` -> ``'<abs(n)> PRECEDING'``
    * ``0`` -> ``'CURRENT ROW'``
    * ``None`` -> ``'UNBOUNDED PRECEDING'``

    :raises ValueError: for any other value (positive ints, non-ints).
    """
    if isinstance(start, int):
        if start < 0:
            return '%d %s' % (abs(start), 'PRECEDING')
        elif start == 0:
            return 'CURRENT ROW'
    elif start is None:
        return 'UNBOUNDED PRECEDING'
    # Format via a 1-tuple so a tuple argument is rendered rather than
    # consumed as format args (which raised TypeError instead of ValueError).
    raise ValueError(
        "start argument must be a negative integer, zero, or None, but got '%s'." % (start,)
    )
def window_frame_end(self, end):
    """
    Return the SQL for the end bound of a window frame.

    ``self`` is unused. It stands in for the ``connection.ops`` object this
    was a method of in Django 2.0.

    * a positive int ``n`` -> ``'<n> FOLLOWING'``
    * ``0`` -> ``'CURRENT ROW'``
    * ``None`` -> ``'UNBOUNDED FOLLOWING'``

    :raises ValueError: for any other value (negative ints, non-ints).
    """
    if isinstance(end, int):
        if end == 0:
            return 'CURRENT ROW'
        elif end > 0:
            return '%d %s' % (end, 'FOLLOWING')
    elif end is None:
        return 'UNBOUNDED FOLLOWING'
    # Format via a 1-tuple so a tuple argument is rendered rather than
    # consumed as format args (which raised TypeError instead of ValueError).
    raise ValueError(
        "end argument must be a positive integer, zero, or None, but got '%s'." % (end,)
    )
class CumeDist(Func):
    """Window function ``CUME_DIST()``; result typed as FloatField."""
    window_compatible = True
    output_field = FloatField()
    name = 'CumeDist'
    function = 'CUME_DIST'
class DenseRank(Func):
    """Window function ``DENSE_RANK()``; result typed as IntegerField."""
    window_compatible = True
    output_field = IntegerField()
    name = 'DenseRank'
    function = 'DENSE_RANK'
class FirstValue(Func):
    """Window function ``FIRST_VALUE(expression)``; takes exactly one argument."""
    window_compatible = True
    name = 'FirstValue'
    function = 'FIRST_VALUE'
    arity = 1
class LagLeadFunction(Func):
    """
    Shared base for ``LAG``/``LEAD``, rendered as ``FUNC(expression, offset[, default])``.

    Subclasses set ``function``.

    :param expression: the value to read from another row; must not be None.
    :param offset: row offset; must be a positive integer.
    :param default: optional fallback argument, appended only when not None.
    :raises ValueError: on a None expression or a non-positive offset.
    """
    window_compatible = True

    def __init__(self, expression, offset=1, default=None, **extra):
        if expression is None:
            raise ValueError(
                '%s requires a non-null source expression.' %
                self.__class__.__name__
            )
        if offset is None or offset <= 0:
            raise ValueError(
                '%s requires a positive integer for the offset.' %
                self.__class__.__name__
            )
        args = (expression, offset)
        if default is not None:
            args += (default,)
        super().__init__(*args, **extra)

    def _resolve_output_field(self):
        # The result has the type of the source expression. Assign
        # self._output_field as well as returning it: Django 1.11's
        # BaseExpression ignores the return value and reads the attribute,
        # so returning alone left output_field unresolved (FieldError).
        # This matches Window._resolve_output_field in this module.
        sources = self.get_source_expressions()
        self._output_field = sources[0].output_field
        return self._output_field
class Lag(LagLeadFunction):
    """Window function ``LAG(expression, offset[, default])``."""
    name = 'Lag'
    function = 'LAG'
class LastValue(Func):
    """Window function ``LAST_VALUE(expression)``; takes exactly one argument."""
    window_compatible = True
    name = 'LastValue'
    function = 'LAST_VALUE'
    arity = 1
class Lead(LagLeadFunction):
    """Window function ``LEAD(expression, offset[, default])``."""
    name = 'Lead'
    function = 'LEAD'
class NthValue(Func):
    """
    Window function ``NTH_VALUE(expression, nth)``.

    :param expression: the value to read; must not be None.
    :param nth: 1-based row position within the frame; must be positive.
    :raises ValueError: on a None expression or a non-positive ``nth``.
    """
    function = 'NTH_VALUE'
    name = 'NthValue'
    window_compatible = True

    def __init__(self, expression, nth=1, **extra):
        if expression is None:
            raise ValueError(
                '%s requires a non-null source expression.' % self.__class__.__name__)
        if nth is None or nth <= 0:
            raise ValueError(
                '%s requires a positive integer as for nth.' % self.__class__.__name__)
        super().__init__(expression, nth, **extra)

    def _resolve_output_field(self):
        # The result has the type of the source expression. Assign
        # self._output_field as well as returning it: Django 1.11's
        # BaseExpression ignores the return value and reads the attribute,
        # so returning alone left output_field unresolved (FieldError).
        sources = self.get_source_expressions()
        self._output_field = sources[0].output_field
        return self._output_field
class Ntile(Func):
    """
    Window function ``NTILE(num_buckets)``; result typed as IntegerField.

    :raises ValueError: if ``num_buckets`` is not greater than 0.
    """
    window_compatible = True
    output_field = IntegerField()
    name = 'Ntile'
    function = 'NTILE'

    def __init__(self, num_buckets=1, **extra):
        invalid = num_buckets <= 0
        if invalid:
            raise ValueError('num_buckets must be greater than 0.')
        super().__init__(num_buckets, **extra)
class PercentRank(Func):
    """Window function ``PERCENT_RANK()``; result typed as FloatField."""
    window_compatible = True
    output_field = FloatField()
    name = 'PercentRank'
    function = 'PERCENT_RANK'
class Rank(Func):
    """Window function ``RANK()``; result typed as IntegerField."""
    window_compatible = True
    output_field = IntegerField()
    name = 'Rank'
    function = 'RANK'
class RowNumber(Func):
    """Window function ``ROW_NUMBER()``; result typed as IntegerField."""
    window_compatible = True
    output_field = IntegerField()
    name = 'RowNumber'
    function = 'ROW_NUMBER'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment