forked from PRML/PRMLT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mixGaussVbPred.m
33 lines (31 loc) · 1.03 KB
/
mixGaussVbPred.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
function [z, R] = mixGaussVbPred(model, X)
% Predict label and responsibility for Gaussian mixture model trained by VB.
% Equation numbers in the comments refer to Bishop, PRML (2006), Chapter 10.
% Input:
%   model: trained model structure output by the variational Bayes (VB)
%          algorithm, with fields
%            alpha: Dirichlet parameters of the mixing weights (k entries)
%            kappa: precision-scale parameters of the component means (k entries)
%            m:     d x k matrix of component means
%            v:     Wishart degrees of freedom (k entries)
%            U:     d x d x k array; per eq. 10.64 U(:,:,i)'*U(:,:,i) plays
%                   the role of inv(W_i), the inverse Wishart scale matrix
%            logW:  log-determinant of each Wishart scale matrix (k entries)
%          The k-entry fields may be stored as row or column vectors.
%   X: d x n data matrix
% Output:
%   z: 1 x n predicted labels; z(j) is the index (1..k) of the component
%      with the largest responsibility for X(:,j), so z indexes the
%      columns of R and the components of model
%   R: n x k responsibility matrix; each row sums to 1
% Written by Mo Chen ([email protected]).
alpha = model.alpha(:)';   % Dirichlet, normalized to 1 x k
kappa = model.kappa(:)';   % Gaussian, normalized to 1 x k
m = model.m;               % Gaussian, d x k
v = model.v(:)';           % Wishart, normalized to 1 x k
U = model.U;               % Wishart, d x d x k
logW = model.logW(:)';     % Wishart, normalized to 1 x k (a column here
                           % would silently broadcast ElogLambda to k x k)
n = size(X,2);
[d,k] = size(m);

% EQ(j,i) = E[(x_j-m_i)'*Lambda_i*(x_j-m_i)] under the posterior      % 10.64
EQ = zeros(n,k);
for i = 1:k
    Q = U(:,:,i)'\bsxfun(@minus,X,m(:,i));   % d x n whitened residuals
    EQ(:,i) = d/kappa(i)+v(i)*dot(Q,Q,1);    % 10.64
end
% E[ln|Lambda_i|], 1 x k                                               % 10.65
ElogLambda = sum(psi(0,0.5*bsxfun(@minus,v+1,(1:d)')),1)+d*log(2)+logW;
% E[ln pi_i], 1 x k                                                    % 10.66
Elogpi = psi(0,alpha)-psi(0,sum(alpha));
% ln rho(j,i), n x k                                                   % 10.46
logRho = -0.5*bsxfun(@minus,EQ,ElogLambda-d*log(2*pi));
logRho = bsxfun(@plus,logRho,Elogpi);
% Normalize each row in log space to avoid underflow in exp          % 10.49
logR = bsxfun(@minus,logRho,logsumexp(logRho,2));
R = exp(logR);
% Label = most responsible component. Labels are intentionally NOT
% compacted with unique(): compaction would make z disagree with the
% columns of R and give the same component different labels on different X.
z = zeros(1,n);
[~,z(:)] = max(R,[],2);