Last active
November 12, 2019 17:42
-
-
Save braised-babbage/eb2682c704740a0b3622820b3bed4bfb to your computer and use it in GitHub Desktop.
Translate a pyquil program to a single einsum call
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## einsum_compiler.py | |
## | |
## What it is | |
## ---------- | |
## | |
## NumPy has a standard function, einsum, which can be used to evaluate large
## nested summations. Here we show how to translate a pyQuil program consisting
## only of gate applications into a single einsum call.
## | |
## | |
## The basic idea | |
## -------------- | |
## | |
## The program | |
## | |
## H 0 | |
## CNOT 0 1 | |
## X 1 | |
## | |
## is equivalent to the circuit | |
## | |
## --- | |
## |q0> ---| H |-----*------------ | |
## --- | | |
## --- --- | |
## |q1> -----------| + |---| X |--- | |
## --- --- | |
## | |
## The usual way to evaluate this is to walk the program from top to bottom
## (equivalently, the diagram from left to right), construct a matrix for each
## gate, and multiply the wavefunction by that matrix.
## | |
## An alternative perspective is that the quantity computed is one large sum, | |
## with indices of summation corresponding to 'wires' in the circuit diagram | |
## (aka a 'tensor contraction'). | |
## | |
## - To a 1Q unitary U, associate an array u of shape (2,2) defined as
##       u[out,in] = <out|U|in>
## - To a 2Q unitary U, associate an array u of shape (2,2,2,2) defined as
##       u[out0,out1,in0,in1] = <out0 out1|U|in0 in1>
## - To an n-qubit wavefunction ψ, associate an array psi of shape (2,...,2),
##   with n dimensions in total, e.g. (for n = 2)
##       psi[out0,out1] = <out0 out1|ψ>
## | |
## The result of applying the circuit to ψ is ψ', where psi'[l,n] is given
## by summing over the indices i, j, k and m:
##
##     psi'[l,n] = sum_{i,j,k,m} psi[i,j]*h[k,i]*cnot[l,m,k,j]*x[n,m]
## | |
## Pictorially, every index is a wire in the diagram | |
## | |
## --- | |
## |q0> -i-| H |--k--*-----l------ | |
## --- | | |
## --- --- | |
## |q1> -----j-----| + |-m-| X |-n- | |
## --- --- | |
## | |
## | |
## with free indices corresponding to free wires, and summed indices
## corresponding to wires connected at both ends. Here l and n are free,
## while i, j, k and m are summed.
## | |
## We thus walk the pyQuil program from top to bottom. A gate on k qubits | |
## consumes k previously free indices and produces k new free indices. The | |
## consumed indices are to be 'summed over' by einsum. | |
from typing import Optional | |
import numpy as np | |
from numpy import einsum # if you hate blas | |
# from opt_einsum import contract as einsum # if you like blas | |
from pyquil import Program | |
from pyquil.quilatom import Parameter | |
from pyquil.quilbase import Gate | |
from pyquil.gates import * | |
from pyquil.gate_matrices import QUANTUM_GATES | |
from pyquil.wavefunction import Wavefunction | |
# Module-level counter that hands out fresh integer einsum labels.
# run() sets it to the qubit count before walking a program, and
# gate_tensor_contraction() advances it by one per qubit a gate acts on.
next_index: int = 0
def run(program: Program) -> Wavefunction:
    """
    Run the Quil program with a single einsum call, returning a Wavefunction.

    The state starts as |0...0> on qubits 0..max_qubit. Every circuit wire is
    given an integer einsum label. Each gate contributes its tensor plus the
    labels of the wires it produces and consumes (see gate_tensor_contraction).

    :param program: A pyQuil program consisting solely of Gate instructions.
    :return: The resulting Wavefunction (amplitudes flattened in C order,
        with the highest-numbered qubit on the first axis).
    :raises ValueError: If the program acts on no qubits, uses a negative
        qubit index, or contains a non-gate instruction. gate_matrix() may
        also raise ValueError for unsupported gates.
    """
    global next_index

    qubits = program.get_qubits()
    if not qubits:
        # max() over an empty set would otherwise fail with an opaque message.
        raise ValueError("Cannot run a program that acts on no qubits.")
    if any(q < 0 for q in qubits):
        # Explicit check instead of `assert`, which is stripped under -O.
        raise ValueError(f"Negative qubit indices are not supported: {sorted(qubits)}")
    num_qubits = max(qubits) + 1

    # An n-qubit wavefunction is an n-dimensional complex array, set to |0...0>.
    wf = np.zeros((2,) * num_qubits, dtype=np.complex128)
    wf[(0,) * num_qubits] = 1 + 0j

    # Label q is the initial free wire of qubit q; labels from num_qubits
    # upwards are handed out by gate_tensor_contraction via next_index.
    free_indices = {q: q for q in range(num_qubits)}
    next_index = num_qubits

    # Alternating (tensor, labels) operands for einsum's sublist form. Axes
    # are listed highest qubit first, so after flattening qubit 0 is the
    # least-significant bit (intended to match pyQuil's amplitude ordering).
    contraction_data = [wf, [free_indices[q] for q in reversed(range(num_qubits))]]
    for instr in program.instructions:
        if not isinstance(instr, Gate):
            raise ValueError(f"Non-gate instruction {instr} is not supported.")
        # Rebinds free_indices[q] to the gate's output label for each qubit.
        tensor, indices = gate_tensor_contraction(instr, free_indices)
        contraction_data.extend((tensor, indices))

    # NOTE: numpy.einsum's sublist form only accepts integer labels in
    # range(52), so programs needing more wire labels than that require an
    # einsum without this limit (e.g. opt_einsum's contract).
    output_labels = [free_indices[q] for q in reversed(range(num_qubits))]
    result = einsum(*contraction_data, output_labels)
    return Wavefunction(result.flatten())
