Skip to content

Commit

Permalink
更新EM算法部分内容, 未完, 待整理
Browse files Browse the repository at this point in the history
  • Loading branch information
SmirkCao committed Sep 18, 2018
1 parent 73493d8 commit 61efaab
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
47 changes: 43 additions & 4 deletions CH10/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@

class HMM(object):

def __init__(self, n_component=0, Q=None, V=None):
def __init__(self, n_component=0,
Q=None,
V=None,
n_iters=5):
self.A = None
self.B = None
self.p = None
Expand All @@ -21,6 +24,21 @@ def __init__(self, n_component=0, Q=None, V=None):
self.T = 0
self.Q = Q
self.V = V
self.n_iters = n_iters
self.alpha = None
self.beta = None
self.gamma = None
self.xi = None
self.Ei = None
self.Ei_ = None
self.Ei_j = None

def init_param(self):
self.A = None
self.B = None
self.p = None

return self

def _do_forward(self, X):
# todo: logsumexp trick
Expand Down Expand Up @@ -81,14 +99,35 @@ def backward(self, obs_seq):
print(prob, prob)
return X

def _do_estep(self):
pass
def _do_estep(self, X):
# 在hmmlearn里面是会没有专门的estep的
_, self.alpha = self._do_forward(X)
_, self.beta = self._do_backward(X)

post_prior = self.alpha*self.beta
self.gamma = post_prior/np.sum(post_prior)
trans_post_prior = self.alpha*self.A*self.B*self.beta
self.xi = trans_post_prior/np.sum(trans_post_prior)
self.Ei = np.sum(self.gamma, axis=1)
self.Ei_ = np.sum(self.gamma[:-1], axis=1)
self.Ei_j = np.sum(self.xi, axis=1)
return self

def _do_mstep(self):
pass
self.A = self.Ei_j/self.Ei
self.B = 1/self.Ei
self.p = self.gamma[:, 0]
return self

def fit(self, X):
# 估计模型参数
self.init_param()
for n_iter in range(self.n_iters):
self._do_estep(X)
self._do_mstep()
# convergence check
if False:
return rst
return self

def predict(self, X):
Expand Down
4 changes: 4 additions & 0 deletions CH10/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def test_backward(self):
self.assertAlmostEqual(prob, 0.13022, places=5)

def test_bkw_frw(self):
# 并没有实际的测试内容
Q = {0: 1, 1: 2, 2: 3}
V = {0: "red", 1: "white"}
hmm_forward = HMM(n_component=3)
Expand All @@ -152,6 +153,9 @@ def test_bkw_frw(self):
alpha = hmm_forward.forward(X)
print(alpha, beta)

def test_EM(self):
pass


if __name__ == '__main__':
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
Expand Down

0 comments on commit 61efaab

Please sign in to comment.