Skip to content

Commit

Permalink
ENH: add inverse_transform method to UnsupervisedSpatialFilter (mne-t…
Browse files Browse the repository at this point in the history
…ools#4371)

* enh: inverse_transform

* flake8
  • Loading branch information
kingjr authored and agramfort committed Jul 4, 2017
1 parent cbcbc43 commit 593da2a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 6 deletions.
2 changes: 1 addition & 1 deletion mne/decoding/tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def test_unsupervised_spatial_filter():
assert_equal(usf.transform(X).ndim, 3)
# test fit_transform
assert_array_almost_equal(usf.transform(X), usf1.fit_transform(X))
# assert shape
assert_equal(usf.transform(X).shape[1], n_components)
assert_array_almost_equal(usf.inverse_transform(usf.transform(X)), X)

# Test with average param
usf = UnsupervisedSpatialFilter(PCA(4), average=True)
Expand Down
43 changes: 38 additions & 5 deletions mne/decoding/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ def fit_transform(self, X, y=None):
Returns
-------
X : array, shape (n_trials, n_channels, n_times)
X : array, shape (n_epochs, n_channels, n_times)
The transformed data.
"""
return self.fit(X).transform(X)
Expand All @@ -645,14 +645,47 @@ def transform(self, X):
Returns
-------
X : array, shape (n_trials, n_channels, n_times)
X : array, shape (n_epochs, n_channels, n_times)
The transformed data.
"""
return self._apply_method(X, 'transform')

def inverse_transform(self, X):
"""Inverse transform the data to its original space.
Parameters
----------
X : array, shape (n_epochs, n_components, n_times)
The data to be inverted.
Returns
-------
X : array, shape (n_epochs, n_channels, n_times)
The transformed data.
"""
return self._apply_method(X, 'inverse_transform')

def _apply_method(self, X, method):
"""Vectorize time samples as trials, apply method and reshape back.
Parameters
----------
X : array, shape (n_epochs, n_dims, n_times)
The data to be inverted.
Returns
-------
X : array, shape (n_epochs, n_dims, n_times)
The transformed data.
"""
n_epochs, n_channels, n_times = X.shape
# trial as time samples
X = np.transpose(X, [1, 0, 2]).reshape([n_channels, n_epochs *
n_times]).T
X = self.estimator.transform(X)
X = np.transpose(X, [1, 0, 2])
X = np.reshape(X, [n_channels, n_epochs * n_times]).T
# apply method
method = getattr(self.estimator, method)
X = method(X)
# put it back to n_epochs, n_dimensions
X = np.reshape(X.T, [-1, n_epochs, n_times]).transpose([1, 0, 2])
return X

Expand Down

0 comments on commit 593da2a

Please sign in to comment.