Skip to content

Commit

Permalink
ENH : cleanup ICA internals + fix some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
agramfort authored and dengemann committed Jul 30, 2013
1 parent 1b2329c commit dc15f39
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 129 deletions.
164 changes: 74 additions & 90 deletions mne/preprocessing/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,17 @@ def decompose_epochs(self, epochs, picks=None, verbose=None):

return self

def _get_sources(self, data):
"""Compute sources from data (operates inplace)"""
if self.pca_mean_ is not None:
data -= self.pca_mean_[:, None]

# Apply first PCA
pca_data = np.dot(self.pca_components_[:self.n_components_], data)
# Apply unmixing to low dimension PCA
sources = np.dot(self.unmixing_matrix_, pca_data)
return sources

def get_sources_raw(self, raw, start=None, stop=None):
"""Estimate raw sources given the unmixing matrix
Expand All @@ -356,17 +367,10 @@ def get_sources_raw(self, raw, start=None, stop=None):
raise RuntimeError('No fit available. Please first fit ICA '
'decomposition.')
start, stop = _check_start_stop(raw, start, stop)
return self._get_sources_raw(raw, start, stop)[0]

def _get_sources_raw(self, raw, start, stop):
"""Aux function"""

picks = [raw.ch_names.index(k) for k in self.ch_names]
data, _ = self._pre_whiten(raw[picks, start:stop][0], raw.info, picks)
pca_data = self._transform_pca(data.T)
n_components = self.n_components_
raw_sources = self._transform_ica(pca_data[:, :n_components]).T
return raw_sources, pca_data
return self._get_sources(data)

def get_sources_epochs(self, epochs, concatenate=False):
"""Estimate epochs sources given the unmixing matrix
Expand All @@ -387,10 +391,6 @@ def get_sources_epochs(self, epochs, concatenate=False):
raise RuntimeError('No fit available. Please first fit ICA '
'decomposition.')

return self._get_sources_epochs(epochs, concatenate)[0]

def _get_sources_epochs(self, epochs, concatenate):

picks = pick_types(epochs.info, include=self.ch_names, exclude=[])

# special case where epochs come picked but fit was 'unpicked'.
Expand All @@ -401,17 +401,15 @@ def _get_sources_epochs(self, epochs, concatenate):
'ica.ch_names' % (len(self.ch_names),
len(picks)))

data, _ = self._pre_whiten(np.hstack(epochs.get_data()[:, picks]),
epochs.info, picks)

pca_data = self._transform_pca(data.T)
sources = self._transform_ica(pca_data[:, :self.n_components_]).T
sources = np.array(np.split(sources, len(epochs.events), 1))
data = np.hstack(epochs.get_data()[:, picks])
data, _ = self._pre_whiten(data, epochs.info, picks)
sources = self._get_sources(data)

if concatenate:
sources = np.hstack(sources)
if not concatenate:
# Put the data back in 3D
sources = np.array(np.split(sources, len(epochs.events), 1))

return sources, pca_data
return sources

@verbose
def save(self, fname):
Expand Down Expand Up @@ -605,7 +603,7 @@ def plot_sources_raw(self, raw, order=None, start=None, stop=None,
n_components=n_components, source_idx=source_idx,
ncol=ncol, nrow=nrow, title=title)
if show:
import matplotlib.pylab as pl
import pylab as pl
pl.show()

return fig
Expand Down Expand Up @@ -819,15 +817,19 @@ def pick_sources_raw(self, raw, include=None, exclude=None,
self.n_pca_components = n_pca_components

start, stop = _check_start_stop(raw, start, stop)
sources, pca_data = self._get_sources_raw(raw, start=start, stop=stop)
recomposed = self._pick_sources(sources, pca_data, include,
self.exclude)

picks = pick_types(raw.info, meg=False, include=self.ch_names,
exclude='bads')

data = raw[picks, start:stop][0]
data, _ = self._pre_whiten(data, raw.info, picks)

data = self._pick_sources(data, include, self.exclude)

if copy is True:
raw = raw.copy()

picks = [raw.ch_names.index(k) for k in self.ch_names]
raw[picks, start:stop] = recomposed
raw[picks, start:stop] = data
return raw

def pick_sources_epochs(self, epochs, include=None, exclude=None,
Expand Down Expand Up @@ -865,29 +867,27 @@ def pick_sources_epochs(self, epochs, include=None, exclude=None,
'working. Please read raw data with '
'preload=True.')

sources, pca_data = self._get_sources_epochs(epochs, True)
picks = pick_types(epochs.info, include=self.ch_names,
picks = pick_types(epochs.info, meg=False, include=self.ch_names,
exclude='bads')

if copy is True:
epochs = epochs.copy()
# special case where epochs come picked but fit was 'unpicked'.
if len(picks) != len(self.ch_names):
raise RuntimeError('Epochs don\'t match fitted data: %i channels '
'fitted but %i channels supplied. \nPlease '
'provide Epochs compatible with '
'ica.ch_names' % (len(self.ch_names),
len(picks)))

if exclude is None:
self.exclude = list(set(self.exclude))
else:
self.exclude = list(set(self.exclude + exclude))
logger.info('Adding sources %s to .exclude' % ', '.join(
[str(i) for i in exclude if i not in self.exclude]))
data = np.hstack(epochs.get_data()[:, picks])
data, _ = self._pre_whiten(data, epochs.info, picks)

if n_pca_components is not None:
self.n_pca_components = n_pca_components
data = self._pick_sources(data, include, exclude)

# put sources-dimension first for selection
recomposed = self._pick_sources(sources, pca_data, include,
self.exclude)
if copy is True:
epochs = epochs.copy()

# restore epochs, channels, tsl order
epochs._data[:, picks] = np.array(np.split(recomposed,
epochs._data[:, picks] = np.array(np.split(data,
len(epochs.events), 1))
epochs.preload = True

Expand All @@ -896,11 +896,10 @@ def pick_sources_epochs(self, epochs, include=None, exclude=None,
def plot_topomap(self, source_idx, ch_type='mag', res=500, layout=None,
vmax=None, cmap='RdBu_r', sensors='k,', colorbar=True,
show=True):
""" plot topographic map of ICA source
"""Plot topographic map of ICA source
Parameters
----------
The ica object to plot from.
source_idx : int | array-like
The indices of the sources to be plotted.
ch_type : 'mag' | 'grad' | 'planar1' | 'planar2' | 'eeg'
Expand Down Expand Up @@ -1019,7 +1018,6 @@ def detect_artifacts(self, raw, start_find=None, stop_find=None,
The ica object with the detected artifact indices marked for
exclusion
"""

logger.info(' Searching for artifacts...')
_detect_artifacts(self, raw=raw, start_find=start_find,
stop_find=stop_find, ecg_ch=ecg_ch,
Expand Down Expand Up @@ -1117,11 +1115,11 @@ def _decompose(self, data, max_pca_components, fit_type):

# get unmixing and add scaling
self.unmixing_matrix_ = getattr(ica, 'components_', 'unmixing_matrix_')
self.unmixing_matrix_ /= np.sqrt(exp_var[sel])[:, None]
self.mixing_matrix_ = linalg.pinv(self.unmixing_matrix_).T
self.unmixing_matrix_ /= np.sqrt(exp_var[sel])[None, :]
self.mixing_matrix_ = linalg.pinv(self.unmixing_matrix_)
self.current_fit = fit_type

def _pick_sources(self, sources, pca_data, include, exclude):
def _pick_sources(self, data, include, exclude):
"""Aux function"""

_n_pca_comp = _check_n_pca_components(self, self.n_pca_components,
Expand All @@ -1131,55 +1129,41 @@ def _pick_sources(self, sources, pca_data, include, exclude):
raise ValueError('n_pca_components must be between '
'n_components and max_pca_components.')

n_components = self.n_components_
n_pca_components = self.n_pca_components

if self.pca_mean_ is not None:
data -= self.pca_mean_[:, None]

# Apply first PCA
pca_data = np.dot(self.pca_components_, data)
# Apply unmixing to low dimension PCA
sources = np.dot(self.unmixing_matrix_, pca_data[:n_components])

if include not in (None, []):
mute = [i for i in xrange(len(sources)) if i not in include]
sources[mute, :] = 0. # include via exclusion
mask = np.ones(len(sources), dtype=np.bool)
mask[include] = False
sources[mask] = 0.
elif exclude not in (None, []):
sources[exclude, :] = 0. # just exclude
sources[exclude] = 0.

# restore pca data
pca_restored = np.dot(sources.T, self.mixing_matrix_)
pca_data[:n_components] = np.dot(self.mixing_matrix_, sources)
data = np.dot(self.pca_components_[:n_components].T,
pca_data[:n_components])
if n_pca_components > n_components:
data += np.dot(self.pca_components_[n_components:_n_pca_comp].T,
pca_data[n_components:_n_pca_comp])

# re-append deselected pca dimension if desired
if _n_pca_comp > self.n_components_:
pca_reappend = pca_data[:, self.n_components_:_n_pca_comp]
pca_restored = np.c_[pca_restored, pca_reappend]

# restore sensor space data
out = self._inverse_transform_pca(pca_restored)
if self.pca_mean_ is not None:
data += self.pca_mean_[:, None]

# restore scaling
if self.noise_cov is None: # revert standardization
out /= self._pre_whitener
data /= self._pre_whitener[:, None]
else:
out = np.dot(out, linalg.pinv(self._pre_whitener))

return out.T

def _transform_pca(self, data):
"""Apply decorrelation / dimensionality reduction on MEEG data.
"""
X = np.atleast_2d(data)
if self.pca_mean_ is not None:
X = X - self.pca_mean_

X = np.dot(X, self.pca_components_.T)
return X

def _transform_ica(self, data):
"""Apply ICA unmixing matrix to recover the latent sources.
"""
return np.dot(np.atleast_2d(data), self.unmixing_matrix_.T)

def _inverse_transform_pca(self, X):
"""Aux function"""
components = self.pca_components_[:X.shape[1]]
X_orig = np.dot(X, components)

if self.pca_mean_ is not None:
X_orig += self.pca_mean_
data = np.dot(linalg.pinv(self._pre_whitener), data)

return X_orig
return data


@verbose
Expand Down Expand Up @@ -1476,7 +1460,7 @@ def read_ica(fname):
ica.n_components_ = unmixing_matrix.shape[0]
ica.pca_explained_variance_ = pca_explained_variance
ica.unmixing_matrix_ = unmixing_matrix
ica.mixing_matrix_ = linalg.pinv(ica.unmixing_matrix_).T
ica.mixing_matrix_ = linalg.pinv(ica.unmixing_matrix_)
ica.exclude = [] if exclude is None else list(exclude)
ica.info = info
logger.info('Ready.')
Expand Down
Loading

0 comments on commit dc15f39

Please sign in to comment.