Skip to content

Instantly share code, notes, and snippets.

@knoguchi
Last active May 29, 2019 22:24
Show Gist options
  • Save knoguchi/5c56a4ddfc6ffc5f920e7ce760a4a711 to your computer and use it in GitHub Desktop.
Save knoguchi/5c56a4ddfc6ffc5f920e7ce760a4a711 to your computer and use it in GitHub Desktop.
import functools
import operator
from sqlalchemy import func, column, select, text, between
from sqlalchemy.schema import *
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.functions import GenericFunction
class hyperunique(GenericFunction):
    """Custom SQLAlchemy generic function whose result type is Integer.

    ``agg_to_column`` uses it as the translation of the ``hyperunique``
    aggregation type.
    """

    # Result type attached to the function expression.
    type = sqltypes.Integer

    def __init__(self, arg, **kwargs):
        # Exactly one positional argument is accepted and forwarded unchanged.
        super(hyperunique, self).__init__(arg, **kwargs)
def sq(s):
    """Return *s* wrapped in single quotes as a SQL string literal.

    Embedded single quotes are escaped by doubling them, which is the
    standard SQL form. A backslash is not an escape character inside a
    standard SQL string literal, so backslash-escaping would terminate the
    literal at the embedded quote.

    :param s: the raw string value.
    :return: the quoted literal, e.g. ``O'Brien`` becomes ``'O''Brien'``.
    """
    return "'{}'".format(s.replace("'", "''"))
def agg_to_column(agg):
    """Translate one Druid aggregation spec into a labelled SQL aggregate.

    :param agg: dict with ``type`` and ``name`` keys, plus ``fieldName``
        naming the input column. ``fieldName`` may be omitted for ``count``.
    :return: the aggregate expression labelled with ``agg['name']``.
    :raises NotImplementedError: if ``agg['type']`` has no translation here.
    :raises KeyError: if a required key is missing from *agg*.
    """
    agg_func_map = {
        'min': func.min,
        'max': func.max,
        'longSum': func.sum,
        'longMin': func.min,
        'longMax': func.max,
        'doubleSum': func.sum,
        'doubleMin': func.min,
        'doubleMax': func.max,
        'count': func.count,
        'hyperunique': hyperunique,
        # Not translated yet: cardinality, filtered, javascript.
    }
    agg_type = agg['type']
    agg_func = agg_func_map.get(agg_type)
    if agg_func is None:
        # Fail with a readable message instead of "NoneType is not callable".
        raise NotImplementedError(
            'unsupported aggregation type: {!r}'.format(agg_type))

    if agg_type == 'count' and agg.get('fieldName') is None:
        # A row-count aggregation carries no input column; an argument-less
        # count() renders as COUNT(*).
        expr = agg_func()
    else:
        expr = agg_func(column(agg['fieldName']))
    return expr.label(agg['name'])
def post_agg_to_column(pagg, aggs):
    """Translate a Druid post-aggregation spec into a SQL expression.

    Supported types are ``arithmetic`` (operators ``/ * + -`` folded left to
    right over every entry of ``fields``) and ``fieldAccess`` (a lookup of a
    previously built column in *aggs*).

    :param pagg: the post-aggregation dict; nested specs live in ``fields``.
    :param aggs: mapping of output name to column expression. When *pagg* has
        a ``name``, the labelled result is stored back into this mapping.
    :return: the expression, labelled with ``pagg['name']`` when one is given.
    :raises NotImplementedError: for any other post-aggregation type.
    :raises KeyError: for an unknown arithmetic operator or an unknown
        ``fieldName``.
    """
    op_map = {
        "/": operator.truediv,
        "*": operator.mul,
        "+": operator.add,
        "-": operator.sub,
    }
    pagg_type = pagg['type']
    if pagg_type == 'arithmetic':
        operands = [post_agg_to_column(field, aggs) for field in pagg['fields']]
        # Fold across all operands rather than only the first two.
        col = functools.reduce(op_map[pagg['fn']], operands)
    elif pagg_type == 'fieldAccess':
        col = aggs[pagg['fieldName']]
    else:
        # Previously fell through with `col` unbound (UnboundLocalError).
        raise NotImplementedError(
            'unsupported postAggregation type: {!r}'.format(pagg_type))

    name = pagg.get('name')
    if name:
        col = col.label(name)
        # Register the result so later post-aggregations can reference it.
        aggs[name] = col
    return col
