Skip to content

Commit

Permalink
[msm] milestone counting: added 'last_core' state-assignment method
Browse files Browse the repository at this point in the history
  • Loading branch information
thempel committed Jul 1, 2019
1 parent 800686a commit 94db159
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 19 deletions.
23 changes: 20 additions & 3 deletions pyemma/msm/estimators/_dtraj_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from msmtools import estimation as msmest
from pyemma.util.annotators import alias, aliased
from pyemma.util.linalg import submatrix
from pyemma.util.discrete_trajectories import visited_set
from pyemma.util.discrete_trajectories import visited_set, rewrite_dtrajs_to_core_sets

__author__ = 'noe'

Expand Down Expand Up @@ -148,7 +148,8 @@ def _compute_connected_sets(C, mincount_connectivity, strong=True):
S = msmest.connected_sets(Cconn, directed=strong)
return S

def count_lagged(self, lag, count_mode='sliding', mincount_connectivity='1/n', show_progress=True, n_jobs=None, name=''):
def count_lagged(self, lag, count_mode='sliding', mincount_connectivity='1/n',
show_progress=True, n_jobs=None, name='', core_set=None, milestoning_method='last_core'):
r""" Counts transitions at given lag time
Parameters
Expand Down Expand Up @@ -182,11 +183,27 @@ def count_lagged(self, lag, count_mode='sliding', mincount_connectivity='1/n', s

# Compute count matrix
count_mode = count_mode.lower()
if count_mode == 'sliding':
if count_mode in ('sliding', 'sample') and core_set is not None:
if milestoning_method == 'last_core':

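# assign unassigned (-1) frames to the last visited core: each pass copies the
# state of the preceding frame into every -1 frame, repeating until no -1 is left.
# NOTE(review): np.roll wraps around, so a trajectory that starts with -1 appears
# to take values from its end, and an all -1 trajectory would never leave the
# loop -- confirm. The fill also modifies the stored dtrajs in place.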
for d in self._dtrajs:
while -1 in d:
mask = (d == -1)
d[mask] = d[np.roll(mask, -1)]
self._C = msmest.count_matrix(self._dtrajs, lag, sliding=count_mode == 'sliding')

else:
raise NotImplementedError('Milestoning method {} not implemented.'.format(milestoning_method))


elif count_mode == 'sliding':
self._C = msmest.count_matrix(self._dtrajs, lag, sliding=True)
elif count_mode == 'sample':
self._C = msmest.count_matrix(self._dtrajs, lag, sliding=False)
elif count_mode == 'effective':
if core_set is not None:
raise RuntimeError('Cannot estimate core set MSM with effective counting.')
from pyemma.util.reflection import getargspec_no_self
argspec = getargspec_no_self(msmest.effective_count_matrix)
kw = {}
Expand Down
19 changes: 12 additions & 7 deletions pyemma/msm/estimators/maximum_likelihood_msm.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,16 @@ def _get_dtraj_stats(self, dtrajs):
# TODO: reassign dtrajs needed?
dtrajstats = dtrajs
else:
self._dtrajs_orginal = dtrajs
# check for -1 in dtrajs and possibly rewrite to core_set
from pyemma.util.discrete_trajectories import milestone_counting
self._dtrajs_full, self._dtrajs_milestone_counting_offsets, self.n_cores = \
milestone_counting(dtrajs, core_set=self.core_set, in_place=False)
if self.core_set is None and any(-1 in d for d in dtrajs):
raise ValueError('Empty core set definition not compatible with unassigned states (-1) in trajectory.')
if self.core_set is not None or any(-1 in d for d in dtrajs):
self._dtrajs_orginal = dtrajs
# check for -1 in dtrajs and possibly rewrite to core_set
from pyemma.util.discrete_trajectories import rewrite_dtrajs_to_core_sets
self._dtrajs_full, self._dtrajs_milestone_counting_offsets, self.n_cores = \
rewrite_dtrajs_to_core_sets(dtrajs, core_set=self.core_set, in_place=False)
else:
self._dtrajs_full = dtrajs

# compute and store discrete trajectory statistics
dtrajstats = _DiscreteTrajectoryStats(self._dtrajs_full)
Expand All @@ -222,13 +227,13 @@ def _get_dtraj_stats(self, dtrajs):
self.logger.warning('Building a dense MSM with {nstates} states. This can be '
'inefficient or unfeasible in terms of both runtime and memory consumption. '
'Consider using sparse=True.'.format(nstates=dtrajstats.nstates))

self.milestoning_method = 'last_core'
# count lagged
dtrajstats.count_lagged(self.lag, count_mode=self.count_mode,
mincount_connectivity=self.mincount_connectivity,
n_jobs=getattr(self, 'n_jobs', None),
show_progress=getattr(self, 'show_progress', False),
name=self.name)
name=self.name, core_set=self.core_set, milestoning_method=self.milestoning_method)
# for other statistics
return dtrajstats

Expand Down
7 changes: 3 additions & 4 deletions pyemma/msm/tests/test_msm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,10 +1113,9 @@ def test_core(self):
assert len(np.setdiff1d(uniq, core_set)) == 0

def test_indices_remapping(self):
dtrajs = [[-1, -1, 1, 0, 0, 1], [-1, 1, 0, 1, 3], [0, 1, 2, 3]]
dtrajs = [[5, 5, 1, 0, 0, 1], [5, 1, 0, 1, 3], [0, 1, 2, 3]]
desired_offsets = [2, 1, 0]
# implicit core_set (omit -1)
msm = pyemma.msm.estimate_markov_model(dtrajs, lag=1)
msm = pyemma.msm.estimate_markov_model(dtrajs, lag=1, core_set=[0, 1, 2, 3])
np.testing.assert_equal(msm.dtrajs_milestone_counting_offsets, desired_offsets)

# sampling
Expand Down Expand Up @@ -1144,7 +1143,7 @@ def test_compare2hmm(self):
def test_compare2hmm_bayes(self):
"""test core set MSM with Bayesian sampling, compare ITS to 2-state BHMM; double-well"""

cmsm = pyemma.msm.bayesian_markov_model(self.dtraj, lag=5, core_set=[34, 65], nsamples=20)
cmsm = pyemma.msm.bayesian_markov_model(self.dtraj, lag=5, core_set=[34, 65], nsamples=20, count_mode='sliding')
hmm = pyemma.msm.bayesian_hidden_markov_model(self.dtraj, 2, lag=5, nsamples=20)

has_overlap = not (np.all(cmsm.sample_conf('timescales') < hmm.sample_conf('timescales')[0]) or
Expand Down
5 changes: 0 additions & 5 deletions pyemma/util/tests/test_discrete_trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,6 @@ def test_core_sets_6(self):
dtrajs = [np.array([0, 1, 1, 2]), np.array([0, 0, 0])]
import warnings

if sys.version_info[0] == 2: # yeah python 2 bugs ftw...
if hasattr(dt.rewrite_dtrajs_to_core_sets, '__globals__'):
if dt.rewrite_dtrajs_to_core_sets.__globals__.has_key('__warningregistry__'):
dt.rewrite_dtrajs_to_core_sets.__globals__['__warningregistry__'].clear()

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always", category=UserWarning, append=False)
dtraj_core, offsets, _ = dt.rewrite_dtrajs_to_core_sets(dtrajs, core_set=[1, 2])
Expand Down

0 comments on commit 94db159

Please sign in to comment.