Skip to content

Commit

Permalink
add sample function of mixture model from prior
Browse files Browse the repository at this point in the history
  • Loading branch information
sth4nth committed Mar 9, 2017
1 parent 018fe82 commit 362845e
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 6 deletions.
15 changes: 15 additions & 0 deletions chapter11/GaussWishart.m
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
function obj = clone(obj)
end

function d = dim(obj)
d = numel(obj.m_);
end

function obj = addData(obj, X)
kappa0 = obj.kappa_;
m0 = obj.m_;
Expand Down Expand Up @@ -89,5 +93,16 @@
c = gammaln((v+d)/2)-gammaln(v/2)-(d*log(v*pi)+2*sum(log(diag(U))))/2;
y = c+o;
end

function [mu, Sigma] = sample(obj)
% Sample a Gaussian distribution from GaussianWishart prior
kappa = obj.kappa_;
m = obj.m_;
nu = obj.nu_;
U = obj.U_;

Sigma = iwishrnd(U'*U,nu);
mu = gaussRnd(m,Sigma/kappa);
end
end
end
4 changes: 2 additions & 2 deletions chapter11/mixDpGb.m
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
n = size(X,2);
[label,Theta,w] = mixDpGbOl(X,alpha,theta);
nk = n*w;
maxIter = 200;
maxIter = 50;
llh = zeros(1,maxIter);
for iter = 1:maxIter
for i = randperm(n)
Expand All @@ -34,7 +34,7 @@
llh(iter) = llh(iter)+sum(p-log(n));
k = discreteRnd(exp(p-logsumexp(p)));
if k == numel(Theta)+1 % add extra cluster
Theta{k} = theta.clone.addSample(x);
Theta{k} = theta.clone().addSample(x);
nk = [nk,1];
else
Theta{k} = Theta{k}.addSample(x);
Expand Down
18 changes: 18 additions & 0 deletions chapter11/mixGaussSample.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
function [X, z] = mixGaussSample(Theta, w, n )
% Genarate samples form a Gaussian mixture model with GaussianWishart prior.
% Input:
% Theta: cell of GaussianWishart priors of components
% w: weight of components
% n: number of data
% Output:
% X: d x n data matrix
% z: 1 x n response variable
% Written by Mo Chen ([email protected]).
z = discreteRnd(w,n);
d = Theta{1}.dim();
X = zeros(d,n);
for i = 1:numel(w)
idx = z==i;
[mu,Sigma] = Theta{i}.sample(); % invpd(wishrnd(W0,v0));
X(:,idx) = gaussRnd(mu,Sigma,sum(idx));
end
13 changes: 9 additions & 4 deletions demo/ch11/mixGaussGb_demo.m
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
d = 2;
k = 3;
n = 500;
[X,label] = mixGaussRnd(d,k,n);
plotClass(X,label);
[X,z] = mixGaussRnd(d,k,n);
plotClass(X,z);

[y,model] = mixGaussGb(X);
[z,Theta,w,llh] = mixGaussGb(X);
figure
plotClass(X,y);
plotClass(X,z);

[X,z] = mixGaussSample(Theta,w,n);
figure
plotClass(X,z);

0 comments on commit 362845e

Please sign in to comment.