Skip to content

Commit

Permalink
Cache Circuit properties between mutations (#6322)
Browse files Browse the repository at this point in the history
This caches various computed properties on `Circuit` so that they do not need to be recomputed when accessed if the circuit has not been mutated. Any mutations cause these properties to be invalidated so that they will be recomputed the next time they are accessed.
  • Loading branch information
maffoo authored Nov 1, 2023
1 parent 5485227 commit 2f7d732
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 15 deletions.
88 changes: 73 additions & 15 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,28 +188,20 @@ def _from_moments(cls: Type[CIRCUIT_TYPE], moments: Iterable['cirq.Moment']) ->
def moments(self) -> Sequence['cirq.Moment']:
pass

@abc.abstractmethod
def freeze(self) -> 'cirq.FrozenCircuit':
"""Creates a FrozenCircuit from this circuit.
If 'self' is a FrozenCircuit, the original object is returned.
"""
from cirq.circuits import FrozenCircuit

if isinstance(self, FrozenCircuit):
return self

return FrozenCircuit(self, strategy=InsertStrategy.EARLIEST)

@abc.abstractmethod
def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
"""Creates a Circuit from this circuit.
Args:
copy: If True and 'self' is a Circuit, returns a copy that circuit.
"""
if isinstance(self, Circuit):
return Circuit.copy(self) if copy else self

return Circuit(self, strategy=InsertStrategy.EARLIEST)

def __bool__(self):
return bool(self.moments)
Expand Down Expand Up @@ -822,6 +814,9 @@ def has_measurements(self):
"""
return protocols.is_measurement(self)

def _is_measurement_(self) -> bool:
return any(protocols.is_measurement(op) for op in self.all_operations())

def are_all_measurements_terminal(self) -> bool:
"""Whether all measurement gates are at the end of the circuit.
Expand Down Expand Up @@ -1383,8 +1378,7 @@ def save_qasm(
self._to_qasm_output(header, precision, qubit_order).save(file_path)

def _json_dict_(self):
ret = protocols.obj_to_dict_helper(self, ['moments'])
return ret
return protocols.obj_to_dict_helper(self, ['moments'])

@classmethod
def _from_json_dict_(cls, moments, **kwargs):
Expand Down Expand Up @@ -1759,6 +1753,16 @@ def __init__(
circuit.
"""
self._moments: List['cirq.Moment'] = []

# Implementation note: the following cached properties are set lazily and then
# invalidated and reset to None in `self._mutated()`, which is called any time
# `self._moments` is changed.
self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None
self._frozen: Optional['cirq.FrozenCircuit'] = None
self._is_measurement: Optional[bool] = None
self._is_parameterized: Optional[bool] = None
self._parameter_names: Optional[AbstractSet[str]] = None

flattened_contents = tuple(ops.flatten_to_ops_or_moments(contents))
if all(isinstance(c, Moment) for c in flattened_contents):
self._moments[:] = cast(Iterable[Moment], flattened_contents)
Expand All @@ -1769,6 +1773,14 @@ def __init__(
else:
self.append(flattened_contents, strategy=strategy)

def _mutated(self) -> None:
"""Clear cached properties in response to this circuit being mutated."""
self._all_qubits = None
self._frozen = None
self._is_measurement = None
self._is_parameterized = None
self._parameter_names = None

@classmethod
def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'Circuit':
new_circuit = Circuit()
Expand Down Expand Up @@ -1831,6 +1843,41 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):
def __copy__(self) -> 'cirq.Circuit':
return self.copy()

def freeze(self) -> 'cirq.FrozenCircuit':
"""Gets a frozen version of this circuit.
Repeated calls to `.freeze()` will return the same FrozenCircuit
instance as long as this circuit is not mutated.
"""
from cirq.circuits.frozen_circuit import FrozenCircuit

if self._frozen is None:
self._frozen = FrozenCircuit.from_moments(*self._moments)
return self._frozen

def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
return self.copy() if copy else self

def all_qubits(self) -> FrozenSet['cirq.Qid']:
if self._all_qubits is None:
self._all_qubits = super().all_qubits()
return self._all_qubits

def _is_measurement_(self) -> bool:
if self._is_measurement is None:
self._is_measurement = super()._is_measurement_()
return self._is_measurement

def _is_parameterized_(self) -> bool:
if self._is_parameterized is None:
self._is_parameterized = super()._is_parameterized_()
return self._is_parameterized

def _parameter_names_(self) -> AbstractSet[str]:
if self._parameter_names is None:
self._parameter_names = super()._parameter_names_()
return self._parameter_names

def copy(self) -> 'Circuit':
"""Return a copy of this circuit."""
copied_circuit = Circuit()
Expand All @@ -1856,11 +1903,13 @@ def __setitem__(self, key, value):
raise TypeError('Can only assign Moments into Circuits.')

self._moments[key] = value
self._mutated()

# pylint: enable=function-redefined

def __delitem__(self, key: Union[int, slice]):
del self._moments[key]
self._mutated()

def __iadd__(self, other):
self.append(other)
Expand Down Expand Up @@ -1889,6 +1938,7 @@ def __imul__(self, repetitions: _INT_TYPE):
if not isinstance(repetitions, (int, np.integer)):
return NotImplemented
self._moments *= int(repetitions)
self._mutated()
return self

def __mul__(self, repetitions: _INT_TYPE):
Expand Down Expand Up @@ -2032,6 +2082,7 @@ def _pick_or_create_inserted_op_moment_index(

if strategy is InsertStrategy.NEW or strategy is InsertStrategy.NEW_THEN_INLINE:
self._moments.insert(splitter_index, Moment())
self._mutated()
return splitter_index

if strategy is InsertStrategy.INLINE:
Expand Down Expand Up @@ -2099,6 +2150,7 @@ def insert(
k = max(k, p + 1)
if strategy is InsertStrategy.NEW_THEN_INLINE:
strategy = InsertStrategy.INLINE
self._mutated()
return k

def insert_into_range(self, operations: 'cirq.OP_TREE', start: int, end: int) -> int:
Expand Down Expand Up @@ -2135,6 +2187,7 @@ def insert_into_range(self, operations: 'cirq.OP_TREE', start: int, end: int) ->

self._moments[i] = self._moments[i].with_operation(op)
op_index += 1
self._mutated()

if op_index >= len(flat_ops):
return end
Expand Down Expand Up @@ -2180,6 +2233,7 @@ def _push_frontier(
if n_new_moments > 0:
insert_index = min(late_frontier.values())
self._moments[insert_index:insert_index] = [Moment()] * n_new_moments
self._mutated()
for q in update_qubits:
if early_frontier.get(q, 0) > insert_index:
early_frontier[q] += n_new_moments
Expand All @@ -2206,13 +2260,12 @@ def _insert_operations(
if len(operations) != len(insertion_indices):
raise ValueError('operations and insertion_indices must have the same length.')
self._moments += [Moment() for _ in range(1 + max(insertion_indices) - len(self))]
self._mutated()
moment_to_ops: Dict[int, List['cirq.Operation']] = defaultdict(list)
for op_index, moment_index in enumerate(insertion_indices):
moment_to_ops[moment_index].append(operations[op_index])
for moment_index, new_ops in moment_to_ops.items():
self._moments[moment_index] = Moment(
self._moments[moment_index].operations + tuple(new_ops)
)
self._moments[moment_index] = self._moments[moment_index].with_operations(*new_ops)

def insert_at_frontier(
self,
Expand Down Expand Up @@ -2274,6 +2327,7 @@ def batch_remove(self, removals: Iterable[Tuple[int, 'cirq.Operation']]) -> None
old_op for old_op in copy._moments[i].operations if op != old_op
)
self._moments = copy._moments
self._mutated()

def batch_replace(
self, replacements: Iterable[Tuple[int, 'cirq.Operation', 'cirq.Operation']]
Expand All @@ -2298,6 +2352,7 @@ def batch_replace(
old_op if old_op != op else new_op for old_op in copy._moments[i].operations
)
self._moments = copy._moments
self._mutated()

def batch_insert_into(self, insert_intos: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None:
"""Inserts operations into empty spaces in existing moments.
Expand All @@ -2318,6 +2373,7 @@ def batch_insert_into(self, insert_intos: Iterable[Tuple[int, 'cirq.OP_TREE']])
for i, insertions in insert_intos:
copy._moments[i] = copy._moments[i].with_operations(insertions)
self._moments = copy._moments
self._mutated()

def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None:
"""Applies a batched insert operation to the circuit.
Expand Down Expand Up @@ -2352,6 +2408,7 @@ def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None
if next_index > insert_index:
shift += next_index - insert_index
self._moments = copy._moments
self._mutated()

def append(
self,
Expand Down Expand Up @@ -2382,6 +2439,7 @@ def clear_operations_touching(
for k in moment_indices:
if 0 <= k < len(self._moments):
self._moments[k] = self._moments[k].without_operations_touching(qubits)
self._mutated()

@property
def moments(self) -> Sequence['cirq.Moment']:
Expand Down
94 changes: 94 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4533,6 +4533,100 @@ def test_freeze_not_relocate_moments():
assert [mc is fc for mc, fc in zip(c, f)] == [True, True]


def test_freeze_is_cached():
q = cirq.q(0)
c = cirq.Circuit(cirq.X(q), cirq.measure(q))
f0 = c.freeze()
f1 = c.freeze()
assert f1 is f0

c.append(cirq.Y(q))
f2 = c.freeze()
f3 = c.freeze()
assert f2 is not f1
assert f3 is f2

c[-1] = cirq.Moment(cirq.Y(q))
f4 = c.freeze()
f5 = c.freeze()
assert f4 is not f3
assert f5 is f4


@pytest.mark.parametrize(
"circuit, mutate",
[
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
lambda c: c.__setitem__(0, cirq.Moment(cirq.Y(cirq.q(0)))),
),
(cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__delitem__(0)),
(cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__imul__(2)),
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
lambda c: c.insert(1, cirq.Y(cirq.q(0))),
),
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
lambda c: c.insert_into_range([cirq.Y(cirq.q(1)), cirq.M(cirq.q(1))], 0, 2),
),
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
lambda c: c.insert_at_frontier([cirq.Y(cirq.q(0)), cirq.Y(cirq.q(1))], 1),
),
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
lambda c: c.batch_replace([(0, cirq.X(cirq.q(0)), cirq.Y(cirq.q(0)))]),
),
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0), cirq.q(1))),
lambda c: c.batch_insert_into([(0, cirq.X(cirq.q(1)))]),
),
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
lambda c: c.batch_insert([(1, cirq.Y(cirq.q(0)))]),
),
(
cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))),
lambda c: c.clear_operations_touching([cirq.q(0)], [0]),
),
],
)
def test_mutation_clears_cached_attributes(circuit, mutate):
cached_attributes = [
"_all_qubits",
"_frozen",
"_is_measurement",
"_is_parameterized",
"_parameter_names",
]

for attr in cached_attributes:
assert getattr(circuit, attr) is None, f"{attr=} is not None"

# Check that attributes are cached after getting them.
qubits = circuit.all_qubits()
frozen = circuit.freeze()
is_measurement = cirq.is_measurement(circuit)
is_parameterized = cirq.is_parameterized(circuit)
parameter_names = cirq.parameter_names(circuit)

for attr in cached_attributes:
assert getattr(circuit, attr) is not None, f"{attr=} is None"

# Check that getting again returns same object.
assert circuit.all_qubits() is qubits
assert circuit.freeze() is frozen
assert cirq.is_measurement(circuit) is is_measurement
assert cirq.is_parameterized(circuit) is is_parameterized
assert cirq.parameter_names(circuit) is parameter_names

# Check that attributes are cleared after mutation.
mutate(circuit)
for attr in cached_attributes:
assert getattr(circuit, attr) is None, f"{attr=} is not None"


def test_factorize_one_factor():
circuit = cirq.Circuit()
q0, q1, q2 = cirq.LineQubit.range(3)
Expand Down
6 changes: 6 additions & 0 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
def moments(self) -> Sequence['cirq.Moment']:
return self._moments

def freeze(self) -> 'cirq.FrozenCircuit':
return self

def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
return Circuit.from_moments(*self)

@property
def tags(self) -> Tuple[Hashable, ...]:
"""Returns a tuple of the Circuit's tags."""
Expand Down

0 comments on commit 2f7d732

Please sign in to comment.