Skip to content

Commit

Permalink
Setup for disabling state_vector copy (quantumlib#5324)
Browse files Browse the repository at this point in the history
Part of quantumlib#3494 (requires deprecation cycle to resolve).

This adds the `copy` parameter to all `state_vector` methods affected by quantumlib#3494, and sets up deprecation messages to change the default copy behavior to "don't copy" in the next Cirq release.

The only substantive changes are in `sparse_simulator.py`, `state_vector.py`, and `state_vector_simulator.py`; all other changes simply inject `copy` to prevent deprecation warnings from `state_vector` calls (True for step results, False for final results).
  • Loading branch information
95-martin-orion authored Jun 10, 2022
1 parent 4f625a2 commit a402385
Show file tree
Hide file tree
Showing 20 changed files with 140 additions and 68 deletions.
34 changes: 17 additions & 17 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def test_recursive_params():

# First example should behave like an X when simulated
result = cirq.Simulator().simulate(cirq.Circuit(circuitop), param_resolver=outer_params)
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])


@pytest.mark.parametrize('add_measurements', [True, False])
Expand Down Expand Up @@ -343,9 +343,9 @@ def test_repeat_zero_times(add_measurements, use_repetition_ids, initial_reps):
subcircuit.freeze(), repetitions=initial_reps, use_repetition_ids=use_repetition_ids
)
result = cirq.Simulator().simulate(cirq.Circuit(op))
assert np.allclose(result.state_vector(), [0, 1] if initial_reps % 2 else [1, 0])
assert np.allclose(result.state_vector(copy=False), [0, 1] if initial_reps % 2 else [1, 0])
result = cirq.Simulator().simulate(cirq.Circuit(op**0))
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])


def test_no_repetition_ids():
Expand Down Expand Up @@ -375,13 +375,13 @@ def test_parameterized_repeat():
assert cirq.parameter_names(op) == {'a'}
assert not cirq.has_unitary(op)
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 0})
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1})
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 2})
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': -1})
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])
with pytest.raises(TypeError, match='Only integer or sympy repetitions are allowed'):
cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1.5})
with pytest.raises(ValueError, match='Circuit contains ops whose symbols were not specified'):
Expand All @@ -390,13 +390,13 @@ def test_parameterized_repeat():
assert cirq.parameter_names(op) == {'a'}
assert not cirq.has_unitary(op)
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 0})
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1})
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 2})
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': -1})
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])
with pytest.raises(TypeError, match='Only integer or sympy repetitions are allowed'):
cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1.5})
with pytest.raises(ValueError, match='Circuit contains ops whose symbols were not specified'):
Expand All @@ -405,11 +405,11 @@ def test_parameterized_repeat():
assert cirq.parameter_names(op) == {'a', 'b'}
assert not cirq.has_unitary(op)
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1, 'b': 1})
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 2, 'b': 1})
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1, 'b': 2})
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])
with pytest.raises(TypeError, match='Only integer or sympy repetitions are allowed'):
cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1.5, 'b': 1})
with pytest.raises(ValueError, match='Circuit contains ops whose symbols were not specified'):
Expand All @@ -418,11 +418,11 @@ def test_parameterized_repeat():
assert cirq.parameter_names(op) == {'a', 'b'}
assert not cirq.has_unitary(op)
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1, 'b': 1})
assert np.allclose(result.state_vector(), [1, 0])
assert np.allclose(result.state_vector(copy=False), [1, 0])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1.5, 'b': 1})
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])
result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1, 'b': 1.5})
assert np.allclose(result.state_vector(), [0, 1])
assert np.allclose(result.state_vector(copy=False), [0, 1])
with pytest.raises(TypeError, match='Only integer or sympy repetitions are allowed'):
cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1.5, 'b': 1.5})
with pytest.raises(ValueError, match='Circuit contains ops whose symbols were not specified'):
Expand Down
6 changes: 4 additions & 2 deletions cirq-core/cirq/contrib/quantum_volume/quantum_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,14 @@ def compute_heavy_set(circuit: cirq.Circuit) -> List[int]:
# output is defined in terms of probabilities, where our wave function is in
# terms of amplitudes. We convert it by using the Born rule: squaring each
# amplitude and taking their absolute value
median = np.median(np.abs(results.state_vector() ** 2))
median = np.median(np.abs(results.state_vector(copy=False) ** 2))

# The output wave function is a vector from the result value (big-endian) to
# the probability of that bit-string. Return all of the bit-string
# values that have a probability greater than the median.
return [idx for idx, amp in enumerate(results.state_vector()) if np.abs(amp**2) > median]
return [
idx for idx, amp in enumerate(results.state_vector(copy=False)) if np.abs(amp**2) > median
]


@dataclass
Expand Down
3 changes: 2 additions & 1 deletion cirq-core/cirq/experiments/grid_parallel_two_qubit_xeb.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,8 @@ def _get_xeb_result(
while moment_index < 2 * depth:
step_result = next(step_results)
moment_index += 1
amplitudes = step_result.state_vector()
# copy=False is safe because state_vector_to_probabilities will copy anyways
amplitudes = step_result.state_vector(copy=False)
probabilities = value.state_vector_to_probabilities(amplitudes)
_, counts = np.unique(measurements, return_counts=True)
empirical_probs = counts / len(measurements)
Expand Down
3 changes: 2 additions & 1 deletion cirq-core/cirq/experiments/xeb_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def __call__(self, task: _Simulate2qXEBTask) -> List[Dict[str, Any]]:
if cycle_depth not in cycle_depths:
continue

psi = step_result.state_vector()
# copy=False is safe because state_vector_to_probabilities will copy anyways
psi = step_result.state_vector(copy=False)
pure_probs = value.state_vector_to_probabilities(psi)

records += [
Expand Down
6 changes: 5 additions & 1 deletion cirq-core/cirq/ops/boolean_hamiltonian_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ def test_circuit(boolean_str):

circuit.append(hamiltonian_gate.on(*qubits))

phi = cirq.Simulator().simulate(circuit, qubit_order=qubits, initial_state=0).state_vector()
phi = (
cirq.Simulator()
.simulate(circuit, qubit_order=qubits, initial_state=0)
.state_vector(copy=False)
)
actual = np.arctan2(phi.real, phi.imag) - math.pi / 2.0 > 0.0

# Compare the two:
Expand Down
32 changes: 23 additions & 9 deletions cirq-core/cirq/ops/common_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,7 @@ def test_xpow_dim_3():

sim = cirq.Simulator()
circuit = cirq.Circuit([x(cirq.LineQid(0, 3)) ** 0.5] * 6)
svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit)]
svs = [step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit)]
# fmt: off
expected = [
[0.67, 0.67, 0.33],
Expand Down Expand Up @@ -1116,7 +1116,7 @@ def test_xpow_dim_4():

sim = cirq.Simulator()
circuit = cirq.Circuit([x(cirq.LineQid(0, 4)) ** 0.5] * 8)
svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit)]
svs = [step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit)]
# fmt: off
expected = [
[0.65, 0.65, 0.27, 0.27],
Expand Down Expand Up @@ -1147,11 +1147,15 @@ def test_zpow_dim_3():

sim = cirq.Simulator()
circuit = cirq.Circuit([z(cirq.LineQid(0, 3)) ** 0.5] * 6)
svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit, initial_state=0)]
svs = [
step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit, initial_state=0)
]
expected = [[1, 0, 0]] * 6
assert np.allclose((svs), expected)

svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit, initial_state=1)]
svs = [
step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit, initial_state=1)
]
# fmt: off
expected = [
[0, L**0.5, 0],
Expand All @@ -1164,7 +1168,9 @@ def test_zpow_dim_3():
# fmt: on
assert np.allclose((svs), expected)

svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit, initial_state=2)]
svs = [
step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit, initial_state=2)
]
# fmt: off
expected = [
[0, 0, L],
Expand Down Expand Up @@ -1192,11 +1198,15 @@ def test_zpow_dim_4():

sim = cirq.Simulator()
circuit = cirq.Circuit([z(cirq.LineQid(0, 4)) ** 0.5] * 8)
svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit, initial_state=0)]
svs = [
step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit, initial_state=0)
]
expected = [[1, 0, 0, 0]] * 8
assert np.allclose((svs), expected)

svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit, initial_state=1)]
svs = [
step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit, initial_state=1)
]
# fmt: off
expected = [
[0, 1j**0.5, 0, 0],
Expand All @@ -1211,7 +1221,9 @@ def test_zpow_dim_4():
# fmt: on
assert np.allclose(svs, expected)

svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit, initial_state=2)]
svs = [
step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit, initial_state=2)
]
# fmt: off
expected = [
[0, 0, 1j, 0],
Expand All @@ -1226,7 +1238,9 @@ def test_zpow_dim_4():
# fmt: on
assert np.allclose(svs, expected)

svs = [step.state_vector() for step in sim.simulate_moment_steps(circuit, initial_state=3)]
svs = [
step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit, initial_state=3)
]
# fmt: off
expected = [
[0, 0, 0, 1j**1.5],
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/mux.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def final_state_vector(
param_resolver=param_resolver,
)

return result.state_vector()
return result.state_vector(copy=False)


def sample_sweep(
Expand Down
4 changes: 3 additions & 1 deletion cirq-core/cirq/sim/simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,9 @@ def _kraus_(self):
cirq.Circuit(Reset11To00().on(*cirq.LineQubit.range(2))), initial_state=k
)
np.testing.assert_allclose(
out.state_vector(), cirq.one_hot(index=k % 3, shape=4, dtype=np.complex64), atol=1e-8
out.state_vector(copy=False),
cirq.one_hot(index=k % 3, shape=4, dtype=np.complex64),
atol=1e-8,
)


Expand Down
10 changes: 8 additions & 2 deletions cirq-core/cirq/sim/sparse_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import numpy as np

from cirq import ops
from cirq import _compat, ops
from cirq._compat import deprecated_parameter
from cirq.sim import simulator, state_vector, state_vector_simulator, state_vector_simulation_state

Expand Down Expand Up @@ -246,7 +246,7 @@ def __init__(
self._dtype = dtype
self._state_vector: Optional[np.ndarray] = None

def state_vector(self, copy: bool = True):
def state_vector(self, copy: Optional[bool] = None):
"""Return the state vector at this point in the computation.
The state is returned in the computational basis with these basis
Expand Down Expand Up @@ -279,6 +279,12 @@ def state_vector(self, copy: bool = True):
parameters from the state vector and store then using False
can speed up simulation by eliminating a memory copy.
"""
if copy is None:
_compat._warn_or_error(
"Starting in v0.16, state_vector will not copy the state by default. "
"Explicitly set copy=True to copy the state."
)
copy = True
if self._state_vector is None:
self._state_vector = np.array([1])
state = self._merged_sim_state
Expand Down
39 changes: 25 additions & 14 deletions cirq-core/cirq/sim/sparse_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,20 @@ def test_simulate_moment_steps(dtype: Type[np.number], split: bool):
simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split)
for i, step in enumerate(simulator.simulate_moment_steps(circuit)):
if i == 0:
np.testing.assert_almost_equal(step.state_vector(), np.array([0.5] * 4))
np.testing.assert_almost_equal(step.state_vector(copy=True), np.array([0.5] * 4))
else:
np.testing.assert_almost_equal(step.state_vector(), np.array([1, 0, 0, 0]))
np.testing.assert_almost_equal(step.state_vector(copy=True), np.array([1, 0, 0, 0]))


def test_simulate_moment_steps_implicit_copy_deprecated():
q0 = cirq.LineQubit(0)
simulator = cirq.Simulator()
steps = list(simulator.simulate_moment_steps(cirq.Circuit(cirq.X(q0))))

with cirq.testing.assert_deprecated(
"state_vector will not copy the state by default", deadline="v0.16"
):
_ = steps[0].state_vector()


@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
Expand All @@ -563,7 +574,7 @@ def test_simulate_moment_steps_empty_circuit(dtype: Type[np.number], split: bool
step = None
for step in simulator.simulate_moment_steps(circuit):
pass
assert np.allclose(step.state_vector(), np.array([1]))
assert np.allclose(step.state_vector(copy=True), np.array([1]))
assert not step.qubit_map


Expand Down Expand Up @@ -599,10 +610,10 @@ def test_simulate_moment_steps_intermediate_measurement(dtype: Type[np.number],
result = int(step.measurements['q(0)'][0])
expected = np.zeros(2)
expected[result] = 1
np.testing.assert_almost_equal(step.state_vector(), expected)
np.testing.assert_almost_equal(step.state_vector(copy=True), expected)
if i == 2:
expected = np.array([np.sqrt(0.5), np.sqrt(0.5) * (-1) ** result])
np.testing.assert_almost_equal(step.state_vector(), expected)
np.testing.assert_almost_equal(step.state_vector(copy=True), expected)


@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
Expand Down Expand Up @@ -710,8 +721,8 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs):

initial_state = np.array([np.sqrt(0.5), np.sqrt(0.5)], dtype=np.complex64)
result = simulator.simulate(circuit, initial_state=initial_state)
np.testing.assert_array_almost_equal(result.state_vector(), initial_state)
assert not initial_state is result.state_vector()
np.testing.assert_array_almost_equal(result.state_vector(copy=False), initial_state)
assert not initial_state is result.state_vector(copy=False)


def test_does_not_modify_initial_state():
Expand All @@ -735,7 +746,7 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs):
result = simulator.simulate(circuit, initial_state=initial_state)
np.testing.assert_array_almost_equal(np.array([1, 0], dtype=np.complex64), initial_state)
np.testing.assert_array_almost_equal(
result.state_vector(), np.array([0, 1], dtype=np.complex64)
result.state_vector(copy=False), np.array([0, 1], dtype=np.complex64)
)


Expand Down Expand Up @@ -787,7 +798,7 @@ def test_simulates_composite():
np.testing.assert_allclose(
c.final_state_vector(ignore_terminal_measurements=False, dtype=np.complex64), expected
)
np.testing.assert_allclose(cirq.Simulator().simulate(c).state_vector(), expected)
np.testing.assert_allclose(cirq.Simulator().simulate(c).state_vector(copy=False), expected)


def test_simulate_measurement_inversions():
Expand All @@ -804,15 +815,15 @@ def test_works_on_pauli_string_phasor():
a, b = cirq.LineQubit.range(2)
c = cirq.Circuit(np.exp(0.5j * np.pi * cirq.X(a) * cirq.X(b)))
sim = cirq.Simulator()
result = sim.simulate(c).state_vector()
result = sim.simulate(c).state_vector(copy=False)
np.testing.assert_allclose(result.reshape(4), np.array([0, 0, 0, 1j]), atol=1e-8)


def test_works_on_pauli_string():
a, b = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.X(a) * cirq.X(b))
sim = cirq.Simulator()
result = sim.simulate(c).state_vector()
result = sim.simulate(c).state_vector(copy=False)
np.testing.assert_allclose(result.reshape(4), np.array([0, 0, 0, 1]), atol=1e-8)


Expand Down Expand Up @@ -1322,9 +1333,9 @@ def test_final_state_vector_is_not_last_object():
initial_state = np.array([1, 0], dtype=np.complex64)
circuit = cirq.Circuit(cirq.wait(q))
result = sim.simulate(circuit, initial_state=initial_state)
assert result.state_vector() is not initial_state
assert not np.shares_memory(result.state_vector(), initial_state)
np.testing.assert_equal(result.state_vector(), initial_state)
assert result.state_vector(copy=False) is not initial_state
assert not np.shares_memory(result.state_vector(copy=False), initial_state)
np.testing.assert_equal(result.state_vector(copy=False), initial_state)


def test_deterministic_gate_noise():
Expand Down
Loading

0 comments on commit a402385

Please sign in to comment.