Skip to content

Commit

Permalink
Add Moment.__sub__ and generalize Moment.__add__ (quantumlib#3216)
Browse files Browse the repository at this point in the history
- `Moment.__add__` now accepts any `OP_TREE`
- `Moment.__sub__` takes an `OP_TREE` and returns a moment with those operations removed
- If there is an operation-to-remove that was not present, an error is raised
- Also bump mypy version to fix an incompatibility with python 3.8 and fix new warnings
    - Add explicit !r to format statements touching bytes
    - Rename `value_equality.py` to avoid name collision with the decorator it defines
    - Ignore false positives related to abstract base type bounds
  • Loading branch information
Strilanc authored Aug 14, 2020
1 parent 72c2a9f commit aff27b8
Show file tree
Hide file tree
Showing 14 changed files with 106 additions and 66 deletions.
5 changes: 2 additions & 3 deletions cirq/contrib/qcircuit/qcircuit_diagram_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,8 @@ def multigate_qcircuit_diagram_info(
assert args.known_qubits is not None
symbols = tuple(box if (args.qubit_map[q] == min_index) else
ghost for q in args.known_qubits)
return protocols.CircuitDiagramInfo(symbols,
exponent=info.exponent,
connected=False)
return protocols.CircuitDiagramInfo(
symbols, exponent=1 if info is None else info.exponent, connected=False)


def fallback_qcircuit_diagram_info(
Expand Down
4 changes: 2 additions & 2 deletions cirq/devices/grid_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
if TYPE_CHECKING:
import cirq

TSelf = TypeVar('TSelf', bound='_BaseGridQid')
TSelf = TypeVar('TSelf', bound='_BaseGridQid') # type: ignore


@functools.total_ordering
@functools.total_ordering # type: ignore
class _BaseGridQid(ops.Qid):
"""The Base class for `GridQid` and `GridQubit`."""

Expand Down
4 changes: 2 additions & 2 deletions cirq/devices/line_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
if TYPE_CHECKING:
import cirq

TSelf = TypeVar('TSelf', bound='_BaseLineQid')
TSelf = TypeVar('TSelf', bound='_BaseLineQid') # type: ignore


@functools.total_ordering
@functools.total_ordering # type: ignore
class _BaseLineQid(ops.Qid):
"""The base class for `LineQid` and `LineQubit`."""

Expand Down
27 changes: 20 additions & 7 deletions cirq/ops/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,26 @@ def _json_dict_(self) -> Dict[str, Any]:
def _from_json_dict_(cls, operations, **kwargs):
return Moment(operations)

def __add__(self,
other: Union['cirq.Operation', 'cirq.Moment']) -> 'cirq.Moment':
if isinstance(other, raw_types.Operation):
return self.with_operation(other)
if isinstance(other, Moment):
return Moment(self.operations + other.operations)
return NotImplemented
def __add__(self, other: 'cirq.OP_TREE') -> 'cirq.Moment':
from cirq.circuits import circuit
if isinstance(other, circuit.Circuit):
return NotImplemented # Delegate to Circuit.__radd__.
return Moment([self.operations, other])

def __sub__(self, other: 'cirq.OP_TREE') -> 'cirq.Moment':
from cirq.ops import op_tree
must_remove = set(op_tree.flatten_to_ops(other))
new_ops = []
for op in self.operations:
if op in must_remove:
must_remove.remove(op)
else:
new_ops.append(op)
if must_remove:
raise ValueError(f"Subtracted missing operations from a moment.\n"
f"Missing operations: {must_remove!r}\n"
f"Moment: {self!r}")
return Moment(new_ops)

# pylint: disable=function-redefined
@overload
Expand Down
22 changes: 22 additions & 0 deletions cirq/ops/moment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,28 @@ def test_add():
with pytest.raises(ValueError, match='Overlap'):
_ = m1 + m2

assert m1 + [[[[cirq.Y(b)]]]] == cirq.Moment(cirq.X(a), cirq.Y(b))
assert m1 + [] == m1


def test_sub():
a, b, c = cirq.LineQubit.range(3)
m = cirq.Moment(cirq.X(a), cirq.Y(b))
assert m - [] == m
assert m - cirq.X(a) == cirq.Moment(cirq.Y(b))
assert m - [[[[cirq.X(a)]], []]] == cirq.Moment(cirq.Y(b))
assert m - [cirq.X(a), cirq.Y(b)] == cirq.Moment()
assert m - [cirq.Y(b)] == cirq.Moment(cirq.X(a))

with pytest.raises(ValueError, match="missing operations"):
_ = m - cirq.X(b)
with pytest.raises(ValueError, match="missing operations"):
_ = m - [cirq.X(a), cirq.Z(c)]

# Preserves relative order.
m2 = cirq.Moment(cirq.X(a), cirq.Y(b), cirq.Z(c))
assert m2 - cirq.Y(b) == cirq.Moment(cirq.X(a), cirq.Z(c))


def test_op_tree():
eq = cirq.testing.EqualsTester()
Expand Down
5 changes: 2 additions & 3 deletions cirq/protocols/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,8 @@ def _has_channel_(self) -> bool:
"""


def channel(val: Any,
default: Any = RaiseTypeErrorIfNotProvided
) -> Union[Tuple[np.ndarray], Sequence[TDefault]]:
def channel(val: Any, default: Any = RaiseTypeErrorIfNotProvided
) -> Union[Tuple[np.ndarray, ...], TDefault]:
r"""Returns a list of matrices describing the channel for the given value.
These matrices are the terms in the operator sum representation of
Expand Down
8 changes: 5 additions & 3 deletions cirq/study/trial_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Defines trial results."""

from typing import (Any, Callable, Dict, Iterable, Optional, Sequence,
TYPE_CHECKING, Tuple, TypeVar, Union)
TYPE_CHECKING, Tuple, TypeVar, Union, cast)

import collections
import io
Expand Down Expand Up @@ -147,7 +147,8 @@ def multi_measurement_histogram( # type: ignore
self,
*, # Forces keyword args.
keys: Iterable[TMeasurementKey],
fold_func: Callable[[Tuple], T] = _tuple_of_big_endian_int
fold_func: Callable[[Tuple], T] = cast(Callable[[Tuple], T],
_tuple_of_big_endian_int)
) -> collections.Counter:
"""Counts the number of times combined measurement results occurred.
Expand Down Expand Up @@ -208,7 +209,8 @@ def histogram( # type: ignore
self,
*, # Forces keyword args.
key: TMeasurementKey,
fold_func: Callable[[Tuple], T] = value.big_endian_bits_to_int
fold_func: Callable[[Tuple], T] = cast(Callable[[Tuple], T],
value.big_endian_bits_to_int)
) -> collections.Counter:
"""Counts the number of times a measurement result occurred.
Expand Down
2 changes: 1 addition & 1 deletion cirq/value/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,5 @@
from cirq.value.type_alias import (
TParamVal,)

from cirq.value.value_equality import (
from cirq.value.value_equality_attr import (
value_equality,)
4 changes: 2 additions & 2 deletions cirq/value/linear_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.

"""Linear combination represented as mapping of things to coefficients."""

import numbers
from typing import (Any, Callable, Dict, ItemsView, Iterable, Iterator,
KeysView, Mapping, MutableMapping, overload, Tuple, TypeVar,
Union, ValuesView, Generic, Optional)

Scalar = Union[complex, float]
Scalar = Union[complex, float, numbers.Complex]
TVector = TypeVar('TVector')

TDefault = TypeVar('TDefault')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines `@cirq.value_equality`, for easy __eq__/__hash__ methods."""

from typing import Union, Callable, overload, Any
Expand Down Expand Up @@ -89,8 +88,8 @@ def _value_equality_ne(self: _SupportsValueEquality,


def _value_equality_hash(self: _SupportsValueEquality) -> int:
return hash((self._value_equality_values_cls_(),
self._value_equality_values_()))
return hash(
(self._value_equality_values_cls_(), self._value_equality_values_()))


def _value_equality_approx_eq(self: _SupportsValueEquality,
Expand Down Expand Up @@ -220,4 +219,6 @@ class return the existing class' type.
setattr(cls, '_approx_eq_', _value_equality_approx_eq)

return cls


# pylint: enable=function-redefined
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

@cirq.value_equality
class BasicC:

def __init__(self, x):
self.x = x

Expand All @@ -27,6 +28,7 @@ def _value_equality_values_(self):

@cirq.value_equality
class BasicD:

def __init__(self, x):
self.x = x

Expand Down Expand Up @@ -81,6 +83,7 @@ def test_value_equality_manual():

@cirq.value_equality(unhashable=True)
class UnhashableC:

def __init__(self, x):
self.x = x

Expand All @@ -90,6 +93,7 @@ def _value_equality_values_(self):

@cirq.value_equality(unhashable=True)
class UnhashableD:

def __init__(self, x):
self.x = x

Expand All @@ -112,16 +116,15 @@ def test_value_equality_unhashable():

# Equality works as expected.
eq = cirq.testing.EqualsTester()
eq.add_equality_group(UnhashableC(1),
UnhashableC(1),
UnhashableCa(1),
eq.add_equality_group(UnhashableC(1), UnhashableC(1), UnhashableCa(1),
UnhashableCb(1))
eq.add_equality_group(UnhashableC(2))
eq.add_equality_group(UnhashableD(1))


@cirq.value_equality(distinct_child_types=True)
class DistinctC:

def __init__(self, x):
self.x = x

Expand All @@ -131,6 +134,7 @@ def _value_equality_values_(self):

@cirq.value_equality(distinct_child_types=True)
class DistinctD:

def __init__(self, x):
self.x = x

Expand Down Expand Up @@ -164,6 +168,7 @@ def test_value_equality_distinct_child_types():

@cirq.value_equality(approximate=True)
class ApproxE:

def __init__(self, x):
self.x = x

Expand All @@ -179,6 +184,7 @@ def test_value_equality_approximate():

@cirq.value_equality(approximate=True)
class PeriodicF:

def __init__(self, x, n):
self.x = x
self.n = n
Expand Down Expand Up @@ -210,6 +216,7 @@ class ApproxEb(ApproxE):

@cirq.value_equality(distinct_child_types=True, approximate=True)
class ApproxG:

def __init__(self, x):
self.x = x

Expand All @@ -235,6 +242,7 @@ def test_value_equality_approximate_typing():

def test_value_equality_forgot_method():
with pytest.raises(TypeError, match='_value_equality_values_'):

@cirq.value_equality
class _:
pass
Expand Down
Loading

0 comments on commit aff27b8

Please sign in to comment.