def gate_tensor_contraction(instr: Gate, qubit_indices: dict):
    """
    Return the (tensor, indices) pair that applies ``instr`` inside einsum.

    For a k-qubit gate, its 2^k x 2^k matrix is reshaped into a rank-2k
    tensor. The first k axes come from the matrix row (output) index and the
    last k from the column (input) index. Accordingly ``indices`` is the k
    fresh output labels followed by the k labels the acted-on qubits held
    in ``qubit_indices`` before the call.

    Side effects: advances the module-level ``next_index`` counter, and for
    each acted-on qubit rebinds ``qubit_indices[q]`` to its fresh output label.

    :param instr: The gate to apply.
    :param qubit_indices: Mapping from qubit index to its current free label.
    :return: ``(tensor, output_labels + input_labels)``.
    :raises ValueError: If the gate lists the same qubit more than once.
    """
    global next_index

    qubits = [q.index for q in instr.qubits]
    if len(set(qubits)) != len(qubits):
        # A repeated qubit would put one label on two input axes (einsum then
        # takes a diagonal) and leave a dangling output label: a silently
        # wrong result rather than an error.
        raise ValueError(f"Gate {instr} acts on the same qubit more than once.")

    # matrix[out, in] -> tensor[out_0, ..., out_{k-1}, in_0, ..., in_{k-1}]
    tensor = gate_matrix(instr).reshape((2,) * (2 * len(qubits)))

    consumed = [qubit_indices[q] for q in qubits]
    produced = list(range(next_index, next_index + len(qubits)))
    next_index += len(qubits)
    for q, label in zip(qubits, produced):
        qubit_indices[q] = label
    return tensor, produced + consumed
def gate_matrix(gate: Gate):
    """
    Get the matrix corresponding to the given gate from QUANTUM_GATES.

    :param gate: A gate with constant (numeric) parameters and no modifiers.
    :return: The gate's matrix.
    :raises ValueError: If any parameter is symbolic, or the gate carries
        modifiers (e.g. DAGGER, CONTROLLED).
    :raises KeyError: If ``gate.name`` is not a key of QUANTUM_GATES.
    """
    # Parameter is one subclass of Expression; compound symbolic parameters
    # such as ``2*theta`` are other Expression subclasses. Checking only
    # Parameter let those reach the numeric gate factories, which presumably
    # fail there with an unhelpful error.
    from pyquil.quilatom import Expression

    if any(isinstance(param, (Parameter, Expression)) for param in gate.params):
        raise ValueError("Cannot produce a matrix from a gate with non-constant parameters.")
    if gate.modifiers:
        raise ValueError('Gate modifiers are not currently supported.')

    entry = QUANTUM_GATES[gate.name]
    # Parameterized gates are called with their parameters; others are used as-is.
    return entry(*gate.params) if gate.params else entry
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment