Skip to content

Instantly share code, notes, and snippets.

@ehartford
Last active January 10, 2025 22:22
Show Gist options
  • Save ehartford/22fd16973569f0871d36f562f610955a to your computer and use it in GitHub Desktop.
Save ehartford/22fd16973569f0871d36f562f610955a to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""
shor.py
Generalized implementation of Shor's algorithm to factor arbitrary numbers into their prime factors.
Can run on both local simulator and IBM Quantum hardware.
python shor.py --local --N 77
"""
import argparse
import logging
import time
from fractions import Fraction
from math import ceil, gcd, isqrt, log2

import numpy as np
from qiskit import QuantumCircuit, transpile
from qiskit_aer import AerSimulator
from qiskit_ibm_runtime import QiskitRuntimeService, Sampler
# Module-wide logging: configure the root logger at INFO level (basicConfig
# does nothing if the root logger already has handlers), then obtain the
# logger this module writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
def quantum_random_int(min_val: int, max_val: int, backend=None) -> int:
    """
    Generate a uniformly distributed random integer using quantum randomness.

    Every qubit is put into equal superposition with a Hadamard gate and then
    measured, which (ideally) yields independent uniform bits. Samples that
    fall outside the requested range are discarded (rejection sampling), so
    each value in the range is equally likely. Reducing with
    ``value % range_size`` instead would favour the low end of the range
    whenever its size is not a power of two.

    Args:
        min_val (int): Minimum value (inclusive).
        max_val (int): Maximum value (exclusive).
        backend: Qiskit backend to use (if None, uses AerSimulator).

    Returns:
        int: Random integer ``r`` with ``min_val <= r < max_val``.

    Raises:
        ValueError: If the range is empty (``max_val <= min_val``).
    """
    range_size = max_val - min_val
    if range_size <= 0:
        raise ValueError(
            f"Empty range: min_val={min_val}, max_val={max_val} (need min_val < max_val)"
        )
    if range_size == 1:
        # Only one possible outcome; no circuit needed.
        return min_val

    # Smallest width whose 2**n_bits outcomes cover the range. This is exact
    # integer arithmetic, unlike ceil(log2(...)) on floats.
    n_bits = (range_size - 1).bit_length()

    qc = QuantumCircuit(n_bits, n_bits)
    for q in range(n_bits):
        qc.h(q)
    qc.measure(range(n_bits), range(n_bits))

    if backend is None:
        backend = AerSimulator()

    # NOTE(review): the circuit is not transpiled for `backend`; real IBM
    # devices may reject non-native gates here. Confirm before relying on this
    # with an IBM Runtime backend.

    # A sample is accepted with probability range_size / 2**n_bits > 1/2, so a
    # small batch of shots almost always yields an accepted value in one run.
    shots = 16
    while True:
        result = backend.run(qc, shots=shots, memory=True).result()
        for bits in result.get_memory():
            value = int(bits, 2)
            if value < range_size:
                return min_val + value
def calculate_required_qubits(N: int) -> tuple[int, int]:
    """
    Calculate the number of qubits needed for Shor's algorithm.

    Args:
        N: Number to factor (positive integer).

    Returns:
        ``(n_count, n_data)`` where:

        * ``n_data`` is the number of bits needed to hold every value in
          ``0 .. N-1``;
        * ``n_count = 2 * n_data`` is the size of the phase-estimation
          (counting) register.

    Raises:
        ValueError: If ``N < 1``.
    """
    N = int(N)
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N}")
    # Exact integer equivalent of ceil(log2(N)). The float version undercounts
    # once N exceeds double precision (e.g. N = 2**53 + 1).
    n_data = (N - 1).bit_length()
    n_count = 2 * n_data
    return n_count, n_data
def get_backend(use_ibm: bool = False, required_qubits: int = 0):
    """
    Get either a local Aer simulator or an IBM Quantum backend.

    Args:
        use_ibm: If true, try to select a real IBM Quantum device.
        required_qubits: Minimum qubit count the selected device must have.

    Returns:
        ``(backend, service)``.

        * With IBM Quantum, ``backend`` is the non-simulator device with at
          least ``required_qubits`` qubits and the fewest pending jobs, and
          ``service`` is the ``QiskitRuntimeService`` used to find it.
        * Otherwise ``backend`` is an ``AerSimulator`` and ``service`` is
          ``None``.

        Any failure while contacting IBM Quantum is logged and falls back to
        the simulator, so callers should check ``service is None`` to learn
        which path was actually taken.
    """
    if not use_ibm:
        logger.info("Using local Aer simulator")
        return AerSimulator(), None

    try:
        service = QiskitRuntimeService()
        # Use `in` for the substring test; Python str has no .contains() method.
        real_backends = [
            b for b in service.backends()
            if "simulator" not in b.name and b.num_qubits >= required_qubits
        ]
        if not real_backends:
            raise ValueError(f"No IBM Quantum backends found with {required_qubits} qubits")
        selected_backend = min(real_backends,
                               key=lambda b: b.status().pending_jobs)
        logger.info(f"Selected IBM Quantum backend: {selected_backend.name}")
        return selected_backend, service
    except Exception as e:
        # Deliberate best-effort: any IBM failure degrades to the simulator.
        logger.error(f"Error accessing IBM Quantum: {str(e)}")
        logger.info("Falling back to Aer simulator")
        return AerSimulator(), None
def controlled_modular_addition(qc: QuantumCircuit, N: int, a: int,
                                ctrl: int, x_reg: list, anc_reg: list):
    """
    Implements controlled modular addition |x⟩ -> |x + a mod N⟩

    Intended behaviour: append gates to ``qc`` that add the classical
    constant ``a`` modulo ``N`` to the register ``x_reg``, conditioned on
    qubit ``ctrl``.

    Args:
        qc: Circuit the gates are appended to (modified in place).
        N: Modulus.
        a: Classical constant to add.
        ctrl: Index of the control qubit.
        x_reg: Target-register qubit indices, least-significant bit first
            (bit i of ``a`` is paired with ``x_reg[i]``).
        anc_reg: Ancilla qubit indices. Positions up to ``len(x_reg) - 1``
            may be indexed, so it should hold at least ``len(x_reg)`` qubits.

    NOTE(review): the gate sequence below does not appear to realise the
    documented map. Verify on basis states with a simulator before relying
    on it:

      * for a set bit ``a_i``, ``x_reg[i]`` itself is never toggled; only
        ``x_reg[j]`` for ``j > i`` are (via ``cx(anc_reg[j-1], x_reg[j])``);
      * the first carry is written into ``anc_reg[0]``, but for ``i > 0`` the
        propagation loop starts by reading ``anc_reg[i]``;
      * the "modular reduction" step XORs N's bit pattern into ``x_reg``
        under ``ctrl`` rather than subtracting N after a comparison.
    """
    # Convert inputs to Python ints
    N, a = int(N), int(a)
    n_bits = len(x_reg)
    # Little-endian bit decomposition: a_bits[i] is bit i of a.
    a_bits = [(a >> i) & 1 for i in range(n_bits)]
    # Addition part
    for i in range(n_bits):
        if a_bits[i]:
            qc.ccx(ctrl, x_reg[i], anc_reg[0])
            # Propagate carry through available qubits
            # (min(n_bits, len(x_reg)) is simply n_bits, since n_bits = len(x_reg).)
            for j in range(i+1, min(n_bits, len(x_reg))):
                qc.ccx(x_reg[j-1], anc_reg[j-1], anc_reg[j])
            # Flip target bits within bounds
            for j in range(i+1, min(n_bits, len(x_reg))):
                qc.cx(anc_reg[j-1], x_reg[j])
            # Uncompute carries
            for j in range(min(n_bits-1, len(x_reg)-1), i, -1):
                qc.ccx(x_reg[j-1], anc_reg[j-1], anc_reg[j])
    # Modular reduction
    N_bits = [(N >> i) & 1 for i in range(n_bits)]
    # Controlled subtraction of N (within bounds)
    # NOTE(review): this toggles x_reg[i] wherever N has a 1 bit, i.e. XORs N
    # into x under ctrl; it is not an arithmetic subtraction of N.
    for i in range(min(n_bits, len(x_reg))):
        if N_bits[i]:
            qc.cx(ctrl, x_reg[i])
    # Uncompute addition
    for i in range(n_bits-1, -1, -1):
        if a_bits[i]:
            for j in range(i+1, min(n_bits-1, len(x_reg)-1)):
                qc.ccx(x_reg[j], anc_reg[j], anc_reg[j+1])
            for j in range(min(n_bits-1, len(x_reg)-1), i, -1):
                if j+1 < len(x_reg):  # Check if index is valid
                    qc.cx(anc_reg[j], x_reg[j+1])
            for j in range(min(n_bits-1, len(x_reg)-1), i+1, -1):
                qc.ccx(x_reg[j-1], anc_reg[j-1], anc_reg[j])
            # Mirrors the first ccx of the addition part for this bit.
            qc.ccx(ctrl, x_reg[i], anc_reg[0])
def controlled_modular_multiplication(qc: QuantumCircuit, N: int, a: int,
ctrl: int, x_reg: list, anc_reg: list):
"""
Implements quantum modular multiplication |x⟩ -> |ax mod N⟩
"""
# Convert inputs to Python ints
N, a = int(N), int(a)
n_bits = len(x_reg)
for i in range(n_bits):
if (a >> i) & 1:
controlled_modular_addition(qc, N, (1 << i), ctrl, x_reg, anc_reg)
def improved_qft_dagger(n_count: int) -> QuantumCircuit:
    """
    Build the inverse quantum Fourier transform on ``n_count`` qubits.

    The bit order is reversed with swaps first. Then, for each qubit ``t``
    in ascending order, a controlled phase of ``-pi / 2**(t - c)`` is
    applied from every lower-indexed qubit ``c``, followed by a Hadamard on
    ``t``. The circuit is named "QFT†" so it is labelled when appended as a
    sub-circuit.
    """
    circuit = QuantumCircuit(n_count, name="QFT†")
    last = n_count - 1
    for lo in range(n_count // 2):
        circuit.swap(lo, last - lo)
    for target in range(n_count):
        for control in range(target):
            distance = target - control
            circuit.cp(-np.pi / float(2 ** distance), control, target)
        circuit.h(target)
    return circuit
def create_shor_circuit(a: int, N: int) -> QuantumCircuit:
    """
    Build the order-finding circuit used by Shor's algorithm for base ``a`` mod ``N``.

    Qubit layout, by index:

    * ``[0, n_count)``: counting register, each qubit put in superposition
      with a Hadamard;
    * ``[n_count, n_count + n_data)``: work register, initialised to the
      value 1 (least-significant bit first);
    * the remaining ``n_data + 2`` qubits: ancillas handed to the modular
      multiplier.

    Counting qubit ``k`` controls multiplication by ``a**(2**k) mod N``. The
    counting register then goes through the inverse QFT and is measured into
    classical bits ``0 .. n_count-1``.
    """
    a = int(a)
    N = int(N)
    n_count, n_data = calculate_required_qubits(N)

    n_ancilla = n_data + 2
    data_start = n_count
    ancilla_start = n_count + n_data
    total_qubits = ancilla_start + n_ancilla
    circuit = QuantumCircuit(total_qubits, n_count)

    counting = list(range(data_start))
    work = list(range(data_start, ancilla_start))
    ancillas = list(range(ancilla_start, total_qubits))

    for qubit in counting:
        circuit.h(qubit)
    circuit.x(work[0])  # work register holds the value 1

    for k, control in enumerate(counting):
        multiplier = pow(a, 1 << k, N)
        controlled_modular_multiplication(circuit, N, multiplier, control, work, ancillas)

    circuit.append(improved_qft_dagger(n_count), counting)
    circuit.measure(counting, range(n_count))
    return circuit
def find_factors_from_period(a: int, r: int, N: int) -> tuple[int, int]:
    """
    Try to split N using a candidate period ``r`` of ``a`` modulo ``N``.

    For an even ``r``, the values ``gcd(a**(r/2) - 1, N)`` and
    ``gcd(a**(r/2) + 1, N)`` are checked in that order. The first one
    strictly between 1 and N is returned together with its cofactor.
    An odd ``r``, or no useful gcd, yields the trivial pair ``(1, N)``.
    """
    a, r, N = int(a), int(r), int(N)
    if r % 2:
        return 1, N
    root = pow(a, r // 2, N)
    for candidate in (gcd(root - 1, N), gcd(root + 1, N)):
        if 1 < candidate < N:
            return candidate, N // candidate
    return 1, N
def check_if_prime(n: int) -> bool:
    """
    Deterministic primality test by trial division.

    Args:
        n: Value to test (converted with ``int``).

    Returns:
        True if ``n`` is prime; False for composites and for ``n < 2``.
    """
    n = int(n)
    if n < 2:
        return False
    if n < 4:
        return True  # 2 and 3
    if n % 2 == 0:
        return False
    # isqrt is exact for arbitrarily large ints, whereas n ** 0.5 goes through
    # a float (inexact, and OverflowError above ~1.8e308). Only odd divisors
    # need testing once 2 has been ruled out.
    for divisor in range(3, isqrt(n) + 1, 2):
        if n % divisor == 0:
            return False
    return True
def _integer_root(n: int, k: int) -> int:
    """Return floor(n ** (1 / k)) for n >= 0 and k >= 1 using exact integer arithmetic."""
    if n < 2:
        return n
    lo, hi = 1, 1 << (n.bit_length() // k + 1)
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if mid ** k <= n:
            lo = mid
        else:
            hi = mid - 1
    return lo


def _split_perfect_power(n: int) -> tuple[int, int]:
    """Return (b, k) with b ** k == n and k maximal, or (n, 1) if n is not a perfect power."""
    for k in range(n.bit_length(), 1, -1):
        b = _integer_root(n, k)
        if b > 1 and b ** k == n:
            return b, k
    return n, 1


def find_prime_factors(N: int, max_tries: int = 10, seed: int = 42, use_ibm: bool = False) -> list[int]:
    """
    Recursively find all prime factors using Shor's algorithm.

    The classical steps run first:

    * powers of two and even numbers are handled by division;
    * primes are returned as-is;
    * perfect powers ``b**k`` are split by taking integer roots. Shor's
      period-finding reduction cannot split a prime power, so without this
      step such N would waste every quantum attempt.

    Remaining odd composites go to ``run_shor_algorithm``. If that fails,
    trial division takes over.

    Args:
        N: Number to factor. Values <= 1 give an empty list.
        max_tries: Forwarded to ``run_shor_algorithm``.
        seed: Forwarded to ``run_shor_algorithm``.
        use_ibm: Forwarded to ``run_shor_algorithm``.

    Returns:
        The prime factors of N with multiplicity (not necessarily sorted).

    Raises:
        Re-raises any exception from the factoring steps after logging it.
    """
    N = int(N)
    if N <= 1:
        return []
    # Handle powers of 2 directly: N == 2**(bit_length - 1)
    if N & (N - 1) == 0:
        return [2] * (N.bit_length() - 1)
    # If N is even, split off one 2 and recurse on the rest
    if N % 2 == 0:
        return [2] + find_prime_factors(N // 2, max_tries, seed, use_ibm)
    # If N is prime, return it
    if check_if_prime(N):
        return [N]
    # Perfect powers (including odd prime powers) are split classically.
    base, exponent = _split_perfect_power(N)
    if exponent > 1:
        return find_prime_factors(base, max_tries, seed, use_ibm) * exponent
    try:
        # Try to factor N using Shor's algorithm
        factor1, factor2 = run_shor_algorithm(N, max_tries, seed, use_ibm)
        if factor1 == 1 or factor2 == 1:
            # Shor's algorithm failed; N is odd and composite, so an odd
            # divisor <= isqrt(N) must exist.
            for i in range(3, isqrt(N) + 1, 2):
                if N % i == 0:
                    return (find_prime_factors(i, max_tries, seed, use_ibm)
                            + find_prime_factors(N // i, max_tries, seed, use_ibm))
            # If we get here, something went wrong
            raise ValueError(f"Failed to factor {N}")
        # Recursively factor the factors
        return (find_prime_factors(factor1, max_tries, seed, use_ibm)
                + find_prime_factors(factor2, max_tries, seed, use_ibm))
    except Exception as e:
        logger.error(f"Error factoring {N}: {str(e)}")
        raise
def run_shor_algorithm(N: int, max_tries: int = 10, seed: int = 42, use_ibm: bool = False) -> tuple[int, int]:
    """
    Run Shor's algorithm to factor N.

    Each attempt works as follows:

    1. Draw a base ``a`` with quantum randomness.
    2. Check ``gcd(a, N)``; a non-trivial value is an immediate factor.
    3. Otherwise run the order-finding circuit and measure the counting
       register.
    4. Turn the measured phase into a period candidate ``r`` with a
       continued-fraction approximation whose denominator is bounded by N.
    5. Try to split N from ``r``.

    Args:
        N: Integer >= 3 to factor. Even N returns ``(2, N // 2)`` directly.
        max_tries: Maximum number of bases to try.
        seed: Accepted for interface compatibility; not used here.
        use_ibm: Request an IBM Quantum backend instead of the local simulator.

    Returns:
        ``(factor1, factor2)`` with ``factor1 * factor2 == N``, or ``(1, N)``
        if no factor was found within ``max_tries`` attempts.

    Raises:
        ValueError: If ``N < 3`` or N is prime.
    """
    N = int(N)
    if N < 3:
        raise ValueError("N must be ≥ 3")
    if N % 2 == 0:
        return 2, N // 2
    if check_if_prime(N):
        raise ValueError(f"{N} is prime")

    n_count, n_data = calculate_required_qubits(N)
    # Same layout as create_shor_circuit: counting + data + (n_data + 2) ancillas.
    total_qubits = n_count + n_data + (n_data + 2)
    logger.info(f"Required qubits: {total_qubits}")
    backend, service = get_backend(use_ibm, total_qubits)
    # get_backend falls back to Aer (service=None) when IBM Quantum is
    # unavailable; follow the path actually taken, not the one requested.
    on_ibm = service is not None

    # Keep track of tried numbers
    tried_numbers = set()
    attempts = 0
    while attempts < max_tries and len(tried_numbers) < N - 2:
        attempts += 1
        logger.info(f"\nAttempt {attempts}/{max_tries}")
        available_numbers = [x for x in range(2, N) if x not in tried_numbers]
        if not available_numbers:
            logger.warning("No more numbers to try")
            break
        # Use quantum RNG to select index from available numbers
        idx = quantum_random_int(0, len(available_numbers), backend)
        a = available_numbers[idx]
        tried_numbers.add(a)
        logger.info(f"Testing with a = {a} (quantum random choice)")
        if (g := gcd(a, N)) > 1:
            logger.info(f"Found factor immediately: gcd({a},{N})={g}")
            return g, N // g
        try:
            qc = create_shor_circuit(a, N)
            opt_level = 1 if on_ibm else 3
            tqc = transpile(qc, backend, optimization_level=opt_level)
            if on_ibm:
                # NOTE(review): this uses the V1 Sampler interface
                # (backend=..., quasi_dists). Confirm it matches the installed
                # qiskit-ibm-runtime, whose newer releases expose SamplerV2.
                sampler = Sampler(backend=backend)
                job = sampler.run(tqc, shots=1)
                logger.info(f"Job ID: {job.job_id()}")
                logger.info("Waiting for job completion...")
                result = job.result()
                counts = result.quasi_dists[0]
                measured_state = max(counts.items(), key=lambda x: x[1])[0]
            else:
                job = backend.run(tqc, shots=1, memory=True)
                result = job.result()
                measured_state = int(result.get_memory()[0], 2)
            # Exact phase k / 2**n_count (no float rounding).
            phase = Fraction(int(measured_state), 2 ** n_count)
            # The period is < N, so take the closest fraction whose denominator
            # is at most N. Bounding by 2**n_count instead would return the
            # exact dyadic fraction, whose denominator is always a power of 2.
            r = phase.limit_denominator(N).denominator
            logger.info(f"Measured phase = {float(phase):.4f}")
            logger.info(f"Estimated period = {r}")
            factor1, factor2 = find_factors_from_period(a, r, N)
            if factor1 > 1:
                logger.info(f"Success! Factors found: {factor1} × {factor2} = {N}")
                return factor1, factor2
        except Exception as e:
            # Best-effort per attempt: log and move on to the next base.
            logger.error(f"Error in attempt {attempts}: {str(e)}")
            continue
    logger.warning("Failed to find factors within maximum attempts")
    return 1, N
if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser(description="Factor N into primes using Shor's algorithm")
    parser.add_argument('--N', type=int, required=True,
                        help='Number to factor')
    # Mutually exclusive backend selection. --local is the default behaviour;
    # only args.ibm is consulted below.
    backend_group = parser.add_mutually_exclusive_group()
    backend_group.add_argument('--local', action='store_true', default=True,
                               help='Use local Aer simulator (default)')
    backend_group.add_argument('--ibm', action='store_true',
                               help='Use IBM Quantum backend')
    parser.add_argument('--tries', type=int, default=10,
                        help='Maximum number of attempts')
    # Forwarded to find_prime_factors.
    parser.add_argument('--seed', type=int, default=1234,
                        help='Random seed')
    args = parser.parse_args()

    use_ibm = args.ibm
    start_time = time.perf_counter()
    try:
        prime_factors = find_prime_factors(args.N, max_tries=args.tries,
                                           seed=args.seed, use_ibm=use_ibm)
    except ValueError as e:
        # Report on stderr with a non-zero exit status so callers/scripts can
        # detect the failure.
        sys.exit(f"Error: {e}")
    elapsed_time = time.perf_counter() - start_time

    prime_factors.sort()
    if prime_factors:
        factorization = " × ".join(map(str, prime_factors))
        product = " = " + str(args.N)
        parentheses = "(" + factorization + ")" if len(prime_factors) > 1 else factorization
    else:
        # N <= 1 has no prime factors; just echo it.
        parentheses = str(args.N)
        product = ""
    print(f"\nExecution time: {elapsed_time:.2f} seconds")
    print(f"Prime factorization: {parentheses}{product}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment