Created
August 26, 2023 17:45
-
-
Save philippefutureboy/27ba48a835c713b45001f2db04b7f527 to your computer and use it in GitHub Desktop.
TagSkipPythonOperator Implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
r"""
@source https://github.com/pytest-dev/pytest/blob/main/src/_pytest/mark/expression.py

Evaluate match expressions, as used by `-k` and `-m`.

The grammar is:

expression: expr? EOF
expr:       and_expr ('or' and_expr)*
and_expr:   not_expr ('and' not_expr)*
not_expr:   'not' not_expr | '(' expr ')' | ident
ident:      (\w|:|\+|-|\.|\[|\]|\\|/)+

The semantics are:

- Empty expression evaluates to False.
- ident evaluates to True or False according to a provided matcher function.
- or/and/not evaluate according to the usual boolean semantics.
"""
import ast | |
import enum | |
import re | |
import types | |
from typing import TYPE_CHECKING, Callable, Iterator, Mapping, Optional, Sequence | |
import attr | |
if TYPE_CHECKING: | |
from typing import NoReturn | |
# Names exported by `from ... import *`; everything else in this module is
# an implementation detail of the parser.
__all__ = ["Expression", "ParseError"]
class TokenType(enum.Enum):
    """Kinds of token produced by :class:`Scanner`.

    Each member's value is a human-readable name; ``Scanner.reject`` uses it
    when building ParseError messages.
    """

    # Grouping.
    LPAREN = "left parenthesis"
    RPAREN = "right parenthesis"

    # Boolean operators (lexed from the reserved words "or", "and", "not").
    OR = "or"
    AND = "and"
    NOT = "not"

    # Any other run of identifier characters.
    IDENT = "identifier"

    # Emitted once after the input has been fully consumed.
    EOF = "end of input"
@attr.s(frozen=True, slots=True, auto_attribs=True)
class Token:
    """An immutable lexeme produced by :class:`Scanner`."""

    # Kind of token (parenthesis, operator, identifier or end of input).
    type: TokenType = attr.ib()
    # Exact source text of the token ("" for the EOF token).
    value: str = attr.ib()
    # 0-based offset of the token's first character in the input string.
    pos: int = attr.ib()
class ParseError(Exception):
    """Raised when a match expression contains invalid syntax.

    :param column: 1-based column in the input at which the error was detected.
    :param message: Human-readable description of the problem.
    """

    def __init__(self, column: int, message: str) -> None:
        self.message = message
        self.column = column

    def __str__(self) -> str:
        return "at column {}: {}".format(self.column, self.message)
class Scanner:
    """Lazy tokenizer for match expressions with one token of lookahead.

    ``current`` always holds the next unconsumed token. The parser consumes
    tokens through ``accept`` and reports failures through ``reject``.
    """

    __slots__ = ("tokens", "current")

    # Identifier characters, per the ``ident`` rule of the grammar in the
    # module docstring. Compiled once and matched in place with a start
    # position, so lexing does not copy the remaining input for every token.
    # ``(?:...)`` is a non-capturing group (the original ``(:?`` was a typo).
    _IDENT_RE = re.compile(r"(?:\w|:|\+|-|\.|\[|\]|\\|/)+")

    # Reserved words that lex as operators instead of identifiers.
    _KEYWORDS = {
        "or": TokenType.OR,
        "and": TokenType.AND,
        "not": TokenType.NOT,
    }

    def __init__(self, input: str) -> None:
        self.tokens = self.lex(input)
        # Prime the lookahead. This may raise ParseError if the first
        # non-blank character is invalid.
        self.current = next(self.tokens)

    def lex(self, input: str) -> Iterator[Token]:
        """Yield the tokens of ``input``, always ending with a single EOF token.

        :raises ParseError: On a character that cannot start any token.
        """
        pos = 0
        end = len(input)
        while pos < end:
            char = input[pos]
            if char in (" ", "\t"):
                pos += 1
            elif char == "(":
                yield Token(TokenType.LPAREN, "(", pos)
                pos += 1
            elif char == ")":
                yield Token(TokenType.RPAREN, ")", pos)
                pos += 1
            else:
                match = self._IDENT_RE.match(input, pos)
                if not match:
                    raise ParseError(
                        pos + 1,
                        f'unexpected character "{char}"',
                    )
                value = match.group(0)
                yield Token(self._KEYWORDS.get(value, TokenType.IDENT), value, pos)
                pos = match.end()
        yield Token(TokenType.EOF, "", pos)

    def accept(self, type: TokenType, *, reject: bool = False) -> Optional[Token]:
        """Consume and return the current token if it has the given type.

        The EOF token is returned but never advanced past, so repeated
        accepts of EOF are safe.

        :param type: Token type to accept.
        :param reject: If True, raise ParseError instead of returning None
            when the current token does not match.
        :returns: The accepted token, or None if it did not match.
        """
        if self.current.type is type:
            token = self.current
            if token.type is not TokenType.EOF:
                self.current = next(self.tokens)
            return token
        if reject:
            self.reject((type,))
        return None

    def reject(self, expected: Sequence[TokenType]) -> "NoReturn":
        """Raise ParseError at the current token, listing the expected types."""
        raise ParseError(
            self.current.pos + 1,
            "expected {}; got {}".format(
                " OR ".join(type.value for type in expected),
                self.current.type.value,
            ),
        )
