|
import ast
import copy
import sys
from pathlib import Path
from shutil import copytree
from textwrap import dedent

import refactor
from refactor import Action, Replace, Rule
|
|
|
|
|
def match_add_prefix(node: ast.AST, prefix_code: str) -> Action | None:
    """Build an action that inserts ``prefix_code`` at the top of a module.

    The new statements go after the module docstring and any
    ``from __future__`` imports. Future statements must precede all other
    code, and a docstring only counts as one when it is the first statement.

    Args:
        node: Node offered by the refactor session; only ``ast.Module`` nodes
            are acted on.
        prefix_code: Python source whose top-level statements are inserted.

    Returns:
        A ``Replace`` action for the whole module. Returns ``None`` when:
        ``node`` is not a module; ``prefix_code`` has no statements; or the
        prefix statements already sit at the insertion point. The last case
        makes re-running the rule a no-op instead of prepending again on
        every pass.
    """
    if not isinstance(node, ast.Module):
        return None

    prefix_nodes = ast.parse(prefix_code).body
    if not prefix_nodes:
        return None

    body = node.body
    insert_at = 0
    # Keep a leading docstring in first position.
    if (
        body
        and isinstance(body[0], ast.Expr)
        and isinstance(body[0].value, ast.Constant)
        and isinstance(body[0].value.value, str)
    ):
        insert_at = 1
    # __future__ imports must stay ahead of any other statement.
    while (
        insert_at < len(body)
        and isinstance(body[insert_at], ast.ImportFrom)
        and body[insert_at].module == "__future__"
    ):
        insert_at += 1

    # ast.dump ignores positions by default, so this detects a prefix that an
    # earlier pass already inserted.
    existing = body[insert_at:insert_at + len(prefix_nodes)]
    if [ast.dump(stmt) for stmt in existing] == [ast.dump(stmt) for stmt in prefix_nodes]:
        return None

    new_body = body[:insert_at] + prefix_nodes + body[insert_at:]
    return Replace(node, ast.Module(body=new_body, type_ignores=node.type_ignores))
|
|
|
def match_add_suffix(node: ast.AST, prefix_code: str) -> Action | None:
    """Build an action that appends code to the end of a module.

    Despite its name, ``prefix_code`` holds the code to *append*. The
    parameter name is kept because callers pass it by keyword.

    Args:
        node: Node offered by the refactor session; only ``ast.Module`` nodes
            are acted on.
        prefix_code: Python source whose top-level statements are appended.

    Returns:
        A ``Replace`` action for the whole module. Returns ``None`` when:
        ``node`` is not a module; the code has no statements; or the module
        already ends with those statements. The last case makes re-running
        the rule a no-op instead of appending again on every pass.
    """
    if not isinstance(node, ast.Module):
        return None

    suffix_nodes = ast.parse(prefix_code).body
    if not suffix_nodes:
        return None

    # ast.dump ignores positions by default, so this detects a suffix that an
    # earlier pass already appended.
    tail = node.body[-len(suffix_nodes):]
    if [ast.dump(stmt) for stmt in tail] == [ast.dump(stmt) for stmt in suffix_nodes]:
        return None

    new_body = node.body + suffix_nodes
    return Replace(node, ast.Module(body=new_body, type_ignores=node.type_ignores))
|
|
|
def match_replace_function_body(node: ast.AST, function_name: str, new_body_code: str) -> Action | None:
    """Build an action that replaces the body of the function ``function_name``.

    Both ``def`` and ``async def`` targets are matched. The target keeps the
    following: its own kind (sync/async), name, signature, decorators,
    return annotation and any other node fields. Only ``body`` changes.

    Args:
        node: Node offered by the refactor session.
        function_name: Name of the function whose body is replaced.
        new_body_code: Either the replacement statements, or a complete
            definition of ``function_name``. In the latter case its body is
            used, so the definition is not nested inside the target.

    Returns:
        A ``Replace`` action. Returns ``None`` when ``node`` is not the
        requested function, or its body already equals the replacement (so
        re-running the rule is a no-op).

    Raises:
        ValueError: if ``new_body_code`` contains no statements, since an
            empty function body would be invalid Python.
    """
    if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) or node.name != function_name:
        return None

    new_body = ast.parse(new_body_code).body
    # A full definition of the same function was supplied: take its body.
    if (
        len(new_body) == 1
        and isinstance(new_body[0], (ast.FunctionDef, ast.AsyncFunctionDef))
        and new_body[0].name == function_name
    ):
        new_body = new_body[0].body

    if not new_body:
        raise ValueError(f"new_body_code for {function_name!r} contains no statements")

    # ast.dump ignores positions by default; equal dumps mean nothing to do.
    if [ast.dump(stmt) for stmt in node.body] == [ast.dump(stmt) for stmt in new_body]:
        return None

    # A shallow copy preserves the node type (sync vs async) and every field
    # other than ``body``.
    new_node = copy.copy(node)
    new_node.body = new_body
    return Replace(node, new_node)
|
|
|
def match_replace_handler_with_mock(node: ast.AST, path: str, method: str, new_handler_code: str) -> Action | None:
    """Build an action that swaps a route handler for ``new_handler_code``.

    A function (``def`` or ``async def``) is replaced when one of its
    decorators meets both conditions:

    * it is a call of the form ``<expr>.<method>(...)``;
    * its route path is the string literal ``path``. The route path is the
      first positional argument or the ``path=`` keyword.

    Args:
        node: Node offered by the refactor session.
        path: Exact route path to look for, e.g. ``"/items/{item_id}"``.
        method: Decorator attribute name, e.g. ``"get"`` for ``@router.get``.
        new_handler_code: Source whose first top-level statement replaces the
            matched handler.

    Returns:
        A ``Replace`` action. Returns ``None`` when ``node`` is not a matching
        handler, or already equals the replacement (so re-running the rule is
        a no-op).
    """
    if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        return None

    def declares_route(decorator: ast.expr) -> bool:
        """Return True if ``decorator`` registers ``method`` on ``path``."""
        if not (
            isinstance(decorator, ast.Call)
            and isinstance(decorator.func, ast.Attribute)
            and decorator.func.attr == method
        ):
            return False
        candidates = decorator.args[:1] + [kw.value for kw in decorator.keywords if kw.arg == "path"]
        # Compare the literal exactly. A substring test on the unparsed
        # decorator would also hit longer routes that merely contain ``path``.
        return any(isinstance(arg, ast.Constant) and arg.value == path for arg in candidates)

    if not any(declares_route(decorator) for decorator in node.decorator_list):
        return None

    new_handler_node = ast.parse(new_handler_code).body[0]
    # ast.dump ignores positions by default, so this detects a handler that an
    # earlier pass already replaced.
    if ast.dump(new_handler_node) == ast.dump(node):
        return None
    return Replace(node, new_handler_node)
|
|
|
def apply_rules(file_path: Path, rule_classes: list):
    """Run a refactor session with ``rule_classes`` over ``file_path``.

    ``rule_classes`` is passed unchanged as ``refactor.Session(rules=...)``.
    If the session reports a change, the diff is written back to the file.
    Either way, a one-line status message is printed.
    """
    outcome = refactor.Session(rules=rule_classes).run_file(file_path)
    if outcome is None:
        print(f"No changes applied to {file_path}")
        return
    outcome.apply_diff()
    print(f"Applied changes to {file_path}")
|
|
|
def mock_handlers(src: Path):
    """Patch the v1 DAG router so one endpoint is served by a mock.

    Two rules are applied to ``laminar/apiserver/routers/v1/dags.py`` under
    ``src``:

    * ``DagsRouteImports`` prepends the imports the mock handler uses.
    * ``DagStructureReplacer`` replaces the handler for
      ``GET /deployments/{deployment_id}/dags/{dag_id}/structure`` with one
      that delegates to ``get_mock_handler``.

    The replacement handler refers to ``router``, ``schemas`` and ``log``.
    These are assumed to already exist in dags.py -- TODO confirm.
    """
    router_file = src / "laminar/apiserver/routers/v1/dags.py"

    handler_imports = dedent(
        """
        from fastapi import Request
        from fastapi.responses import JSONResponse
        from laminar.apitest.mocks import get_mock_handler
        """
    ).strip()

    structure_handler = dedent(
        """
        @router.get("/deployments/{deployment_id}/dags/{dag_id}/structure")
        async def dag_structure(request: Request) -> JSONResponse:
            mock = get_mock_handler(
                route_path="/deployments/{deployment_id}/dags/{dag_id}/structure",
                RequestModel=None,
                ResponseModel=schemas.v1.DagStructure
            )
            return await mock(request=request, log=log)
        """
    )

    class DagsRouteImports(refactor.Rule):
        def match(self, node: ast.AST) -> refactor.Action | None:
            return match_add_prefix(node=node, prefix_code=handler_imports)

    class DagStructureReplacer(refactor.Rule):
        def match(self, node: ast.AST) -> refactor.Action | None:
            return match_replace_handler_with_mock(
                node,
                path="/deployments/{deployment_id}/dags/{dag_id}/structure",
                method="get",
                new_handler_code=structure_handler,
            )

    apply_rules(router_file, [DagsRouteImports, DagStructureReplacer])
|
|
|
def patch_dependencies(src: Path):
    """Add a ``get_mocks`` dependency to ``laminar/apiserver/dependencies.py``.

    Two rules are applied:

    * one prepends ``from laminar.apitest.mocks import Mocks``;
    * one appends a ``get_mocks()`` function that builds and returns a new
      ``Mocks()`` on every call.
    """
    mocks_import = dedent("""
        from laminar.apitest.mocks import Mocks
    """).strip()

    get_mocks_source = dedent("""
        def get_mocks() -> Mocks:
            return Mocks()
    """).strip()

    class DependenciesPrefixer(Rule):
        def match(self, node: ast.AST) -> Action | None:
            return match_add_prefix(node=node, prefix_code=mocks_import)

    class DependenciesSuffixer(Rule):
        def match(self, node: ast.AST) -> Action | None:
            return match_add_suffix(node=node, prefix_code=get_mocks_source)

    target = src / "laminar/apiserver/dependencies.py"
    apply_rules(target, [DependenciesPrefixer, DependenciesSuffixer])
|
|
|
def patch_build_app(src: Path):
    """Patch ``laminar/app.py`` under ``src`` to run against the API-test mocks.

    Two rules are applied:

    * ``AppLifespanReplacer`` reduces the body of ``app_lifespan`` to a bare
      ``yield``. Whatever the function did before or after yielding no
      longer runs.
    * ``BuildAppPatcher`` rewrites the body of ``build_app``:
      - it drops expression-statement calls that mention
        ``db_connection_scope_middleware``;
      - it replaces every ``match`` statement with code that imports
        ``module`` from ``.apiserver``, sets ``app.state.mocks = Mocks()``
        and mounts the mock router from ``.apitest``.
    """

    mock_setup_code = dedent("""
        from laminar.apitest.mocks import Mocks
        from .apiserver import router as module  # type: ignore[no-redef]
        from .apitest import router as mock_module
        app.state.mocks = Mocks()
        app.include_router(mock_module.router)
    """)

    class AppLifespanReplacer(Rule):
        def match(self, node: ast.AST) -> Action | None:
            # Pass only the replacement statements. A whole
            # ``async def app_lifespan(...)`` here would be parsed as the
            # body, defining a nested function instead of replacing it.
            return match_replace_function_body(
                node=node,
                function_name="app_lifespan",
                new_body_code="yield",
            )

    class BuildAppPatcher(Rule):
        def match(self, node: ast.AST) -> Action | None:
            if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) or node.name != "build_app":
                return None

            new_body = []
            changed = False
            for stmt in node.body:
                if (
                    isinstance(stmt, ast.Expr)
                    and isinstance(stmt.value, ast.Call)
                    and "db_connection_scope_middleware" in ast.unparse(stmt)
                ):
                    changed = True
                    continue

                if isinstance(stmt, ast.Match):
                    # Parse per occurrence so no AST nodes are shared.
                    new_body.extend(ast.parse(mock_setup_code).body)
                    changed = True
                else:
                    new_body.append(stmt)

            # Nothing to patch (e.g. an earlier pass already did it). Avoid
            # re-emitting the whole function, which would drop its comments.
            if not changed:
                return None

            # A shallow copy keeps the node type and every field except body.
            new_node = copy.copy(node)
            new_node.body = new_body
            ast.fix_missing_locations(new_node)
            return Replace(node, new_node)

    apply_rules(
        src / "laminar" / "app.py",
        [AppLifespanReplacer, BuildAppPatcher],
    )
|
|
|
|
|
|
|
def main():
    """Install the API-test mocks into the source tree named on the command line.

    Usage: ``python <this script> SRC_DIR``. ``SRC_DIR`` is resolved against
    the current directory; an absolute path is used as-is. Steps:

    1. Copy the directory containing this script to
       ``SRC_DIR/laminar/apitest``.
    2. Patch the app module.
    3. Patch the DAG router.
    4. Patch the dependencies module.

    Raises:
        SystemExit: if ``SRC_DIR`` is missing or is not an existing directory.
    """
    # Explicit checks rather than ``assert``, which ``python -O`` strips.
    if len(sys.argv) < 2:
        raise SystemExit(f"usage: {Path(__file__).name} SRC_DIR")
    src = Path.cwd() / sys.argv[1]
    if not src.is_dir():
        raise SystemExit(f"error: {src} is not an existing directory")

    # copytree raises if the destination already exists, so a tree is
    # patched at most once.
    copytree(Path(__file__).parent, src / "laminar" / "apitest")
    patch_build_app(src)
    mock_handlers(src)
    patch_dependencies(src)


if __name__ == "__main__":
    main()