# quantum_em_mis_tests.py
# Unit tests and test-data collection for quantum_em_mis_core.
|
|
|
import torch |
|
import numpy as np |
|
from datetime import datetime |
|
import logging |
|
import unittest |
|
import time |
|
import psutil |
|
from pathlib import Path |
|
import json |
|
from typing import List, Tuple, Dict, Any |
|
from complextensor import ComplexTensor |
|
from quantum_em_mis_core import MISTransform, EnhancedQuantumWaveFunction, EnhancedEMField |
|
|
|
# Configure logging at import time. Unless the root logger already has
# handlers, basicConfig routes INFO-and-above records to a log file whose
# name embeds the local start time of this run.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    filename='quantum_em_mis_tests_{}.log'.format(
        datetime.now().strftime("%Y%m%d_%H%M%S")
    ),
)

# Logger used by the tests in this module.
logger = logging.getLogger(__name__)
|
|
|
class DataCollector:
    """Collect and save test data (tensors, performance numbers, metadata).

    Each instance writes into a directory named ``test_data_<YYYYmmdd_HHMMSS>``.
    The directory is created under *base_dir*, or relative to the current
    working directory when *base_dir* is None.

    The directory name only has one-second resolution. Several collectors
    created within the same second (e.g. one per test in ``setUp``) therefore
    share a directory. ``save_metadata`` accounts for that by merging with any
    ``metadata.json`` already present instead of overwriting it.
    """

    def __init__(self, base_dir: "str | Path | None" = None):
        """Create the output directory.

        Args:
            base_dir: Optional parent directory. None keeps the historical
                behaviour of creating the directory relative to the CWD.
        """
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        dir_name = f"test_data_{self.timestamp}"
        self.data_dir = Path(dir_name) if base_dir is None else Path(base_dir) / dir_name
        self.data_dir.mkdir(exist_ok=True)
        self.metadata = {
            "timestamp": self.timestamp,
            "data_files": [],
        }

    def save_tensor_data(self, data: torch.Tensor, name: str, metadata: "Dict | None" = None):
        """Save *data* as ``<name>.npy`` and record an entry in the metadata.

        Args:
            data: A torch tensor (detached and copied to CPU before saving), or
                any array-like accepted by ``np.asarray``.
            name: Base filename without the ``.npy`` extension.
            metadata: Optional extra information stored with the entry. It is
                copied, so later changes by the caller do not leak into the
                saved metadata.
        """
        filename = f"{name}.npy"
        if isinstance(data, torch.Tensor):
            array = data.detach().cpu().numpy()
            dtype = str(data.dtype)
        else:
            array = np.asarray(data)
            dtype = str(array.dtype)
        np.save(self.data_dir / filename, array)

        self.metadata["data_files"].append({
            "filename": filename,
            "shape": list(array.shape),
            "dtype": dtype,
            "metadata": dict(metadata) if metadata else {},
        })

    def save_complex_tensor_data(self, data: "ComplexTensor", name: str, metadata: "Dict | None" = None):
        """Save the real and imaginary parts as ``<name>_real.npy`` / ``<name>_imag.npy``."""
        self.save_tensor_data(data.real, f"{name}_real", metadata)
        self.save_tensor_data(data.imag, f"{name}_imag", metadata)

    def log_performance(self, name: str, execution_time: float, memory_used: float):
        """Append one JSON line with timing (s) and memory delta (MiB) to performance.json."""
        with open(self.data_dir / "performance.json", "a", encoding="utf-8") as f:
            json.dump({
                "name": name,
                "time": execution_time,
                "memory_mb": memory_used,
                "timestamp": datetime.now().isoformat(),
            }, f)
            f.write('\n')

    @staticmethod
    def _read_existing_entries(path: Path) -> List[Dict[str, Any]]:
        """Return the ``data_files`` entries of an existing metadata file, or [] if absent/unreadable."""
        try:
            with open(path, encoding="utf-8") as f:
                existing = json.load(f)
        except (OSError, ValueError):  # missing file, unreadable, or corrupt JSON
            return []
        entries = existing.get("data_files") if isinstance(existing, dict) else None
        if not isinstance(entries, list):
            return []
        return [e for e in entries if isinstance(e, dict) and "filename" in e]

    def save_metadata(self):
        """Write metadata.json, merging with entries saved earlier into the same directory.

        Entries are keyed by filename. When the same file was recorded more
        than once, the most recent entry (this collector's) wins, because the
        ``.npy`` on disk was overwritten as well.
        """
        path = self.data_dir / "metadata.json"
        by_filename: Dict[str, Dict[str, Any]] = {}
        for entry in self._read_existing_entries(path) + self.metadata["data_files"]:
            # Re-assigning an existing key keeps its original position in the dict.
            by_filename[entry["filename"]] = entry
        payload = {**self.metadata, "data_files": list(by_filename.values())}
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)
|
|
|
class QuantumEMMISTests(unittest.TestCase):
    """Tests for MISTransform, EnhancedQuantumWaveFunction and EnhancedEMField.

    Each test gets a fresh DataCollector (setUp) that records inputs, outputs
    and per-call timing/memory figures; its metadata is written in tearDown.
    """

    def setUp(self):
        self.data_collector = DataCollector()

    def tearDown(self):
        self.data_collector.save_metadata()

    def measure_performance(self, name: str, callable_obj, *args, **kwargs):
        """Call ``callable_obj(*args, **kwargs)``, log its cost under *name*, and return its result.

        Elapsed time comes from ``time.perf_counter()``. That clock is
        monotonic and high-resolution, whereas ``time.time()`` can jump when
        the system clock is adjusted. The memory figure is the change in this
        process's resident set size in MiB, so it is only a rough indicator.
        """
        process = psutil.Process()
        start_mem = process.memory_info().rss / 1024 / 1024
        start_time = time.perf_counter()

        result = callable_obj(*args, **kwargs)

        execution_time = time.perf_counter() - start_time
        memory_used = process.memory_info().rss / 1024 / 1024 - start_mem

        self.data_collector.log_performance(name, execution_time, memory_used)
        return result

    @staticmethod
    def _batched_meshgrid(low: float, high: float, n: int, dims: int) -> List[torch.Tensor]:
        """Return *dims* 'ij'-indexed coordinate grids over [low, high], each with a leading axis of size 1."""
        axis = torch.linspace(low, high, n)
        return [grid.unsqueeze(0) for grid in torch.meshgrid(*([axis] * dims), indexing='ij')]

    @classmethod
    def _complex_plane(cls, low: float, high: float, n: int = 50) -> "ComplexTensor":
        """Return a ComplexTensor whose real/imag parts are the x/y grids over [low, high]^2."""
        X, Y = cls._batched_meshgrid(low, high, n, 2)
        return ComplexTensor(X, Y)

    def test_mis_scale_invariance(self):
        """Run MISTransform on one complex grid at several scales and record the outputs.

        NOTE(review): despite the name, no invariance property is asserted;
        outputs are only saved for offline comparison.
        """
        logger.info("Testing MIS scale invariance")
        mis = MISTransform()
        Z = self._complex_plane(-5, 5)

        for scale in [0.1, 1.0, 10.0]:
            scaled_Z = ComplexTensor(Z.real * scale, Z.imag * scale)
            transformed = self.measure_performance(
                f"mis_transform_scale_{scale}", mis, scaled_Z, t=1.0
            )
            self.data_collector.save_complex_tensor_data(
                transformed, f"mis_scale_{scale}", {"scale": scale}
            )

    def test_quantum_evolution(self):
        """Evolve a 3-state, 2-D wave function and check the probabilities sum to 1.

        NOTE(review): psi comes from ``evolve_unnormalized``, so the unit-sum
        assertion relies on ``measure`` normalising. Confirm this against
        quantum_em_mis_core.
        """
        logger.info("Testing quantum evolution")
        qwf = EnhancedQuantumWaveFunction(n_states=3, n_dimensions=2)
        X, Y = self._batched_meshgrid(-5, 5, 50, 2)

        for t in [0.0, 0.5, 1.0, 2.0]:
            psi = self.measure_performance(
                f"quantum_evolution_t_{t}", qwf.evolve_unnormalized, [X, Y], t
            )
            self.data_collector.save_complex_tensor_data(psi, f"psi_t_{t}", {"time": t})

            prob = qwf.measure(psi)
            self.data_collector.save_tensor_data(prob, f"probability_t_{t}", {"time": t})

            total_prob = prob.sum().item()
            logger.info(f"Total probability at t={t}: {total_prob}")
            self.assertAlmostEqual(total_prob, 1.0, places=5)

    def test_em_field_operations(self):
        """Check curl/divergence on a plane-wave field and verify div E = 0 and div B = 0."""
        logger.info("Testing EM field operations")
        em = EnhancedEMField(grid_size=50)

        k = 2 * np.pi / 10
        omega = 2 * np.pi
        # Only the x coordinate is needed: every field below varies along x alone.
        X = self._batched_meshgrid(-5, 5, 50, 3)[0]

        # Vector potential A = A_amp * cos(k x) z_hat.
        A_amp = 1.0

        # Magnetic field B = curl A. For A = A_z(x) z_hat the only non-zero
        # component is B_y = -dA_z/dx = k * A_amp * sin(k x). Putting this
        # term in B_x instead gives a field with non-zero divergence
        # (dB_x/dx = -k^2 A_amp cos(k x)), which violates div B = 0.
        B_x = torch.zeros_like(X)
        B_y = k * A_amp * torch.sin(k * X)
        B_z = torch.zeros_like(X)

        # Electric field: only E_z is non-zero, and it depends on x only.
        Ex = ComplexTensor(torch.zeros_like(X), torch.zeros_like(X))
        Ey = ComplexTensor(torch.zeros_like(X), torch.zeros_like(X))
        Ez = ComplexTensor(omega * A_amp * torch.sin(k * X), torch.zeros_like(X))

        for component, field in zip(['x', 'y', 'z'], [Ex, Ey, Ez]):
            self.data_collector.save_complex_tensor_data(
                field, f"E_{component}", {"field": "electric", "component": component}
            )

        curl = self.measure_performance("em_curl", em.curl, [Ex, Ey, Ez])
        for i, component in enumerate(['x', 'y', 'z']):
            self.data_collector.save_complex_tensor_data(
                curl[i], f"curl_{component}", {"operation": "curl", "component": component}
            )

        div = self.measure_performance("em_divergence", em.divergence, [Ex, Ey, Ez])
        self.data_collector.save_complex_tensor_data(
            div, "divergence", {"operation": "divergence"}
        )

        # Gauss's law in vacuum: div E = 0 (E_z does not vary along z).
        div_mean = div.abs().mean().item()
        logger.info(f"Mean divergence: {div_mean}")
        self.assertLess(div_mean, 1e-5)

        # Gauss's law for magnetism: div B = 0 (B_y does not vary along y).
        B_fields = [ComplexTensor(B, torch.zeros_like(B)) for B in (B_x, B_y, B_z)]
        B_div = self.measure_performance("b_field_divergence", em.divergence, B_fields)
        self.data_collector.save_complex_tensor_data(
            B_div, "B_divergence", {"operation": "B-field divergence"}
        )
        B_div_mean = B_div.abs().mean().item()
        logger.info(f"Mean B-field divergence: {B_div_mean}")
        self.assertLess(B_div_mean, 1e-5)

    def test_edge_cases(self):
        """Run MISTransform on grids of tiny (~1e-10) and huge (~1e10) magnitude and record results."""
        logger.info("Testing edge cases")
        mis = MISTransform()

        Z_near_zero = self._complex_plane(-1e-10, 1e-10)
        result = self.measure_performance("mis_near_zero", mis, Z_near_zero, t=1.0)
        self.data_collector.save_complex_tensor_data(
            result, "edge_case_near_zero", {"case": "near_zero"}
        )

        Z_large = self._complex_plane(-1e10, 1e10)
        result = self.measure_performance("mis_large_magnitude", mis, Z_large, t=1.0)
        self.data_collector.save_complex_tensor_data(
            result, "edge_case_large_magnitude", {"case": "large_magnitude"}
        )
|
|
def run_tests():
    """Run every QuantumEMMISTests case with a verbose text runner.

    Returns:
        The ``unittest.TestResult`` produced by the run.
    """
    logger.info("Starting comprehensive test suite")
    loader = unittest.TestLoader()
    text_runner = unittest.TextTestRunner(verbosity=2)
    outcome = text_runner.run(loader.loadTestsFromTestCase(QuantumEMMISTests))
    logger.info("Test suite completed")
    return outcome
|
|
|
if __name__ == "__main__":
    import sys

    # Report the outcome to the shell/CI: exit status 1 if any test failed
    # or errored, 0 otherwise. Without this the script always exits 0.
    sys.exit(0 if run_tests().wasSuccessful() else 1)