Skip to content

Instantly share code, notes, and snippets.

@richardkiss
Created August 22, 2025 22:29
Show Gist options
  • Save richardkiss/13f4614dee7988da0b4d4bf6d4f1645f to your computer and use it in GitHub Desktop.
Save richardkiss/13f4614dee7988da0b4d4bf6d4f1645f to your computer and use it in GitHub Desktop.
Systematic API stub generator that uses `is` identity comparison for import detection
#!/usr/bin/env python3
"""
Systematic import detector using is-comparison to find true import sources.
Clean algorithm: try all potential import locations, use 'is' to find the truth.
"""
import inspect
import importlib
from typing import Dict, Set, List, Optional, get_type_hints, get_origin, get_args
class SystematicImportDetector:
    """Detect required imports by testing candidate modules with an ``is`` comparison.

    A type's ``__module__`` is only the first candidate: a module is accepted as
    the import source when ``getattr(module, type.__name__)`` is the very same
    object as the type. The class also renders a stub class for an API class
    whose ``metadata.message_type_to_request`` values each expose a ``method``.
    """

    def __init__(self) -> None:
        # Maps (type name, id(type object)) to the module name found for it,
        # or to None when no probed module exported that exact object.
        self.true_module_cache = {}
        self.potential_locations = [
            # Common alternative locations for types that might lie about their module
            'chia_rs',
            'chia_rs.sized_ints',
            'chia_rs.sized_bytes',
            'builtins',
            'typing',
        ]

    def find_true_import_source(self, type_obj) -> Optional[str]:
        """Return the name of the first module that exports ``type_obj`` itself.

        The object's own ``__module__`` is probed first, then every entry of
        ``self.potential_locations``. Results (including failures) are cached.

        Returns:
            The module name, or None when ``type_obj`` has no ``__name__`` or
            no probed module holds the identical object under that name.
        """
        if not hasattr(type_obj, '__name__'):
            return None
        type_name = type_obj.__name__

        cache_key = (type_name, id(type_obj))
        if cache_key in self.true_module_cache:
            return self.true_module_cache[cache_key]

        # Claimed module first, then the alternatives; dict.fromkeys drops
        # duplicates while keeping first-seen order.
        claimed_module = getattr(type_obj, '__module__', None)
        candidates = [claimed_module] if claimed_module else []
        candidates.extend(self.potential_locations)
        unique_locations = list(dict.fromkeys(candidates))

        print(f"Finding true source for {type_name}...")
        print(f" Claimed module: {getattr(type_obj, '__module__', 'none')}")
        print(f" Testing locations: {unique_locations}")

        missing = object()  # sentinel: distinguishes "no attribute" from any real value
        for module_name in unique_locations:
            try:
                module = importlib.import_module(module_name)
                candidate = getattr(module, type_name, missing)
                if candidate is type_obj:
                    print(f" ✓ Found true source: {module_name}")
                    self.true_module_cache[cache_key] = module_name
                    return module_name
                if candidate is missing:
                    print(f" ✗ {module_name} has no {type_name}")
                else:
                    print(f" ✗ {module_name}.{type_name} exists but is different object")
            except ImportError:
                print(f" ✗ Cannot import {module_name}")
            except Exception as e:
                # Importing or probing an arbitrary module can fail in many ways;
                # report it and keep trying the remaining locations.
                print(f" ✗ Error checking {module_name}: {e}")

        print(f" ? No true source found for {type_name}")
        self.true_module_cache[cache_key] = None
        return None

    def get_method_types(self, api_class) -> Set[object]:
        """Collect the type objects used in the API's request-method signatures.

        Returns an empty set when ``api_class`` has no
        ``metadata.message_type_to_request``. Resolved type hints are preferred;
        if they cannot be resolved for a method, its raw signature annotations
        are inspected instead.
        """
        all_types = set()
        if not hasattr(api_class, 'metadata') or not hasattr(api_class.metadata, 'message_type_to_request'):
            return all_types

        for request_info in api_class.metadata.message_type_to_request.values():
            method = request_info.method
            method_name = getattr(method, '__name__', repr(method))
            try:
                for type_hint in get_type_hints(method).values():
                    all_types.update(self.extract_all_types_from_hint(type_hint))
            except Exception as e:
                print(f"Warning: Could not get type hints for {method_name}: {e}")
                # Fallback to signature inspection
                try:
                    sig = inspect.signature(method)
                    annotations = [
                        param.annotation
                        for param in sig.parameters.values()
                        if param.annotation is not inspect.Parameter.empty
                    ]
                    if sig.return_annotation is not inspect.Signature.empty:
                        annotations.append(sig.return_annotation)
                    for annotation in annotations:
                        all_types.update(self.extract_all_types_from_annotation(annotation))
                except Exception as e2:
                    print(f"Warning: Could not analyze {method_name}: {e2}")
        return all_types

    @staticmethod
    def _add_if_named(found, candidate) -> None:
        """Add ``candidate`` to ``found`` when it has both ``__module__`` and ``__name__``."""
        if hasattr(candidate, '__module__') and hasattr(candidate, '__name__'):
            try:
                found.add(candidate)
            except TypeError:
                # Unhashable hint objects cannot be stored in the set; skip them.
                pass

    def extract_all_types_from_hint(self, type_hint) -> Set[object]:
        """Extract all type objects from a resolved type hint, recursively.

        The hint itself, its ``get_origin`` and everything reachable through
        ``get_args`` are collected when they carry ``__module__`` and ``__name__``.
        """
        found = set()

        # Callable[[A, B], R] exposes its parameter types as a plain list.
        if isinstance(type_hint, list):
            for item in type_hint:
                found.update(self.extract_all_types_from_hint(item))
            return found

        self._add_if_named(found, type_hint)

        # Handle generic types (Optional[X], List[Y], etc.)
        origin = get_origin(type_hint)
        if origin is not None:
            self._add_if_named(found, origin)
            for arg in get_args(type_hint):
                found.update(self.extract_all_types_from_hint(arg))
        return found

    def extract_all_types_from_annotation(self, annotation) -> Set[object]:
        """Extract type objects from a raw annotation (fallback, non-recursive)."""
        found = set()
        self._add_if_named(found, annotation)
        return found

    def generate_import_for_type(self, type_obj) -> Optional[str]:
        """Return ``from <module> import <name>`` for ``type_obj``, or None.

        None is returned when the object has no ``__name__``, when no source
        module is found, or when the source is ``builtins`` or ``typing``.
        """
        if not hasattr(type_obj, '__name__'):
            return None
        true_module = self.find_true_import_source(type_obj)
        # Skip unresolved types, Python builtins and typing generics
        if not true_module or true_module in ('builtins', 'typing'):
            return None
        return f'from {true_module} import {type_obj.__name__}'

    def generate_protocol_module_imports(self, type_objects) -> Set[str]:
        """Generate protocol module imports (``from chia.protocols import X``).

        One statement is produced per distinct ``chia.protocols.*`` module that
        is the import source of any object in ``type_objects``.
        """
        protocol_imports = set()
        for type_obj in type_objects:
            true_module = self.find_true_import_source(type_obj)
            if true_module and true_module.startswith('chia.protocols.'):
                protocol_name = true_module.split('.')[-1]  # e.g., 'harvester_protocol'
                protocol_imports.add(f'from chia.protocols import {protocol_name}')
        return protocol_imports

    @staticmethod
    def _mentions_optional(type_obj) -> bool:
        """Return True when ``type_obj`` is a Union that includes None, or is shown as Optional."""
        from typing import Union

        # Structural check first: it does not depend on how the hint is rendered by str().
        if get_origin(type_obj) is Union and type(None) in get_args(type_obj):
            return True
        return 'Optional' in str(type_obj)

    def generate_imports_for_api(self, api_class) -> List[str]:
        """Generate the complete, sorted list of import lines for an API class."""
        used_type_objects = self.get_method_types(api_class)
        print(f"Found {len(used_type_objects)} unique type objects")

        # Basic imports and API infrastructure that every stub needs
        imports = {
            'from __future__ import annotations',
            'import logging',
            'from typing import TYPE_CHECKING, ClassVar, cast',
            'from chia.server.api_protocol import ApiMetadata',
        }
        if any(self._mentions_optional(obj) for obj in used_type_objects):
            imports.add('from typing import Optional')

        imports.update(self.generate_protocol_module_imports(used_type_objects))
        for type_obj in used_type_objects:
            import_stmt = self.generate_import_for_type(type_obj)
            if import_stmt:
                imports.add(import_stmt)

        return self.sort_imports(list(imports))

    def sort_imports(self, imports: List[str]) -> List[str]:
        """Group and sort import lines; every input line is emitted exactly once.

        Order: ``__future__``; plain ``import`` lines, ``from typing`` lines and
        any other unclassified imports; a fixed ``TYPE_CHECKING`` block;
        ``chia_rs`` imports; then ``chia`` imports under a fixed comment line.
        Groups are separated by empty strings.
        """
        future, plain, typing_imports, other, chia_rs, chia = [], [], [], [], [], []
        for imp in imports:
            if imp.startswith('from __future__'):
                future.append(imp)
            elif imp.startswith('from typing import'):
                typing_imports.append(imp)
            # chia_rs must be tested before the broader 'chia' prefix so that it
            # is not listed a second time in the chia section.
            elif imp.startswith(('from chia_rs', 'import chia_rs')):
                chia_rs.append(imp)
            elif imp.startswith(('from chia', 'import chia')):
                chia.append(imp)
            elif imp.startswith('import '):
                plain.append(imp)
            else:
                other.append(imp)

        sorted_imports = []
        # Future imports
        sorted_imports.extend(sorted(future))
        if future:
            sorted_imports.append('')
        # Standard library, typing, and anything not recognised above
        sorted_imports.extend(sorted(plain))
        sorted_imports.extend(sorted(typing_imports))
        sorted_imports.extend(sorted(other))
        if plain or typing_imports or other:
            sorted_imports.append('')
        # TYPE_CHECKING block
        sorted_imports.append('if TYPE_CHECKING:')
        sorted_imports.append('    from chia.server.api_protocol import ApiProtocol')
        sorted_imports.append('')
        # Third party (chia_rs and submodules)
        if chia_rs:
            sorted_imports.extend(sorted(chia_rs))
            sorted_imports.append('')
        # Chia imports
        sorted_imports.append('# Minimal imports to avoid circular dependencies')
        sorted_imports.extend(sorted(chia))
        return sorted_imports

    @staticmethod
    def _signature_lines(source_lines) -> List[str]:
        """Return the decorator and signature lines of a function's source, right-stripped.

        The end of the signature is located by parsing the source and taking
        everything before the first body statement. If the source cannot be
        parsed, or the body starts on the ``def`` line, fall back to cutting
        after the first line that ends with ``:``.
        """
        import ast
        import textwrap

        stripped = [line.rstrip() for line in source_lines]
        try:
            tree = ast.parse(textwrap.dedent(''.join(source_lines)))
        except (SyntaxError, ValueError):
            tree = None

        if tree is not None and tree.body:
            func = tree.body[0]
            if (isinstance(func, (ast.FunctionDef, ast.AsyncFunctionDef))
                    and func.body[0].lineno > func.lineno):
                header = stripped[:func.body[0].lineno - 1]
                # Drop blank and comment-only lines sitting between the
                # signature and the first body statement.
                while header and (not header[-1].strip() or header[-1].lstrip().startswith('#')):
                    header.pop()
                if header:
                    return header

        header = []
        for line in stripped:
            header.append(line)
            if line.endswith(':'):
                break
        return header

    @staticmethod
    def _returns_nothing(return_annotation) -> bool:
        """Return True for a missing return annotation or one that denotes None."""
        if return_annotation is inspect.Signature.empty:
            return True
        if return_annotation is None or return_annotation is type(None):
            return True
        # Postponed (string) annotations arrive here as their source text.
        return isinstance(return_annotation, str) and return_annotation.strip() == 'None'

    def _render_stub_method(self, method) -> List[str]:
        """Return the stub lines for one method: signature, docstring, optional return, blank."""
        source_lines, _ = inspect.getsourcelines(method)
        header = self._signature_lines(source_lines)

        # Re-indent the signature to class-body depth, whatever depth the
        # original method was defined at.
        margin = len(header[0]) - len(header[0].lstrip())
        rendered = []
        for line in header:
            lead = len(line) - len(line.lstrip())
            text = line[min(margin, lead):]
            rendered.append('    ' + text if text else '')

        rendered.append('        """Stub method."""')
        if not self._returns_nothing(inspect.signature(method).return_annotation):
            rendered.append('        return None')
        rendered.append('')
        return rendered

    def generate_stub_class(self, api_class, output_file=None) -> str:
        """Generate a complete stub class from a real API class.

        Args:
            api_class: Class to mirror. A trailing ``API`` in its name becomes
                ``ApiStub``; any other name gets ``Stub`` appended.
            output_file: Optional path; when given, the stub is also written there.

        Returns:
            The stub module source as a single string.
        """
        api_name = api_class.__name__
        if api_name.endswith("API"):
            stub_name = api_name[:-3] + "ApiStub"
        else:
            stub_name = api_name + "Stub"

        lines = list(self.generate_imports_for_api(api_class))
        lines.extend(['', ''])
        lines.extend([
            f"class {stub_name}:",
            f'    """Lightweight API stub for {api_name} to break circular dependencies."""',
            '',
            '    if TYPE_CHECKING:',
            f'        _protocol_check: ClassVar[ApiProtocol] = cast("{stub_name}", None)',
            '',
            '    log: logging.Logger',
            '    metadata: ClassVar[ApiMetadata] = ApiMetadata()',
            '',
            '    def ready(self) -> bool:',
            '        """Check if the service is ready."""',
            '        return True',
            '',
        ])

        if hasattr(api_class, 'metadata') and hasattr(api_class.metadata, 'message_type_to_request'):
            for request_info in api_class.metadata.message_type_to_request.values():
                method = request_info.method
                try:
                    # Render fully before appending, so a failure never leaves
                    # a signature without a body in the stub.
                    method_lines = self._render_stub_method(method)
                except Exception as e:
                    print(f"Warning: Could not process method {getattr(method, '__name__', method)}: {e}")
                    continue
                lines.extend(method_lines)

        stub_content = '\n'.join(lines)

        if output_file:
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(stub_content)
            print(f"Generated: {output_file}")
        return stub_content

    def validate_stub(self, stub_content: str) -> bool:
        """Return True when ``stub_content`` parses as Python source, else print why and return False."""
        import ast

        try:
            ast.parse(stub_content)
        except SyntaxError as e:
            print(f"Validation failed: {e}")
            return False
        return True
def test_systematic_detection():
    """Demo run of the detector against HarvesterAPI and WalletNodeAPI.

    Prints the imports generated for HarvesterAPI, checks whether any import
    generated for WalletNodeAPI mentions ``RespondToPhUpdates``, lists the
    first ten WalletNodeAPI imports, and finally reports whether a stub
    generated for HarvesterAPI parses. Any ImportError raised along the way
    is reported instead of propagated.
    """
    try:
        from chia.harvester.harvester_api import HarvesterAPI
        from chia.wallet.wallet_node_api import WalletNodeAPI

        detector = SystematicImportDetector()
        print("=== SYSTEMATIC IMPORT DETECTION WITH IS-COMPARISON ===")
        print()

        # HarvesterAPI: show every generated import line
        print("--- HarvesterAPI ---")
        harvester_imports = detector.generate_imports_for_api(HarvesterAPI)
        print("\nGenerated imports:")
        for line in harvester_imports:
            print(f" {line}")

        # WalletNodeAPI: look for the first import mentioning RespondToPhUpdates
        print("\n--- WalletAPI (RespondToPhUpdates test) ---")
        wallet_imports = detector.generate_imports_for_api(WalletNodeAPI)
        resolved = next(
            (line for line in wallet_imports if 'RespondToPhUpdates' in line),
            None,
        )
        if resolved is None:
            print("✗ RespondToPhUpdates not found")
        else:
            print(f"✓ Correctly resolved: {resolved}")

        print("\nFirst 10 imports:")
        for line in wallet_imports[:10]:
            print(f" {line}")

        # The generated stub must at least parse
        print("\n--- Validation Test ---")
        stub_source = detector.generate_stub_class(HarvesterAPI)
        if not detector.validate_stub(stub_source):
            print("✗ Stub validation failed")
        else:
            print("✓ Generated stub validates successfully")
    except ImportError as e:
        print(f"Could not import APIs: {e}")
def main():
    """Generate and validate a stub file for each known API class.

    Each target is imported by module and class name, rendered with a single
    shared detector, written to its output path, and then checked for
    parseability. A failure on one target is printed and the loop moves on.
    """
    detector = SystematicImportDetector()
    # (module to import, API class name, path of the stub file to write)
    targets = (
        ('chia.harvester.harvester_api', 'HarvesterAPI', 'chia/apis/harvester_stub_systematic.py'),
        ('chia.farmer.farmer_api', 'FarmerAPI', 'chia/apis/farmer_stub_systematic.py'),
        ('chia.timelord.timelord_api', 'TimelordAPI', 'chia/apis/timelord_stub_systematic.py'),
        ('chia.introducer.introducer_api', 'IntroducerAPI', 'chia/apis/introducer_stub_systematic.py'),
        ('chia.wallet.wallet_node_api', 'WalletNodeAPI', 'chia/apis/wallet_stub_systematic.py'),
        ('chia.full_node.full_node_api', 'FullNodeAPI', 'chia/apis/full_node_stub_systematic.py'),
    )

    for module_name, class_name, output_file in targets:
        try:
            api_class = getattr(importlib.import_module(module_name), class_name)
            print(f"\nGenerating {output_file}...")
            stub_content = detector.generate_stub_class(api_class, output_file)
            if detector.validate_stub(stub_content):
                report = f"✓ {output_file} generated and validated"
            else:
                report = f"✗ {output_file} failed validation"
            print(report)
        except ImportError as e:
            print(f"Could not import {module_name}.{class_name}: {e}")
        except Exception as e:
            print(f"Error generating {class_name}: {e}")
if __name__ == "__main__":
    import sys

    # A first command-line argument of "demo" runs the printed walkthrough;
    # anything else (or no argument) generates the stub files.
    demo_requested = sys.argv[1:2] == ["demo"]
    if not demo_requested:
        main()
    else:
        test_systematic_detection()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment