Skip to content

Commit

Permalink
Refactor run to take in oracle rather than function paramters directly
Browse files Browse the repository at this point in the history
  • Loading branch information
joelin0 committed Jun 27, 2017
1 parent 0d4f15c commit 6c1c196
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
22 changes: 12 additions & 10 deletions grove/bernstein_vazirani/bernstein_vazirani.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,20 @@ def bernstein_vazirani(oracle, qubits, ancilla):
return p


def run(cxn, vec_a, b):
def run(cxn, oracle, qubits, ancilla):
"""
Runs the Bernstein-Vazirani algorithm.
:param cxn: the QVM connection to use to run the programs
:param vec_a: a vector of 0s and 1s, to represent the a vector.
:param b: a bit, 0 or 1, to represent the b constant
:param oracle: the oracle to query that represents a function of the form f(x) = a*x+b (mod 2).
:param qubits: the input qubits
:param ancilla: the ancilla qubit
:return: a tuple that includes:
- the program's determination of a
- the program's determination of b
- the main program used to determine a
- the oracle used
:rtype: tuple
"""
# First, create the program to find a
qubits = range(len(vec_a))
ancilla = len(vec_a)

oracle = oracle_function(vec_a, b, qubits, ancilla)
bv_program = bernstein_vazirani(oracle, qubits, ancilla)

results = cxn.run_and_measure(bv_program, qubits)
Expand All @@ -81,7 +78,7 @@ def run(cxn, vec_a, b):
results = cxn.run_and_measure(oracle, [ancilla])
bv_b = results[0][0]

return bv_a, bv_b, bv_program, oracle
return bv_a, bv_b, bv_program


if __name__ == "__main__":
Expand All @@ -101,7 +98,12 @@ def run(cxn, vec_a, b):
b = int(raw_input("Give a single bit for b: "))

qvm = api.SyncConnection()
a, b, bv_program, oracle = run(qvm, vec_a, b)
qubits = range(len(vec_a))
ancilla = len(vec_a)

oracle = oracle_function(vec_a, b, qubits, ancilla)

a, b, bv_program = run(qvm, oracle, qubits, ancilla)
bitstring_a = "".join(map(str, a))
print "-----------------------------------"
print "The bitstring a is given by: ", bitstring
Expand Down
4 changes: 2 additions & 2 deletions grove/bernstein_vazirani/tests/test_bernstein_vazirani.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from grove.bernstein_vazirani.bernstein_vazirani import bernstein_vazirani, oracle_function
import pytest

@pytest.mark.skip(reason="Must add support for Forest connections in testing")
#@pytest.mark.skip(reason="Must add support for Forest connections in testing")
class TestOracleFunction(object):
def test_one_qubit(self):
vec_a = np.array([1])
Expand All @@ -32,7 +32,7 @@ def test_four_qubits(self):
for x in range(2**len(vec_a)):
_oracle_test_helper(vec_a, b, x)

@pytest.mark.skip(reason="Must add support for Forest connections in testing")
#@pytest.mark.skip(reason="Must add support for Forest connections in testing")
class TestBernsteinVazirani(object):
def test_one_qubit_all_zeros(self):
_bv_test_helper(np.array([0]), 0)
Expand Down

0 comments on commit 6c1c196

Please sign in to comment.