Created August 22, 2025 22:29
-
-
Save richardkiss/13f4614dee7988da0b4d4bf6d4f1645f to your computer and use it in GitHub Desktop.
Systematic API stub generator using is-comparison for import detection
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
Systematic import detector using is-comparison to find true import sources. | |
Clean algorithm: try all potential import locations, use 'is' to find the truth. | |
""" | |
import inspect | |
import importlib | |
from typing import Dict, Set, List, Optional, get_type_hints, get_origin, get_args | |
class SystematicImportDetector:
    """Detects required imports by systematically testing import locations with 'is' comparison.

    A type's ``__module__`` may not be the place it should be imported from.
    Instead of trusting it, each candidate module is imported and a location
    counts only if ``getattr(module, name) is type_obj``.
    """

    # Indentation of members of the generated stub class, and of their bodies.
    _MEMBER_INDENT = '    '
    _BODY_INDENT = '        '

    def __init__(self):
        # (type name, id(type object)) -> resolved module name, or None when no
        # candidate location exposes the identical object.
        self.true_module_cache = {}
        self.potential_locations = [
            # Common alternative locations for types that might lie about their module
            'chia_rs',
            'chia_rs.sized_ints',
            'chia_rs.sized_bytes',
            'builtins',
            'typing',
        ]

    def find_true_import_source(self, type_obj) -> Optional[str]:
        """Find the module that genuinely exports ``type_obj``, using ``is`` comparison.

        The claimed ``__module__`` is tried first, then ``self.potential_locations``
        (duplicates removed, first occurrence kept).  Progress is printed.

        Args:
            type_obj: Any object; only objects with a ``__name__`` can be resolved.

        Returns:
            The dotted module name whose attribute of the same name *is*
            ``type_obj``, or ``None``.  Results, including ``None``, are cached.
        """
        if not hasattr(type_obj, '__name__'):
            return None
        type_name = type_obj.__name__

        cache_key = (type_name, id(type_obj))
        if cache_key in self.true_module_cache:
            return self.true_module_cache[cache_key]

        claimed_module = getattr(type_obj, '__module__', None)
        candidates = ([claimed_module] if claimed_module else []) + self.potential_locations
        # dict.fromkeys de-duplicates while preserving first-seen order.
        unique_locations = list(dict.fromkeys(candidates))

        print(f"Finding true source for {type_name}...")
        print(f" Claimed module: {getattr(type_obj, '__module__', 'none')}")
        print(f" Testing locations: {unique_locations}")

        missing = object()
        for module_name in unique_locations:
            try:
                module = importlib.import_module(module_name)
                candidate = getattr(module, type_name, missing)
            except ImportError:
                print(f" ✗ Cannot import {module_name}")
                continue
            except Exception as e:
                print(f" ✗ Error checking {module_name}: {e}")
                continue

            if candidate is missing:
                print(f" ✗ {module_name} has no {type_name}")
            elif candidate is type_obj:
                print(f" ✓ Found true source: {module_name}")
                self.true_module_cache[cache_key] = module_name
                return module_name
            else:
                print(f" ✗ {module_name}.{type_name} exists but is different object")

        print(f" ? No true source found for {type_name}")
        self.true_module_cache[cache_key] = None
        return None

    @staticmethod
    def _iter_request_methods(api_class):
        """Yield ``request_info.method`` for each value of ``api_class.metadata.message_type_to_request``.

        Yields nothing if the class has no ``metadata`` or the metadata has no
        (or a ``None``) ``message_type_to_request``.
        """
        metadata = getattr(api_class, 'metadata', None)
        requests = getattr(metadata, 'message_type_to_request', None)
        if requests is None:
            return
        for request_info in requests.values():
            yield request_info.method

    def get_method_types(self, api_class) -> Set[object]:
        """Extract all type objects used in API method signatures.

        ``get_type_hints`` (which evaluates string annotations in the method's
        globals) is tried first.  If it fails, the raw annotations from
        ``inspect.signature`` are used, which yields nothing for string annotations.

        Returns:
            The set of discovered type objects; empty if the class has no request metadata.
        """
        all_types: Set[object] = set()
        for method in self._iter_request_methods(api_class):
            method_name = getattr(method, '__name__', method)
            try:
                hints = get_type_hints(method)
                for type_hint in hints.values():
                    all_types.update(self.extract_all_types_from_hint(type_hint))
            except Exception as e:
                print(f"Warning: Could not get type hints for {method_name}: {e}")
                # Fallback to signature inspection
                try:
                    sig = inspect.signature(method)
                    raw_annotations = [param.annotation for param in sig.parameters.values()]
                    raw_annotations.append(sig.return_annotation)
                    for annotation in raw_annotations:
                        if annotation is not inspect.Parameter.empty:
                            all_types.update(self.extract_all_types_from_annotation(annotation))
                except Exception as e2:
                    print(f"Warning: Could not analyze {method_name}: {e2}")
        return all_types

    def extract_all_types_from_hint(self, type_hint) -> Set[object]:
        """Extract all type objects from a resolved type hint.

        The hint itself is kept when it has both ``__module__`` and ``__name__``.
        For generic aliases (``Optional[X]``, ``List[Y]``...), the origin and every
        argument are processed recursively.  Lists/tuples of hints, such as the
        parameter list in ``Callable[[A, B], R]``, are walked too.
        """
        found: Set[object] = set()
        if isinstance(type_hint, (list, tuple)):
            # get_args(Callable[[A, B], R]) returns ([A, B], R): the first element is a plain list.
            for item in type_hint:
                found |= self.extract_all_types_from_hint(item)
            return found

        if hasattr(type_hint, '__module__') and hasattr(type_hint, '__name__'):
            found.add(type_hint)

        origin = get_origin(type_hint)
        if origin:
            if hasattr(origin, '__module__') and hasattr(origin, '__name__'):
                found.add(origin)
            for arg in get_args(type_hint):
                found |= self.extract_all_types_from_hint(arg)
        return found

    def extract_all_types_from_annotation(self, annotation) -> Set[object]:
        """Extract type objects from a raw annotation (fallback method).

        Only the annotation itself is considered; it is kept if it has both
        ``__module__`` and ``__name__``.  String annotations yield an empty set.
        """
        if hasattr(annotation, '__module__') and hasattr(annotation, '__name__'):
            return {annotation}
        return set()

    def generate_import_for_type(self, type_obj) -> Optional[str]:
        """Return ``from <module> import <name>`` for ``type_obj``, or ``None``.

        ``None`` is returned in three cases: the object has no ``__name__``, no true
        source module is found, or that module is ``builtins`` or ``typing``.
        """
        if not hasattr(type_obj, '__name__'):
            return None
        true_module = self.find_true_import_source(type_obj)
        if not true_module or true_module in ('builtins', 'typing'):
            return None
        return f'from {true_module} import {type_obj.__name__}'

    def generate_protocol_module_imports(self, type_objects) -> Set[str]:
        """Generate module imports for types defined under ``chia.protocols``.

        A type from ``chia.protocols.harvester_protocol`` yields
        ``from chia.protocols import harvester_protocol``.  Deeper modules import
        their leaf module from its real parent package.
        NOTE(review): presumably needed because copied signatures may reference
        messages as ``<module>.<Type>``; confirm against the real APIs.
        """
        protocol_imports: Set[str] = set()
        for type_obj in type_objects:
            true_module = self.find_true_import_source(type_obj)
            if not true_module or not true_module.startswith('chia.protocols.'):
                continue
            parent_package, _, module_leaf = true_module.rpartition('.')
            protocol_imports.add(f'from {parent_package} import {module_leaf}')
        return protocol_imports

    def generate_imports_for_api(self, api_class) -> List[str]:
        """Generate the complete, ordered import section for a stub of ``api_class``.

        The result always includes the stub's fixed imports (``__future__``,
        ``logging``, typing helpers, ``ApiMetadata``).  ``Optional`` is added when a
        discovered type renders with ``Optional`` in its text.  Protocol module
        imports and one import per resolvable type follow.  See ``sort_imports``
        for the layout.
        """
        used_type_objects = self.get_method_types(api_class)
        print(f"Found {len(used_type_objects)} unique type objects")

        imports = {
            'from __future__ import annotations',
            'import logging',
            'from typing import TYPE_CHECKING, ClassVar, cast',
            'from chia.server.api_protocol import ApiMetadata',
        }
        # Optional[X] aliases render as "typing.Optional[X]".
        if any('Optional' in str(obj) for obj in used_type_objects):
            imports.add('from typing import Optional')

        imports.update(self.generate_protocol_module_imports(used_type_objects))
        for type_obj in used_type_objects:
            import_stmt = self.generate_import_for_type(type_obj)
            if import_stmt:
                imports.add(import_stmt)

        return self.sort_imports(list(imports))

    @staticmethod
    def _import_root(stmt: str) -> str:
        """Return the top-level package named by an import ('from a.b import c' -> 'a')."""
        words = stmt.split()
        return words[1].split('.')[0] if len(words) > 1 else ''

    @staticmethod
    def _typing_sort_key(name: str):
        """Order imported names like isort/ruff order-by-type: CONSTANTS, Classes, then others."""
        if name.isupper():
            rank = 0
        elif name[:1].isupper():
            rank = 1
        else:
            rank = 2
        return rank, name

    def sort_imports(self, imports: List[str]) -> List[str]:
        """Sort imports in a logical order.

        Layout (each non-empty group is followed by a blank line):
          1. ``from __future__`` imports;
          2. other non-chia imports: plain ``import x`` first, then ``from x import y``,
             with all ``from typing import`` lines merged into one;
          3. an ``if TYPE_CHECKING:`` block importing ``ApiProtocol`` (always emitted);
          4. ``chia_rs`` imports;
        then a marker comment and the ``chia`` imports.

        Each statement is classified by its top-level package, so it appears exactly
        once and nothing is silently dropped.  Blank entries are ignored.
        """
        typing_prefix = 'from typing import '
        future: List[str] = []
        plain: List[str] = []
        from_imports: List[str] = []
        chia_rs_imports: List[str] = []
        chia_imports: List[str] = []
        typing_names: Set[str] = set()

        for stmt in sorted({s.strip() for s in imports if s.strip()}):
            root = self._import_root(stmt)
            if root == '__future__':
                future.append(stmt)
            elif stmt.startswith(typing_prefix):
                names = (n.strip(' ()') for n in stmt[len(typing_prefix):].split(','))
                typing_names.update(n for n in names if n)
            elif root == 'chia_rs':
                chia_rs_imports.append(stmt)
            elif root == 'chia':
                chia_imports.append(stmt)
            elif stmt.startswith('import '):
                plain.append(stmt)
            else:
                from_imports.append(stmt)

        if typing_names:
            merged = ', '.join(sorted(typing_names, key=self._typing_sort_key))
            from_imports.append(typing_prefix + merged)
            from_imports.sort()

        sorted_imports: List[str] = []
        if future:
            sorted_imports += future + ['']
        other = plain + from_imports
        if other:
            sorted_imports += other + ['']

        # TYPE_CHECKING block
        sorted_imports += [
            'if TYPE_CHECKING:',
            f'{self._MEMBER_INDENT}from chia.server.api_protocol import ApiProtocol',
            '',
        ]

        # Third party (chia_rs and submodules)
        if chia_rs_imports:
            sorted_imports += chia_rs_imports + ['']

        # Chia imports
        sorted_imports.append('# Minimal imports to avoid circular dependencies')
        sorted_imports += chia_imports
        return sorted_imports

    def _extract_signature_lines(self, method) -> List[str]:
        """Return the decorators and ``def`` header of ``method``, re-indented to class-member level.

        The header ends at the ``:`` that closes the ``def`` statement.  That colon
        is located with ``tokenize``, so multi-line parameter lists, colons inside
        brackets and trailing comments are handled.  Code after that colon on the
        same line (a one-line body) is dropped; a trailing comment is kept.  If
        tokenizing fails, the first line whose stripped text ends with ``:`` ends
        the header instead.
        """
        import io
        import tokenize

        source_lines, _ = inspect.getsourcelines(method)

        header_end = None  # (1-based row, col just past the header's colon)
        depth = 0
        seen_def = False
        try:
            readline = io.StringIO(''.join(source_lines)).readline
            for tok in tokenize.generate_tokens(readline):
                if tok.type == tokenize.NAME and tok.string == 'def' and depth == 0:
                    seen_def = True
                elif tok.type == tokenize.OP and tok.string in ('(', '[', '{'):
                    depth += 1
                elif tok.type == tokenize.OP and tok.string in (')', ']', '}'):
                    depth -= 1
                elif tok.type == tokenize.OP and tok.string == ':' and depth == 0 and seen_def:
                    header_end = tok.end
                    break
        except (tokenize.TokenError, SyntaxError):
            header_end = None

        if header_end is not None:
            row, col = header_end
            header = source_lines[:row]
            tail = header[-1][col:]
            if tail.strip() and not tail.lstrip().startswith('#'):
                header[-1] = header[-1][:col]
        else:
            header = []
            for line in source_lines:
                header.append(line)
                if line.rstrip().endswith(':'):
                    break

        # Re-indent relative to the first header line so nested definitions also
        # land at class-member level in the stub.
        first = header[0]
        base_indent = first[:len(first) - len(first.lstrip())]
        reindented = []
        for line in header:
            text = line.rstrip()
            if not text:
                reindented.append('')
            elif text.startswith(base_indent):
                reindented.append(self._MEMBER_INDENT + text[len(base_indent):])
            else:
                reindented.append(self._MEMBER_INDENT + text.lstrip())
        return reindented

    @staticmethod
    def _returns_nothing(annotation) -> bool:
        """Return True when a return annotation means "no value" (or there is none).

        ``-> None`` is stored as ``None`` (or the string ``'None'`` under postponed
        annotations), not ``NoneType``; all of these forms are accepted.
        """
        return (
            annotation is None
            or annotation is type(None)
            or annotation is inspect.Signature.empty
            or (isinstance(annotation, str) and annotation.strip() == 'None')
        )

    def _stub_method_lines(self, method) -> List[str]:
        """Build the stub for one handler: header, docstring, optional ``return None``, blank line."""
        stub_lines = self._extract_signature_lines(method)
        return_annotation = inspect.signature(method).return_annotation
        stub_lines.append(f'{self._BODY_INDENT}"""Stub method."""')
        if not self._returns_nothing(return_annotation):
            # Return None for non-void methods
            stub_lines.append(f'{self._BODY_INDENT}return None')
        stub_lines.append('')
        return stub_lines

    def generate_stub_class(self, api_class, output_file=None) -> str:
        """Generate a complete stub module from a real API class.

        The stub is named ``<Name>ApiStub`` when the class name ends in ``API``,
        otherwise ``<Name>Stub``.  It contains the generated imports, a class with
        ``log``/``metadata``/``ready()`` and, for each registered handler, the
        original decorators and signature with a stub body.  Handlers that cannot
        be processed are skipped with a warning.

        Args:
            api_class: The API class to mirror.
            output_file: If given, the stub is also written there (UTF-8).

        Returns:
            The stub source text.
        """
        api_name = api_class.__name__
        stub_name = api_name[:-3] + "ApiStub" if api_name.endswith("API") else api_name + "Stub"
        member, body = self._MEMBER_INDENT, self._BODY_INDENT

        lines = list(self.generate_imports_for_api(api_class))
        lines += ['', '']
        lines += [
            f"class {stub_name}:",
            f'{member}"""Lightweight API stub for {api_name} to break circular dependencies."""',
            '',
            f'{member}if TYPE_CHECKING:',
            f'{body}_protocol_check: ClassVar[ApiProtocol] = cast("{stub_name}", None)',
            '',
            f'{member}log: logging.Logger',
            f'{member}metadata: ClassVar[ApiMetadata] = ApiMetadata()',
            '',
            f'{member}def ready(self) -> bool:',
            f'{body}"""Check if the service is ready."""',
            f'{body}return True',
            '',
        ]

        for method in self._iter_request_methods(api_class):
            try:
                # Built separately so a failure never leaves a half-written method behind.
                lines.extend(self._stub_method_lines(method))
            except Exception as e:
                print(f"Warning: Could not process method {getattr(method, '__name__', method)}: {e}")

        stub_content = '\n'.join(lines)

        if output_file:
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(stub_content)
            print(f"Generated: {output_file}")

        return stub_content

    def validate_stub(self, stub_content: str) -> bool:
        """Validate that the generated stub parses as Python (syntax only; names are not resolved)."""
        import ast

        try:
            ast.parse(stub_content)
        except SyntaxError as e:
            print(f"Validation failed: {e}")
            return False
        return True
def test_systematic_detection():
    """Test the systematic import detection with is-comparison.

    This is a printed demo rather than an assertion-based test.  It resolves
    imports for ``HarvesterAPI`` and ``WalletNodeAPI``, reports whether
    ``RespondToPhUpdates`` was resolved, shows the first ten wallet imports, and
    checks that a generated ``HarvesterAPI`` stub parses.  If the chia APIs cannot
    be imported, a message is printed and the demo stops.
    """
    try:
        from chia.harvester.harvester_api import HarvesterAPI
        from chia.wallet.wallet_node_api import WalletNodeAPI
    except ImportError as e:
        # Only the chia imports are guarded, so failures inside the detector are
        # not misreported as missing APIs.
        print(f"Could not import APIs: {e}")
        return

    detector = SystematicImportDetector()
    print("=== SYSTEMATIC IMPORT DETECTION WITH IS-COMPARISON ===")
    print()

    # Test with HarvesterAPI
    print("--- HarvesterAPI ---")
    imports = detector.generate_imports_for_api(HarvesterAPI)
    print("\nGenerated imports:")
    for imp in imports:
        print(f" {imp}")

    print("\n--- WalletAPI (RespondToPhUpdates test) ---")
    imports = detector.generate_imports_for_api(WalletNodeAPI)

    # Check for RespondToPhUpdates
    for imp in imports:
        if 'RespondToPhUpdates' in imp:
            print(f"✓ Correctly resolved: {imp}")
            break
    else:
        print("✗ RespondToPhUpdates not found")

    # Show first 10 imports
    print("\nFirst 10 imports:")
    for imp in imports[:10]:
        print(f" {imp}")

    print("\n--- Validation Test ---")
    stub = detector.generate_stub_class(HarvesterAPI)
    if detector.validate_stub(stub):
        print("✓ Generated stub validates successfully")
    else:
        print("✗ Stub validation failed")
def main():
    """Generate all API stubs using systematic detection.

    Each (module, class, output path) target is imported, turned into a stub,
    written to its output path and syntax-checked.  Problems with one target are
    reported and do not stop the remaining targets.
    """
    detector = SystematicImportDetector()
    targets = (
        ('chia.harvester.harvester_api', 'HarvesterAPI', 'chia/apis/harvester_stub_systematic.py'),
        ('chia.farmer.farmer_api', 'FarmerAPI', 'chia/apis/farmer_stub_systematic.py'),
        ('chia.timelord.timelord_api', 'TimelordAPI', 'chia/apis/timelord_stub_systematic.py'),
        ('chia.introducer.introducer_api', 'IntroducerAPI', 'chia/apis/introducer_stub_systematic.py'),
        ('chia.wallet.wallet_node_api', 'WalletNodeAPI', 'chia/apis/wallet_stub_systematic.py'),
        ('chia.full_node.full_node_api', 'FullNodeAPI', 'chia/apis/full_node_stub_systematic.py'),
    )

    for module_path, api_name, destination in targets:
        try:
            api_module = importlib.import_module(module_path)
            api_cls = getattr(api_module, api_name)
            print(f"\nGenerating {destination}...")
            generated = detector.generate_stub_class(api_cls, destination)
            is_valid = detector.validate_stub(generated)
            print(
                f"✓ {destination} generated and validated"
                if is_valid
                else f"✗ {destination} failed validation"
            )
        except ImportError as err:
            print(f"Could not import {module_path}.{api_name}: {err}")
        except Exception as err:
            print(f"Error generating {api_name}: {err}")
if __name__ == "__main__":
    import sys

    # `python script.py demo` runs the detection walkthrough; anything else generates the stubs.
    wants_demo = sys.argv[1:2] == ["demo"]
    if wants_demo:
        test_systematic_detection()
    else:
        main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.