-
Notifications
You must be signed in to change notification settings - Fork 92
/
dnetEstep.m
43 lines (40 loc) · 1.27 KB
/
dnetEstep.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
function model = dnetEstep(model, Ypred)
% DNETESTEP Do an E-step (update importance weights) on a Density Network model.
% FORMAT
% DESC updates the importance weights (or posterior responsibilities) for
% a Density Network model. The predictions at the mixture component
% centres are obtained by calling dnetOut on model.X_u.
% ARG model : the model which is to be updated.
% RETURN model : the model with updated weights.
%
% FORMAT
% DESC updates the importance weights (or posterior responsibilities) for
% a Density Network model using predictions supplied by the caller.
% ARG model : the model which is to be updated.
% ARG Ypred : model predictions at the mixture component centres, one
% row per centre (model.M rows).
% RETURN model : the model with updated weights. model.w is an
% N x M sparse matrix whose rows sum to one, and model.X is set to
% model.w*model.X_u.
%
% SEEALSO : dnetCreate, dnetMstep, dnetOut
%
% COPYRIGHT : Neil D. Lawrence, 2008
% MLTOOLS

if nargin < 2
  Ypred = dnetOut(model, model.X_u);
end

% Unnormalised log-weights: entry (n, m) is -beta/2 times the squared
% Euclidean distance between data row model.y(n, :) and Ypred(m, :).
logW = zeros(model.N, model.M);
if model.N <= model.M
  % No more data points than centres: iterate over the data points.
  for n = 1:model.N
    resid = bsxfun(@minus, model.y(n, :), Ypred);
    logW(n, :) = -0.5*sum(resid.*resid, 2)'*model.beta;
  end
else
  % Fewer centres than data points: iterate over the centres.
  for m = 1:model.M
    resid = bsxfun(@minus, model.y, Ypred(m, :));
    logW(:, m) = -0.5*sum(resid.*resid, 2)*model.beta;
  end
end

% Shift every row so its largest log-weight is zero before exponentiating;
% this keeps at least one entry per row at exp(0) = 1 and avoids the whole
% row underflowing to zero.
logW = bsxfun(@minus, logW, max(logW, [], 2));
unnorm = exp(logW);

% Normalise each row to sum to one; store sparsely since many entries may
% have underflowed to exactly zero.
model.w = sparse(bsxfun(@rdivide, unnorm, sum(unnorm, 2)));
model.X = model.w*model.X_u;
end