Created
December 2, 2019 10:49
-
-
Save twheys/5635a932ca6cfce0d114a86fb55f6c80 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Although the code in this module appears broken on inspection, it is not. Sly uses some hacky syntax (for example, bare token names and the `_` decorator inside class bodies).
https://sly.readthedocs.io/en/latest/sly.html#writing-a-parser | |
""" | |
from pypika import ( | |
Bracket, | |
Case, | |
Not, | |
Order, | |
Schema, | |
analytics as an, | |
functions as fn, | |
) | |
from pypika.enums import ( | |
DatePart, | |
SqlTypes, | |
) | |
from pypika.functions import ( | |
Cast, | |
Extract, | |
) | |
from pypika.terms import ( | |
Field, | |
NullValue, | |
Star, | |
ValueWrapper, | |
) | |
from sly import ( | |
Lexer, | |
Parser, | |
) | |
# Upper-cased function names that the parser builds as aggregate functions
# instead of plain functions.
AGGREGATE_FUNCTION_NAMES = {
    'COUNT',
    'SUM',
    'SUM_FLOAT',
    'MIN',
    'MAX',
    'AVG',
    'STD',
    'STDDEV',
    'APPROXIMATE_PERCENTILE',
}
class ExpressionSyntaxError(Exception):
    """Raised when an expression cannot be tokenized or parsed.

    Both the lexer's and the parser's ``error`` handlers in this module raise
    this exception with a human-readable message.
    """
class PyPikaLexer(Lexer):
    """Sly lexer that tokenizes SQL-like expressions for ``PyPikaParser``.

    The bare upper-case names used below (``NAME``, ``IN``, ...) are not
    defined anywhere as Python variables; Sly's class machinery resolves them
    as token names. Per the Sly documentation, token patterns are tried in the
    order in which they are defined in the class body, so the order of the
    rules in this class is significant (e.g. ``DECIMAL`` before ``INTEGER``,
    ``<=`` before ``<``) and must not be rearranged.
    """

    # Set of token names. This is always required
    tokens = {NAME, QUOTED_NAME, DECIMAL, INTEGER, STRING,
              PLUS, MINUS, TIMES, DIVIDE, MODULO,
              EQ, LT, LE, GT, GE, NE, NE2,
              TRUE, FALSE, NULL, NULLS,
              IN, IS, AS, FROM, BY, LIKE, ILIKE, NOT, AND, OR, USING, PARAMETERS, PERCENTILE,
              CASE, WHEN, THEN, ELSE, END,
              DBL_PIPE,
              DISTINCT, BETWEEN,
              OVER, PARTITION, ORDER, ASC, DESC, IGNORE,
              CAST, APPROXIMATE_PERCENTILE, EXTRACT,
              INTEGER_TYPE, FLOAT_TYPE, NUMERIC_TYPE, SIGNED_TYPE, UNSIGNED_TYPE, BOOLEAN_TYPE, CHAR_TYPE, VARCHAR_TYPE,
              BINARY_TYPE, VARBINARY_TYPE, LONG_TYPE,
              YEAR, QUARTER, MONTH, WEEK, DAY, HOUR, MINUTE, SECOND, MICROSECOND
              }

    # Maps an upper-cased identifier to the keyword token type it should be
    # re-tagged as. Consulted by the NAME rule below, which makes keyword
    # matching case-insensitive.
    special_tokens = {
        'IN': IN,
        'IS': IS,
        'AS': AS,
        'FROM': FROM,
        'BY': BY,
        'NULL': NULL,
        'NULLS': NULLS,
        'NOT': NOT,
        'AND': AND,
        'OR': OR,
        'CASE': CASE,
        'WHEN': WHEN,
        'THEN': THEN,
        'ELSE': ELSE,
        'END': END,
        'DISTINCT': DISTINCT,
        'BETWEEN': BETWEEN,
        'TRUE': TRUE,
        'FALSE': FALSE,
        'OVER': OVER,
        'IGNORE': IGNORE,
        'PARTITION': PARTITION,
        'USING': USING,
        'PARAMETERS': PARAMETERS,
        'PERCENTILE': PERCENTILE,
        'ORDER': ORDER,
        'ASC': ASC,
        'DESC': DESC,
        # Special Functions
        'CAST': CAST,
        'APPROXIMATE_PERCENTILE': APPROXIMATE_PERCENTILE,
        'EXTRACT': EXTRACT,
        'LIKE': LIKE,
        'ILIKE': ILIKE,
        # TYPES
        'INTEGER': INTEGER_TYPE,
        'FLOAT': FLOAT_TYPE,
        'NUMERIC': NUMERIC_TYPE,
        'SIGNED': SIGNED_TYPE,
        'UNSIGNED': UNSIGNED_TYPE,
        'BOOLEAN': BOOLEAN_TYPE,
        'CHAR': CHAR_TYPE,
        'VARCHAR': VARCHAR_TYPE,
        'BINARY': BINARY_TYPE,
        'VARBINARY': VARBINARY_TYPE,
        'LONG': LONG_TYPE,
        # TIME UNITS
        'YEAR': YEAR,
        'QUARTER': QUARTER,
        'MONTH': MONTH,
        'WEEK': WEEK,
        'DAY': DAY,
        'HOUR': HOUR,
        'MINUTE': MINUTE,
        'SECOND': SECOND,
        'MICROSECOND': MICROSECOND,
    }

    # Single characters passed through to the parser as literal tokens.
    literals = {r'.', r',', r'(', r')', r'"', r"'"}

    # String containing ignored characters
    ignore = ' \t'

    # Double-quoted identifier; a doubled quote ("") is allowed inside.
    @_(r'\"([^"\n]|"")*\"')
    def QUOTED_NAME(self, t):
        # Strip the surrounding quotes. Doubled quotes inside the name are
        # left as-is (unlike STRING below, which un-escapes them).
        t.value = t.value[1:-1]
        return t

    # Single-quoted string literal; a doubled quote ('') is allowed inside.
    @_(r"\'([^'\n]|'')*\'")
    def STRING(self, t):
        # Strip the surrounding quotes and collapse '' to a single quote.
        t.value = t.value[1:-1].replace("''", "'")
        return t

    # Floating point literals: ".5", "1.", "1.5", each with an optional
    # exponent, or an integer with a mandatory exponent ("1e3").
    @_(r'(\d*\.\d+)([eE][-+]?[0-9]+)?',
       r'(\d+\.\d*)([eE][-+]?[0-9]+)?',
       r'(\d+)[eE][-+]?[0-9]+')
    def DECIMAL(self, t):
        t.value = float(t.value)
        return t

    # Plain integer literal; defined after DECIMAL so that "1.5" is not
    # split into separate tokens.
    @_(r'\d+')
    def INTEGER(self, t):
        t.value = int(t.value)
        return t

    # Regular expression rules for tokens
    PLUS = r'\+'
    MINUS = r'-'
    TIMES = r'\*'
    DIVIDE = r'/'
    MODULO = r'%'
    EQ = r'='
    # Two-character operators are listed before their one-character prefixes.
    NE = r'<>'
    NE2 = r'!='
    LE = r'<='
    LT = r'<'
    GE = r'>='
    GT = r'>'
    DBL_PIPE = r'\|\|'

    # Identifier: a letter followed by letters, digits, '_', '@' or '#'.
    @_(r'[a-zA-Z][a-zA-Z0-9_@#]*')
    def NAME(self, t):
        # Re-tag keywords (case-insensitively); t.value keeps the original
        # spelling of the identifier.
        upper_value = t.value.upper()
        if upper_value in self.special_tokens:
            t.type = self.special_tokens[upper_value]
        return t

    # Newlines produce no token; they only advance the line counter.
    @_(r'\n+')
    def newline(self, t):
        self.lineno += t.value.count('\n')

    def error(self, t):
        """Raise ``ExpressionSyntaxError`` for input that matches no rule.

        NOTE(review): per the Sly documentation, ``t.value`` here holds the
        whole remaining untokenized input rather than just the offending
        character, and ``self.index`` is an offset into the full text rather
        than a column - confirm that this is the message wanted.
        """
        raise ExpressionSyntaxError(f"Syntax Error: illegal value '{t.value}' on line {self.lineno}:{self.index}")
def build_case(case, when_then_list, else_=None):
    """Attach WHEN/THEN branches and an optional ELSE to a case object.

    :param case: Object exposing ``when(criterion, result)`` and
        ``else_(value)``; each call's return value is used for the next step.
    :param when_then_list: Iterable of ``(criterion, result)`` pairs, applied
        in order.
    :param else_: Value for the ELSE branch; skipped when ``None``.
    :return: The object returned by the last builder call, or ``case`` itself
        when nothing was applied.
    """
    built = case
    for criterion, result in when_then_list:
        built = built.when(criterion, result)

    if else_ is None:
        return built
    return built.else_(else_)
def build_analytic(func, partitions=(), orders=()):
    """Apply partition and ordering terms to an analytic function object.

    :param func: Object exposing ``over(term)`` and
        ``orderby(term, order=...)``; each call's return value is used for the
        next step.
    :param partitions: Iterable of partition terms, each passed to ``over``
        in its own call.
    :param orders: Iterable of ``(term, direction)`` pairs, each passed to
        ``orderby`` with the direction as the ``order`` keyword.
    :return: The object returned by the last builder call, or ``func`` itself
        when both iterables are empty.
    """
    window = func
    for partition_term in partitions:
        window = window.over(partition_term)
    for sort_term, direction in orders:
        window = window.orderby(sort_term, order=direction)
    return window
class PyPikaParser(Parser):
    """Sly parser that turns a token stream from ``PyPikaLexer`` into PyPika terms.

    The grammar rules are the methods decorated with ``@_``; several methods
    intentionally share a name because each one contributes alternatives to
    the same non-terminal. The first rule defined (``expression``) is the
    start symbol, and the order of the rules is kept exactly as originally
    written because the generated parser may depend on it.

    :param tables: Mapping of table alias to table object. It must support
        ``in``, item access and ``keys()``; the looked-up object is passed as
        ``table=`` to ``Field`` and to ``Star``.
    """
    # Uncomment this in order to write debug logs
    # debugfile = 'parser.out'

    # Get the token list from the lexer (required)
    tokens = PyPikaLexer.tokens

    def __init__(self, tables):
        super().__init__()
        self.tables = tables

    def _get_table_for_alias(self, alias):
        """Return the table registered under ``alias``.

        :raises Exception: if ``alias`` is not a key of ``self.tables``.
        """
        if alias not in self.tables:
            raise Exception(f'Invalid table name [{alias}]')
        return self.tables[alias]

    # Operator precedence, lowest first. UMINUS is a fictitious token used
    # only via %prec in the unary-minus rule.
    precedence = (
        ('left', DBL_PIPE),
        ('left', PLUS, MINUS),
        ('left', TIMES, DIVIDE, MODULO),
        ('right', UMINUS),
        ('right', NOT),
    )

    # BOOLEAN EXPRESSIONS
    @_('expression OR and_condition')
    def expression(self, p):
        return p.expression | p.and_condition

    @_('and_condition')
    def expression(self, p):
        return p.and_condition

    @_('and_condition AND condition')
    def and_condition(self, p):
        return p.and_condition & p.condition

    @_('condition')
    def and_condition(self, p):
        return p.condition

    # CONDITIONS
    @_('operand')
    def condition(self, p):
        return p.operand

    @_('operand EQ operand')
    def condition(self, p):
        return p.operand0 == p.operand1

    @_('operand NE operand',
       'operand NE2 operand')
    def condition(self, p):
        return p.operand0 != p.operand1

    @_('operand GT operand')
    def condition(self, p):
        return p.operand0 > p.operand1

    @_('operand GE operand')
    def condition(self, p):
        return p.operand0 >= p.operand1

    @_('operand LT operand')
    def condition(self, p):
        return p.operand0 < p.operand1

    @_('operand LE operand')
    def condition(self, p):
        return p.operand0 <= p.operand1

    @_('operand IN "(" operand_list ")"')
    def condition(self, p):
        return p.operand.isin(p.operand_list)

    @_('operand NOT IN "(" operand_list ")"')
    def condition(self, p):
        return p.operand.notin(p.operand_list)

    @_('operand_list "," operand')
    def operand_list(self, p):
        return p.operand_list + [p.operand]

    @_('operand')
    def operand_list(self, p):
        return [p.operand]

    @_('operand LIKE operand')
    def condition(self, p):
        return p.operand0.like(p.operand1)

    @_('operand NOT LIKE operand')
    def condition(self, p):
        return p.operand0.not_like(p.operand1)

    @_('operand ILIKE operand')
    def condition(self, p):
        return p.operand0.ilike(p.operand1)

    @_('operand NOT ILIKE operand')
    def condition(self, p):
        return p.operand0.not_ilike(p.operand1)

    @_('operand BETWEEN operand AND operand')
    def condition(self, p):
        return p.operand0.between(p.operand1, p.operand2)

    @_('operand NOT BETWEEN operand AND operand')
    def condition(self, p):
        return p.operand0.not_between(p.operand1, p.operand2)

    @_('operand IS NULL')
    def condition(self, p):
        return p.operand.isnull()

    @_('operand IS NOT NULL')
    def condition(self, p):
        return p.operand.notnull()

    @_('NOT expression')
    def condition(self, p):
        return Not(p.expression)

    @_('"(" expression ")"')
    def condition(self, p):
        return Bracket(p.expression)

    # OPERANDS, FACTORS AND TERMS
    @_('factor DBL_PIPE factor')
    def operand(self, p):
        # Extend an existing Concat in place rather than nesting a new one.
        if isinstance(p.factor0, fn.Concat):
            p.factor0.args += [p.factor1]
            return p.factor0
        return fn.Concat(p.factor0, p.factor1)

    @_('factor')
    def operand(self, p):
        return p.factor

    @_('term TIMES term')
    def factor(self, p):
        return p.term0 * p.term1

    @_('term DIVIDE term')
    def factor(self, p):
        return p.term0 / p.term1

    @_('term MODULO term')
    def factor(self, p):
        return p.term0 % p.term1

    @_('term PLUS term')
    def factor(self, p):
        return p.term0 + p.term1

    @_('term MINUS term')
    def factor(self, p):
        return p.term0 - p.term1

    @_('MINUS term %prec UMINUS')
    def factor(self, p):
        return -p.term

    @_('term')
    def factor(self, p):
        return p.term

    # A term may itself be an operand; this is what lets binary operators
    # chain, since the factor rules only combine two terms at a time.
    @_(
        'value',
        'function',
        'case',
        'case_when',
        'operand',
    )
    def term(self, p):
        return p[0]

    @_('"(" operand ")"')
    def term(self, p):
        return Bracket(p.operand)

    @_('alias "." column_ref')
    def term(self, p):
        table = self._get_table_for_alias(p.alias)
        return Field(p.column_ref, table=table)

    @_('column_ref')
    def term(self, p):
        # An unqualified column is only resolvable when exactly one table
        # is registered.
        table_keys = list(self.tables.keys())
        if len(table_keys) != 1:
            raise Exception('Ambiguous column name. When using more than one table, column names must be prefixed.')
        table_key = table_keys[0]
        table = self.tables[table_key]
        return Field(p.column_ref, table=table)

    # VALUES
    @_(
        'string',
        'numeric',
        'boolean',
        'constant',
    )
    def value(self, p):
        return p[0]

    @_('null')
    def value(self, p):
        return p.null

    # CASE STATEMENTS
    @_('CASE term when_then_list ELSE expression END', )
    def case(self, p):
        return build_case(Case(p.term), p.when_then_list, p.expression)

    @_('CASE term when_then_list END')
    def case(self, p):
        return build_case(Case(p.term), p.when_then_list)

    @_('CASE when_then_list ELSE expression END')
    def case_when(self, p):
        return build_case(Case(), p.when_then_list, p.expression)

    @_('CASE when_then_list END')
    def case_when(self, p):
        return build_case(Case(), p.when_then_list)

    @_('when_then_list when_then_stmt')
    def when_then_list(self, p):
        return p.when_then_list + [p.when_then_stmt]

    @_('when_then_stmt')
    def when_then_list(self, p):
        return [p.when_then_stmt]

    @_('WHEN expression THEN expression')
    def when_then_stmt(self, p):
        return p.expression0, p.expression1

    # NAMES AND LITERALS
    @_('NAME', 'QUOTED_NAME')
    def alias(self, p):
        return p[0]

    @_('alias')
    def column_ref(self, p):
        return p.alias

    @_('NULL')
    def null(self, p):
        return NullValue()

    @_('STRING')
    def string(self, p):
        return ValueWrapper(p.STRING)

    @_('DECIMAL',
       'INTEGER')
    def numeric(self, p):
        return ValueWrapper(p[0])

    @_('TRUE')
    def boolean(self, p):
        return ValueWrapper(True)

    @_('FALSE')
    def boolean(self, p):
        return ValueWrapper(False)

    @_('time_unit')
    def constant(self, p):
        return p[0]

    # DATA TYPES
    @_('data_type_with_arg "(" INTEGER ")"')
    def data_type(self, p):
        # The SqlTypes member is called with the length argument.
        return p.data_type_with_arg(p.INTEGER)

    @_('data_type_with_arg')
    def data_type(self, p):
        return p.data_type_with_arg

    @_('LONG_TYPE VARCHAR_TYPE')
    def data_type(self, p):
        return SqlTypes.LONG_VARCHAR

    @_('LONG_TYPE VARBINARY_TYPE')
    def data_type(self, p):
        return SqlTypes.LONG_VARBINARY

    @_('INTEGER_TYPE',
       'FLOAT_TYPE',
       'NUMERIC_TYPE',
       'SIGNED_TYPE',
       'UNSIGNED_TYPE',
       'BOOLEAN_TYPE')
    def data_type(self, p):
        # The token value is the type keyword as typed; SqlTypes attributes
        # are looked up by its upper-cased form.
        return getattr(SqlTypes, p[0].upper())

    @_('CHAR_TYPE ',
       'VARCHAR_TYPE',
       'BINARY_TYPE',
       'VARBINARY_TYPE')
    def data_type_with_arg(self, p):
        return getattr(SqlTypes, p[0].upper())

    @_('YEAR',
       'QUARTER',
       'MONTH',
       'WEEK',
       'DAY',
       'HOUR',
       'MINUTE',
       'SECOND',
       'MICROSECOND')
    def time_unit(self, p):
        # DatePart members are looked up by the lower-cased keyword.
        time_unit_string = p[0]
        return DatePart[str.lower(time_unit_string)]

    # FUNCTIONS
    @_(
        'cast',
        'extract',
        'analytic',
        'approximate_percentile',
    )
    def function(self, p):
        return p[0]

    @_('NAME "(" DISTINCT arguments_list ")"')
    def function(self, p):
        upper_name = p.NAME.upper()
        return fn.DistinctOptionFunction(upper_name,
                                         *p.arguments_list) \
            .distinct()

    @_('NAME "(" ")"',
       'NAME "(" arguments_list ")"')
    def function(self, p):
        upper_name = p.NAME.upper()
        # The argument list is only present for the second alternative.
        args = p.arguments_list \
            if 'arguments_list' in p._namemap \
            else []
        func = fn.AggregateFunction \
            if upper_name in AGGREGATE_FUNCTION_NAMES \
            else fn.Function
        return func(upper_name,
                    *args)

    @_('alias "." alias "(" ")"',
       'alias "." alias "(" arguments_list ")"')
    def function(self, p):
        schema = p.alias0
        name = p.alias1
        # The argument list is only present for the second alternative.
        args = p.arguments_list \
            if 'arguments_list' in p._namemap \
            else []
        # Classify case-insensitively, as the unqualified function rule does.
        # Previously the raw name was compared against the upper-case set, so
        # ``myschema.sum(x)`` was built as a plain Function. The emitted name
        # itself is left exactly as written.
        func = fn.AggregateFunction \
            if name.upper() in AGGREGATE_FUNCTION_NAMES \
            else fn.Function
        return func(name,
                    *args,
                    schema=Schema(schema))

    @_('TIMES')
    def arguments_list(self, p):
        return [Star()]

    @_('alias "." TIMES')
    def arguments_list(self, p):
        table = self._get_table_for_alias(p.alias)
        return [Star(table)]

    @_('arguments_list "," expression')
    def arguments_list(self, p):
        return p.arguments_list + [p.expression]

    @_('expression')
    def arguments_list(self, p):
        return [p.expression]

    @_('CAST "(" expression AS data_type ")"')
    def cast(self, p):
        return Cast(p.expression, p.data_type)

    @_('APPROXIMATE_PERCENTILE "(" term USING PARAMETERS PERCENTILE EQ DECIMAL ")"')
    def approximate_percentile(self, p):
        return fn.ApproximatePercentile(p.term, p.DECIMAL)

    @_('EXTRACT "(" time_unit FROM expression ")"')
    def extract(self, p):
        return Extract(p.time_unit, p.expression)

    # ANALYTIC FUNCTIONS
    @_('analytic_function OVER "(" partition_by ")"')
    def analytic(self, p):
        return build_analytic(p.analytic_function, partitions=p.partition_by)

    @_('analytic_function OVER "(" order_by ")"')
    def analytic(self, p):
        return build_analytic(p.analytic_function, orders=p.order_by)

    @_('analytic_function OVER "(" partition_by order_by ")"')
    def analytic(self, p):
        return build_analytic(p.analytic_function, partitions=p.partition_by, orders=p.order_by)

    @_('function_ignore_nulls')
    def analytic_function(self, p):
        return p.function_ignore_nulls

    @_('function')
    def analytic_function(self, p):
        # Rebuild the parsed function as an analytic function; only its
        # name and args are carried over.
        return an.AnalyticFunction(p.function.name, *p.function.args)

    @_('NAME "(" arguments_list IGNORE NULLS ")"')
    def function_ignore_nulls(self, p):
        upper_name = p.NAME.upper()
        return an.IgnoreNullsAnalyticFunction(upper_name,
                                              *p.arguments_list) \
            .ignore_nulls()

    @_('PARTITION BY arguments_list')
    def partition_by(self, p):
        return p.arguments_list

    @_('ORDER BY arguments_list_orientation')
    def order_by(self, p):
        return p.arguments_list_orientation

    # Each entry is an (expression, orientation) pair; orientation is None
    # when neither ASC nor DESC was given.
    @_('arguments_list_orientation "," expression orientation')
    def arguments_list_orientation(self, p):
        return p.arguments_list_orientation + [(p.expression, p.orientation)]

    @_('arguments_list_orientation "," expression')
    def arguments_list_orientation(self, p):
        return p.arguments_list_orientation + [(p.expression, None)]

    @_('expression orientation')
    def arguments_list_orientation(self, p):
        return [(p.expression, p.orientation)]

    @_('expression')
    def arguments_list_orientation(self, p):
        return [(p.expression, None)]

    @_('ASC')
    def orientation(self, p):
        return Order.asc

    @_('DESC')
    def orientation(self, p):
        return Order.desc

    def error(self, token):
        """Raise ``ExpressionSyntaxError`` for an unexpected token.

        :param token: The offending token, or a falsy value when the input
            ended before the expression was complete.
        :raises ExpressionSyntaxError: always.
        """
        if token:
            lineno = getattr(token, 'lineno', 0)
            index = getattr(token, 'index', 0)
            raise ExpressionSyntaxError(f'Syntax error on line:column {lineno}:{index}, '
                                        f'unexpected value \'{token.value}\'')
        raise ExpressionSyntaxError('Parse error in input. Unexpected end of expression.')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment