Created
November 10, 2025 19:23
-
-
Save haileyok/b3afde0ed7a6a3cd8874819097d122c1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import operator | |
| from typing import Any, Callable, Dict, List, get_origin | |
| from google.cloud import bigquery | |
| from pydruid.query import Query | |
| from osprey.engine.ast import grammar | |
| from osprey.engine.ast_validator.validation_context import ValidatedSources | |
| from osprey.engine.ast_validator.validators.validate_call_kwargs import ValidateCallKwargs | |
| from osprey.engine.ast_validator.validators.validate_static_types import ValidateStaticTypes | |
| from osprey.engine.udf.base import QueryUdfBase, UDFBase | |
| from osprey.engine.utils.osprey_unary_executor import OspreyUnaryExecutor | |
class BQTransformException(Exception):
    """Raised when an Osprey AST node cannot be turned into BigQuery SQL.

    The exception message is ``'<error>: <node class name>'`` and the offending node is kept on
    ``self.node`` so callers can inspect it.
    """

    def __init__(self, node: grammar.ASTNode, error: str):
        node_kind = node.__class__.__name__
        message = f'{error}: {node_kind}'
        super().__init__(message)
        self.node = node
class BQTranslator:
    """Given an osprey_ast node tree, transform it into a parameterized BigQuery SQL WHERE clause.

    Constant values are never interpolated into the SQL text: each one is registered as a named
    query parameter (``@param_N``) and handed back next to the SQL by :meth:`transform`.
    Column names (``grammar.Name`` identifiers) are emitted verbatim.
    """

    # Ordering operators mirrored for the case where the column is the right-hand operand,
    # e.g. `2 < NumEars` must become `NumEars > 2`.
    _MIRRORED_ORDERING: Dict[str, str] = {'<': '>', '<=': '>=', '>': '<', '>=': '<='}

    def __init__(self, validated_sources: ValidatedSources):
        """Pull the validator results and the root expression out of ``validated_sources``.

        The root expression is the value of the first statement of the entry point, which must be
        an assignment. Missing validator results (``KeyError``) fall back to empty mappings.
        """
        self._validated_sources = validated_sources

        try:
            self._udf_node_mapping = validated_sources.get_validator_result(ValidateCallKwargs)
        except KeyError:
            self._udf_node_mapping = {}

        try:
            static_types_result = validated_sources.get_validator_result(ValidateStaticTypes)
            self._name_types = static_types_result.name_type_and_span_cache
        except KeyError:
            self._name_types = {}

        assign_node = validated_sources.sources.get_entry_point().ast_root.statements[0]
        assert isinstance(assign_node, grammar.Assign)
        self._root = assign_node.value

        self._param_counter = 0
        # Holds both ScalarQueryParameter and ArrayQueryParameter instances.
        self._params: List[Any] = []

    def transform(self) -> Dict[str, Any]:
        """Translate the root expression.

        Returns:
            ``{'sql': <WHERE clause text>, 'params': <list of BigQuery query parameters>}``.
        """
        # Start from a clean slate so that calling transform() more than once does not return
        # parameters left over from an earlier call. A new list is bound (rather than cleared)
        # so a previously returned result keeps its own parameters.
        self._param_counter = 0
        self._params = []
        sql = self._transform(self._root)
        return {'sql': sql, 'params': self._params}

    def _get_next_param_name(self) -> str:
        """Helper to return the next valid param name for the query"""
        name = f'param_{self._param_counter}'
        self._param_counter += 1
        return name

    @staticmethod
    def _bq_type_of(value: Any) -> str:
        """Map a Python value to the BigQuery scalar type name used for its query parameter.

        ``None`` and any type not listed below are sent as ``STRING``.
        """
        # check bool before int since bool is a subclass of int in python. we love this.
        if isinstance(value, bool):
            return 'BOOL'
        if isinstance(value, int):
            return 'INT64'
        if isinstance(value, float):
            return 'FLOAT64'
        return 'STRING'

    @staticmethod
    def _escape_like(value: str) -> str:
        """Backslash-escape the LIKE metacharacters (``\\``, ``%``, ``_``) in ``value``.

        Without this, a search for text such as ``'50%'`` would treat the ``%`` as a wildcard.
        The backslash itself is escaped first so the escapes added afterwards are not doubled.
        """
        return value.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')

    def _add_param(self, value: Any) -> str:
        """Register ``value`` as a scalar query parameter and return its ``@name`` placeholder."""
        param_name = self._get_next_param_name()
        bq_type = self._bq_type_of(value)
        self._params.append(bigquery.ScalarQueryParameter(param_name, bq_type, value))
        return f'@{param_name}'

    def _is_array_column(self, column_name: str) -> bool:
        """Check if a column is a List/Array type based on static type information. Needed for some BQ queries.

        Columns without static type information are treated as non-array.
        """
        if column_name not in self._name_types:
            return False

        col_type = self._name_types[column_name].type
        return get_origin(col_type) is list

    def _transform(self, node: grammar.ASTNode) -> str:
        """Dispatch ``node`` to the ``transform_<NodeClassName>`` method and return its SQL.

        Raises:
            BQTransformException: if no transformer exists for the node's class.
        """
        method = 'transform_' + node.__class__.__name__
        transformer = getattr(self, method, None)
        if not transformer:
            raise BQTransformException(node, 'Unknown AST Expression')

        ret = transformer(node)
        assert isinstance(ret, str)
        return ret

    def transform_BooleanOperation(self, node: grammar.BooleanOperation) -> str:
        """Translate boolean operations into SQL where"""
        assert isinstance(node.operand, (grammar.And, grammar.Or))

        # Deliberately not named `operator`: that would shadow the imported `operator` module.
        keyword = 'AND' if isinstance(node.operand, grammar.And) else 'OR'
        conds = [self._transform(v) for v in node.values]
        joined = f' {keyword} '.join(conds)
        return f'({joined})'

    @staticmethod
    def _ordering_symbol(comparator: Any, column_on_right: bool) -> str:
        """Return the SQL ordering operator for ``comparator``, or ``''`` if it is not one.

        The generated SQL always puts the column on the left, so when the column was the
        right-hand operand in the source expression the operator is mirrored.
        """
        if isinstance(comparator, grammar.LessThan):
            symbol = '<'
        elif isinstance(comparator, grammar.LessThanEquals):
            symbol = '<='
        elif isinstance(comparator, grammar.GreaterThan):
            symbol = '>'
        elif isinstance(comparator, grammar.GreaterThanEquals):
            symbol = '>='
        else:
            return ''

        if column_on_right:
            symbol = BQTranslator._MIRRORED_ORDERING[symbol]
        return symbol

    def transform_BinaryComparison(self, node: grammar.BinaryComparison) -> str:
        """Translate various binary comparisons into SQL where

        Raises:
            BQTransformException: for unsupported comparators, for column-to-column comparisons
                other than ``==`` / ``!=``, and for ordering comparisons against ``None``.
        """
        comparator = node.comparator

        if isinstance(node.left, grammar.Name) and isinstance(node.right, grammar.Name):
            left_col = node.left.identifier
            right_col = node.right.identifier

            # Username == OldUsername
            if isinstance(comparator, grammar.Equals):
                return f'{left_col} = {right_col}'
            # Username != OldUsername
            if isinstance(comparator, grammar.NotEquals):
                return f'{left_col} != {right_col}'
            # TODO: decide if we actually need this? should be supported tbh. keeping w/ druid for now tho
            raise BQTransformException(
                node.comparator, 'When comparing two features, only the `==` and `!=` operators are supported'
            )

        dim = get_comparison_dimension(node)
        value = get_comparison_value(node)

        if isinstance(comparator, grammar.Equals):
            # Friend == None
            if value is None:
                return f'{dim} IS NULL'
            # Username == 'Kitten'
            return f'{dim} = {self._add_param(value)}'

        if isinstance(comparator, grammar.NotEquals):
            # Friend != None
            if value is None:
                return f'{dim} IS NOT NULL'
            # Username != 'Bunny'
            return f'{dim} != {self._add_param(value)}'

        if isinstance(comparator, grammar.In):
            return self._transform_in_comparison(node, dim, value)

        if isinstance(comparator, grammar.NotIn):
            in_clause = self._transform_in_comparison(node, dim, value)
            return f'NOT ({in_clause})'

        # NumEars < 2, NumEyes <= 2, NumPaws > 4, NumTails >= 1 -- and the mirrored spellings such as
        # `2 < NumEars`. The dimension is taken from the left operand when that is a Name, so the
        # column sits on the right exactly when the left operand is not a Name.
        column_on_right = not isinstance(node.left, grammar.Name)
        symbol = self._ordering_symbol(comparator, column_on_right)
        if symbol:
            if value is None:
                # `col < NULL` never matches in SQL; fail loudly instead of returning no rows.
                raise BQTransformException(comparator, 'Ordering comparisons against None are not supported')
            return f'{dim} {symbol} {self._add_param(value)}'

        raise BQTransformException(node.comparator, 'Unknown Binary Comparator')

    def _transform_in_comparison(self, node: grammar.BinaryComparison, dim: str, value: Any) -> str:
        """Transform queries like 'Mariners' in PostText to a SQL where

        * string value, array column -> ``@p IN UNNEST(col)``
        * string value, other column -> case-insensitive substring match via ``LIKE``
        * list value -> ``col IN UNNEST(@p)``; an empty list can never match, so it becomes ``FALSE``

        Raises:
            BQTransformException: if ``value`` is neither a string nor a list.
        """
        if isinstance(value, str):
            if self._is_array_column(dim):
                param = self._add_param(value)
                return f'{param} IN UNNEST({dim})'

            # Escape LIKE metacharacters so the text is matched literally between the two wildcards.
            param = self._add_param(f'%{self._escape_like(value)}%')
            return f'LOWER({dim}) LIKE LOWER({param})'

        if isinstance(value, list):
            # empty guy is just false
            if not value:
                return 'FALSE'

            param_name = self._get_next_param_name()
            # The element type of the array parameter is inferred from the first item.
            bq_type = self._bq_type_of(value[0])
            self._params.append(bigquery.ArrayQueryParameter(param_name, bq_type, value))
            return f'{dim} IN UNNEST(@{param_name})'

        raise BQTransformException(node, 'Invalid "in" comparison value type, must be string or list')

    def transform_UnaryOperation(self, node: grammar.UnaryOperation) -> str:
        """Transform unary operations into a SQL where"""
        if isinstance(node.operator, grammar.Not):
            operand_sql = self._transform(node.operand)
            return f'NOT ({operand_sql})'

        raise BQTransformException(node, 'Unknown Unary Operator')

    # TODO: there is probably a better way to do this?
    def transform_Call(self, node: grammar.Call) -> str:
        """Transform various function calls into SQL where. Requires UDFs implement to_bigquery_sql()

        Raises:
            BQTransformException: if no UDF is mapped to ``node``, the UDF is not a
                ``QueryUdfBase``, or it does not provide ``to_bigquery_sql()``.
        """
        try:
            udf, _ = self._udf_node_mapping[id(node)]
        except KeyError:
            # Surface a translation error instead of a bare KeyError carrying only an object id.
            raise BQTransformException(node, 'No validated UDF is registered for this call') from None

        if not isinstance(udf, QueryUdfBase):
            raise BQTransformException(node, f'UDF {udf.__class__.__name__} is not a QueryUdfBase')

        # it seems like QueryUdfBase has a to_druid_query method, which would let us actually run the query. this means
        # unfortunately that we'll need to implement sql transformers for a bunch of things
        if not hasattr(udf, 'to_bigquery_sql'):
            raise BQTransformException(node, f'UDF {udf.__class__.__name__} does not implement to_bigquery_sql()')

        # ignoring this type for now because, well it doesn't exist
        return udf.to_bigquery_sql()  # type: ignore
