diff --git a/cirq-core/cirq/transformers/measurement_transformers.py b/cirq-core/cirq/transformers/measurement_transformers.py index 7e38c8c991b..7bf034ceed3 100644 --- a/cirq-core/cirq/transformers/measurement_transformers.py +++ b/cirq-core/cirq/transformers/measurement_transformers.py @@ -292,7 +292,20 @@ def drop_terminal_measurements( def flip_inversion(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': if isinstance(op.gate, ops.MeasurementGate): return [ - ops.X(q) if b else ops.I(q) for q, b in zip(op.qubits, op.gate.full_invert_mask()) + ( + (ops.X if b else ops.I) + if q.dimension == 2 + else ( + ops.MatrixGate( + # Per SimulationState.measure(), swap 0,1 but leave other dims alone + np.eye(q.dimension)[[1, 0, *range(2, q.dimension)]], + qid_shape=(q.dimension,), + ) + if b + else ops.IdentityGate(qid_shape=(q.dimension,)) + ) + ).on(q) + for q, b in zip(op.qubits, op.gate.full_invert_mask()) ] return op diff --git a/cirq-core/cirq/transformers/measurement_transformers_test.py b/cirq-core/cirq/transformers/measurement_transformers_test.py index 1290ce70606..a941993dec0 100644 --- a/cirq-core/cirq/transformers/measurement_transformers_test.py +++ b/cirq-core/cirq/transformers/measurement_transformers_test.py @@ -759,6 +759,41 @@ def test_drop_terminal(): ) +def test_drop_terminal_qudit(): + q0, q1 = cirq.LineQid.range(2, dimension=3) + circuit = cirq.Circuit( + cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q0, q1, key='m', invert_mask=[0, 1]))) + ) + dropped = cirq.drop_terminal_measurements(circuit) + expected_inversion_matrix = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + cirq.testing.assert_same_circuits( + dropped, + cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.IdentityGate(qid_shape=(3,)).on(q0), + cirq.MatrixGate(expected_inversion_matrix, qid_shape=(3,)).on(q1), + ) + ) + ), + ) + # Verify behavior equivalent to simulator (invert_mask swaps 0,1 but leaves 2 alone) + dropped.append(cirq.measure(q0, q1, key='m')) + sim = cirq.Simulator() + c0 = sim.simulate(circuit, initial_state=[0, 0]) + d0 = sim.simulate(dropped, initial_state=[0, 0]) + assert np.all(c0.measurements['m'] == [0, 1]) + assert np.all(d0.measurements['m'] == [0, 1]) + c1 = sim.simulate(circuit, initial_state=[1, 1]) + d1 = sim.simulate(dropped, initial_state=[1, 1]) + assert np.all(c1.measurements['m'] == [1, 0]) + assert np.all(d1.measurements['m'] == [1, 0]) + c2 = sim.simulate(circuit, initial_state=[2, 2]) + d2 = sim.simulate(dropped, initial_state=[2, 2]) + assert np.all(c2.measurements['m'] == [2, 2]) + assert np.all(d2.measurements['m'] == [2, 2]) + + def test_drop_terminal_nonterminal_error(): q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(