Skip to content

Commit

Permalink
Unify transformation comparison implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
shumpohl committed Jun 4, 2024
1 parent f78b832 commit fc0e5e7
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions qupulse/program/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __hash__(self):
return 0x1234991

def __eq__(self, other):
return isinstance(other, IdentityTransformation)
return self is other

def get_input_channels(self, output_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]:
return output_channels
Expand Down Expand Up @@ -134,7 +134,9 @@ def __hash__(self):
return hash(self._transformations)

def __eq__(self, other):
return self._transformations == getattr(other, '_transformations', None)
if isinstance(other, ChainedTransformation):
return self._transformations == other._transformations
return NotImplemented

def chain(self, next_transformation) -> Transformation:
return chain_transformations(*self.transformations, next_transformation)
Expand Down Expand Up @@ -223,11 +225,11 @@ def __hash__(self):
return hash((self._input_channels, self._output_channels, self._matrix.tobytes()))

def __eq__(self, other):
if isinstance(other, type(self)):
if isinstance(other, LinearTransformation):
return (self._input_channels == other._input_channels and
self._output_channels == other._output_channels and
np.array_equal(self._matrix, other._matrix))
return False
return NotImplemented

@property
def compare_key(self) -> Tuple[Tuple[ChannelID], Tuple[ChannelID], bytes]:
Expand Down Expand Up @@ -278,7 +280,9 @@ def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> Abstrac
return input_channels

def __eq__(self, other):
return isinstance(other, OffsetTransformation) and self._offsets == other._offsets
if isinstance(other, OffsetTransformation):
return self._offsets == other._offsets
return NotImplemented

def __hash__(self):
return hash(self._offsets)
Expand Down Expand Up @@ -320,7 +324,9 @@ def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> Abstrac
return input_channels

def __eq__(self, other):
return isinstance(other, ScalingTransformation) and self._factors == other._factors
if isinstance(other, ScalingTransformation):
return self._factors == other._factors
return NotImplemented

def __hash__(self):
return hash(self._factors)
Expand Down Expand Up @@ -393,7 +399,9 @@ def __hash__(self):
return hash(self._channels)

def __eq__(self, other):
return isinstance(other, ParallelChannelTransformation) and self._channels == other._channels
if isinstance(other, ParallelChannelTransformation):
return self._channels == other._channels
return NotImplemented

@property
def compare_key(self) -> Hashable:
Expand Down

0 comments on commit fc0e5e7

Please sign in to comment.