# Lookup table from an Osprey binary-operator AST class to the Python function that evaluates it.
# Originally copied from the Druid translator.
_BINARY_OPERATORS: Dict[type, Callable[[Any, Any], Any]] = {
    # arithmetic
    grammar.Add: operator.add,
    grammar.Subtract: operator.sub,
    grammar.Multiply: operator.mul,
    grammar.Divide: operator.truediv,
    grammar.FloorDivide: operator.floordiv,
    grammar.Modulo: operator.mod,
    grammar.Pow: operator.pow,
    # bit shifts
    grammar.LeftShift: operator.lshift,
    grammar.RightShift: operator.rshift,
    # bitwise
    grammar.BitwiseOr: operator.or_,
    grammar.BitwiseAnd: operator.and_,
    grammar.BitwiseXor: operator.xor,
}
def evaluate_binary_operation(node: grammar.BinaryOperation) -> Any:
    """Evaluates a BinaryOperation node with constant values (e.g., 3600 * 24)

    Both operands are resolved with ``get_ast_node_value`` and combined with the Python function
    registered for the operator class in ``_BINARY_OPERATORS``.

    Raises:
        BQTransformException: if the operator is not supported, or if evaluating the constant
            expression fails (division by zero, negative shift count, incompatible operand
            types, overflow).
    """
    left_value = get_ast_node_value(node.left)
    right_value = get_ast_node_value(node.right)

    operator_func = _BINARY_OPERATORS.get(node.operator.__class__)
    if operator_func is None:
        raise BQTransformException(node, f'Unsupported binary operator: {node.operator.__class__.__name__}')

    try:
        return operator_func(left_value, right_value)
    except (ArithmeticError, TypeError, ValueError) as exc:
        # Report bad constant arithmetic as a translation error rather than letting a raw
        # ZeroDivisionError/TypeError/ValueError escape from the translator.
        raise BQTransformException(node, f'Could not evaluate constant expression ({exc})') from exc
