Skip to content

Instantly share code, notes, and snippets.

@philippefutureboy
Created August 26, 2023 17:45
Show Gist options
  • Save philippefutureboy/27ba48a835c713b45001f2db04b7f527 to your computer and use it in GitHub Desktop.
Save philippefutureboy/27ba48a835c713b45001f2db04b7f527 to your computer and use it in GitHub Desktop.
TagSkipPythonOperator Implementation
r"""
@source https://github.com/pytest-dev/pytest/blob/main/src/_pytest/mark/expression.py
Evaluate match expressions, as used by `-k` and `-m`.
The grammar is:
expression: expr? EOF
expr: and_expr ('or' and_expr)*
and_expr: not_expr ('and' not_expr)*
not_expr: 'not' not_expr | '(' expr ')' | ident
ident: (\w|:|\+|-|\.|\[|\]|\\|/)+
The semantics are:
- Empty expression evaluates to False.
- ident evaluates to True or False according to a provided matcher function.
- or/and/not evaluate according to the usual boolean semantics.
"""
import ast
import enum
import re
import types
from typing import TYPE_CHECKING, Callable, Iterator, Mapping, Optional, Sequence
import attr
if TYPE_CHECKING:
    # Only needed for the "NoReturn" annotation on Scanner.reject; the guard
    # keeps this import out of the runtime path.
    from typing import NoReturn

# Public API of this module.
__all__ = [
    "Expression",
    "ParseError",
]
class TokenType(enum.Enum):
    """Kinds of tokens produced by the Scanner.

    Each member's value is the human-readable name used in ParseError
    messages ("expected ...; got ...").
    """

    LPAREN = "left parenthesis"
    RPAREN = "right parenthesis"
    # Keyword tokens: the lexer emits these only for the exact words
    # "or", "and" and "not"; any other word becomes an IDENT.
    OR = "or"
    AND = "and"
    NOT = "not"
    IDENT = "identifier"
    # Always the last token yielded by Scanner.lex.
    EOF = "end of input"
@attr.s(frozen=True, slots=True, auto_attribs=True)
class Token:
    """A single immutable lexical token produced by ``Scanner.lex``."""

    # Kind of token (parenthesis, keyword, identifier, or end of input).
    type: TokenType
    # Source text of the token; the empty string for EOF.
    value: str
    # 0-based offset of the token's first character in the input string
    # (ParseError reports it as pos + 1, a 1-based column).
    pos: int
class ParseError(Exception):
    """Raised when a match expression is not syntactically valid.

    :param column: 1-based column in the input line where the problem was found.
    :param message: Human-readable description of the problem.
    """

    def __init__(self, column: int, message: str) -> None:
        self.message = message
        self.column = column

    def __str__(self) -> str:
        return "at column {}: {}".format(self.column, self.message)
# One identifier or keyword, mirroring the ``ident`` rule in the module
# docstring. Compiled once at import time rather than looked up per token.
# (The group is non-capturing; the previous "(:?" was a typo for "(?:".)
_IDENT_PATTERN = re.compile(r"(?:\w|:|\+|-|\.|\[|\]|\\|/)+")


class Scanner:
    """Tokenizer for match expressions with a single token of lookahead.

    ``current`` holds the next unconsumed token; ``accept`` consumes it when it
    has the requested type.
    """

    __slots__ = ("tokens", "current")

    def __init__(self, input: str) -> None:
        self.tokens = self.lex(input)
        # Prime the lookahead. lex() always yields at least an EOF token, but
        # raises ParseError here already if the first token is invalid.
        self.current = next(self.tokens)

    def lex(self, input: str) -> Iterator[Token]:
        """Yield the tokens of ``input``, ending with a single EOF token.

        Spaces and tabs separate tokens and are otherwise ignored.

        :raises ParseError: On a character that cannot start any token.
        """
        pos = 0
        while pos < len(input):
            char = input[pos]
            if char in (" ", "\t"):
                pos += 1
            elif char == "(":
                yield Token(TokenType.LPAREN, "(", pos)
                pos += 1
            elif char == ")":
                yield Token(TokenType.RPAREN, ")", pos)
                pos += 1
            else:
                # Match in place at ``pos`` instead of slicing input[pos:],
                # which copied the rest of the string for every token.
                match = _IDENT_PATTERN.match(input, pos)
                if match is None:
                    raise ParseError(
                        pos + 1,
                        f'unexpected character "{char}"',
                    )
                value = match.group(0)
                if value == "or":
                    yield Token(TokenType.OR, value, pos)
                elif value == "and":
                    yield Token(TokenType.AND, value, pos)
                elif value == "not":
                    yield Token(TokenType.NOT, value, pos)
                else:
                    yield Token(TokenType.IDENT, value, pos)
                pos = match.end()
        yield Token(TokenType.EOF, "", pos)

    def accept(self, type: TokenType, *, reject: bool = False) -> Optional[Token]:
        """Consume the current token if it has the given type.

        :param type: The token type to accept.
        :param reject: If True, raise ParseError instead of returning None on
            a mismatch.
        :returns: The accepted token, or None if the current token has a
            different type.
        """
        if self.current.type is type:
            token = self.current
            # EOF is sticky: lex() finishes right after yielding it, so another
            # next() would raise StopIteration.
            if token.type is not TokenType.EOF:
                self.current = next(self.tokens)
            return token
        if reject:
            self.reject((type,))
        return None

    def reject(self, expected: Sequence[TokenType]) -> "NoReturn":
        """Raise a ParseError at the current token listing the expected types."""
        raise ParseError(
            self.current.pos + 1,
            "expected {}; got {}".format(
                " OR ".join(type.value for type in expected),
                self.current.type.value,
            ),
        )
# True, False and None are legal match expression identifiers,
# but illegal as Python identifiers. To fix this, this prefix
# is added to identifiers in the conversion to Python AST.
# not_expr() prepends it when building ast.Name nodes, and MatcherAdapter
# strips it again before handing the identifier to the user's matcher.
IDENT_PREFIX = "$"
def expression(s: Scanner) -> ast.Expression:
    """Parse a complete match expression: ``expression: expr? EOF``.

    An empty input compiles to the constant ``False``.

    :raises ParseError: If tokens remain after a complete ``expr``.
    """
    if s.accept(TokenType.EOF):
        # ast.Constant replaces ast.NameConstant, which was deprecated in
        # Python 3.8 and removed in Python 3.14.
        ret: ast.expr = ast.Constant(False)
    else:
        ret = expr(s)
        s.accept(TokenType.EOF, reject=True)
    return ast.fix_missing_locations(ast.Expression(ret))
def expr(s: Scanner) -> ast.expr:
    """Parse ``expr: and_expr ('or' and_expr)*``.

    ``or`` binds more loosely than ``and``; chains of ``or`` nest to the left.
    """
    left = and_expr(s)
    while True:
        if s.accept(TokenType.OR) is None:
            return left
        left = ast.BoolOp(op=ast.Or(), values=[left, and_expr(s)])
def and_expr(s: Scanner) -> ast.expr:
    """Parse ``and_expr: not_expr ('and' not_expr)*``.

    Chains nest to the left: ``a and b and c`` becomes
    ``BoolOp(And, [BoolOp(And, [a, b]), c])``.
    """
    node = not_expr(s)
    while s.accept(TokenType.AND) is not None:
        right = not_expr(s)
        node = ast.BoolOp(op=ast.And(), values=[node, right])
    return node
def not_expr(s: Scanner) -> ast.expr:
    """Parse ``not_expr: 'not' not_expr | '(' expr ')' | ident``.

    Identifiers become ``ast.Name`` nodes carrying IDENT_PREFIX. If none of
    the alternatives match, ``Scanner.reject`` raises ParseError.
    """
    if s.accept(TokenType.NOT):
        operand = not_expr(s)
        return ast.UnaryOp(ast.Not(), operand)

    if s.accept(TokenType.LPAREN):
        inner = expr(s)
        s.accept(TokenType.RPAREN, reject=True)
        return inner

    token = s.accept(TokenType.IDENT)
    if token is None:
        s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT))
    return ast.Name(IDENT_PREFIX + token.value, ast.Load())
class MatcherAdapter(Mapping[str, bool]):
    """Present a matcher callable as the read-only locals mapping given to eval().

    Only item lookup is implemented; iteration and length raise
    NotImplementedError.
    """

    def __init__(self, matcher: Callable[[str], bool]) -> None:
        self.matcher = matcher

    def __getitem__(self, key: str) -> bool:
        # Names arrive with IDENT_PREFIX prepended; drop it so the matcher
        # sees the identifier exactly as written in the expression.
        identifier = key[len(IDENT_PREFIX) :]
        return self.matcher(identifier)

    def __iter__(self) -> Iterator[str]:
        raise NotImplementedError()

    def __len__(self) -> int:
        raise NotImplementedError()
