Skip to content

Commit

Permalink
Make sure PEC preserves measurement gates (unitaryfund#1844)
Browse files Browse the repository at this point in the history
* fix bug

* check last get is measurement in tests
  • Loading branch information
andreamari authored May 23, 2023
1 parent b14c56c commit 753f89f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
6 changes: 0 additions & 6 deletions mitiq/pec/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,19 +149,13 @@ def sample_circuit(
norm = 1.0

for op in ideal.all_operations():
# Ignore all measurements.
if cirq.is_measurement(op):
continue

sequences, loc_signs, loc_norm = sample_sequence(
cirq.Circuit(op),
representations,
num_samples=num_samples,
random_state=random_state,
)

norm *= loc_norm

for j in range(num_samples):
sampled_signs[j] *= loc_signs[j]
cirq_seq, _ = convert_to_mitiq(sequences[j])
Expand Down
11 changes: 10 additions & 1 deletion mitiq/pec/tests/test_pec_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,10 @@ def test_qubit_independent_representation_cirq():
)

expected_a = Circuit([cirq.I.on(LineQubit(0)), cirq.X.on(LineQubit(1))])
expected_a.append(measure_each(*LineQubit.range(2)))

expected_b = Circuit([cirq.I.on(LineQubit(0)), cirq.Z.on(LineQubit(1))])
expected_b.append(measure_each(*LineQubit.range(2)))

for _ in range(5):
seqs, signs, norm = sample_circuit(circuit, representations=[rep])
Expand Down Expand Up @@ -272,9 +275,15 @@ def test_sample_circuit_cirq(measure):
)

assert isinstance(sampled_circuits[0], Circuit)
assert len(sampled_circuits[0]) == 2
assert signs[0] in (-1, 1)
assert norm >= 1
if measure:
assert len(sampled_circuits[0]) == 3
assert cirq.is_measurement(
list(sampled_circuits[0].all_operations())[-1] # last gate
)
else:
assert len(sampled_circuits[0]) == 2


def test_sample_circuit_partial_representations():
Expand Down

0 comments on commit 753f89f

Please sign in to comment.