Skip to content

Commit

Permalink
Preliminary support for outputting to OpenQASM 3.0 (quantumlib#6795)
Browse files Browse the repository at this point in the history
* Preliminary support for outputting to OpenQASM 3.0

- This supports changing the version and the obvious
changes in outputting to OpenQASM 3.0
- This PR does not cover imports from OpenQASM 3.0
  • Loading branch information
dstrain115 authored Nov 19, 2024
1 parent 326df25 commit f2c7330
Show file tree
Hide file tree
Showing 13 changed files with 80 additions and 35 deletions.
18 changes: 14 additions & 4 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1318,14 +1318,19 @@ def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool)
return self
return self._from_moments(resolved_moments)

def _qasm_(self) -> str:
return self.to_qasm()
def _qasm_(self, args: Optional['cirq.QasmArgs'] = None) -> str:
if args is None:
output = self._to_qasm_output()
else:
output = self._to_qasm_output(precision=args.precision, version=args.version)
return str(output)

def _to_qasm_output(
self,
header: Optional[str] = None,
precision: int = 10,
qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT,
version: str = '2.0',
) -> 'cirq.QasmOutput':
"""Returns a QASM object equivalent to the circuit.
Expand All @@ -1335,6 +1340,8 @@ def _to_qasm_output(
precision: Number of digits to use when representing numbers.
qubit_order: Determines how qubits are ordered in the QASM
register.
version: Version of OpenQASM to render as output. Defaults
to OpenQASM 2.0. For OpenQASM 3.0, set this to '3.0'.
"""
if header is None:
header = f'Generated from Cirq v{cirq._version.__version__}'
Expand All @@ -1344,14 +1351,15 @@ def _to_qasm_output(
qubits=qubits,
header=header,
precision=precision,
version='2.0',
version=version,
)

def to_qasm(
self,
header: Optional[str] = None,
precision: int = 10,
qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT,
version: str = '2.0',
) -> str:
"""Returns QASM equivalent to the circuit.
Expand All @@ -1361,9 +1369,11 @@ def to_qasm(
precision: Number of digits to use when representing numbers.
qubit_order: Determines how qubits are ordered in the QASM
register.
version: Version of OpenQASM to output. Defaults to OpenQASM 2.0.
Specify '3.0' if OpenQASM 3.0 is desired.
"""

return str(self._to_qasm_output(header, precision, qubit_order))
return str(self._to_qasm_output(header, precision, qubit_order, version))

def save_qasm(
self,
Expand Down
22 changes: 21 additions & 1 deletion cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3579,7 +3579,7 @@ def test_insert_operations_errors():
@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
def test_to_qasm(circuit_cls):
q0 = cirq.NamedQubit('q0')
circuit = circuit_cls(cirq.X(q0))
circuit = circuit_cls(cirq.X(q0), cirq.measure(q0, key='mmm'))
assert circuit.to_qasm() == cirq.qasm(circuit)
assert (
circuit.to_qasm()
Expand All @@ -3591,9 +3591,29 @@ def test_to_qasm(circuit_cls):
// Qubits: [q0]
qreg q[1];
creg m_mmm[1];
x q[0];
measure q[0] -> m_mmm[0];
"""
)
assert circuit.to_qasm(version="3.0") == cirq.qasm(circuit, args=cirq.QasmArgs(version="3.0"))
assert (
circuit.to_qasm(version="3.0")
== f"""// Generated from Cirq v{cirq.__version__}
OPENQASM 3.0;
include "stdgates.inc";
// Qubits: [q0]
qubit[1] q;
bit[1] m_mmm;
x q[0];
m_mmm[0] = measure q[0];
"""
)

Expand Down
30 changes: 21 additions & 9 deletions cirq-core/cirq/circuits/qasm_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _has_unitary_(self):
return True

def _qasm_(self, qubits: Tuple['cirq.Qid', ...], args: 'cirq.QasmArgs') -> str:
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
return args.format(
'u3({0:half_turns},{1:half_turns},{2:half_turns}) {3};\n',
self.theta,
Expand Down Expand Up @@ -246,7 +246,7 @@ def __str__(self) -> str:
return ''.join(output)

def _write_qasm(self, output_func: Callable[[str], None]) -> None:
self.args.validate_version('2.0')
self.args.validate_version('2.0', '3.0')

# Generate nice line spacing
line_gap = [0]
Expand All @@ -267,18 +267,26 @@ def output(text):
output('\n')

# Version
output('OPENQASM 2.0;\n')
output('include "qelib1.inc";\n')
output(f'OPENQASM {self.args.version};\n')
if self.args.version == '2.0':
output('include "qelib1.inc";\n')
else:
output('include "stdgates.inc";\n')

output_line_gap(2)

# Function definitions
# None yet

# Register definitions
# Qubit registers

output(f"// Qubits: [{', '.join(map(str, self.qubits))}]\n")
if len(self.qubits) > 0:
output(f'qreg q[{len(self.qubits)}];\n')
if self.args.version == '2.0':
output(f'qreg q[{len(self.qubits)}];\n')
else:
output(f'qubit[{len(self.qubits)}] q;\n')
# Classical registers
# Pick an id for the creg that will store each measurement
already_output_keys: Set[str] = set()
Expand All @@ -288,11 +296,15 @@ def output(text):
continue
already_output_keys.add(key)
meas_id = self.args.meas_key_id_map[key]
comment = self.meas_comments[key]
if comment is None:
output(f'creg {meas_id}[{len(meas.qubits)}];\n')
if self.meas_comments[key] is not None:
comment = f' // Measurement: {self.meas_comments[key]}'
else:
comment = ''

if self.args.version == '2.0':
output(f'creg {meas_id}[{len(meas.qubits)}];{comment}\n')
else:
output(f'creg {meas_id}[{len(meas.qubits)}]; // Measurement: {comment}\n')
output(f'bit[{len(meas.qubits)}] {meas_id};{comment}\n')
# In OpenQASM 2.0, the transformation of global phase gates is ignored.
# Therefore, no newline is created when the operations contained in
# a circuit consist only of global phase gates.
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/circuits/qasm_output_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def test_precision():
def test_version():
(q0,) = _make_qubits(1)
with pytest.raises(ValueError):
output = cirq.QasmOutput((), (q0,), version='3.0')
output = cirq.QasmOutput((), (q0,), version='4.0')
_ = str(output)


Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/classically_controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
return local_keys.union(protocols.control_keys(self._sub_operation))

def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]:
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
if len(self._conditions) > 1:
raise ValueError('QASM does not support multiple conditions.')
subop_qasm = protocols.qasm(self._sub_operation, args=args)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/common_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def _has_stabilizer_effect_(self) -> Optional[bool]:
return True

def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
return args.format('reset {0};\n', qubits[0])

def _qid_shape_(self):
Expand Down
18 changes: 9 additions & 9 deletions cirq-core/cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def _circuit_diagram_info_(
)

def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
if self._global_shift == 0:
if self._exponent == 1:
return args.format('x {0};\n', qubits[0])
Expand Down Expand Up @@ -374,7 +374,7 @@ def __repr__(self) -> str:
return f'cirq.Rx(rads={proper_repr(self._rads)})'

def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
return args.format('rx({0:half_turns}) {1};\n', self._exponent, qubits[0])

def _json_dict_(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -478,7 +478,7 @@ def _circuit_diagram_info_(
)

def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
if self._exponent == 1 and self.global_shift != -0.5:
return args.format('y {0};\n', qubits[0])

Expand Down Expand Up @@ -560,7 +560,7 @@ def __repr__(self) -> str:
return f'cirq.Ry(rads={proper_repr(self._rads)})'

def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
return args.format('ry({0:half_turns}) {1};\n', self._exponent, qubits[0])

def _json_dict_(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -791,7 +791,7 @@ def _circuit_diagram_info_(
return protocols.CircuitDiagramInfo(wire_symbols=('Z',), exponent=e)

def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
args.validate_version('2.0')
args.validate_version('2.0', '3.0')

if self.global_shift == 0:
if self._exponent == 1:
Expand Down Expand Up @@ -910,7 +910,7 @@ def __repr__(self) -> str:
return f'cirq.Rz(rads={proper_repr(self._rads)})'

def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
return args.format('rz({0:half_turns}) {1};\n', self._exponent, qubits[0])

def _json_dict_(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -1016,7 +1016,7 @@ def _circuit_diagram_info_(
)

def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
if self._exponent == 0:
return args.format('id {0};\n', qubits[0])
elif self._exponent == 1 and self._global_shift == 0:
Expand Down Expand Up @@ -1204,7 +1204,7 @@ def _circuit_diagram_info_(
def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
if self._exponent != 1:
return None # Don't have an equivalent gate in QASM
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
return args.format('cz {0},{1};\n', qubits[0], qubits[1])

def _has_stabilizer_effect_(self) -> Optional[bool]:
Expand Down Expand Up @@ -1405,7 +1405,7 @@ def controlled(
def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
if self._exponent != 1:
return None # Don't have an equivalent gate in QASM
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
return args.format('cx {0},{1};\n', qubits[0], qubits[1])

def _has_stabilizer_effect_(self) -> Optional[bool]:
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _circuit_diagram_info_(self, args) -> Tuple[str, ...]:
return ('I',) * self.num_qubits()

def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
return ''.join([args.format('id {0};\n', qubit) for qubit in qubits])

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/matrix_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _circuit_diagram_info_(
return protocols.CircuitDiagramInfo(wire_symbols=[main, *rest])

def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
if self._qid_shape == (2,):
return protocols.qasm(
phased_x_z_gate.PhasedXZGate.from_matrix(self._matrix), args=args, qubits=qubits
Expand Down
7 changes: 5 additions & 2 deletions cirq-core/cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,15 +227,18 @@ def _circuit_diagram_info_(
def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
if self.confusion_map or not all(d == 2 for d in self._qid_shape):
return NotImplemented
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
invert_mask = self.invert_mask
if len(invert_mask) < len(qubits):
invert_mask = invert_mask + (False,) * (len(qubits) - len(invert_mask))
lines = []
for i, (qubit, inv) in enumerate(zip(qubits, invert_mask)):
if inv:
lines.append(args.format('x {0}; // Invert the following measurement\n', qubit))
lines.append(args.format('measure {0} -> {1:meas}[{2}];\n', qubit, self.key, i))
if args.version == '2.0':
lines.append(args.format('measure {0} -> {1:meas}[{2}];\n', qubit, self.key, i))
else:
lines.append(args.format('{1:meas}[{2}] = measure {0};\n', qubit, self.key, i))
if inv:
lines.append(args.format('x {0}; // Undo the inversion\n', qubit))
return ''.join(lines)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/phased_x_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optio
if cirq.is_parameterized(self):
return None

args.validate_version('2.0')
args.validate_version('2.0', '3.0')

e = cast(float, value.canonicalize_half_turns(self._exponent))
p = cast(float, self.phase_exponent)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/swap_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _circuit_diagram_info_(
def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
if self._exponent != 1:
return None # Don't have an equivalent gate in QASM
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
return args.format('swap {0},{1};\n', qubits[0], qubits[1])

def __str__(self) -> str:
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/ops/three_qubit_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optio
if self._exponent != 1:
return None

args.validate_version('2.0')
args.validate_version('2.0', '3.0')
lines = [
args.format('h {0};\n', qubits[2]),
args.format('ccx {0},{1},{2};\n', qubits[0], qubits[1], qubits[2]),
Expand Down Expand Up @@ -483,7 +483,7 @@ def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optio
if self._exponent != 1:
return None

args.validate_version('2.0')
args.validate_version('2.0', '3.0')
return args.format('ccx {0},{1},{2};\n', qubits[0], qubits[1], qubits[2])

def __repr__(self) -> str:
Expand Down Expand Up @@ -661,7 +661,7 @@ def _circuit_diagram_info_(
return protocols.CircuitDiagramInfo(('@', '×', '×'))

def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
args.validate_version('2.0')
args.validate_version('2.0', '3.0')
return args.format('cswap {0},{1},{2};\n', qubits[0], qubits[1], qubits[2])

def _value_equality_values_(self):
Expand Down

0 comments on commit f2c7330

Please sign in to comment.