forked from PRML/PRMLT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfa.m
51 lines (45 loc) · 1.5 KB
/
fa.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
function [W, mu, psi, llh] = fa(X, m, tol, maxiter)
% Perform EM algorithm for factor analysis model
% Input:
%   X: d x n data matrix
%   m: dimension of target (latent) space
%   tol: (optional) relative tolerance on the change in log-likelihood used
%        as the convergence test (default 1e-4)
%   maxiter: (optional) size of the log-likelihood buffer; the EM loop runs
%        at most maxiter-1 iterations (default 500, must be >= 2)
% Output:
%   W: d x m weight matrix
%   mu: d x 1 mean vector
%   psi: d x 1 variance vector (diagonal of the noise covariance Psi)
%   llh: row vector of log-likelihood values, one per EM iteration.
%        If the loop stops at maxiter without converging, W and psi have
%        been updated by one more M step than llh(end) reflects.
% Reference: Pattern Recognition and Machine Learning by Christopher M. Bishop
% Written by Mo Chen ([email protected]).
if nargin < 3 || isempty(tol)
    tol = 1e-4;
end
if nargin < 4 || isempty(maxiter)
    maxiter = 500;
end
if maxiter < 2
    error('fa:maxiter', 'maxiter must be at least 2.');
end
[d,n] = size(X);
mu = mean(X,2);
X = bsxfun(@minus,X,mu);        % center the data
llh = -inf(1,maxiter);          % llh(1) = -inf is a sentinel so the first convergence test fails
I = eye(m);
r = dot(X,X,2);                 % r(i) = sum of squares of row i of the centered data
W = randn(d,m);                 % random initialization of the loadings
lambda = 1./rand(d,1);          % lambda = 1./psi, the per-dimension noise precision
for iter = 2:maxiter
    T = bsxfun(@times,W,sqrt(lambda));
    M = T'*T+I;                 % M = W'*inv(Psi)*W+I
    U = chol(M);                % M = U'*U
    WInvPsiX = bsxfun(@times,W,lambda)'*X;    % WInvPsiX = W'*inv(Psi)*X
    % likelihood of the current parameters, with C = W*W'+Psi:
    % log(det(C)) = log(det(M))+log(det(Psi)), and inv(C) via the Woodbury identity
    logdetC = 2*sum(log(diag(U)))-sum(log(lambda));
    Q = U'\WInvPsiX;
    trInvCS = (r'*lambda-dot(Q(:),Q(:)))/n;   % trace(inv(C)*S)
    llh(iter) = -n*(d*log(2*pi)+logdetC+trInvCS)/2;
    if abs(llh(iter)-llh(iter-1)) < tol*abs(llh(iter-1)); break; end % check likelihood for convergence
    % E step
    Ez = M\WInvPsiX;            % 12.66
    V = U\I;                    % inv(U) via a triangular solve, so inv(M) = V*V'
    Ezz = n*(V*V')+Ez*Ez';      % 12.67, summed over all n samples
    % M step
    U = chol(Ezz);
    XEz = X*Ez';
    W = (XEz/U)/U';             % 12.69: W = XEz*inv(Ezz)
    % 12.70. NOTE(review): there is no guard if the denominator approaches 0
    % (psi_i -> 0, a Heywood case); lambda would then blow up.
    lambda = n./(r-dot(W,XEz,2));
end
llh = llh(2:iter);
psi = 1./lambda;