Skip to content

Commit

Permalink
[msm.estimators.MaximumLikelihoodHMSM] new submodel inplace kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
thempel committed Jun 26, 2018
1 parent 0214687 commit 8a2fba2
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions pyemma/msm/estimators/maximum_likelihood_hmsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import absolute_import
from six.moves import range
from pyemma.util.annotators import alias, aliased, fix_docs

import numpy as _np
Expand Down Expand Up @@ -276,9 +275,9 @@ def _estimate(self, dtrajs):
states_subset = 'populous-strong'

# return submodel (will return self if all None)
self._internal_submodel_call = True
return self.submodel(states=states_subset, obs=observe_subset,
mincount_connectivity=self.mincount_connectivity)
mincount_connectivity=self.mincount_connectivity,
inplace=True)

@property
def msm_init(self):
Expand Down Expand Up @@ -365,7 +364,7 @@ def discrete_trajectories_obs(self):
# Submodel functions using estimation information (counts)
################################################################################

def submodel(self, states=None, obs=None, mincount_connectivity='1/n'):
def submodel(self, states=None, obs=None, mincount_connectivity='1/n', inplace=False):
"""Returns a HMM with restricted state space
Parameters
Expand All @@ -388,6 +387,9 @@ def submodel(self, states=None, obs=None, mincount_connectivity='1/n'):
Counts lower than that will count zero in the connectivity check and
may thus separate the resulting transition matrix. Default value:
1/nstates.
inplace : Bool
if True, submodel is estimated in-place, overwriting the original
estimator and possibly discarding information. Default value: False
Returns
-------
Expand All @@ -410,14 +412,12 @@ def submodel(self, states=None, obs=None, mincount_connectivity='1/n'):
S = _tmatrix_disconnected.connected_sets(self.count_matrix,
mincount_connectivity=mincount_connectivity,
strong=True)
if hasattr(self, '_internal_submodel_call') and self._internal_submodel_call:
if inplace:
submodel_estimator = self
else:
from copy import deepcopy
submodel_estimator = deepcopy(self)

self._internal_submodel_call = False

if len(S) > 1:
# keep only non-negligible transitions
C = _np.zeros(self.count_matrix.shape)
Expand Down

0 comments on commit 8a2fba2

Please sign in to comment.