def get_comparison_dimension(node: grammar.BinaryComparison) -> str:
    """Return the column name referenced by a binary comparison.

    The left operand is checked first, then the right one; the identifier of the first operand
    that is a ``grammar.Name`` is returned.

    Raises:
        BQTransformException: if neither operand is a ``grammar.Name``.
    """
    for operand in (node.left, node.right):
        if isinstance(operand, grammar.Name):
            return operand.identifier

    raise BQTransformException(node, 'Binary Comparator must contain at least one column')
def get_comparison_value(node: grammar.BinaryComparison) -> Any:
    """Extracts the constant value for a binary comparison

    The left operand is checked first, then the right one; the first operand that is a literal,
    a unary operation, or a binary operation is evaluated with ``get_ast_node_value``.

    Returns:
        The evaluated value. ``None`` is returned only when the operand evaluates to ``None``.

    Raises:
        BQTransformException: if neither operand is a constant expression. Previously this case
            fell off the end of the function and returned ``None``, which callers could not tell
            apart from a literal ``None`` (so e.g. ``Col == SomeCall()`` became ``Col IS NULL``).
    """
    constant_types = (grammar.Literal, grammar.UnaryOperation, grammar.BinaryOperation)
    for operand in (node.left, node.right):
        if isinstance(operand, constant_types):
            return get_ast_node_value(operand)

    raise BQTransformException(node, 'Binary Comparator must contain at least one constant value')
def get_ast_node_value(node: grammar.ASTNode) -> Any:
    """Resolve a constant AST expression to a plain Python value.

    * Unary operations are evaluated through ``OspreyUnaryExecutor`` (e.g. negative numbers).
    * Binary operations are folded as constant arithmetic (e.g., 3600 * 24).
    * Lists are resolved item by item into a Python list.
    * ``None_`` yields ``None``; String / Number / Boolean yield their ``value``.

    Raises:
        BQTransformException: for any other node type.
    """
    if isinstance(node, grammar.UnaryOperation):
        return OspreyUnaryExecutor(node).get_execution_value()

    if isinstance(node, grammar.BinaryOperation):
        return evaluate_binary_operation(node)

    if isinstance(node, grammar.List):
        return list(map(get_ast_node_value, node.items))

    if isinstance(node, grammar.None_):
        return None

    scalar_literals = (grammar.String, grammar.Number, grammar.Boolean)
    if isinstance(node, scalar_literals):
        return node.value

    raise BQTransformException(node, 'Node has no known value attribute')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment