Skip to content

Commit

Permalink
Add == and != methods between Projection (mne-tools#11147)
Browse files Browse the repository at this point in the history
* add comparison of proj and test

* add entry to changelog

* simpler

* simpler as it also checks type

* try with different py:obj role

* fix dataset download in test

* fix x-ref

* fix test
  • Loading branch information
mscheltienne authored Sep 9, 2022
1 parent 8c3e6fb commit 6df57c3
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
2 changes: 1 addition & 1 deletion doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Enhancements
- Add :func:`mne.minimum_norm.apply_inverse_tfr_epochs` to apply inverse methods to time-frequency resolved epochs (:gh:`11095` by `Alex Rockhill`_)
- Add :func:`mne.chpi.get_active_chpi` to retrieve the number of active hpi coils for each time point (:gh:`11122` by `Eduard Ort`_)
- Add example of how to obtain time-frequency decomposition using narrow bandpass Hilbert transforms to :ref:`ex-tfr-comparison` (:gh:`11116` by `Alex Rockhill`_)

- Add ``==`` and ``!=`` comparison between `mne.Projection` objects (:gh:`11147` by `Mathieu Scheltienne`_)

Bugs
~~~~
Expand Down
11 changes: 10 additions & 1 deletion mne/io/proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
write_float_matrix, end_block, start_block)
from ..defaults import (_INTERPOLATION_DEFAULT, _BORDER_DEFAULT,
_EXTRAPOLATE_DEFAULT)
from ..utils import logger, verbose, warn, fill_doc, _validate_type
from ..utils import (logger, verbose, warn, fill_doc, _validate_type,
object_diff)


class Projection(dict):
Expand Down Expand Up @@ -70,6 +71,14 @@ def __deepcopy__(self, memodict):
result[k] = v # kind, active, desc, explained_var immutable
return result

def __eq__(self, other):
"""Equality == method."""
return True if len(object_diff(self, other)) == 0 else False

def __ne__(self, other):
"""Different != method."""
return not self.__eq__(other)

@fill_doc
def plot_topomap(self, info, cmap=None, sensors=True,
colorbar=False, res=64, size=1, show=True,
Expand Down
24 changes: 24 additions & 0 deletions mne/io/tests/test_proj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from copy import deepcopy

from mne.datasets.testing import data_path, requires_testing_data
from mne.io import read_raw_fif

directory = data_path(download=False) / "MEG" / "sample"
fname = directory / "sample_audvis_trunc_raw.fif"


@requires_testing_data
def test_eq_ne():
"""Test == and != between projectors."""
raw = read_raw_fif(fname, preload=False)

pca1 = deepcopy(raw.info["projs"][0])
pca2 = deepcopy(raw.info["projs"][1])
car = deepcopy(raw.info["projs"][3])

assert pca1 != pca2
assert pca1 != car
assert pca2 != car
assert pca1 == raw.info["projs"][0]
assert pca2 == raw.info["projs"][1]
assert car == raw.info["projs"][3]

0 comments on commit 6df57c3

Please sign in to comment.