Last active
January 10, 2025 22:22
-
-
Save ehartford/22fd16973569f0871d36f562f610955a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
shor.py | |
Generalized implementation of Shor's algorithm to factor arbitrary numbers into their prime factors. | |
Can run on both local simulator and IBM Quantum hardware. | |
python shor.py --local --N 77 | |
""" | |
import argparse
import logging
import sys
import time
from fractions import Fraction
from math import ceil, gcd, isqrt, log2

import numpy as np
from qiskit import QuantumCircuit, transpile
from qiskit_aer import AerSimulator
from qiskit_ibm_runtime import QiskitRuntimeService, Sampler
# Set up logging
# NOTE(review): basicConfig runs at import time, so merely importing this module
# configures the root logger at INFO level for the whole process.
logging.basicConfig(level=logging.INFO)
# Module logger for progress, warning and error messages.
logger = logging.getLogger(__name__)
def quantum_random_int(min_val: int, max_val: int, backend=None) -> int:
    """
    Generate a uniformly distributed random integer using quantum randomness.

    Each qubit is put into equal superposition with a Hadamard gate and
    measured, giving one random bit per qubit. Samples outside the requested
    range are rejected and redrawn (rejection sampling), so every value in
    ``[min_val, max_val)`` is equally likely. Reducing with ``%`` instead would
    favour the low end of the range whenever its size is not a power of two.

    Args:
        min_val (int): Minimum value (inclusive).
        max_val (int): Maximum value (exclusive).
        backend: Qiskit backend to use (if None, uses AerSimulator).

    Returns:
        int: Random integer in ``[min_val, max_val)``.

    Raises:
        ValueError: If ``max_val <= min_val`` (empty range).
    """
    range_size = max_val - min_val
    if range_size <= 0:
        raise ValueError(f"Empty range: [{min_val}, {max_val})")
    if range_size == 1:
        # Only one possible value; no circuit needed.
        return min_val

    # Smallest bit count that can represent every offset 0..range_size-1
    # (equals ceil(log2(range_size)), computed exactly with integers).
    n_bits = (range_size - 1).bit_length()

    qc = QuantumCircuit(n_bits, n_bits)
    qc.h(range(n_bits))
    qc.measure(range(n_bits), range(n_bits))

    if backend is None:
        backend = AerSimulator()
    # Map the circuit onto the backend's native gate set (needed for hardware).
    tqc = transpile(qc, backend)

    # 2**n_bits < 2 * range_size, so each sample is accepted with probability
    # > 1/2; a batch of shots per job makes a second job very unlikely.
    while True:
        result = backend.run(tqc, shots=16, memory=True).result()
        for bitstring in result.get_memory():
            offset = int(bitstring, 2)
            if offset < range_size:
                return min_val + offset
def calculate_required_qubits(N: int) -> tuple[int, int]:
    """
    Calculate the number of qubits needed for Shor's algorithm.

    The data register must hold every residue 0..N-1, which takes
    ``n = ceil(log2(N))`` qubits. The counting (phase-estimation) register
    uses ``2n`` qubits.

    Args:
        N: Number to factor (must be >= 1).

    Returns:
        tuple[int, int]: ``(n_count, n_data)``.

    Raises:
        ValueError: If ``N < 1``.
    """
    N = int(N)
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    # (N - 1).bit_length() == ceil(log2(N)), computed exactly in integers.
    # math.log2 goes through a float and rounds wrongly once N exceeds 2**53.
    n_data = (N - 1).bit_length()
    return 2 * n_data, n_data
def get_backend(use_ibm: bool = False, required_qubits: int = 0): | |
""" | |
Get either a local Aer simulator or IBM Quantum backend. | |
""" | |
if use_ibm: | |
try: | |
service = QiskitRuntimeService() | |
backends = service.backends() | |
real_backends = [b for b in backends if not b.name.contains('simulator') | |
and b.configuration().n_qubits >= required_qubits] | |
if not real_backends: | |
raise ValueError(f"No IBM Quantum backends found with {required_qubits} qubits") | |
selected_backend = min(real_backends, | |
key=lambda b: b.status().pending_jobs) | |
logger.info(f"Selected IBM Quantum backend: {selected_backend.name}") | |
return selected_backend, service | |
except Exception as e: | |
logger.error(f"Error accessing IBM Quantum: {str(e)}") | |
logger.info("Falling back to Aer simulator") | |
return AerSimulator(), None | |
else: | |
logger.info("Using local Aer simulator") | |
return AerSimulator(), None | |
def controlled_modular_addition(qc: QuantumCircuit, N: int, a: int,
                                ctrl: int, x_reg: list, anc_reg: list):
    """
    Intended to implement controlled modular addition |x⟩ -> |x + a mod N⟩.

    Appends CCX/CX gates to ``qc`` in place; nothing is returned.

    Args:
        qc: Circuit the gates are appended to.
        N: Modulus; its low ``len(x_reg)`` bits drive the "reduction" step.
        a: Constant to add; its low ``len(x_reg)`` bits select which gate
            sequences are emitted.
        ctrl: Index of the control qubit.
        x_reg: Qubit indices of the target register; x_reg[i] is paired with
            bit i of ``a`` and ``N``.
        anc_reg: Ancilla qubit indices. Indices 0..len(x_reg)-1 are used, so it
            needs at least len(x_reg) entries.

    NOTE(review): this does not appear to compute x + a mod N:
      * the gate sequence emitted for bit i of ``a`` only targets x_reg[j] for
        j > i (and ancillas), never x_reg[i] itself;
      * the "controlled subtraction of N" applies CX(ctrl -> x_reg[i]) for each
        1-bit of N, i.e. it XORs N into x under control; that is neither a
        subtraction nor conditioned on x >= N;
      * the "uncompute" section does not mirror the compute section gate for
        gate, so the ancillas are likely not returned to |0⟩.
    Verify against a reference modular adder before relying on it.
    """
    # Convert inputs to Python ints
    N, a = int(N), int(a)
    n_bits = len(x_reg)
    # Little-endian bits of a: a_bits[i] is bit i.
    a_bits = [(a >> i) & 1 for i in range(n_bits)]
    # Addition part
    # (min(n_bits, len(x_reg)) below always equals n_bits, since n_bits = len(x_reg).)
    for i in range(n_bits):
        if a_bits[i]:
            qc.ccx(ctrl, x_reg[i], anc_reg[0])
            # Propagate carry through available qubits
            for j in range(i+1, min(n_bits, len(x_reg))):
                qc.ccx(x_reg[j-1], anc_reg[j-1], anc_reg[j])
            # Flip target bits within bounds
            for j in range(i+1, min(n_bits, len(x_reg))):
                qc.cx(anc_reg[j-1], x_reg[j])
            # Uncompute carries
            for j in range(min(n_bits-1, len(x_reg)-1), i, -1):
                qc.ccx(x_reg[j-1], anc_reg[j-1], anc_reg[j])
    # Modular reduction
    # Little-endian bits of N.
    N_bits = [(N >> i) & 1 for i in range(n_bits)]
    # Controlled subtraction of N (within bounds)
    # NOTE(review): this is a controlled XOR with N's bit pattern, not a subtraction.
    for i in range(min(n_bits, len(x_reg))):
        if N_bits[i]:
            qc.cx(ctrl, x_reg[i])
    # Uncompute addition
    # Walks the bits of a from most to least significant.
    for i in range(n_bits-1, -1, -1):
        if a_bits[i]:
            for j in range(i+1, min(n_bits-1, len(x_reg)-1)):
                qc.ccx(x_reg[j], anc_reg[j], anc_reg[j+1])
            for j in range(min(n_bits-1, len(x_reg)-1), i, -1):
                if j+1 < len(x_reg):  # Check if index is valid
                    # Skips j = len(x_reg) - 1, whose j+1 would be out of range.
                    qc.cx(anc_reg[j], x_reg[j+1])
            for j in range(min(n_bits-1, len(x_reg)-1), i+1, -1):
                qc.ccx(x_reg[j-1], anc_reg[j-1], anc_reg[j])
            qc.ccx(ctrl, x_reg[i], anc_reg[0])
def controlled_modular_multiplication(qc: QuantumCircuit, N: int, a: int,
                                      ctrl: int, x_reg: list, anc_reg: list):
    """
    Append a controlled modular multiplication |x⟩ -> |a·x mod N⟩ to ``qc``.

    When ``ctrl`` is |1⟩, the data register ``x_reg`` (``x_reg[0]`` is the
    least-significant bit) is mapped x -> a·x mod N for every x < N. Basis
    states x >= N are left unchanged, so the whole map is a permutation and
    therefore unitary. When ``ctrl`` is |0⟩ the register is untouched.

    The operation is appended as one unitary gate built from the explicit
    2^(n+1) x 2^(n+1) permutation matrix (n = len(x_reg)). This is exact, but
    its size grows exponentially with n, so it is only practical for the small
    N that can be simulated anyway.

    Args:
        qc: Circuit to append to (modified in place).
        N: Modulus, with 2 <= N <= 2**len(x_reg).
        a: Multiplier; must be coprime to N so the map is invertible.
        ctrl: Index of the control qubit.
        x_reg: Qubit indices of the data register.
        anc_reg: Ancilla qubit indices. Unused by this implementation; kept
            for interface compatibility.

    Raises:
        ValueError: If N does not fit in ``x_reg`` or gcd(a, N) != 1.
    """
    N, a = int(N), int(a)
    n_bits = len(x_reg)
    dim = 1 << n_bits
    if not 2 <= N <= dim:
        raise ValueError(f"N={N} does not fit in a {n_bits}-qubit register")
    a %= N
    if gcd(a, N) != 1:
        raise ValueError(f"a={a} is not coprime to N={N}; x -> a*x mod N is not invertible")

    # The gate acts on [ctrl, *x_reg]. Qiskit treats the first qubit as the
    # least-significant bit of the matrix index, so index = ctrl_bit + 2 * x.
    matrix = np.zeros((2 * dim, 2 * dim), dtype=complex)
    for x in range(dim):
        y = (a * x) % N if x < N else x
        matrix[2 * x, 2 * x] = 1          # ctrl = 0: identity
        matrix[2 * y + 1, 2 * x + 1] = 1  # ctrl = 1: x -> a*x mod N
    qc.unitary(matrix, [ctrl, *x_reg], label=f"c-{a}x mod {N}")
def improved_qft_dagger(n_count: int) -> QuantumCircuit:
    """
    Build the inverse quantum Fourier transform on ``n_count`` qubits.

    The qubit order is first reversed with SWAPs. Then, for each target qubit t
    in increasing order, a controlled-phase of -pi / 2**(t - c) is applied from
    every lower qubit c, followed by a Hadamard on t.

    Returns:
        QuantumCircuit: A circuit named "QFT†" with ``n_count`` qubits and no
        classical bits.
    """
    circuit = QuantumCircuit(n_count, name="QFT†")
    half = n_count // 2
    for low, high in zip(range(half), reversed(range(n_count))):
        circuit.swap(low, high)
    for target in range(n_count):
        for control in range(target):
            angle = -np.pi / float(2 ** (target - control))
            circuit.cp(angle, control, target)
        circuit.h(target)
    return circuit
def create_shor_circuit(a: int, N: int) -> QuantumCircuit:
    """
    Build the order-finding (phase-estimation) circuit for base ``a`` mod ``N``.

    Qubit layout, lowest index first, with n = ceil(log2(N)):
      * counting register: 2n qubits, each put into superposition;
      * data register: n qubits, initialised to |1⟩;
      * ancilla register: n + 2 qubits, passed to the multiplier.
    Counting qubit k controls a multiplication by a^(2^k) mod N. The inverse
    QFT is then applied to the counting register, and its qubits are measured
    into classical bits 0..2n-1.
    """
    a, N = int(a), int(N)
    n_count, n_data = calculate_required_qubits(N)
    n_ancilla = n_data + 2
    total_qubits = n_count + n_data + n_ancilla
    circuit = QuantumCircuit(total_qubits, n_count)

    data_start = n_count
    ancilla_start = n_count + n_data
    counting = list(range(n_count))
    x_register = list(range(data_start, ancilla_start))
    ancilla = list(range(ancilla_start, total_qubits))

    for qubit in counting:
        circuit.h(qubit)
    circuit.x(x_register[0])

    for k, control in enumerate(counting):
        multiplier = pow(a, 2 ** k, N)
        controlled_modular_multiplication(circuit, N, multiplier, control,
                                          x_register, ancilla)

    circuit.append(improved_qft_dagger(n_count), counting)
    circuit.measure(counting, range(n_count))
    return circuit
def find_factors_from_period(a: int, r: int, N: int) -> tuple[int, int]:
    """
    Try to split N using a candidate period r of a^x mod N.

    For even r, gcd(a^(r/2) - 1, N) and gcd(a^(r/2) + 1, N) are checked, in
    that order, for a non-trivial factor.

    Returns:
        tuple[int, int]: ``(factor, N // factor)`` for the first non-trivial
        factor found, otherwise ``(1, N)``. Odd r always gives ``(1, N)``.
    """
    a, r, N = int(a), int(r), int(N)
    if r % 2:
        return 1, N
    y = pow(a, r // 2, N)
    candidates = (gcd(y - 1, N), gcd(y + 1, N))
    factor = next((f for f in candidates if 1 < f < N), None)
    return (1, N) if factor is None else (factor, N // factor)
def check_if_prime(n: int) -> bool:
    """
    Deterministic primality test by trial division.

    Args:
        n: Integer to test.

    Returns:
        bool: True if ``n`` is prime; False otherwise, including every n < 2.
    """
    n = int(n)
    if n < 2:
        return False
    if n < 4:
        return True  # 2 and 3
    if n % 2 == 0:
        return False
    # isqrt is exact for arbitrarily large ints. int(n ** 0.5) goes through a
    # float and can undershoot the true root above 2**53, skipping the deciding
    # divisor.
    for i in range(3, isqrt(n) + 1, 2):
        if n % i == 0:
            return False
    return True
def find_prime_factors(N: int, max_tries: int = 10, seed: int = 42, use_ibm: bool = False) -> list[int]: | |
""" | |
Recursively find all prime factors using Shor's algorithm. | |
""" | |
N = int(N) | |
if N <= 1: | |
return [] | |
# Handle powers of 2 directly | |
if N & (N - 1) == 0: # If N is a power of 2 | |
factors = [] | |
while N > 1: | |
factors.append(2) | |
N //= 2 | |
return factors | |
# If N is even, factor out all 2s first | |
if N % 2 == 0: | |
factors = [2] | |
factors.extend(find_prime_factors(N // 2, max_tries, seed, use_ibm)) | |
return factors | |
# If N is prime, return it | |
if check_if_prime(N): | |
return [N] | |
try: | |
# Try to factor N using Shor's algorithm | |
factor1, factor2 = run_shor_algorithm(N, max_tries, seed, use_ibm) | |
if factor1 == 1 or factor2 == 1: | |
# If Shor's algorithm failed, try classical factoring | |
for i in range(3, int(N ** 0.5) + 1, 2): | |
if N % i == 0: | |
return find_prime_factors(i, max_tries, seed, use_ibm) + \ | |
find_prime_factors(N // i, max_tries, seed, use_ibm) | |
# If we get here, something went wrong | |
raise ValueError(f"Failed to factor {N}") | |
# Recursively factor the factors | |
return find_prime_factors(factor1, max_tries, seed, use_ibm) + \ | |
find_prime_factors(factor2, max_tries, seed, use_ibm) | |
except Exception as e: | |
logger.error(f"Error factoring {N}: {str(e)}") | |
raise | |
def run_shor_algorithm(N: int, max_tries: int = 10, seed: int = 42, use_ibm: bool = False) -> tuple[int, int]: | |
""" | |
Run Shor's algorithm to factor N. | |
""" | |
N = int(N) | |
if N < 3: | |
raise ValueError("N must be ≥ 3") | |
if N % 2 == 0: | |
return 2, N // 2 | |
if check_if_prime(N): | |
raise ValueError(f"{N} is prime") | |
n_count, n_data = calculate_required_qubits(N) | |
total_qubits = n_count + n_data + n_data + 1 | |
logger.info(f"Required qubits: {total_qubits}") | |
backend, service = get_backend(use_ibm, total_qubits) | |
# Keep track of tried numbers | |
tried_numbers = set() | |
attempts = 0 | |
while attempts < max_tries and len(tried_numbers) < N-2: | |
attempts += 1 | |
logger.info(f"\nAttempt {attempts}/{max_tries}") | |
# Generate a random number using quantum randomness | |
available_numbers = list(set(range(2, N)) - tried_numbers) | |
if not available_numbers: | |
logger.warning("No more numbers to try") | |
break | |
# Use quantum RNG to select index from available numbers | |
idx = quantum_random_int(0, len(available_numbers), backend) | |
a = available_numbers[idx] | |
tried_numbers.add(a) | |
logger.info(f"Testing with a = {a} (quantum random choice)") | |
if (g := gcd(a, N)) > 1: | |
logger.info(f"Found factor immediately: gcd({a},{N})={g}") | |
return g, N // g | |
try: | |
qc = create_shor_circuit(a, N) | |
opt_level = 1 if use_ibm else 3 | |
tqc = transpile(qc, backend, optimization_level=opt_level) | |
if use_ibm: | |
sampler = Sampler(backend=backend) | |
job = sampler.run(tqc, shots=1) | |
logger.info(f"Job ID: {job.job_id()}") | |
logger.info("Waiting for job completion...") | |
result = job.result() | |
counts = result.quasi_dists[0] | |
measured_state = max(counts.items(), key=lambda x: x[1])[0] | |
else: | |
job = backend.run(tqc, shots=1, memory=True) | |
result = job.result() | |
measured_state = int(result.get_memory()[0], 2) | |
n_count = calculate_required_qubits(N)[0] | |
measured_phase = measured_state / (2**n_count) | |
frac = Fraction(measured_phase).limit_denominator(2**n_count) | |
r = frac.denominator | |
logger.info(f"Measured phase = {measured_phase:.4f}") | |
logger.info(f"Estimated period = {r}") | |
factor1, factor2 = find_factors_from_period(a, r, N) | |
if factor1 > 1: | |
logger.info(f"Success! Factors found: {factor1} × {factor2} = {N}") | |
return factor1, factor2 | |
except Exception as e: | |
logger.error(f"Error in attempt {attempts}: {str(e)}") | |
continue | |
logger.warning("Failed to find factors within maximum attempts") | |
return 1, N | |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Factor N into primes using Shor's algorithm")
    parser.add_argument('--N', type=int, required=True,
                        help='Number to factor')
    # Create a mutually exclusive group for backend selection
    backend_group = parser.add_mutually_exclusive_group()
    backend_group.add_argument('--local', action='store_true', default=True,
                               help='Use local Aer simulator (default)')
    backend_group.add_argument('--ibm', action='store_true',
                               help='Use IBM Quantum backend')
    parser.add_argument('--tries', type=int, default=10,
                        help='Maximum number of attempts')
    parser.add_argument('--seed', type=int, default=1234,
                        help='Random seed')
    args = parser.parse_args()
    # The local simulator is used unless --ibm is given; --local only documents the default.
    use_ibm = args.ibm

    # perf_counter is monotonic, so the measured duration is immune to clock changes.
    start_time = time.perf_counter()
    try:
        prime_factors = find_prime_factors(args.N, max_tries=args.tries,
                                           seed=args.seed, use_ibm=use_ibm)
    except ValueError as e:
        # Report failure on stderr with a non-zero exit status so callers can detect it.
        print(f"Error: {str(e)}", file=sys.stderr)
        sys.exit(1)
    elapsed_time = time.perf_counter() - start_time

    prime_factors.sort()
    if prime_factors:
        factorization = " × ".join(map(str, prime_factors))
        parentheses = f"({factorization})" if len(prime_factors) > 1 else factorization
        product = " = " + str(args.N)
    else:
        # N <= 1 has no prime factors; just echo it.
        parentheses = str(args.N)
        product = ""
    print(f"\nExecution time: {elapsed_time:.2f} seconds")
    print(f"Prime factorization: {parentheses}{product}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment