diff --git a/pyemma/msm/estimators/_dtraj_stats.py b/pyemma/msm/estimators/_dtraj_stats.py index 1c3a7c2bd..548afa95b 100644 --- a/pyemma/msm/estimators/_dtraj_stats.py +++ b/pyemma/msm/estimators/_dtraj_stats.py @@ -23,7 +23,7 @@ from msmtools import estimation as msmest from pyemma.util.annotators import alias, aliased from pyemma.util.linalg import submatrix -from pyemma.util.discrete_trajectories import visited_set +from pyemma.util.discrete_trajectories import visited_set, rewrite_dtrajs_to_core_sets __author__ = 'noe' @@ -148,7 +148,8 @@ def _compute_connected_sets(C, mincount_connectivity, strong=True): S = msmest.connected_sets(Cconn, directed=strong) return S - def count_lagged(self, lag, count_mode='sliding', mincount_connectivity='1/n', show_progress=True, n_jobs=None, name=''): + def count_lagged(self, lag, count_mode='sliding', mincount_connectivity='1/n', + show_progress=True, n_jobs=None, name='', core_set=None, milestoning_method='last_core'): r""" Counts transitions at given lag time Parameters @@ -182,11 +183,27 @@ def count_lagged(self, lag, count_mode='sliding', mincount_connectivity='1/n', s # Compute count matrix count_mode = count_mode.lower() - if count_mode == 'sliding': + if count_mode in ('sliding', 'sample') and core_set is not None: + if milestoning_method == 'last_core': + + # assign -1 frames to last visited core + for d in self._dtrajs: + while -1 in d: + mask = (d == -1) + d[mask] = d[np.roll(mask, -1)] + self._C = msmest.count_matrix(self._dtrajs, lag, sliding=count_mode == 'sliding') + + else: + raise NotImplementedError('Milestoning method {} not implemented.'.format(milestoning_method)) + + + elif count_mode == 'sliding': self._C = msmest.count_matrix(self._dtrajs, lag, sliding=True) elif count_mode == 'sample': self._C = msmest.count_matrix(self._dtrajs, lag, sliding=False) elif count_mode == 'effective': + if core_set is not None: + raise RuntimeError('Cannot estimate core set MSM with effective counting.') from pyemma.util.reflection import getargspec_no_self argspec = getargspec_no_self(msmest.effective_count_matrix) kw = {} diff --git a/pyemma/msm/estimators/maximum_likelihood_msm.py b/pyemma/msm/estimators/maximum_likelihood_msm.py index 56fbaba52..28e2cb47b 100644 --- a/pyemma/msm/estimators/maximum_likelihood_msm.py +++ b/pyemma/msm/estimators/maximum_likelihood_msm.py @@ -209,11 +209,16 @@ def _get_dtraj_stats(self, dtrajs): # TODO: reassign dtrajs needed? dtrajstats = dtrajs else: - self._dtrajs_orginal = dtrajs - # check for -1 in dtrajs and possibly rewrite to core_set - from pyemma.util.discrete_trajectories import milestone_counting - self._dtrajs_full, self._dtrajs_milestone_counting_offsets, self.n_cores = \ - milestone_counting(dtrajs, core_set=self.core_set, in_place=False) + if self.core_set is None and any(-1 in d for d in dtrajs): + raise ValueError('Empty core set definition not compatible with unassigned states (-1) in trajectory.') + if self.core_set is not None or any(-1 in d for d in dtrajs): + self._dtrajs_orginal = dtrajs + # check for -1 in dtrajs and possibly rewrite to core_set + from pyemma.util.discrete_trajectories import rewrite_dtrajs_to_core_sets + self._dtrajs_full, self._dtrajs_milestone_counting_offsets, self.n_cores = \ + rewrite_dtrajs_to_core_sets(dtrajs, core_set=self.core_set, in_place=False) + else: + self._dtrajs_full = dtrajs # compute and store discrete trajectory statistics dtrajstats = _DiscreteTrajectoryStats(self._dtrajs_full) @@ -222,13 +227,13 @@ def _get_dtraj_stats(self, dtrajs): self.logger.warning('Building a dense MSM with {nstates} states. This can be ' 'inefficient or unfeasible in terms of both runtime and memory consumption. ' 'Consider using sparse=True.'.format(nstates=dtrajstats.nstates)) - + self.milestoning_method = 'last_core' # count lagged dtrajstats.count_lagged(self.lag, count_mode=self.count_mode, mincount_connectivity=self.mincount_connectivity, n_jobs=getattr(self, 'n_jobs', None), show_progress=getattr(self, 'show_progress', False), - name=self.name) + name=self.name, core_set=self.core_set, milestoning_method=self.milestoning_method) # for other statistics return dtrajstats diff --git a/pyemma/msm/tests/test_msm.py b/pyemma/msm/tests/test_msm.py index 3655e9588..28e2eba2b 100644 --- a/pyemma/msm/tests/test_msm.py +++ b/pyemma/msm/tests/test_msm.py @@ -1113,10 +1113,9 @@ def test_core(self): assert len(np.setdiff1d(uniq, core_set)) == 0 def test_indices_remapping(self): - dtrajs = [[-1, -1, 1, 0, 0, 1], [-1, 1, 0, 1, 3], [0, 1, 2, 3]] + dtrajs = [[5, 5, 1, 0, 0, 1], [5, 1, 0, 1, 3], [0, 1, 2, 3]] desired_offsets = [2, 1, 0] - # implicit core_set (omit -1) - msm = pyemma.msm.estimate_markov_model(dtrajs, lag=1) + msm = pyemma.msm.estimate_markov_model(dtrajs, lag=1, core_set=[0, 1, 2, 3]) np.testing.assert_equal(msm.dtrajs_milestone_counting_offsets, desired_offsets) # sampling @@ -1144,7 +1143,7 @@ def test_compare2hmm(self): def test_compare2hmm_bayes(self): """test core set MSM with Bayesian sampling, compare ITS to 2-state BHMM; double-well""" - cmsm = pyemma.msm.bayesian_markov_model(self.dtraj, lag=5, core_set=[34, 65], nsamples=20) + cmsm = pyemma.msm.bayesian_markov_model(self.dtraj, lag=5, core_set=[34, 65], nsamples=20, count_mode='sliding') hmm = pyemma.msm.bayesian_hidden_markov_model(self.dtraj, 2, lag=5, nsamples=20) has_overlap = not (np.all(cmsm.sample_conf('timescales') < hmm.sample_conf('timescales')[0]) or diff --git a/pyemma/util/tests/test_discrete_trajectories.py b/pyemma/util/tests/test_discrete_trajectories.py index 463d2fd68..ef61bf248 100644 --- a/pyemma/util/tests/test_discrete_trajectories.py +++ b/pyemma/util/tests/test_discrete_trajectories.py @@ -243,11 +243,6 @@ def test_core_sets_6(self): dtrajs = [np.array([0, 1, 1, 2]), np.array([0, 0, 0])] import warnings - if sys.version_info[0] == 2: # yeah python 2 bugs ftw... - if hasattr(dt.rewrite_dtrajs_to_core_sets, '__globals__'): - if dt.rewrite_dtrajs_to_core_sets.__globals__.has_key('__warningregistry__'): - dt.rewrite_dtrajs_to_core_sets.__globals__['__warningregistry__'].clear() - with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always", category=UserWarning, append=False) dtraj_core, offsets, _ = dt.rewrite_dtrajs_to_core_sets(dtrajs, core_set=[1, 2])