# Match expressions may use identifiers such as True, False and None, which are
# not valid Python variable names. To sidestep that, every identifier gets this
# prefix when it is turned into a Python AST name. MatcherAdapter strips it
# again before calling the matcher.
IDENT_PREFIX = "$"
def expression(s: Scanner) -> ast.Expression:
    """Parse ``expression: expr? EOF`` into a Python ``ast.Expression``.

    An input that is immediately at EOF compiles to the constant ``False``.
    Otherwise the entire input must be consumed by ``expr``, or a ParseError
    is raised.
    """
    if s.accept(TokenType.EOF):
        # ast.NameConstant was deprecated in Python 3.8 and removed in 3.14.
        # ast.Constant is the supported node for literal values.
        ret: ast.expr = ast.Constant(False)
    else:
        ret = expr(s)
        s.accept(TokenType.EOF, reject=True)
    return ast.fix_missing_locations(ast.Expression(ret))
def expr(s: Scanner) -> ast.expr:
    """Parse ``expr: and_expr ('or' and_expr)*``.

    Operands are folded left-associatively, so ``a or b or c`` becomes
    ``BoolOp(Or, [BoolOp(Or, [a, b]), c])``.
    """
    node = and_expr(s)
    while True:
        if not s.accept(TokenType.OR):
            return node
        node = ast.BoolOp(ast.Or(), [node, and_expr(s)])
def and_expr(s: Scanner) -> ast.expr:
    """Parse ``and_expr: not_expr ('and' not_expr)*``.

    Operands are folded left-associatively, like ``expr`` does for ``or``.
    """
    node = not_expr(s)
    while True:
        if not s.accept(TokenType.AND):
            return node
        node = ast.BoolOp(ast.And(), [node, not_expr(s)])
def not_expr(s: Scanner) -> ast.expr:
    """Parse ``not_expr: 'not' not_expr | '(' expr ')' | ident``.

    An identifier becomes an ``ast.Name`` carrying IDENT_PREFIX. Any other
    token raises ParseError via ``Scanner.reject``.
    """
    if s.accept(TokenType.NOT):
        operand = not_expr(s)
        return ast.UnaryOp(ast.Not(), operand)
    if s.accept(TokenType.LPAREN):
        inner = expr(s)
        s.accept(TokenType.RPAREN, reject=True)
        return inner
    token = s.accept(TokenType.IDENT)
    if token is None:
        # reject() always raises, so the return below only runs for a real token.
        s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT))
    return ast.Name(IDENT_PREFIX + token.value, ast.Load())
class MatcherAdapter(Mapping[str, bool]):
    """Expose a matcher callable as the ``locals`` mapping that eval() needs.

    Each name lookup strips IDENT_PREFIX from the key and asks the matcher.
    Only item lookup is required here, so iteration and length are
    deliberately unsupported.
    """

    def __init__(self, matcher: Callable[[str], bool]) -> None:
        self.matcher = matcher

    def __getitem__(self, key: str) -> bool:
        bare_name = key[len(IDENT_PREFIX) :]
        return self.matcher(bare_name)

    def __len__(self) -> int:
        raise NotImplementedError()

    def __iter__(self) -> Iterator[str]:
        raise NotImplementedError()
class Expression:
    """A compiled match expression as used by -k and -m.

    The expression can be evaluated against different matchers.
    """

    __slots__ = ("code",)

    def __init__(self, code: types.CodeType) -> None:
        self.code = code

    @classmethod
    def compile(cls, input: str) -> "Expression":
        """Compile a match expression.

        :param input: The input expression - one line.
        :raises ParseError: If ``input`` is not a valid match expression.
        """
        astexpr = expression(Scanner(input))
        code: types.CodeType = compile(
            astexpr,
            filename="<pytest match expression>",
            mode="eval",
        )
        # cls (not a hard-coded Expression) so subclasses compile to themselves.
        return cls(code)

    def evaluate(self, matcher: Callable[[str], bool]) -> bool:
        """Evaluate the match expression.

        :param matcher:
            Given an identifier, should return whether it matches or not.
            Should be prepared to handle arbitrary strings as input.

        :returns: Whether the expression matches or not.
        """
        # When built via compile(), self.code comes from an AST made only of
        # BoolOp/UnaryOp/Name/Constant nodes. Builtins are emptied and every
        # name lookup goes through MatcherAdapter, so untrusted input cannot
        # reach arbitrary Python here.
        ret: bool = eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher))
        return ret
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
from typing import Any, List, Optional | |
from airflow.exceptions import AirflowSkipException | |
from airflow.models import VariableAccessor | |
from airflow.operators.python import PythonOperator | |
from airflow.utils.context import Context | |
from airflow.utils.trigger_rule import TriggerRule | |
from lib.airflow.operators._expressions import Expression | |
# Logger named after this module, so output can be filtered per module.
LOG = logging.getLogger(__name__)
class TagSkipPythonOperator(PythonOperator):
    """PythonOperator that runs or skips itself based on a tag query.

    At execution time the operator reads ``dag_run.conf["tag_query"]``. This is
    a match expression: identifiers combined with ``and``/``or``/``not`` and
    parentheses. Each identifier evaluates to True when it is one of this
    task's ``tags``.

    If the expression is False, the task raises ``AirflowSkipException``. If
    the run has no ``tag_query``, the task always runs its callable.

    Args:
        tags: Tags attached to this task.
        trigger_rule: Trigger rule passed to ``PythonOperator``. Defaults to
            ``TriggerRule.NONE_FAILED``, presumably so that skipped upstream
            tasks do not cause this task to be skipped as well (TODO confirm).
        **kwargs: Forwarded unchanged to ``PythonOperator``.
    """

    def __init__(
        self,
        *,
        tags: List[str],
        trigger_rule: str = TriggerRule.NONE_FAILED,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs, trigger_rule=trigger_rule)
        # A set gives O(1) membership tests when the query is evaluated.
        self.tags = set(tags)

    @staticmethod
    def _get_tag_query(context: Context) -> Optional[str]:
        """Return ``dag_run.conf["tag_query"]``, or None when no query is set.

        Raises:
            TypeError: tag_query is present but is not a str
        """
        dag_run = context["dag_run"]
        # dag_run may lack a conf, and conf itself may be None. Treat both as
        # "no tag_query" rather than failing with
        # "'NoneType' object is not subscriptable".
        conf = getattr(dag_run, "conf", None)
        if conf is None:
            return None
        try:
            tag_query = conf["tag_query"]
        except KeyError:
            return None
        if not isinstance(tag_query, str):
            raise TypeError(
                f"dag_run.conf['tag_query'] must be a str, got {type(tag_query).__name__}"
            )
        return tag_query

    def execute(self, context: Context) -> Any:
        """
        Determine whether the task should run, then skip it or run the callable.

        Overrides PythonOperator.execute.

        Args:
            context (airflow.utils.context.Context): Airflow task instance Context

        Raises:
            TypeError: tag_query is not a str
            ParseError: tag_query is not a valid match expression
            AirflowSkipException: Task is to be skipped

        Returns:
            Any: the return value of the callable
        """
        tag_query = self._get_tag_query(context)
        # Sorted so the logged tag list is stable across runs (set order is not).
        logged_tags = sorted(self.tags)

        if tag_query is not None:
            matches = Expression.compile(tag_query).evaluate(
                lambda term: term in self.tags
            )
            if not matches:
                LOG.info(
                    "[SKIP] Task %s tagged with tags %s is to be skipped (tag_query='%s')",
                    self.task_id,
                    logged_tags,
                    tag_query,
                )
                raise AirflowSkipException()

        LOG.info(
            "[RUN] Task %s tagged with tags %s is to proceed (tag_query='%s')",
            self.task_id,
            logged_tags,
            tag_query,
        )
        return super().execute(context)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment