% Simulate Shor in a 3-qubit computer
% (GitHub gist by @mfikes, id 49a26c76e761ef1770f64eec3d262d67; last active July 13, 2024)
% Toy simulation of Shor's algorithm on an n-qubit state vector.
% NOTE: this script interleaves local function definitions with script code,
% which GNU Octave accepts; MATLAB (R2016b+) requires local functions in a
% script to appear after the last line of script code.

% Number to factor
N = 6; % Example number
a = 5; % Base for a^x mod N -- fixed here (not random); must be coprime to N

% Shor's reduction assumes gcd(a, N) == 1; otherwise gcd(a, N) is already a
% nontrivial factor and there is no period to find.
if gcd(a, N) ~= 1
    error('a = %d is not coprime to N = %d (gcd = %d).', a, N, gcd(a, N));
end

% Number of qubits
n = 3;
dim = 2^n;
% modular_exponentiation stores a^x mod N (a value in 0..N-1) at index
% (value + 1), so the register needs at least N basis states.
if dim < N
    error('n = %d qubits give only %d basis states; need at least N = %d.', n, dim, N);
end

% Initialize the state vector to |000>
state = zeros(dim, 1);
state(1) = 1;

% Define the Hadamard gate
H = (1/sqrt(2)) * [1, 1; 1, -1];

% Apply the Hadamard gate to each qubit: Hn = H (x) H (x) ... (x) H (n factors)
Hn = 1;
for q = 1:n   % 'q' rather than 'i' so the imaginary unit is not shadowed
    Hn = kron(Hn, H);
end
state = Hn * state;

% Display the state vector after applying Hadamard gates
disp('State vector after applying Hadamard gates:');
disp(state);
% Define modular exponentiation function (simplified for demonstration).
% Moves the amplitude of each basis state |x> onto index (a^x mod N) + 1,
% summing amplitudes that land on the same residue, then renormalizes.
% This many-to-one map collapses the input register rather than entangling
% it with a second register, so it simplifies (and is not) the unitary
% oracle of the full algorithm.
%
% Inputs:  state - column vector of 2^n amplitudes
%          a, N  - base and modulus (non-negative integers, N >= 1)
%          n     - number of qubits
% Output:  new_state - unit-norm column vector of 2^n amplitudes
function new_state = modular_exponentiation(state, a, N, n)
    dim = 2^n;
    if N > dim
        error('modular_exponentiation: N = %d exceeds the %d available basis states.', N, dim);
    end
    new_state = zeros(dim, 1);
    % Track a^x mod N incrementally: evaluating a^x directly in double
    % precision stops being exact once a^x exceeds 2^53.
    power_mod = mod(1, N);   % a^0 mod N (0 when N == 1)
    for x = 0:(dim-1)
        y = power_mod;
        new_state(y+1) = new_state(y+1) + state(x+1);
        power_mod = mod(power_mod * a, N);
    end
    nrm = norm(new_state);
    if nrm == 0
        % Dividing by zero would silently fill the state with NaN.
        error('modular_exponentiation: amplitudes cancel to zero; cannot normalize.');
    end
    new_state = new_state / nrm; % Normalize the state
end
% Send the register through the a^x mod N map, then print the new amplitudes
state = modular_exponentiation(state, a, N, n);
fprintf('%s\n', 'State vector after modular exponentiation:');
disp(state);
% Define the QFT function.
% Applies the quantum Fourier transform as an explicit dense matrix whose
% entry at (row x+1, column y+1), for 0-based x and y, is
% exp(2*pi*i*x*y / 2^n) / sqrt(2^n).
%
% Inputs:  state - column vector of 2^n amplitudes
%          n     - number of qubits
% Output:  state - transformed column vector
function state = qft(state, n)
    dim = 2^n;
    % rowIdx(r, c) = r - 1 and colIdx(r, c) = c - 1
    [rowIdx, colIdx] = ndgrid(0:(dim-1));
    F = exp(2 * pi * 1i * rowIdx .* colIdx / dim);
    F = F / sqrt(dim);
    state = F * state;
end
% Move into the Fourier basis, then print the transformed amplitudes
state = qft(state, n);
fprintf('%s\n', 'State vector after QFT:');
disp(state);
% Measurement simulation (classical part, simplified).
% Samples a basis-state index with probability proportional to |amplitude|^2.
%
% Input:  state - non-empty column vector of amplitudes (need not be exactly
%                 normalized)
% Output: measured_value - sampled 0-based basis-state index
function measured_value = measure(state)
    if isempty(state)
        error('measure: state vector is empty.');
    end
    probabilities = abs(state).^2;
    cdf = cumsum(probabilities);
    total = cdf(end);
    if ~(total > 0)
        error('measure: state has zero or NaN total probability.');
    end
    % Scale the uniform draw by the actual total. If rounding leaves
    % sum(probabilities) slightly below 1, an unscaled rand could exceed every
    % cumulative value and make find() return empty.
    measured_value = find(cdf >= rand * total, 1) - 1;
end
% Draw one outcome from the post-QFT distribution and report it
measured_value = measure(state);
fprintf('%s\n', ['Measured value: ', num2str(measured_value)]);
% Classical post-processing to find factors (refined).
% Treats measured_value directly as a candidate period r. This is a
% simplification: full Shor would recover r from measured_value / 2^n, e.g.
% via continued fractions. For even r > 0 it returns the nontrivial divisors
% of N among gcd(a^(r/2) - 1, N) and gcd(a^(r/2) + 1, N), each paired with its
% cofactor.
%
% Inputs:  measured_value - candidate period r (non-negative integer, may be [])
%          N              - number to factor
%          a              - base used for the modular exponentiation
% Output:  factors - sorted row vector of distinct nontrivial factors of N,
%                    or [] when none are found
function factors = classical_post_processing(measured_value, N, a)
    factors = [];
    % The reduction needs an even, positive period.
    if isempty(measured_value) || measured_value <= 0 || mod(measured_value, 2) == 1
        return;
    end
    r = measured_value;
    % a^(r/2) mod N by repeated modular multiplication; evaluating a^(r/2)
    % directly in double precision stops being exact once it exceeds 2^53.
    half_power = mod(1, N);
    for k = 1:(r/2)
        half_power = mod(half_power * a, N);
    end
    % gcd(x +/- 1, N) == gcd((x mod N) +/- 1 mod N, N), so work with residues.
    candidates = [gcd(mod(half_power - 1, N), N), gcd(mod(half_power + 1, N), N)];
    for d = candidates
        if d > 1 && d < N
            factors = [factors, d, N / d];
        end
    end
    % Both gcds can be nontrivial and complementary (e.g. N = 15 gives 3 and 5),
    % which would otherwise list each factor twice.
    factors = unique(factors);
end
% Find factors
factors = classical_post_processing(measured_value, N, a);
if isempty(factors)
    % e.g. an odd or zero measurement yields no usable period; rerunning resamples
    disp('Factors: none found for this measurement (rerun to sample again)');
else
    disp(['Factors: ', num2str(factors)]);
end
% (end of gist; GitHub page footer removed)