Skip to content

Commit

Permalink
Better type handling for rshifts (#369)
Browse files Browse the repository at this point in the history
**Context:**

```
from mrmustard.lab_dev import Coherent

Coherent([0, 1]) >> Coherent([0]).dual
```
At the moment, the code above returns a `CircuitComponent`. We want it
to return a `Ket`.
  • Loading branch information
SamFerracin authored Mar 22, 2024
1 parent 2ee1b98 commit 447856e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 43 deletions.
42 changes: 16 additions & 26 deletions mrmustard/lab_dev/states/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from mrmustard.physics.gaussian import purity
from mrmustard.physics.representations import Bargmann, Fock
from ..circuit_components import CircuitComponent
from ..transformations.transformations import Unitary, Channel

__all__ = ["State", "DM", "Ket"]

Expand Down Expand Up @@ -351,17 +350,14 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent:
Contracts ``self`` and ``other`` as it would in a circuit, adding the adjoints when
they are missing.
Returns a ``DM`` when ``other`` is a ``Unitary`` or a ``Channel``, and ``other`` acts on
``self``'s modes. Otherwise, it returns a ``CircuitComponent``.
Returns a ``DM`` when the wires of the resulting components are compatible with those
of a ``Ket``, a ``CircuitComponent`` otherwise.
"""
component = super().__rshift__(other)
ret = super().__rshift__(other)

if isinstance(other, (Unitary, Channel)) and set(other.modes).issubset(self.modes):
dm = DM()
dm._wires = component.wires
dm._representation = component.representation
return dm
return component
if not ret.wires.input and ret.wires.bra.modes == ret.wires.ket.modes:
return DM._from_attributes("", ret.representation, ret.wires)
return ret

def __repr__(self) -> str:
return super().__repr__().replace("CircuitComponent", "DM")
Expand Down Expand Up @@ -478,23 +474,17 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent:
Contracts ``self`` and ``other`` as it would in a circuit, adding the adjoints when
they are missing.
Returns a ``State`` (either ``Ket`` or ``DM``) when ``other`` is a ``Unitary`` or a
``Channel``, and ``other`` acts on ``self``'s modes. Otherwise, it returns a
``CircuitComponent``.
Returns a ``DM`` or a ``Ket`` when the wires of the resulting components are compatible
with those of a ``DM`` or of a ``Ket``, a ``CircuitComponent`` otherwise.
"""
component = super().__rshift__(other)

if isinstance(other, Unitary) and set(other.modes).issubset(set(self.modes)):
ket = Ket()
ket._wires = component.wires
ket._representation = component.representation
return ket
elif isinstance(other, Channel) and set(other.modes).issubset(set(self.modes)):
dm = DM()
dm._wires = component.wires
dm._representation = component.representation
return dm
return component
ret = super().__rshift__(other)

if not ret.wires.input:
if not ret.wires.bra:
return Ket._from_attributes("", ret.representation, ret.wires)
if ret.wires.bra.modes == ret.wires.ket.modes:
return DM._from_attributes("", ret.representation, ret.wires)
return ret

def __repr__(self) -> str:
return super().__repr__().replace("CircuitComponent", "Ket")
23 changes: 7 additions & 16 deletions mrmustard/lab_dev/transformations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,13 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent:
Returns a ``Unitary`` when ``other`` is a ``Unitary``, a ``Channel`` when ``other`` is a
``Channel``, and a ``CircuitComponent`` otherwise.
"""
component = super().__rshift__(other)
ret = super().__rshift__(other)

if isinstance(other, Unitary):
unitary = Unitary()
unitary._wires = component.wires
unitary._representation = component.representation
return unitary
return Unitary._from_attributes("", ret.representation, ret.wires)
elif isinstance(other, Channel):
channel = Channel()
channel._wires = component.wires
channel._representation = component.representation
return channel
return component
return Channel._from_attributes("", ret.representation, ret.wires)
return ret

def __repr__(self) -> str:
return super().__repr__().replace("CircuitComponent", "Unitary")
Expand Down Expand Up @@ -101,14 +95,11 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent:
Returns a ``Channel`` when ``other`` is a ``Unitary`` or a ``Channel``, and a
``CircuitComponent`` otherwise.
"""
component = super().__rshift__(other)
ret = super().__rshift__(other)

if isinstance(other, (Unitary, Channel)):
channel = Channel()
channel._wires = component.wires
channel._representation = component.representation
return channel
return component
return Channel._from_attributes("", ret.representation, ret.wires)
return ret

def __repr__(self) -> str:
return super().__repr__().replace("CircuitComponent", "Channel")
12 changes: 11 additions & 1 deletion tests/test_lab_dev/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,18 @@ def test_rshift(self):
channel.name, channel.representation, channel.wires
) # pylint: disable=protected-access

# gates
assert isinstance(ket >> unitary, Ket)
assert isinstance(ket >> channel, DM)
assert isinstance(ket >> unitary >> channel, DM)
assert isinstance(ket >> channel >> unitary, DM)
assert isinstance(ket >> u_component, CircuitComponent)
assert isinstance(ket >> ch_component, CircuitComponent)

# measurements
assert isinstance(ket >> Coherent([0], 1).dual, Ket)
assert isinstance(ket >> Coherent([0], 1).dm().dual, DM)

def test_repr(self):
ket = Coherent([0, 1], 1)
ket_component = CircuitComponent._from_attributes(
Expand Down Expand Up @@ -273,13 +278,18 @@ def test_rshift(self):
) # pylint: disable=protected-access

dm = ket >> channel
assert isinstance(dm, DM)

# gates
assert isinstance(dm, DM)
assert isinstance(dm >> unitary >> channel, DM)
assert isinstance(dm >> channel >> unitary, DM)
assert isinstance(dm >> u_component, CircuitComponent)
assert isinstance(dm >> ch_component, CircuitComponent)

# measurements
assert isinstance(dm >> Coherent([0], 1).dual, DM)
assert isinstance(dm >> Coherent([0], 1).dm().dual, DM)

def test_repr(self):
ket = Coherent([0, 1], 1)
channel = Attenuator([1], 1)
Expand Down

0 comments on commit 447856e

Please sign in to comment.