def tr_filter_one(filters):
    """Translate a Druid query filter into a SQL boolean expression.

    Handles ``selector``, ``columnComparison`` (first two dimensions),
    ``and`` / ``or`` (recursively over ``fields``) and ``not`` (over
    ``field``).

    :param filters: the filter dict.
    :return: a SQLAlchemy boolean expression.
    :raises NotImplementedError: for ``regex``, ``javascript`` and
        ``extraction`` filters, which are recognised but not translated.
    :raises ValueError: for any other filter type.
    """
    op_map = {"and": operator.and_, "or": operator.or_}
    t = filters.get('type')

    if t == 'selector':
        value = filters['value']
        # FIXME: there must be a better way to create a literal string than
        # rendering it through text().
        if isinstance(value, (int, float)):
            literal = text(str(value))
        else:
            literal = text(sq(value))
        return operator.eq(column(filters['dimension']), literal)

    if t == 'columnComparison':
        dims = filters['dimensions']
        # Build each column exactly once; wrapping a column object in
        # column() again does not yield a usable column name.
        return operator.eq(column(dims[0]), column(dims[1]))

    if t in op_map:
        fields = [tr_filter_one(f) for f in filters.get('fields')]
        return functools.reduce(op_map[t], fields)

    if t == 'not':
        # Negate the single nested filter.
        return operator.inv(tr_filter_one(filters['field']))

    if t in ('regex', 'javascript', 'extraction'):
        # NotImplemented is a comparison sentinel, not an exception class.
        raise NotImplementedError(
            'filter type {!r} is not supported yet'.format(t))

    raise ValueError('unknown filter type: {!r}'.format(t))
def tr_intervals(intervals):
    """Build the ``__time BETWEEN <start> AND <end>`` condition.

    Only the first entry of *intervals* is used, and it must be a single
    ``<start>/<end>`` string, e.g.
    ``["2012-01-01T00:00:00.000/2012-01-03T00:00:00.000"]``.

    NOTE(review): BETWEEN includes both endpoints. Confirm this matches the
    intended interval semantics (Druid intervals are documented as excluding
    the end instant).
    """
    start, end = intervals[0].split('/')
    lower_bound = text(sq(start))
    upper_bound = text(sq(end))
    return between(column('__time'), lower_bound, upper_bound)
def tr_filter(sql, intervals, filters):
    """Attach the WHERE clause to *sql*.

    The time-interval condition is always applied; when *filters* is truthy
    its translation is AND-ed onto it.
    """
    condition = tr_intervals(intervals)
    if filters:
        condition = operator.and_(condition, tr_filter_one(filters))
    return sql.where(condition)
def tr_having(sql, filters, aggs):
    """Attach a HAVING clause built from a Druid having spec.

    Only the numeric comparisons ``equalTo``, ``greaterThan`` and
    ``lessThan`` are translated.

    :param sql: the select statement to extend.
    :param filters: the having spec dict.
    :param aggs: mapping of aggregation name to column expression, used to
        resolve ``filters['aggregation']``.
    :return: *sql* with the HAVING condition applied.
    :raises NotImplementedError: for ``filter``, ``dimSelector``, ``and``,
        ``or`` and ``not`` specs, which are recognised but not translated.
    :raises ValueError: for any other spec type.
    """
    comparisons = {
        'equalTo': operator.eq,
        'greaterThan': operator.gt,
        'lessThan': operator.lt,
    }
    t = filters.get('type')

    if t in comparisons:
        agg = aggs[filters['aggregation']]
        # The comparison value is numeric and rendered inline.
        threshold = text(str(filters['value']))
        return sql.having(comparisons[t](agg, threshold))

    if t in ('filter', 'dimSelector', 'and', 'or', 'not'):
        # Raise instead of falling through: an implicit None return would
        # replace the caller's statement with None.
        raise NotImplementedError(
            'having type {!r} is not supported yet'.format(t))

    raise ValueError('unknown having type: {!r}'.format(t))
def tr_limit_spec(sql, limit_spec):
    """Apply ORDER BY and LIMIT from a Druid ``limitSpec`` of type ``default``.

    Each entry of ``columns`` is either a plain column name (ascending) or a
    dict with ``dimension`` and an optional ``direction``; any direction other
    than ``descending`` (case-insensitive) sorts ascending. Entries of any
    other Python type are ignored. Missing ``columns`` or ``limit`` keys are
    treated as "not specified".

    :param sql: the select statement to extend.
    :param limit_spec: the limitSpec dict; specs whose ``type`` is not
        ``default`` leave *sql* unchanged.
    :return: the statement with ordering and limit applied.
    """
    if limit_spec['type'] != 'default':
        return sql

    orders = []
    for spec in limit_spec.get('columns') or []:
        if isinstance(spec, str):
            # Column given by name only: order is ascending.
            orders.append(column(spec).asc())
        elif isinstance(spec, dict):
            col = column(spec['dimension'])
            direction = spec.get('direction') or 'ascending'
            if direction.lower() == 'descending':
                orders.append(col.desc())
            else:
                orders.append(col.asc())
    if orders:
        sql = sql.order_by(*orders)

    limit = limit_spec.get('limit')
    if limit:
        # Rendered through text() so the number appears inline in the SQL.
        sql = sql.limit(text(str(limit)))
    return sql
