Skip to content

Instantly share code, notes, and snippets.

@kingbuzzman
Created March 31, 2025 10:58
Show Gist options
  • Save kingbuzzman/2d46d3e6aa69661d088c53c31e328ccd to your computer and use it in GitHub Desktop.
Django app validator -- WIP
from __future__ import annotations
import ast
from collections import namedtuple
import copy
import sys
import os
import site
from types import ModuleType
from django.utils.module_loading import import_string
import yaml
# Record describing one resolved import: the source line number, the imported
# module object, whether the import was nested inside a def/class, whether it
# sat under an `if TYPE_CHECKING:` guard, and whether it is first-party code.
Import = namedtuple("Import", ["line", "module", "is_nested", "is_typing", "is_local"])
class CircularDependencyError(Exception):
    """Raised when the module-dependency graph in the config contains a cycle."""
    pass
class DotDict(dict):
    """A dict whose keys are also readable/writable as attributes.

    Nested dicts (including dicts inside lists) are wrapped recursively at
    construction time so dotted access works at any depth.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Recursively wrap nested mappings so attribute access works deep down.
        for key in list(self):
            value = self[key]
            if isinstance(value, dict):
                self[key] = DotDict(value)
            elif isinstance(value, list):
                wrapped = []
                for item in value:
                    wrapped.append(DotDict(item) if isinstance(item, dict) else item)
                self[key] = wrapped

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails; fall back to keys.
        try:
            return self[attr]
        except KeyError:
            raise AttributeError(f"'DotDict' object has no attribute '{attr}'")

    def __setattr__(self, attr, value):
        # Attribute writes become plain key writes (no recursive wrapping).
        self[attr] = value

    def __delattr__(self, attr):
        try:
            del self[attr]
        except KeyError:
            raise AttributeError(f"'DotDict' object has no attribute '{attr}'")
def deep_merge(base, override):
    """Recursively fold *override* into *base*, mutating and returning *base*.

    Keys missing from base are copied over.  Dict values merge recursively,
    list values are unioned (missing items appended, order preserved), and
    for anything else the pre-existing base value wins.
    """
    for key, new_value in override.items():
        if key not in base:
            base[key] = new_value
        else:
            current = base[key]
            if isinstance(current, dict) and isinstance(new_value, dict):
                deep_merge(current, new_value)
            elif isinstance(current, list) and isinstance(new_value, list):
                for item in new_value:
                    if item not in current:
                        current.append(item)
            # Scalars / mismatched types: keep the existing base value.
    return base
def expand_module_dependencies(modules):
    """Expand direct dependencies into ordered transitive dependency lists.

    Given ``{module: [direct deps]}``, returns ``{module: [all deps]}`` where
    each list is duplicate-free and keeps discovery order (a direct dep first,
    then its own deps).  Raises CircularDependencyError on a cycle.
    """
    memo = {}  # module -> fully expanded dependency list

    def walk(module, trail):
        if module in trail:
            cycle = " -> ".join(trail + [module])
            raise CircularDependencyError(f"Circular dependency detected: {cycle}")
        if module not in memo:
            ordered = []
            # Modules absent from the mapping are treated as dependency-free.
            for direct in modules.get(module, []):
                if direct in ordered:
                    continue
                ordered.append(direct)
                for transitive in walk(direct, trail + [module]):
                    if transitive not in ordered:
                        ordered.append(transitive)
            memo[module] = ordered
        return memo[module]

    return {name: walk(name, []) for name in modules}
def expand_all_module_dependencies(config, app_name):
    """Ensure *app_name* has an entry under ``config.apps``.

    Apps without an explicit section inherit a copy of the ``__all__``
    defaults.  Mutates *config* in place and returns it.

    Fix: the original had two consecutive ``return config`` statements
    (the second one unreachable); collapsed into a single return.
    """
    if app_name not in config.apps:
        # NOTE(review): DotDict(...) is a shallow copy, so nested values are
        # shared with __all__ — confirm per-app mutation is intended to be safe.
        config.apps[app_name] = DotDict(config.apps.__all__)
    return config
def load_constraints(config_path):
    """Load the YAML constraints file and pre-expand module dependencies.

    The ``__all__`` defaults (apps -> __all__ -> modules) are expanded first,
    then deep-merged into each concrete app section, whose combined module
    graph is expanded in turn.  Returns the configuration as a DotDict.
    """
    with open(config_path, encoding="utf-8") as fh:
        config = yaml.safe_load(fh)

    def _expand(module_graph):
        # Report configuration cycles before letting them propagate.
        try:
            return expand_module_dependencies(module_graph)
        except CircularDependencyError as e:
            print(f"Error in configuration: {e}")
            raise

    # Shared defaults live under apps.__all__.
    defaults = config["apps"]["__all__"]
    defaults["modules"] = _expand(defaults["modules"])

    for app_name, app_cfg in list(config["apps"].items()):
        if app_name == "__all__":
            continue
        # Fold the (already expanded) defaults into the app's own section,
        # then expand the combined module graph.
        deep_merge(app_cfg, copy.deepcopy(defaults))
        config["apps"][app_name]["modules"] = _expand(app_cfg["modules"])
    return DotDict(config)
def is_module_local(module):
    """Return True when *module* appears to be first-party project code.

    Built-in modules (no ``__file__``) and anything located under a
    site-packages directory (or the parent directory of a path ending in
    "site-packages") are reported as non-local.
    """
    if not hasattr(module, "__file__"):
        # Built-in modules have no backing file on disk.
        return False
    module_path = os.path.abspath(module.__file__)

    candidate_dirs = []
    for sp in site.getsitepackages():
        candidate_dirs.append(sp)
        if sp.endswith("site-packages"):
            # Also exclude the directory containing site-packages.
            candidate_dirs.append(os.path.dirname(sp))
    # Include the per-user site-packages directory as well.
    user_site = site.getusersitepackages()
    if user_site:
        candidate_dirs.append(user_site)

    # Local means: not inside any known installed-packages directory.
    return not any(module_path.startswith(os.path.abspath(d)) for d in candidate_dirs)
def get_imports_from_file(file_path):
    """Parse *file_path* and collect every import statement it contains.

    Each result is a dict with: the dotted module path, the names pulled from
    it ("unknown_references", empty for plain ``import``), the line number,
    whether the statement is nested in a def/class, and whether it sits under
    an ``if TYPE_CHECKING:`` guard.
    """
    # Fallback module for `from . import x` style statements, derived from the
    # file's directory.  NOTE(review): assumes file_path is relative to the
    # project root — confirm against how project_structure builds paths.
    current_module = ".".join(os.path.dirname(file_path).split(os.sep))
    with open(file_path, encoding="utf-8") as fh:
        tree = ast.parse(fh.read(), filename=file_path)

    class _Collector(ast.NodeVisitor):
        def __init__(self):
            self.imports = []
            self.nesting_level = 0  # > 0 while inside a function/class body
            self.in_typing = False  # True inside an `if TYPE_CHECKING:` body

        def _visit_scoped(self, node):
            # Functions and classes only bump the nesting counter.
            self.nesting_level += 1
            self.generic_visit(node)
            self.nesting_level -= 1

        visit_FunctionDef = _visit_scoped
        visit_AsyncFunctionDef = _visit_scoped
        visit_ClassDef = _visit_scoped

        @staticmethod
        def _is_type_checking(test):
            # Matches both `TYPE_CHECKING` and `typing.TYPE_CHECKING`.
            if isinstance(test, ast.Name):
                return test.id == "TYPE_CHECKING"
            return (
                isinstance(test, ast.Attribute)
                and isinstance(test.value, ast.Name)
                and test.value.id == "typing"
                and test.attr == "TYPE_CHECKING"
            )

        def visit_If(self, node):
            if self._is_type_checking(node.test):
                previous = self.in_typing
                self.in_typing = True
                for stmt in node.body:
                    self.visit(stmt)
                self.in_typing = previous
                # The else branch executes at runtime: no typing flag there.
                for stmt in node.orelse:
                    self.visit(stmt)
            else:
                self.generic_visit(node)

        def _record(self, module, names, lineno):
            self.imports.append(
                {
                    "module": module,
                    "unknown_references": names,
                    "line": lineno,
                    "is_nested": self.nesting_level > 0,
                    "is_typing": self.in_typing,
                }
            )

        def visit_Import(self, node):
            # e.g. "import module_a" / "import package_a.module_x as m"
            for alias in node.names:
                self._record(alias.name, [], node.lineno)
            self.generic_visit(node)

        def visit_ImportFrom(self, node):
            # `node.module` is None for `from . import x`; fall back to the
            # module derived from the file's location.
            self._record(
                node.module or current_module,
                [alias.name for alias in node.names],
                node.lineno,
            )
            self.generic_visit(node)

    collector = _Collector()
    collector.visit(tree)
    return collector.imports
def get_import_modules(file_path):
    """Resolve the imports of *file_path* into Import records.

    Resolution is best-effort: Django's import_string is tried first, then a
    plain ``__import__`` fallback; anything unresolvable is silently skipped
    (matching the original intent — the debug prints were already commented
    out).  `from X import name` references that resolve to modules are also
    recorded.

    Fix: the fallback branch previously called ``__import__(path)`` twice,
    importing the module a second time just to build the record.
    """
    relevant_paths = []
    for imp in get_imports_from_file(file_path):
        path = imp["module"]
        try:
            module = import_string(path)
            relevant_paths.append(
                Import(imp["line"], module, imp["is_nested"], imp["is_typing"], is_module_local(module))
            )
        except ImportError:
            try:
                # NOTE(review): __import__("a.b") returns the top-level
                # package "a", so dotted fallbacks record the root module —
                # confirm that is the intended behaviour.
                module = __import__(path)
                relevant_paths.append(
                    Import(imp["line"], module, imp["is_nested"], imp["is_typing"], is_module_local(module))
                )
            except ImportError:
                # Best-effort: unresolvable imports are ignored.
                pass
        # `from X import name` — record `name` too when it is itself a module.
        for unknown_reference in imp["unknown_references"]:
            path = imp["module"] + "." + unknown_reference
            try:
                obj = import_string(path)
            except ImportError:
                # Not importable as a dotted path; skip it.
                continue
            if isinstance(obj, ModuleType):
                relevant_paths.append(
                    Import(imp["line"], obj, imp["is_nested"], imp["is_typing"], is_module_local(obj))
                )
    return relevant_paths
def project_structure(config):
    """Yield ``(app_name, module_name, file_path)`` for each relevant .py file.

    Walks ``config.project_root``; the app name is the first path component
    below the root and the module name is the second (or the file's own stem
    for top-level files).  Ignored apps/modules from the config are skipped.
    """
    for root, _, files in os.walk(config.project_root):
        parts = root.split(os.sep)[1:]
        if not parts:
            # The root itself has no app component.
            continue
        app_name = parts[0]
        if app_name in config.ignore_apps:
            continue
        # Make sure this app has a config entry (inherits __all__ defaults).
        expand_all_module_dependencies(config, app_name)
        module = parts[1] if len(parts) > 1 else None
        if module in config.apps[app_name].ignore_modules:
            continue
        for filename in files:
            if not filename.endswith(".py"):
                continue
            # Top-level files use their own stem as the module name.
            module_name = module or os.path.splitext(filename)[0]
            if module_name in config.apps[app_name].ignore_modules:
                continue
            yield app_name, module_name, os.path.join(root, filename)
def validate_file(app_name, module_name, file_path, config):
    """Return a list of violation strings for *file_path*.

    A violation is any local (first-party), non-TYPE_CHECKING import whose
    target module is not in the allow-list configured for
    ``(app_name, module_name)``.

    Fixes: removed the unreachable ``imported_module_name is None`` check
    (``str.split`` never yields None) and the double-negative length test.
    """
    violations = []
    app_rules = config.apps[app_name]
    allowed_imports = set(app_rules.modules.get(module_name, []))
    for imp in get_import_modules(file_path):
        # Third-party/stdlib imports and typing-only imports are always fine.
        if not imp.is_local or imp.is_typing:
            continue
        parts = imp.module.__name__.split('.')
        # Need at least <root>.<app>.<module> to identify a cross-module import.
        if len(parts) <= 2:
            continue
        imported_app_name, imported_module_name = parts[1], parts[2]
        if imported_module_name not in allowed_imports:
            violations.append(
                f"{file_path}:{imp.line}: Module '{imported_module_name}' in app '{imported_app_name}'"
            )
    return violations
def engine():
    """Entry point: load ``validator.yml`` and report every violation found.

    Prints each violation and exits with status 1 when any exist.
    """
    config = load_constraints("validator.yml")
    issues = []
    for app_name, module, file_path in project_structure(config):
        issues.extend(validate_file(app_name, module, file_path, config))
    if issues:
        print("The following violations were found:")
        for issue in issues:
            print(issue)
        sys.exit(1)


if __name__ == "__main__":
    # Fix: the original called engine() unconditionally at import time, so
    # merely importing this module would run the validator (and possibly
    # sys.exit).  The guard keeps script behaviour identical.
    engine()
# validator.yml — configuration consumed by load_constraints() in the script.
project_root: "nest"
# Apps skipped entirely during the walk.
ignore_apps: ["payout"]
# NOTE(review): app_hierarchy is not referenced anywhere in the validator
# code shown — presumably planned for cross-app checks (gist is marked WIP).
app_hierarchy: [
  ["common"],
  ["application"],
  ["payout", "payee"],
]
apps:
  # Defaults inherited (deep-merged) by every app lacking its own section.
  __all__:
    ignore_modules: ["migrations", "tests"]
    # module -> list of modules it is allowed to import from; the validator
    # expands these lists transitively before checking.
    modules:
      constants: []
      exceptions: []
      utils: ["constants", "exceptions"]
      models: ["utils"]
      views: ["models", "tasks"]
      serializers: ["models"]
      urls: ["views"]
      tasks: ["models"]
      forms: ["models", "tasks"]
      admin: ["models", "forms"]
  # Per-app overrides merge on top of the __all__ defaults.
  payee:
    modules:
      views: ['permissions', "backends"]
# app_a:
#   allowed_apps: []
#   modules:
#     constants: []
#     exceptions: []
#     utils:
#       allowed_imports: ["module_a"]
#     "constants": set(),
#     "exceptions": set(),
#     "utils": {"constants", "exceptions"},
# app_b:
#   allowed_apps: ["app_a"] # can import from app_a
#   modules:
#     module_x:
#       allowed_imports: [] # no restrictions in this simple example
#     module_y:
#       allowed_imports: ["module_x", "module_a"]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment