Skip to content

Commit

Permalink
Fix cirq.Simulator assuming a reset is safe to cache before sampling (q…
Browse files Browse the repository at this point in the history
…uantumlib#3093)

- This resulted in some circuits with random behavior giving long lists of identical samples when repeated
- Change the general strategy used by the simulator to one of extracting a unitary prefix that can be cached, and then looking at the rest to decide whether sampling can be done efficiently or not
  • Loading branch information
Strilanc authored Jun 17, 2020
1 parent b0b398e commit 037afff
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 68 deletions.
64 changes: 35 additions & 29 deletions cirq/sim/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,8 @@
as the simulation iterates through the moments of a cirq.
"""

from typing import (
Any,
Dict,
Iterator,
List,
Sequence,
Tuple,
Optional,
TYPE_CHECKING,
)
from typing import (Any, Dict, Iterator, List, Sequence, Tuple, Optional,
TYPE_CHECKING, Set, cast)

import abc
import collections
Expand Down Expand Up @@ -471,29 +463,43 @@ def sample_measurement_ops(self,
instances or a qubit is acted upon multiple times by different
operations from `measurement_ops`.
"""
bounds = {} # type: Dict[str, Tuple]
all_qubits = [] # type: List[ops.Qid]
meas_ops = {}
current_index = 0

# Sanity checks.
seen_measurement_keys: Set[str] = set()
for op in measurement_ops:
gate = op.gate
if not isinstance(gate, ops.MeasurementGate):
raise ValueError('{} was not a MeasurementGate'.format(gate))
raise ValueError(f'{op.gate} was not a MeasurementGate')
key = protocols.measurement_key(gate)
meas_ops[key] = gate
if key in bounds:
raise ValueError(
'Duplicate MeasurementGate with key {}'.format(key))
bounds[key] = (current_index, current_index + len(op.qubits))
all_qubits.extend(op.qubits)
current_index += len(op.qubits)
indexed_sample = self.sample(all_qubits, repetitions, seed=seed)

results = {}
for k, (s, e) in bounds.items():
before_invert_mask = indexed_sample[:, s:e]
results[k] = before_invert_mask ^ (np.logical_and(
before_invert_mask < 2, meas_ops[k].full_invert_mask()))
if key in seen_measurement_keys:
raise ValueError(f'Duplicate MeasurementGate with key {key}')
seen_measurement_keys.add(key)

# Find measured qubits, ensuring a consistent ordering.
measured_qubits = []
seen_qubits: Set[cirq.Qid] = set()
for op in measurement_ops:
for q in op.qubits:
if q not in seen_qubits:
seen_qubits.add(q)
measured_qubits.append(q)

# Perform whole-system sampling of the measured qubits.
indexed_sample = self.sample(measured_qubits, repetitions, seed=seed)

# Extract results for each measurement.
results: Dict[str, np.ndarray] = {}
qubits_to_index = {q: i for i, q in enumerate(measured_qubits)}
for op in measurement_ops:
gate = cast(ops.MeasurementGate, op.gate)
out = np.zeros(shape=(repetitions, len(op.qubits)), dtype=np.int8)
inv_mask = gate.full_invert_mask()
for i, q in enumerate(op.qubits):
out[:, i] = indexed_sample[:, qubits_to_index[q]]
if inv_mask[i]:
out[:, i] ^= out[:, i] < 2
results[gate.key] = out

return results


Expand Down
104 changes: 71 additions & 33 deletions cirq/sim/sparse_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"""A simulator that uses numpy's einsum for sparse matrix operations."""

import collections
from typing import Dict, Iterator, List, Type, TYPE_CHECKING, DefaultDict
from typing import Dict, Iterator, List, Type, TYPE_CHECKING, DefaultDict, \
Tuple, cast, Set

import numpy as np

Expand Down Expand Up @@ -134,35 +135,41 @@ def _run(self, circuit: circuits.Circuit,
param_resolver = param_resolver or study.ParamResolver({})
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
self._check_all_resolved(resolved_circuit)

def measure_or_mixture(op):
return not protocols.has_unitary(op) and (
protocols.is_measurement(op) or protocols.has_channel(op))

if resolved_circuit.are_all_matches_terminal(measure_or_mixture):
return self._run_sweep_terminal_sample(resolved_circuit,
repetitions)
return self._run_sweep_repeat(resolved_circuit, repetitions)

def _run_sweep_terminal_sample(self, circuit: circuits.Circuit,
repetitions: int) -> Dict[str, np.ndarray]:
for step_result in self._base_iterator(
circuit=circuit,
qubit_order=ops.QubitOrder.DEFAULT,
initial_state=0,
perform_measurements=False):
qubit_order = sorted(resolved_circuit.all_qubits())

# Simulate as many unitary operations as possible before having to
# repeat work for each sample.
unitary_prefix, general_suffix = _split_into_unitary_then_general(
resolved_circuit)
step_result = None
for step_result in self._base_iterator(circuit=unitary_prefix,
qubit_order=qubit_order,
initial_state=0,
perform_measurements=False):
pass
# We can ignore the mixtures since this is a run method which
# does not return the state.
measurement_ops = [op for _, op, _ in
circuit.findall_operations_with_gate_type(
ops.MeasurementGate)]
return step_result.sample_measurement_ops(measurement_ops,
repetitions,
seed=self._prng)

def _run_sweep_repeat(self, circuit: circuits.Circuit,
repetitions: int) -> Dict[str, np.ndarray]:
assert step_result is not None

# When an otherwise unitary circuit ends with non-demolition computation
# basis measurements, we can sample the results more efficiently.
general_ops = list(general_suffix.all_operations())
if all(isinstance(op.gate, ops.MeasurementGate) for op in general_ops):
return step_result.sample_measurement_ops(measurement_ops=cast(
List[ops.GateOperation], general_ops),
repetitions=repetitions,
seed=self._prng)

qid_shape = protocols.qid_shape(qubit_order)
intermediate_state = step_result.state_vector().reshape(qid_shape)
return self._brute_force_samples(initial_state=intermediate_state,
circuit=general_suffix,
repetitions=repetitions,
qubit_order=qubit_order)

def _brute_force_samples(self, initial_state: np.ndarray,
circuit: circuits.Circuit,
qubit_order: 'cirq.QubitOrderOrList',
repetitions: int) -> Dict[str, np.ndarray]:
"""Repeatedly simulate a circuit in order to produce samples."""
if repetitions == 0:
return {
key: np.empty(shape=[0, 1])
Expand All @@ -172,10 +179,9 @@ def _run_sweep_repeat(self, circuit: circuits.Circuit,
measurements: DefaultDict[str, List[
np.ndarray]] = collections.defaultdict(list)
for _ in range(repetitions):
all_step_results = self._base_iterator(
circuit,
qubit_order=ops.QubitOrder.DEFAULT,
initial_state=0)
all_step_results = self._base_iterator(circuit,
initial_state=initial_state,
qubit_order=qubit_order)

for step_result in all_step_results:
for k, v in step_result.measurements.items():
Expand Down Expand Up @@ -331,3 +337,35 @@ def sample(self,
self, None),
repetitions=repetitions,
seed=seed)


def _split_into_unitary_then_general(circuit: 'cirq.Circuit'
) -> Tuple['cirq.Circuit', 'cirq.Circuit']:
"""Splits the circuit into a unitary prefix and non-unitary suffix.
The splitting happens in a per-qubit fashion. A non-unitary operation on
qubit A will cause later operations on A to be part of the non-unitary
suffix, but later operations on other qubits will continue to be put into
the unitary part (as long as those qubits have had no non-unitary operation
up to that point).
"""
blocked_qubits: Set[cirq.Qid] = set()
unitary_prefix = circuits.Circuit()
general_suffix = circuits.Circuit()
for moment in circuit:
unitary_part = []
general_part = []
for op in moment:
qs = set(op.qubits)
if not protocols.has_unitary(op):
blocked_qubits |= qs

if qs.isdisjoint(blocked_qubits):
unitary_part.append(op)
else:
general_part.append(op)
if unitary_part:
unitary_prefix.append(ops.Moment(unitary_part))
if general_part:
general_suffix.append(ops.Moment(general_part))
return unitary_prefix, general_suffix
60 changes: 54 additions & 6 deletions cirq/sim/sparse_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_run_measure_at_end_no_repetitions(dtype):
'1': np.empty([0, 1])
})
assert result.repetitions == 0
# We expect one call per b0,b1.
assert mock_sim.call_count == 4


Expand All @@ -125,6 +126,7 @@ def test_run_repetitions_measure_at_end(dtype):
np.testing.assert_equal(result.measurements,
{'0': [[b0]] * 3, '1': [[b1]] * 3})
assert result.repetitions == 3
# We expect one call per b0,b1.
assert mock_sim.call_count == 4


Expand All @@ -147,7 +149,8 @@ def test_run_invert_mask_measure_not_terminal(dtype):
np.testing.assert_equal(result.measurements,
{'m': [[1 - b0, b1]] * 3})
assert result.repetitions == 3
assert mock_sim.call_count == 12
# We expect repeated calls per b0,b1 instead of one call.
assert mock_sim.call_count > 4


@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
Expand All @@ -169,7 +172,8 @@ def test_run_partial_invert_mask_measure_not_terminal(dtype):
np.testing.assert_equal(result.measurements,
{'m': [[1 - b0, b1]] * 3})
assert result.repetitions == 3
assert mock_sim.call_count == 12
# We expect repeated calls per b0,b1 instead of one call.
assert mock_sim.call_count > 4


@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
Expand All @@ -190,7 +194,8 @@ def test_run_measurement_not_terminal_no_repetitions(dtype):
'1': np.empty([0, 1])
})
assert result.repetitions == 0
assert mock_sim.call_count == 0
# We expect one call per b0,b1 instead of one call.
assert mock_sim.call_count == 4


@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
Expand All @@ -208,7 +213,8 @@ def test_run_repetitions_measurement_not_terminal(dtype):
np.testing.assert_equal(result.measurements,
{'0': [[b0]] * 3, '1': [[b1]] * 3})
assert result.repetitions == 3
assert mock_sim.call_count == 12
# We expect repeated calls per b0,b1 instead of one call.
assert mock_sim.call_count > 4


@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
Expand All @@ -233,8 +239,7 @@ def test_run_mixture(dtype):
simulator = cirq.Simulator(dtype=dtype)
circuit = cirq.Circuit(cirq.bit_flip(0.5)(q0), cirq.measure(q0))
result = simulator.run(circuit, repetitions=100)
assert sum(result.measurements['0'])[0] < 80
assert sum(result.measurements['0'])[0] > 20
assert 20 < sum(result.measurements['0'])[0] < 80


@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
Expand Down Expand Up @@ -900,3 +905,46 @@ def test_random_seed_mixture_deterministic():
[[1], [0], [0], [0], [1], [0], [0], [1], [1], [1], [1], [1],
[0], [1], [0], [0], [0], [0], [0], [1], [0], [1], [1], [0],
[1], [1], [1], [1], [1], [0]])


def test_entangled_reset_does_not_break_randomness():
"""
A previous version of cirq made the mistake of assuming that it was okay to
cache the wavefunction produced by general channels on unrelated qubits
before repeatedly sampling measurements. This test checks for that mistake.
"""

a, b = cirq.LineQubit.range(2)
circuit = cirq.Circuit(cirq.H(a), cirq.CNOT(a, b),
cirq.ResetChannel().on(a), cirq.measure(b,
key='out'))
samples = cirq.Simulator().sample(circuit, repetitions=100)['out']
counts = samples.value_counts()
assert len(counts) == 2
assert 10 <= counts[0] <= 90
assert 10 <= counts[1] <= 90


def test_overlapping_measurements_at_end():
a, b = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.H(a),
cirq.CNOT(a, b),

# These measurements are not on independent qubits but they commute.
cirq.measure(a, key='a'),
cirq.measure(a, key='not a', invert_mask=(True,)),
cirq.measure(b, key='b'),
cirq.measure(a, b, key='ab'),
)

samples = cirq.Simulator().sample(circuit, repetitions=100)
np.testing.assert_array_equal(samples['a'].values,
samples['not a'].values ^ 1)
np.testing.assert_array_equal(samples['a'].values * 2 + samples['b'].values,
samples['ab'].values)

counts = samples['b'].value_counts()
assert len(counts) == 2
assert 10 <= counts[0] <= 90
assert 10 <= counts[1] <= 90

0 comments on commit 037afff

Please sign in to comment.