Last active
October 20, 2025 03:17
-
-
Save haileyok/fd8e08cdb2a493098595fc6e52b163be to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import operator | |
| from typing import Any, Callable, Dict, List | |
| from google.cloud import bigquery | |
| from osprey.engine.ast import grammar | |
| from osprey.engine.ast_validator.validation_context import ValidatedSources | |
| from osprey.engine.ast_validator.validators.validate_call_kwargs import ValidateCallKwargs | |
| from osprey.engine.udf.base import QueryUdfBase | |
| from osprey.engine.utils.osprey_unary_executor import OspreyUnaryExecutor | |
class BQTransformException(Exception):
    """Raised when part of an Osprey AST cannot be expressed as BigQuery SQL.

    The message is ``"<error>: <NodeClassName>"`` and the offending node is kept on
    ``self.node`` so callers can point at the failing expression.
    """

    def __init__(self, node: grammar.ASTNode, error: str):
        node_kind = type(node).__name__
        super().__init__('{}: {}'.format(error, node_kind))
        self.node = node
class BQTranslator:
    """Translate a validated Osprey query AST into a parameterized BigQuery SQL WHERE clause.

    The root expression is the value of the first statement of the entry-point source, which
    must be an assignment. Constant values are never inlined into the SQL text; each one is bound
    as a named query parameter (``@param_0``, ``@param_1``, ...) collected in ``self._params``.
    """

    def __init__(self, validated_sources: ValidatedSources):
        try:
            # Indexed by id() of Call nodes; each entry unpacks as (udf, _) in transform_Call.
            self._udf_node_mapping = validated_sources.get_validator_result(ValidateCallKwargs)
        except KeyError:
            # No ValidateCallKwargs result available; any Call node will then fail in transform_Call.
            self._udf_node_mapping = {}
        assign_node = validated_sources.sources.get_entry_point().ast_root.statements[0]
        assert isinstance(assign_node, grammar.Assign)
        self._root = assign_node.value
        self._param_counter = 0
        # Holds bigquery.ScalarQueryParameter and bigquery.ArrayQueryParameter instances.
        self._params: List[Any] = []

    def transform(self) -> Dict[str, Any]:
        """Translate the root expression.

        Returns:
            A dict with ``'sql'`` (the WHERE-clause text) and ``'params'`` (the BigQuery query
            parameters referenced by that text).
        """
        sql = self._transform(self._root)
        return {'sql': sql, 'params': self._params}

    def _get_next_param_name(self) -> str:
        """Return the next unused query parameter name (``param_0``, ``param_1``, ...)."""
        name = f'param_{self._param_counter}'
        self._param_counter += 1
        return name

    @staticmethod
    def _bq_type_for(value: Any) -> str:
        """Return the BigQuery parameter type used to bind ``value``.

        ``bool`` must be checked before ``int``: ``bool`` is a subclass of ``int``, so checking
        ``int`` first would bind ``True``/``False`` as INT64. ``None``, ``str`` and any other
        type fall back to STRING.
        """
        if isinstance(value, bool):
            return 'BOOL'
        if isinstance(value, int):
            return 'INT64'
        if isinstance(value, float):
            return 'FLOAT64'
        return 'STRING'

    @staticmethod
    def _escape_like(value: str) -> str:
        """Escape LIKE metacharacters so ``value`` is matched literally.

        BigQuery LIKE treats percent and underscore as wildcards and backslash as the escape
        character. The backslash is escaped first so the escapes added afterwards are not
        doubled.
        """
        return value.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')

    def _add_param(self, value: Any) -> str:
        """Bind ``value`` as a scalar query parameter and return its ``@name`` placeholder."""
        param_name = self._get_next_param_name()
        self._params.append(bigquery.ScalarQueryParameter(param_name, self._bq_type_for(value), value))
        return f'@{param_name}'

    def _transform(self, node: grammar.ASTNode) -> str:
        """Dispatch ``node`` to the ``transform_<NodeClassName>`` method.

        Raises:
            BQTransformException: if no transformer exists for the node type.
        """
        transformer = getattr(self, 'transform_' + node.__class__.__name__, None)
        if not transformer:
            raise BQTransformException(node, 'Unknown AST Expression')
        ret = transformer(node)
        assert isinstance(ret, str)
        return ret

    def transform_BooleanOperation(self, node: grammar.BooleanOperation) -> str:
        """Translate ``and``/``or`` into a parenthesized SQL ``AND``/``OR`` expression.

        Raises:
            BQTransformException: if the operator is neither And nor Or.
        """
        # Named sql_op (not `operator`) to avoid shadowing the module-level `operator` import.
        if isinstance(node.operand, grammar.And):
            sql_op = 'AND'
        elif isinstance(node.operand, grammar.Or):
            sql_op = 'OR'
        else:
            raise BQTransformException(node.operand, 'Unknown Boolean Operator')
        conditions = [self._transform(v) for v in node.values]
        joined = f' {sql_op} '.join(conditions)
        return f'({joined})'

    @staticmethod
    def _ordering_operator(comparator: grammar.ASTNode, column_on_right: bool) -> str:
        """Return the SQL operator for an ordering comparator, with the column written first.

        The generated SQL always puts the column on the left. When the column was the right
        operand in the Osprey expression (``5 < NumEars``), the operator is mirrored so the
        meaning is preserved (``NumEars > @param``).

        Raises:
            BQTransformException: if ``comparator`` is not an ordering comparator.
        """
        if isinstance(comparator, grammar.LessThan):
            op, mirrored = '<', '>'
        elif isinstance(comparator, grammar.LessThanEquals):
            op, mirrored = '<=', '>='
        elif isinstance(comparator, grammar.GreaterThan):
            op, mirrored = '>', '<'
        elif isinstance(comparator, grammar.GreaterThanEquals):
            op, mirrored = '>=', '<='
        else:
            raise BQTransformException(comparator, 'Unknown Binary Comparator')
        return mirrored if column_on_right else op

    def transform_BinaryComparison(self, node: grammar.BinaryComparison) -> str:
        """Translate comparisons such as ``Username == 'Kitten'`` or ``NumEars < 2`` into SQL.

        Comparisons between two features support only ``==`` and ``!=``. Otherwise one operand
        must be a feature (``Name``) and the other a constant, which is bound as a parameter.

        Raises:
            BQTransformException: for unsupported comparator/operand combinations.
        """
        comparator = node.comparator
        if isinstance(node.left, grammar.Name) and isinstance(node.right, grammar.Name):
            left_col = node.left.identifier
            right_col = node.right.identifier
            # Username == OldUsername
            if isinstance(comparator, grammar.Equals):
                return f'{left_col} = {right_col}'
            # Username != OldUsername
            if isinstance(comparator, grammar.NotEquals):
                return f'{left_col} != {right_col}'
            # TODO: ordering comparisons between two features could be supported; kept in line
            # with the Druid translator for now.
            raise BQTransformException(
                comparator, 'When comparing two features, only the `==` and `!=` operators are supported'
            )

        dim = get_comparison_dimension(node)
        value = get_comparison_value(node)

        # Friend == None -> IS NULL; Username == 'Kitten' -> = @param
        if isinstance(comparator, grammar.Equals):
            if value is None:
                return f'{dim} IS NULL'
            return f'{dim} = {self._add_param(value)}'
        # Friend != None -> IS NOT NULL; Username != 'Bunny' -> != @param
        if isinstance(comparator, grammar.NotEquals):
            if value is None:
                return f'{dim} IS NOT NULL'
            return f'{dim} != {self._add_param(value)}'
        if isinstance(comparator, grammar.In):
            return self._transform_in_comparison(node, dim, value)
        if isinstance(comparator, grammar.NotIn):
            in_clause = self._transform_in_comparison(node, dim, value)
            return f'NOT ({in_clause})'

        # NumEars < 2, NumTails >= 1, ... get_comparison_dimension prefers the left operand, so
        # the column is the right operand exactly when the left operand is not a Name.
        sql_op = self._ordering_operator(comparator, column_on_right=not isinstance(node.left, grammar.Name))
        if value is None:
            # An ordering comparison with NULL is never true in SQL; fail loudly instead of
            # emitting a query that silently matches nothing.
            raise BQTransformException(comparator, 'Ordering comparisons against None are not supported')
        return f'{dim} {sql_op} {self._add_param(value)}'

    def _transform_in_comparison(self, node: grammar.BinaryComparison, dim: str, value: Any) -> str:
        """Translate ``in`` comparisons such as ``'Mariners' in PostText`` or ``UserId in [1, 2]``.

        A string value becomes a case-insensitive substring match on ``dim``. LIKE wildcards in
        the value are escaped so the text is matched literally. A list value becomes
        ``dim IN UNNEST(@array_param)``, and an empty list translates to ``FALSE``.

        Raises:
            BQTransformException: if ``value`` is neither a string nor a list.
        """
        if isinstance(value, str):
            # Case-insensitive to match the Druid translator's behavior.
            # NOTE(review): this always tests "value is a substring of dim", even for
            # `Name in 'text'`, where the operands are reversed; confirm the intended semantics.
            param = self._add_param(f'%{self._escape_like(value)}%')
            return f'LOWER({dim}) LIKE LOWER({param})'
        if isinstance(value, list):
            if not value:
                return 'FALSE'
            param_name = self._get_next_param_name()
            # The element type is inferred from the first element only.
            self._params.append(bigquery.ArrayQueryParameter(param_name, self._bq_type_for(value[0]), value))
            return f'{dim} IN UNNEST(@{param_name})'
        raise BQTransformException(node, 'Invalid "in" comparison value type, must be string or list')

    def transform_UnaryOperation(self, node: grammar.UnaryOperation) -> str:
        """Translate ``not <expr>`` into ``NOT (<expr>)``.

        Raises:
            BQTransformException: for any other unary operator.
        """
        if not isinstance(node.operator, grammar.Not):
            raise BQTransformException(node, 'Unknown Unary Operator')
        operand_sql = self._transform(node.operand)
        return f'NOT ({operand_sql})'

    # TODO: no UDF implements to_bigquery_sql() yet, so this path is currently untested.
    def transform_Call(self, node: grammar.Call) -> str:
        """Translate a UDF call by delegating to the UDF's ``to_bigquery_sql()``.

        Raises:
            BQTransformException: if the call has no validated UDF, the UDF is not a
                ``QueryUdfBase``, or the UDF does not implement ``to_bigquery_sql()``.
        """
        try:
            udf, _ = self._udf_node_mapping[id(node)]
        except KeyError:
            raise BQTransformException(node, 'No validated UDF found for function call') from None
        if not isinstance(udf, QueryUdfBase):
            raise BQTransformException(node, 'Unknown function call type')
        # QueryUdfBase appears to provide to_druid_query(); a BigQuery SQL equivalent has to be
        # implemented per UDF.
        if not hasattr(udf, 'to_bigquery_sql'):
            raise BQTransformException(node, f'UDF {udf.__class__.__name__} does not implement to_bigquery_sql()')
        # to_bigquery_sql() is not declared on QueryUdfBase yet, hence the type: ignore.
        return udf.to_bigquery_sql()  # type: ignore
# The comparison helpers below were copied from the Druid translator as it existed at the time.
def get_comparison_dimension(node: grammar.BinaryComparison) -> str:
    """Return the feature (column) name that a binary comparison filters on.

    The left operand is used when it is a ``Name``; otherwise the right operand is tried.

    Raises:
        BQTransformException: if neither operand is a ``Name``.
    """
    for operand in (node.left, node.right):
        if isinstance(operand, grammar.Name):
            return operand.identifier
    raise BQTransformException(node, 'Binary Comparator must contain at least one column')
def get_comparison_value(node: grammar.BinaryComparison) -> Any:
    """Return the constant side of a binary comparison as a Python value.

    An operand qualifies when it is a ``Literal``, ``UnaryOperation`` or ``BinaryOperation``
    node. The left operand is checked first, and the qualifying operand is evaluated with
    ``get_ast_node_value``.

    Returns ``None`` when neither operand qualifies.
    NOTE(review): callers cannot tell that case apart from a literal ``None``; confirm this
    fallback is intended.
    """
    value_node_types = (grammar.Literal, grammar.UnaryOperation, grammar.BinaryOperation)
    for operand in (node.left, node.right):
        if isinstance(operand, value_node_types):
            return get_ast_node_value(operand)
    return None
def get_ast_node_value(node: grammar.ASTNode) -> Any:
    """Evaluate a constant expression node into a plain Python value.

    Supported node types:
      * ``UnaryOperation``: evaluated with ``OspreyUnaryExecutor`` (e.g. negative numbers).
      * ``List``: each item is evaluated recursively into a Python ``list``.
      * ``None_``: ``None``.
      * ``String`` / ``Number`` / ``Boolean``: the node's ``value``.

    Constant arithmetic (``BinaryOperation``, e.g. ``3600 * 24``) is not evaluated yet.

    Raises:
        BQTransformException: for ``BinaryOperation`` and any other unsupported node type.
    """
    if isinstance(node, grammar.UnaryOperation):
        return OspreyUnaryExecutor(node).get_execution_value()
    if isinstance(node, grammar.List):
        return [get_ast_node_value(item) for item in node.items]
    if isinstance(node, grammar.None_):
        return None
    if isinstance(node, (grammar.String, grammar.Number, grammar.Boolean)):
        return node.value
    if isinstance(node, grammar.BinaryOperation):
        # get_comparison_value routes BinaryOperation operands here, but no evaluator exists yet;
        # report that explicitly instead of the generic error below.
        raise BQTransformException(node, 'Constant arithmetic is not supported in BigQuery queries yet')
    raise BQTransformException(node, 'Node has no known value attribute')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment