Skip to content

Commit

Permalink
refactor kalmanFilter and fix kalmanSmoother
Browse files Browse the repository at this point in the history
  • Loading branch information
sth4nth committed Nov 28, 2018
1 parent 2025472 commit 0523c2c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions chapter13/LDS/kalmanFilter.m
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@
llh(1) = logGauss(X(:,1),C*mu0,R);
for i = 2:n
[mu(:,i), V(:,:,i), llh(i)] = ...
forwardStep(X(:,i), mu(:,i-1), V(:,:,i-1), A, G, C, S, I);
forwardUpdate(X(:,i), mu(:,i-1), V(:,:,i-1), A, G, C, S, I);
end
llh = sum(llh);

function [mu, V, llh] = forwardStep(x, mu, V, A, G, C, S, I)
function [mu, V, llh] = forwardUpdate(x, mu, V, A, G, C, S, I)
P = A*V*A'+G; % 13.88
PC = P*C';
R = C*PC+S;
Expand Down
28 changes: 14 additions & 14 deletions chapter13/LDS/kalmanSmoother.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function [nu, U, Ezz, Ezy, llh] = kalmanSmoother(model, X)
function [nu, U, llh, Ezz, Ezy] = kalmanSmoother(model, X)
% Kalman smoother (forward-backward algorithm for linear dynamic system)
% NOTE: This is the exact implementation of the Kalman smoother algorithm in PRML.
% However, this algorithm is not practical. It is numerical unstable.
Expand Down Expand Up @@ -26,20 +26,19 @@
P = zeros(q,q,n); % C_{t+1|t}
Amu = zeros(q,n); % u_{t+1|t}
llh = zeros(1,n);
I = eye(q);

% forward
PC = P0*C';
R = C*PC+S;
K = PC/R;
mu(:,1) = mu0+K*(X(:,1)-C*mu0);
V(:,:,1) = (I-K*C)*P0;
V(:,:,1) = (eye(q)-K*C)*P0;
P(:,:,1) = P0; % useless, just make a point
Amu(:,1) = mu0; % useless, just make a point
llh(1) = logGauss(X(:,1),C*mu0,R);
for i = 2:n
[mu(:,i), V(:,:,i), Amu(:,i), P(:,:,i), llh(i)] = ...
forwardStep(X(:,i), mu(:,i-1), V(:,:,i-1), A, G, C, S, I);
forwardUpdate(X(:,i), mu(:,i-1), V(:,:,i-1), A, G, C, S);
end
llh = sum(llh);
% backward
Expand All @@ -53,24 +52,25 @@
Ezz(:,:,n) = U(:,:,n)+nu(:,n)*nu(:,n)';
for i = n-1:-1:1
[nu(:,i), U(:,:,i), Ezz(:,:,i), Ezy(:,:,i)] = ...
backwardStep(nu(:,i+1), U(:,:,i+1), mu(:,i), V(:,:,i), Amu(:,i+1), P(:,:,i+1), A);
backwardUpdate(nu(:,i+1), U(:,:,i+1), mu(:,i), V(:,:,i), Amu(:,i+1), P(:,:,i+1), A);
end

function [mu, V, Amu, P, llh] = forwardStep(x, mu0, V0, A, G, C, S, I)
function [mu1, V1, Amu, P, llh] = forwardUpdate(x, mu0, V0, A, G, C, S)
k = numel(mu0);
P = A*V0*A'+G; % 13.88
PC = P*C';
R = C*PC+S;
K = PC/R; % 13.92
Amu = A*mu0;
CAmu = C*Amu;
mu = Amu+K*(x-CAmu); % 13.89
V = (I-K*C)*P; % 13.90
mu1 = Amu+K*(x-CAmu); % 13.89
V1 = (eye(k)-K*C)*P; % 13.90
llh = logGauss(x,CAmu,R); % 13.91


function [nu, U, Ezz, Ezy] = backwardStep(nu0, U0, mu, V, Amu, P, A)
J = V*A'/P; % 13.102
nu = mu+J*(nu0-Amu); % 13.100
U = V+J*(U0-P)*J'; % 13.101
Ezy = J*U0+nu0*nu'; % 13.106
Ezz = U+nu*nu'; % 13.107
function [nu0, U0, E00, E10] = backwardUpdate(nu1, U1, mu, V, Amu, P, A)
J = V*A'/P; % 13.102
nu0 = mu+J*(nu1-Amu); % 13.100
U0 = V+J*(U1-P)*J'; % 13.101
E00 = U0+nu0*nu0'; % 13.107
E10 = U1*J'+nu1*nu0'; % 13.106

0 comments on commit 0523c2c

Please sign in to comment.