import functools
import operator

from sqlalchemy import func, column, select, text, between
from sqlalchemy.schema import MetaData, Table
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.functions import GenericFunction

class hyperunique(GenericFunction):
    """Placeholder for Druid's hyperUnique aggregator; renders as hyperunique(<col>)."""
    type = sqltypes.Integer

    def __init__(self, arg, **kwargs):
        GenericFunction.__init__(self, arg, **kwargs)

def sq(s):
    """Single-quote a string literal, escaping embedded quotes."""
    return "'{}'".format(s.replace("'", "\\'"))

def agg_to_column(agg):
    """Translate a Druid aggregation spec into a labeled SQLAlchemy column."""
    agg_func_map = dict(
        min=func.min,
        max=func.max,
        longSum=func.sum,
        longMin=func.min,
        longMax=func.max,
        doubleSum=func.sum,
        doubleMin=func.min,
        doubleMax=func.max,
        count=func.count,
        hyperUnique=hyperunique,
        # cardinality=cardinality,
        # filtered=filtered,
        # javascript=javascript,
    )
    agg_func = agg_func_map.get(agg['type'])
    if agg_func is None:
        raise ValueError('unsupported aggregation type: {}'.format(agg['type']))
    col = column(agg['fieldName'])
    return agg_func(col).label(agg['name'])
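
# For illustration, with an aggregation from the sample query below:
#   agg_to_column({"type": "longSum", "name": "total_usage", "fieldName": "user_count"})
# renders as: sum(user_count) AS total_usage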
def post_agg_to_column(pagg, aggs):
    """Translate a Druid postAggregation spec, resolving fieldAccess against aggs."""
    op_map = {"/": operator.truediv, "*": operator.mul, "+": operator.add, "-": operator.sub}
    t = pagg['type']
    if t == 'arithmetic':
        lhs = post_agg_to_column(pagg['fields'][0], aggs)
        rhs = post_agg_to_column(pagg['fields'][1], aggs)
        fn = op_map[pagg['fn']]
        col = fn(lhs, rhs)
    elif t == 'fieldAccess':
        col = aggs[pagg['fieldName']]
    else:
        raise ValueError('unsupported postAggregation type: {}'.format(t))
    name = pagg.get('name')
    if name:
        col = col.label(name)
        aggs[name] = col
    return col
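
# For illustration, the avg_usage postAggregation in the sample query below
# resolves both fieldAccess entries through aggs and renders as:
#   sum(data_transfer) / sum(user_count) AS avg_usage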
def tr_filter_one(filters):
    """Translate a single Druid filter spec into a SQLAlchemy boolean expression."""
    op_map = {"and": operator.and_, "or": operator.or_}
    t = filters.get('type')
    if t == 'selector':
        v = filters['value']
        # FIXME: there must be a better way to create a literal string
        # (e.g. sqlalchemy.literal with literal_binds at compile time)
        if isinstance(v, (int, float)):
            c = text(str(v))
        else:
            c = text(sq(v))
        return operator.eq(column(filters['dimension']), c)
    elif t == 'columnComparison':
        dims = [column(c) for c in filters['dimensions']]
        return operator.eq(dims[0], dims[1])
    elif t in op_map:
        fields = [tr_filter_one(f) for f in filters['fields']]
        return functools.reduce(op_map[t], fields)
    elif t in ('regex', 'javascript', 'extraction'):
        raise NotImplementedError(t)
    raise ValueError('unknown filter type: {}'.format(t))

def tr_intervals(intervals):
    """
    Supports a single <start>/<end> interval only, e.g.
    ["2012-01-01T00:00:00.000/2012-01-03T00:00:00.000"]
    """
    lhs, rhs = intervals[0].split('/')
    return between(column('__time'), text(sq(lhs)), text(sq(rhs)))
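
# e.g. tr_intervals(["2012-01-01T00:00:00.000/2012-01-03T00:00:00.000"]) renders as:
#   __time BETWEEN '2012-01-01T00:00:00.000' AND '2012-01-03T00:00:00.000'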
def tr_filter(sql, intervals, filters):
    """Apply the WHERE clause: interval bounds plus the optional query filter."""
    if filters:
        return sql.where(operator.and_(tr_intervals(intervals), tr_filter_one(filters)))
    return sql.where(tr_intervals(intervals))

def tr_having(sql, filters, aggs):
    """Apply the HAVING clause. Only numeric comparisons are implemented so far."""
    t = filters.get('type')
    if t in ('equalTo', 'greaterThan', 'lessThan'):
        op_map = dict(equalTo=operator.eq, greaterThan=operator.gt, lessThan=operator.lt)
        # numeric value compared against a named aggregation
        agg = aggs[filters["aggregation"]]
        value = filters["value"]
        return sql.having(op_map[t](agg, text(str(value))))
    elif t == 'filter':
        # any druid query filter
        raise NotImplementedError(t)
    elif t == 'dimSelector':
        # dimension value
        raise NotImplementedError(t)
    elif t in ('and', 'or', 'not'):
        # logical expressions
        raise NotImplementedError(t)
    raise ValueError('unknown having type: {}'.format(t))

def tr_limit_spec(sql, limit_spec):
    """Apply ORDER BY and LIMIT from a Druid limitSpec."""
    t = limit_spec['type']
    if t != 'default':
        raise ValueError('unsupported limitSpec type: {}'.format(t))
    orders = []
    for c in limit_spec.get('columns') or []:
        if isinstance(c, str):
            # column specified by name; order is ascending
            orders.append(column(c).asc())
        elif isinstance(c, dict):
            col = column(c['dimension'])
            if c.get('direction') == 'descending':
                col = col.desc()
            else:
                col = col.asc()
            orders.append(col)
    if orders:
        sql = sql.order_by(*orders)
    limit = limit_spec.get('limit')
    if limit:
        sql = sql.limit(limit)
    return sql

def to_sql(query_dict, excluded_columns=None):
    """Translate a Druid native query dict into a SQLAlchemy SELECT statement."""
    # columns are aggregations and postAggregations
    aggs = {}
    columns = []
    excluded_columns = excluded_columns or []
    for agg in query_dict.get('aggregations') or []:
        col = agg_to_column(agg)
        aggs[col.name] = col
        if col.name not in excluded_columns:
            columns.append(col)
    for pagg in query_dict.get('postAggregations') or []:
        col = post_agg_to_column(pagg, aggs)
        aggs[col.name] = col
        if col.name not in excluded_columns:
            columns.append(col)

    # building SELECT ... FROM ...
    # (SQLAlchemy 1.x-style select(); under 2.0 this would be select(*columns).select_from(tbl))
    tbl = Table(query_dict['dataSource'], MetaData())
    sql = select(columns, from_obj=tbl)

    # apply WHERE
    intervals = query_dict['intervals']
    filters = query_dict.get('filter')
    sql = tr_filter(sql, intervals, filters)

    # granularity is a part of the GROUP BY clause:
    # translate it to GROUP BY FLOOR(__time TO <bucket size>)
    buckets = ['second', 'minute', 'fifteen_minute', 'thirty_minute', 'hour',
               'day', 'week', 'month', 'quarter', 'year']
    granularity = query_dict['granularity']
    assert granularity in buckets
    groups = [text('FLOOR(__time TO {})'.format(granularity.upper()))]

    # translate dimensions to GROUP BY <dimensions>
    if query_dict.get('queryType') == 'groupBy':
        groups.extend(text(x) for x in query_dict.get('dimensions', []))

    # apply GROUP BY
    sql = sql.group_by(*groups)

    # apply HAVING
    having = query_dict.get('having')
    if having:
        sql = tr_having(sql, having, aggs)

    # apply ORDER BY and LIMIT
    limit_spec = query_dict.get('limitSpec')
    if limit_spec:
        sql = tr_limit_spec(sql, limit_spec)
    return sql

if __name__ == '__main__':
    q = {
        "queryType": "groupBy",
        "dataSource": "sample_datasource",
        "granularity": "day",
        "dimensions": ["country", "device"],
        "limitSpec": {"type": "default", "limit": 5000, "columns": [
            "country",
            {
                "dimension": "data_transfer",
                "direction": "descending",
                "dimensionOrder": "numeric"
            }
        ]},
        "filter": {
            "type": "and",
            "fields": [
                {"type": "selector", "dimension": "carrier", "value": "AT&T"},
                {"type": "or",
                 "fields": [
                     {"type": "selector", "dimension": "make", "value": "Apple"},
                     {"type": "selector", "dimension": "make", "value": "Samsung"}
                 ]}
            ]
        },
        "aggregations": [
            {"type": "longSum", "name": "total_usage", "fieldName": "user_count"},
            {"type": "doubleSum", "name": "data_transfer", "fieldName": "data_transfer"}
        ],
        "postAggregations": [
            {"type": "arithmetic",
             "name": "avg_usage",
             "fn": "/",
             "fields": [
                 {"type": "fieldAccess", "fieldName": "data_transfer"},
                 {"type": "fieldAccess", "fieldName": "total_usage"}
             ]}
        ],
        "intervals": ["2012-01-01T00:00:00.000/2012-01-03T00:00:00.000"],
        "having": {
            "type": "greaterThan",
            "aggregation": "total_usage",
            "value": 100
        }
    }
    print(to_sql(q))
""" | |
SELECT sum(user_count) AS total_usage, sum(data_transfer) AS data_transfer, sum(data_transfer) / sum(user_count) AS avg_usage | |
FROM sample_datasource | |
WHERE __time BETWEEN '2012-01-01T00:00:00.000' AND '2012-01-03T00:00:00.000' AND carrier = 'AT&T' AND (make = 'Apple' OR make = 'Samsung') GROUP BY FLOOR(__time TO DAY), country, device | |
HAVING sum(user_count) > 100 ORDER BY country ASC, data_transfer DESC | |
LIMIT 5000 | |
""" |
WIP. To render Druid-flavored SQL, compile the generated statement against a Druid dialect: sql.compile(dialect=DruidDialect).
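
A minimal sketch of that compile step, assuming pydruid is installed and exposes a SQLAlchemy dialect (the DruidHTTPDialect import path is an assumption; check the pydruid version you have installed):

    # Assumption: pydruid ships a SQLAlchemy dialect under this module path.
    from pydruid.db.sqlalchemy import DruidHTTPDialect

    sql = to_sql(q)
    print(sql.compile(dialect=DruidHTTPDialect()))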