def to_sql(query_dict, excluded_columns=None):
    """Translate a Druid native JSON query (as a dict) into a SQLAlchemy select.

    Only aggregations and post-aggregations are selected; the time bucket and
    (for ``groupBy`` queries) the dimensions appear in GROUP BY only.

    :param query_dict: the query; reads ``dataSource``, ``intervals`` and
        ``granularity`` (required) plus ``aggregations``,
        ``postAggregations``, ``filter``, ``queryType``, ``dimensions``,
        ``having`` and ``limitSpec`` (optional).
    :param excluded_columns: output names to leave out of the SELECT list;
        they are still built so that post-aggregations and HAVING can
        reference them.
    :return: the select statement.
    :raises ValueError: if ``granularity`` is not a supported simple
        granularity string.
    """
    aggs = {}
    columns = []
    excluded_columns = excluded_columns or []

    # Columns are the aggregations followed by the post-aggregations.
    for agg in query_dict.get('aggregations') or []:
        agg_col = agg_to_column(agg)
        aggs[agg_col.name] = agg_col
        if agg_col.name not in excluded_columns:
            columns.append(agg_col)
    for pagg in query_dict.get('postAggregations') or []:
        pagg_col = post_agg_to_column(pagg, aggs)
        aggs[pagg_col.name] = pagg_col
        if pagg_col.name not in excluded_columns:
            columns.append(pagg_col)

    # SELECT ... FROM ...
    tbl = Table(query_dict['dataSource'], MetaData())
    sql = select(columns, from_obj=tbl)

    # WHERE
    sql = tr_filter(sql, query_dict['intervals'], query_dict.get('filter'))

    # Granularity becomes the leading GROUP BY term. Buckets that are plain
    # time units use FLOOR(__time TO <UNIT>); sub-hour multiples are not time
    # units, so they are expressed as TIME_FLOOR with an ISO-8601 period.
    unit_buckets = ('second', 'minute', 'hour', 'day', 'week', 'month',
                    'quarter', 'year')
    period_buckets = {'fifteen_minute': 'PT15M', 'thirty_minute': 'PT30M'}
    granularity = query_dict['granularity']
    if granularity == 'all':
        # A single bucket over the whole interval: no time grouping.
        groups = []
    elif granularity in unit_buckets:
        groups = [text('FLOOR(__time TO {})'.format(granularity.upper()))]
    elif granularity in period_buckets:
        groups = [text("TIME_FLOOR(__time, '{}')".format(
            period_buckets[granularity]))]
    else:
        raise ValueError('unsupported granularity: {!r}'.format(granularity))

    # groupBy queries additionally group by their dimensions.
    if query_dict.get('queryType') == 'groupBy':
        groups.extend(text(dim) for dim in query_dict.get('dimensions') or [])

    # GROUP BY / HAVING
    if groups:
        sql = sql.group_by(*groups)
    having = query_dict.get('having')
    if having:
        sql = tr_having(sql, having, aggs)

    # ORDER BY and LIMIT
    limit_spec = query_dict.get('limitSpec')
    if limit_spec:
        sql = tr_limit_spec(sql, limit_spec)
    return sql
if __name__ == '__main__':
    # Demo: translate a sample groupBy query and print the resulting SQL.
    carrier_filter = {"type": "selector", "dimension": "carrier", "value": "AT&T"}
    make_filter = {
        "type": "or",
        "fields": [
            {"type": "selector", "dimension": "make", "value": "Apple"},
            {"type": "selector", "dimension": "make", "value": "Samsung"},
        ],
    }
    ordering = [
        "country",
        {
            "dimension": "data_transfer",
            "direction": "descending",
            "dimensionOrder": "numeric",
        },
    ]
    avg_usage = {
        "type": "arithmetic",
        "name": "avg_usage",
        "fn": "/",
        "fields": [
            {"type": "fieldAccess", "fieldName": "data_transfer"},
            {"type": "fieldAccess", "fieldName": "total_usage"},
        ],
    }
    q = {
        "queryType": "groupBy",
        "dataSource": "sample_datasource",
        "granularity": "day",
        "dimensions": ["country", "device"],
        "limitSpec": {"type": "default", "limit": 5000, "columns": ordering},
        "filter": {"type": "and", "fields": [carrier_filter, make_filter]},
        "aggregations": [
            {"type": "longSum", "name": "total_usage", "fieldName": "user_count"},
            {"type": "doubleSum", "name": "data_transfer", "fieldName": "data_transfer"},
        ],
        "postAggregations": [avg_usage],
        "intervals": ["2012-01-01T00:00:00.000/2012-01-03T00:00:00.000"],
        "having": {
            "type": "greaterThan",
            "aggregation": "total_usage",
            "value": 100,
        },
    }
    print(to_sql(q))

# Sample output of the demo above, kept as a bare (unused) string literal.
"""
SELECT sum(user_count) AS total_usage, sum(data_transfer) AS data_transfer, sum(data_transfer) / sum(user_count) AS avg_usage
FROM sample_datasource
WHERE __time BETWEEN '2012-01-01T00:00:00.000' AND '2012-01-03T00:00:00.000' AND carrier = 'AT&T' AND (make = 'Apple' OR make = 'Samsung') GROUP BY FLOOR(__time TO DAY), country, device
HAVING sum(user_count) > 100 ORDER BY country ASC, data_transfer DESC
LIMIT 5000
"""
@knoguchi
Copy link
Author

knoguchi commented May 29, 2019

WIP.

  • DruidDialect didn't seem to work when I tried sql.compile(dialect=DruidDialect).
  • postAgg fieldName resolving might be wrong.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment