Last active
April 25, 2024 14:47
-
-
Save dibrinsofor/b6eee7d3799de58d9e13f831c6c09069 to your computer and use it in GitHub Desktop.
Flagging interesting patterns in source code that uses Mypy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pathlib | |
import click | |
import io, os | |
from enum import Enum | |
from typing import Optional, Union | |
import mypy.nodes | |
from mypy.parse import parse | |
from mypy.options import Options | |
from mypy.errors import CompileError, Errors | |
import sys | |
sys.path.append("../check") | |
from check import flush_errs, serialize, get_file_name, log_error, get_files, check_error_found, rare_error_found, parse_call_expr | |
# Error messages collected while scanning.
# NOTE(review): nothing in this file adds to this set; the helpers imported
# from `check` presumably record errors in their own collection — confirm,
# otherwise the `len(ERRORS)` checks in `sample` never trigger a flush.
ERRORS = set()


class PATTERNS(Enum):
    """Kinds of pattern this scanner can flag (values 1..3, in this order)."""

    # Uses of the Unpack type [TODO: ignore].
    Unpack = 1
    # Weakly typed callables invoked with forwarded *args / **kwargs.
    Wrapper = 2
    # Code that depends on literal keys of a weakly typed dict.
    DepDicts = 3
class Flagged:
    """One occurrence of a flagged pattern.

    Attributes:
        line: source line number reported for the hit (the detectors in this
            file pass the enclosing function's ``node.line``).
        var_name: name of the parameter involved in the pattern.
        pattern: the ``PATTERNS`` member that was matched.
    """

    line: int
    var_name: str
    # String forward reference: the annotation is not needed at runtime.
    pattern: "PATTERNS"

    def __init__(self, line, patterns, v_n) -> None:
        # Parameter names are kept as-is so existing positional/keyword
        # callers keep working.
        self.line = line
        self.pattern = patterns
        self.var_name = v_n

    def __repr__(self) -> str:
        # `sample` prints lists of these; the default repr only shows an
        # address, which makes the scanner's output useless.
        return (
            f"{type(self).__name__}(line={self.line!r}, "
            f"pattern={self.pattern!r}, var_name={self.var_name!r})"
        )
def trace_untyped_wrapper(name: str, stmt: any) -> bool:
    """Return True if ``stmt`` calls ``name`` as ``name(*args, **kwargs)``.

    Walks nested function bodies, ``return`` statements, assignments whose
    right-hand side is a call, and bare expression statements, looking for a
    call whose callee is named ``name`` and whose argument kinds are exactly
    ``[ARG_STAR, ARG_STAR2]``.

    Args:
        name: name of the weakly typed callable parameter being traced.
        stmt: a mypy statement or expression node.

    Returns:
        True if such a forwarding call is found, otherwise False. Statements
        of kinds not handled here are reported via
        ``rare_error_found('Statement')``.
    """
    # TODO: flag any cases that attempt to type the args and kwargs
    star_kinds = [mypy.nodes.ArgKind.ARG_STAR, mypy.nodes.ArgKind.ARG_STAR2]
    if isinstance(stmt, mypy.nodes.FuncDef):
        return any(trace_untyped_wrapper(name, sub_stmt) for sub_stmt in stmt.body.body)
    if isinstance(stmt, mypy.nodes.ReturnStmt):
        # `expr` may be None (bare `return`); hasattr handles that too.
        if hasattr(stmt.expr, "callee"):
            return trace_untyped_wrapper(name, stmt.expr)
        return False
    if isinstance(stmt, mypy.nodes.AssignmentStmt):
        if isinstance(stmt.rvalue, mypy.nodes.CallExpr):
            return trace_untyped_wrapper(name, stmt.rvalue)
        return False
    if isinstance(stmt, mypy.nodes.ExpressionStmt):
        # A bare `fn(*args, **kwargs)` statement is a forwarding call as well.
        if isinstance(stmt.expr, mypy.nodes.CallExpr):
            return trace_untyped_wrapper(name, stmt.expr)
        return False
    if isinstance(stmt, mypy.nodes.CallExpr):
        # Not every callee has a `.name` (e.g. `f()(...)`, `d[k](...)`);
        # those cannot be the traced parameter.
        callee_name = getattr(stmt.callee, "name", None)
        return callee_name == name and stmt.arg_kinds == star_kinds
    rare_error_found('Statement')
    return False
def trace_untyped_dict(name: str, stmt: any) -> bool:
    """Return True if ``stmt`` relies on a literal key of the dict ``name``.

    Recognised forms are ``<literal> in name`` / ``<literal> not in name`` and
    ``name[<literal>]``. They are found directly, inside comparisons, inside
    binary operations, and inside ``if``/``elif`` conditions. ``<literal>`` is
    a str, int, bytes or float literal.

    Args:
        name: name of the weakly typed dict parameter being traced.
        stmt: a mypy statement or expression node.

    Returns:
        True if a literal-key dependency on ``name`` is found, otherwise
        False. Node kinds not handled here are reported via
        ``rare_error_found('Statement')``.
    """
    # A tuple works with isinstance on every Python version; the previous
    # `Union[A | B | ...]` needed 3.10+ for `|` between classes.
    literal_keys = (
        mypy.nodes.StrExpr,
        mypy.nodes.IntExpr,
        mypy.nodes.BytesExpr,
        mypy.nodes.FloatExpr,
    )
    if isinstance(stmt, mypy.nodes.IfStmt):
        # Only the `if`/`elif` condition expressions are inspected.
        return any(trace_untyped_dict(name, cond) for cond in stmt.expr)
    if isinstance(stmt, mypy.nodes.ComparisonExpr):
        # In a chained comparison, operator i relates operands i and i + 1.
        for op, left, right in zip(stmt.operators, stmt.operands, stmt.operands[1:]):
            if op in ("in", "not in"):
                # Operands need not be names (e.g. `x in [1, 2]`).
                if getattr(right, "name", None) == name and isinstance(left, literal_keys):
                    return True
            else:
                for side in (left, right):
                    if isinstance(side, mypy.nodes.IndexExpr) and trace_untyped_dict(name, side):
                        return True
        return False
    if isinstance(stmt, mypy.nodes.OpExpr):
        # Trace both sides (no short-circuit) so each side is visited.
        left = trace_untyped_dict(name, stmt.left)
        right = trace_untyped_dict(name, stmt.right)
        return left or right
    if isinstance(stmt, mypy.nodes.IndexExpr):
        # The base may be any expression (e.g. `f()["k"]`), not only a name.
        return getattr(stmt.base, "name", None) == name and isinstance(stmt.index, literal_keys)
    rare_error_found('Statement')
    return False
def has_wrapper_pattern(node: any) -> list[Flagged]:
    """Flag weak callable parameters that ``node`` calls with ``*args, **kwargs``.

    Args:
        node: a mypy ``FuncDef``.

    Returns:
        For every parameter accepted by ``is_weak_callable``, one
        ``Flagged(node.line, PATTERNS.Wrapper, <param>)`` per body statement
        in which ``trace_untyped_wrapper`` finds the forwarding call. Empty
        when the function has no signature type.
    """
    if node.type is None:
        return []
    line = node.line
    found_patts: list[Flagged] = []
    for idx, arg in enumerate(node.type.arg_types):
        if not is_weak_callable(arg):
            continue
        var_name = node.arg_names[idx]
        for stmt in node.body.body:
            if trace_untyped_wrapper(var_name, stmt):
                # A wrapper hit is PATTERNS.Wrapper; recording it as DepDicts
                # (copied from has_dep_dicts) mislabelled every result.
                found_patts.append(Flagged(line, PATTERNS.Wrapper, var_name))
    return found_patts
def has_unpacked_type(node: any) -> list[Flagged]:
    """Flag uses of the ``Unpack`` type in ``node`` — not implemented yet.

    Args:
        node: a mypy ``FuncDef``.

    Raises:
        NotImplementedError: always; this detector is still a placeholder.
    """
    # TODO: implement detection of `Unpack` usages.
    raise NotImplementedError("Unpack pattern detection is not implemented yet")
def has_dep_dicts(node: any) -> list[Flagged]:
    """Flag weakly typed dict parameters whose literal keys ``node`` relies on.

    For every parameter accepted by ``is_weak_dict``, each body statement for
    which ``trace_untyped_dict`` reports a literal-key dependency adds one
    ``Flagged(node.line, PATTERNS.DepDicts, <param>)``.

    Args:
        node: a mypy ``FuncDef`` with a signature type.

    Returns:
        The list of hits, in parameter order and then statement order.
    """
    func_line = node.line
    hits: list[Flagged] = []
    for position, arg_type in enumerate(node.type.arg_types):
        if not is_weak_dict(arg_type):
            continue
        dict_param = node.arg_names[position]
        hits.extend(
            Flagged(func_line, PATTERNS.DepDicts, dict_param)
            for body_stmt in node.body.body
            if trace_untyped_dict(dict_param, body_stmt)
        )
    return hits
def check_for_pattern(node: any) -> list[Flagged]:
    """Run every pattern detector on ``node`` and return all of their hits.

    Detectors that are still stubs raise NotImplementedError; in the original
    file ``has_unpacked_type`` always does. Such detectors are skipped so that
    the implemented ones still run.

    Args:
        node: a mypy ``FuncDef``.

    Returns:
        Hits from ``has_wrapper_pattern``, ``has_unpacked_type`` and
        ``has_dep_dicts``, concatenated in that order.
    """
    patterns_found: list[Flagged] = []
    for detector in (has_wrapper_pattern, has_unpacked_type, has_dep_dicts):
        try:
            hits = detector(node)
        except NotImplementedError:
            # Without this, one unfinished detector aborts the whole scan.
            continue
        patterns_found.extend(hits)
    return patterns_found
def is_weak_type(t: any) -> bool:
    """Report whether ``t`` is a weak type — placeholder, not implemented.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError
def is_weak_callable(t: any) -> bool:
    """Return True if the type ``t`` is a bare or ``Any``-parameterised Callable.

    The check works on ``str(t)``, upper-cased, so it applies to mypy's
    printed form of unanalysed types (names carry a trailing ``?``).

    Args:
        t: a mypy type (or anything whose ``str`` is a type spelling).

    Returns:
        True for exactly ``Callable?``, or for a spelling that contains
        ``CALLABLE?`` and a standalone ``ANY?``. Otherwise False.
    """
    import re

    text = str(t).upper()
    if text == "CALLABLE?":
        return True
    # `\b` stops user types such as `Company?` or `Many?` from counting as `Any?`.
    return "CALLABLE?" in text and re.search(r"\bANY\?", text) is not None
def is_weak_dict(t: any) -> bool:
    """Return True if the type ``t`` is a bare or ``Any``-parameterised Dict.

    The check works on ``str(t)``, upper-cased, so it applies to mypy's
    printed form of unanalysed types (names carry a trailing ``?``).

    Args:
        t: a mypy type (or anything whose ``str`` is a type spelling).

    Returns:
        True for exactly ``Dict?``, or for a spelling that contains ``DICT?``
        and a standalone ``ANY?``. Otherwise False.
    """
    import re

    text = str(t).upper()
    if text == "DICT?":
        return True
    # `\b` stops user types such as `Company?` or `Many?` from counting as `Any?`.
    return "DICT?" in text and re.search(r"\bANY\?", text) is not None
# only looking for weak callables or dicts now
def func_is_weak(node: any) -> bool:
    """Return True if any parameter of ``node`` has a weak callable/dict type.

    Args:
        node: a mypy ``FuncDef``.

    Returns:
        False when ``node.type`` is missing or None, or when no parameter
        type matches ``is_weak_callable`` / ``is_weak_dict``; True otherwise.
    """
    signature = getattr(node, "type", None)
    if signature is None:
        # mypy leaves `type` as None for functions without annotations;
        # reading `.arg_types` on it used to crash the whole scan.
        return False
    return any(is_weak_callable(arg) or is_weak_dict(arg) for arg in signature.arg_types)
def parse_expr(stmt: any) -> any:
    """Turn a simple mypy expression node into a plain Python value.

    - Literal nodes (int, str, bytes, float, complex) yield their ``value``.
    - Names yield the identifier.
    - Calls are delegated to ``parse_call_expr``.
    - ``...`` yields None.
    - Anything else is reported via ``rare_error_found('Expression')`` and
      also yields None.
    """
    literal_nodes = (
        mypy.nodes.IntExpr,
        mypy.nodes.StrExpr,
        mypy.nodes.BytesExpr,
        mypy.nodes.FloatExpr,
        mypy.nodes.ComplexExpr,
    )
    if isinstance(stmt, literal_nodes):
        return stmt.value
    if isinstance(stmt, mypy.nodes.NameExpr):
        return stmt.name
    if isinstance(stmt, mypy.nodes.EllipsisExpr):
        return None
    if isinstance(stmt, mypy.nodes.CallExpr):
        return parse_call_expr(stmt)
    rare_error_found('Expression')
    return None
def parse_node(node: any) -> Optional[list[list[Flagged]]]:
    """Collect flagged patterns from one top-level or class-level mypy node.

    - Plain and decorated functions with a weak callable/dict parameter are
      scanned with ``check_for_pattern``.
    - Class bodies are scanned recursively.
    - Overloads, expression statements and assignments are only announced
      with ``print`` for now.
    - Any other node is reported via ``rare_error_found('Statement')``.

    Returns:
        A list with one entry per function that produced at least one hit;
        each entry is that function's list of ``Flagged`` objects.
    """
    # NOTE: leftover breakpoint() calls were removed; they dropped every run
    # into the debugger on the first class, overload, expression or assignment.
    patts_found = []
    if isinstance(node, (mypy.nodes.FuncDef, mypy.nodes.Decorator)):
        func = node.func if isinstance(node, mypy.nodes.Decorator) else node
        if func_is_weak(func):
            patts = check_for_pattern(func)
            if patts:
                patts_found.append(patts)
    elif isinstance(node, mypy.nodes.OverloadedFuncDef):
        # TODO: scan the individual overload items.
        print("overloaded", node.line)
    elif isinstance(node, mypy.nodes.ClassDef):
        print("ClassDef", node.line)
        # TODO: top level has type_vars and decorators attr. look into them
        for stmt in node.defs.body:
            res = parse_node(stmt)
            if res is not None:
                patts_found.extend(res)
    elif isinstance(node, mypy.nodes.ExpressionStmt):
        print("ExpressionStmt")
    elif isinstance(node, mypy.nodes.AssignmentStmt):
        # TODO: distinguish type-alias assignments from plain ones. In mypy,
        # `AssignmentStmt.is_alias_def` is a bool attribute, not a method, so
        # the old `node.is_alias_def()` call raised TypeError.
        print("Assignment")
    else:
        rare_error_found('Statement')
    return patts_found
def scan_file(file_name: str) -> "list[list[Flagged]]":
    """Parse ``file_name`` with mypy and collect flagged patterns from it.

    Args:
        file_name: path to a ``.py`` source file.

    Returns:
        The concatenated ``parse_node`` results for the file's top-level
        nodes. Empty when the file is not a ``.py`` file or mypy raises
        CompileError; both cases are reported via ``check_error_found``.
    """
    t = []
    if not file_name.endswith(".py"):
        # This scanner reads source code; the old message said the opposite.
        check_error_found("expected source code not type stub")
        return t
    # `with` closes the file on every path; previously it leaked on the
    # CompileError return. Python source is UTF-8 by default (PEP 3120).
    with open(pathlib.Path(file_name), "r", encoding="utf-8") as f:
        source = f.read()
    options = Options()
    errors = Errors(options)
    try:
        ast = parse(source, file_name, None, errors, options)
    except CompileError:
        # NOTE(review): when an `Errors` object is passed in, mypy may record
        # syntax errors in it instead of raising; consider also checking
        # `errors.is_errors()` here — confirm for the mypy version in use.
        check_error_found("unable to scan file")
        return t
    for stmt in ast.defs:
        result = parse_node(stmt)
        # Was `result is not None or result != []`, which is always true.
        if result:
            t.extend(result)
    return t
@click.command()
@click.option('-r', '--run-tests', 'tests', is_flag=True, help="Run tests without generating new scan files")
@click.argument("filepath", type=click.Path(exists=True))
def sample(filepath, tests) -> None:
    """Scan FILEPATH (a file or a directory) and print flagged patterns."""
    # Treat a single file as a one-element directory listing so both modes
    # print their results; the single-file path used to discard them.
    if os.path.isdir(filepath):
        files = get_files(filepath)
    else:
        files = [filepath]
    for file in files:
        p = scan_file(file)
        print(p)
        # NOTE(review): nothing in this file adds to ERRORS, so this flush may
        # never run unless `check` records into the same set — confirm.
        if len(ERRORS) != 0:
            flush_errs(file)
    if not tests:
        # TODO: write output
        ...


if __name__ == "__main__":
    sample()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
[WIP]: Parser to flag interesting patterns in source code (identified during a manual study) that may indicate the need for more precise Mypy types.