class Expression:
    """A compiled match expression as used by -k and -m.

    The expression can be evaluated against different matchers.
    """

    __slots__ = ("code",)

    def __init__(self, code: types.CodeType) -> None:
        self.code = code

    @classmethod
    def compile(cls, input: str) -> "Expression":
        """Compile a match expression.

        :param input: The input expression - one line.
        :returns: An instance of ``cls`` wrapping the compiled code object.
        :raises ParseError: If ``input`` is not a valid match expression.
        """
        astexpr = expression(Scanner(input))
        # Inside the method body, ``compile`` resolves to the builtin; class
        # attributes are not in a method's lookup scope.
        code: types.CodeType = compile(
            astexpr,
            filename="<pytest match expression>",
            mode="eval",
        )
        # ``cls`` rather than a hard-coded ``Expression`` so subclasses get
        # instances of themselves.
        return cls(code)

    def evaluate(self, matcher: Callable[[str], bool]) -> bool:
        """Evaluate the match expression.

        :param matcher:
            Given an identifier, should return whether it matches or not.
            Should be prepared to handle arbitrary strings as input.
        :returns: Whether the expression matches or not.
        """
        # Builtins are emptied and every name lookup is routed to the matcher.
        # ``and``/``or`` return one of their operands, so coerce the result to
        # honor the ``bool`` return type even for matchers returning truthy
        # non-bool values.
        return bool(eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher)))
import logging
from typing import Any, List, Optional
from airflow.exceptions import AirflowSkipException
from airflow.models import VariableAccessor
from airflow.operators.python import PythonOperator
from airflow.utils.context import Context
from airflow.utils.trigger_rule import TriggerRule
from lib.airflow.operators._expressions import Expression
# Module-level logger; TagSkipPythonOperator logs its run/skip decision here.
LOG = logging.getLogger(__name__)
class TagSkipPythonOperator(PythonOperator):
    """PythonOperator that skips itself unless its tags satisfy the run's tag query.

    The query is read from ``dag_run.conf["tag_query"]`` and uses the pytest
    ``-m``/``-k`` expression syntax (``and``, ``or``, ``not``, parentheses). Each
    identifier in the query is true when it is one of this task's ``tags``.
    When no (or a blank) query is supplied, the task always runs.

    Args:
        tags: Tags attached to this task, matched against query identifiers.
        trigger_rule: Airflow trigger rule. Defaults to ``NONE_FAILED``,
            presumably so that upstream tasks skipped by their own tag query do
            not block this one. TODO confirm intent.
        **kwargs: Forwarded to ``PythonOperator``.
    """

    def __init__(
        self,
        *,
        tags: List[str],
        trigger_rule: str = TriggerRule.NONE_FAILED,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs, trigger_rule=trigger_rule)
        self.tags = set(tags)

    @staticmethod
    def _read_tag_query(dag_run: Any) -> Optional[str]:
        """Return ``dag_run.conf["tag_query"]``, or None when there is no usable query.

        None is returned when:

        * the DAG run is missing or has no ``conf``;
        * ``conf`` is None or empty (previously a None conf raised an uncaught
          TypeError);
        * the key is absent;
        * the query is blank. An empty expression evaluates to False, which
          would silently skip every task.

        Raises:
            TypeError: ``tag_query`` is present but is not a str.
        """
        conf = getattr(dag_run, "conf", None)
        if not conf:
            return None
        try:
            tag_query = conf["tag_query"]
        except KeyError:
            return None
        if not isinstance(tag_query, str):
            raise TypeError(
                f"dag_run.conf['tag_query'] must be a str, got {type(tag_query).__name__}"
            )
        # Strip the same whitespace characters the expression lexer ignores.
        if not tag_query.strip(" \t"):
            return None
        return tag_query

    def execute(self, context: Context) -> Any:
        """
        Determine whether or not the task is to be skipped and skips the task or executes the task.
        Overrides PythonOperator.execute.
        Args:
            context (airflow.utils.context.Context): Airflow task instance Context
        Raises:
            TypeError: tag_query is not a str
            ParseError: tag_query is not a valid match expression
            AirflowSkipException: Task is to be skipped
        Returns:
            Any: the return value of the callable
        """
        tag_query = self._read_tag_query(context["dag_run"])
        if tag_query is None:
            should_run = True
        else:
            # An invalid query raises ParseError and fails the task rather than
            # guessing whether to run it.
            should_run = Expression.compile(tag_query).evaluate(
                lambda term: term in self.tags
            )

        if not should_run:
            LOG.info(
                "[SKIP] Task %s tagged with tags %s is to be skipped (tag_query='%s')",
                self.task_id,
                list(self.tags),
                tag_query,
            )
            raise AirflowSkipException()

        LOG.info(
            "[RUN] Task %s tagged with tags %s is to proceed (tag_query='%s')",
            self.task_id,
            list(self.tags),
            tag_query,
        )
        return super().execute(context)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment