Skip to content

Commit

Permalink
Allow measurement keys with differing qubits in OpenQASM (#6803)
Browse files Browse the repository at this point in the history
* Allow measurement keys with differing qubits in OpenQASM

- Previously, if there was a measurement with the same key
but two differently sized registers, then the qasm output
might select the wrong key and size the register incorrectly.
- This PR examines all the measurements first, and selects the
biggest one, so to correctly size the classical register.
- Adds unit tests to demonstrate.

Fixes: #6508

---------

Co-authored-by: Pavol Juhas <[email protected]>
  • Loading branch information
dstrain115 and pavoljuhas authored Nov 22, 2024
1 parent 76125a2 commit e0087b0
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 16 deletions.
48 changes: 32 additions & 16 deletions cirq-core/cirq/circuits/qasm_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Utility classes for representing QASM."""

from typing import Callable, Dict, Iterator, Optional, Sequence, Set, Tuple, Union, TYPE_CHECKING
from typing import Callable, Dict, Iterator, Optional, Sequence, Tuple, Union, TYPE_CHECKING

import re
import numpy as np
Expand Down Expand Up @@ -203,6 +203,7 @@ def __init__(
qubit_id_map=qubit_id_map,
meas_key_id_map=meas_key_id_map,
)
self.cregs = self._generate_cregs()

def _generate_measurement_ids(self) -> Tuple[Dict[str, str], Dict[str, Optional[str]]]:
# Pick an id for the creg that will store each measurement
Expand All @@ -226,6 +227,30 @@ def _generate_measurement_ids(self) -> Tuple[Dict[str, str], Dict[str, Optional[
def _generate_qubit_ids(self) -> Dict['cirq.Qid', str]:
return {qubit: f'q[{i}]' for i, qubit in enumerate(self.qubits)}

def _generate_cregs(self) -> Dict[str, tuple[int, str]]:
"""Pick an id for the creg that will store each measurement
This function finds the largest measurement using each key.
That is, if multiple measurements are made with the same key,
it will use the key with the most number of qubits.
Returns: dictionary with key of measurement id and value of (#qubits, comment).
"""
cregs: Dict[str, tuple[int, str]] = {}
for meas in self.measurements:
key = protocols.measurement_key_name(meas)
meas_id = self.args.meas_key_id_map[key]

if self.meas_comments[key] is not None:
comment = f' // Measurement: {self.meas_comments[key]}'
else:
comment = ''

if meas_id not in cregs or cregs[meas_id][0] < len(meas.qubits):
cregs[meas_id] = (len(meas.qubits), comment)

return cregs

def is_valid_qasm_id(self, id_str: str) -> bool:
"""Test if id_str is a valid id in QASM grammar."""
return self.valid_id_re.match(id_str) is not None
Expand Down Expand Up @@ -287,24 +312,15 @@ def output(text):
output(f'qreg q[{len(self.qubits)}];\n')
else:
output(f'qubit[{len(self.qubits)}] q;\n')
# Classical registers
# Pick an id for the creg that will store each measurement
already_output_keys: Set[str] = set()
for meas in self.measurements:
key = protocols.measurement_key_name(meas)
if key in already_output_keys:
continue
already_output_keys.add(key)
meas_id = self.args.meas_key_id_map[key]
if self.meas_comments[key] is not None:
comment = f' // Measurement: {self.meas_comments[key]}'
else:
comment = ''

# Classical registers
for meas_id in self.cregs:
length, comment = self.cregs[meas_id]
if self.args.version == '2.0':
output(f'creg {meas_id}[{len(meas.qubits)}];{comment}\n')
output(f'creg {meas_id}[{length}];{comment}\n')
else:
output(f'bit[{len(meas.qubits)}] {meas_id};{comment}\n')
output(f'bit[{length}] {meas_id};{comment}\n')

# In OpenQASM 2.0, the transformation of global phase gates is ignored.
# Therefore, no newline is created when the operations contained in
# a circuit consist only of global phase gates.
Expand Down
55 changes: 55 additions & 0 deletions cirq-core/cirq/circuits/qasm_output_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,58 @@ def test_reset():
reset q[1];
""".strip()
)


def test_different_sized_registers():
qubits = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.measure(qubits[0], key='c'), cirq.measure(qubits, key='c'))
output = cirq.QasmOutput(
c.all_operations(), tuple(sorted(c.all_qubits())), header='Generated from Cirq!'
)
assert (
str(output)
== """// Generated from Cirq!
OPENQASM 2.0;
include "qelib1.inc";
// Qubits: [q(0), q(1)]
qreg q[2];
creg m_c[2];
measure q[0] -> m_c[0];
// Gate: cirq.MeasurementGate(2, cirq.MeasurementKey(name='c'), ())
measure q[0] -> m_c[0];
measure q[1] -> m_c[1];
"""
)
# OPENQASM 3.0
output3 = cirq.QasmOutput(
c.all_operations(),
tuple(sorted(c.all_qubits())),
header='Generated from Cirq!',
version='3.0',
)
assert (
str(output3)
== """// Generated from Cirq!
OPENQASM 3.0;
include "stdgates.inc";
// Qubits: [q(0), q(1)]
qubit[2] q;
bit[2] m_c;
m_c[0] = measure q[0];
// Gate: cirq.MeasurementGate(2, cirq.MeasurementKey(name='c'), ())
m_c[0] = measure q[0];
m_c[1] = measure q[1];
"""
)

0 comments on commit e0087b0

Please sign in to comment.