Commit
mixBernRnd not done and nbBern not tested
sth4nth committed Mar 6, 2016
1 parent 7aa53bb commit fae1b73
Showing 10 changed files with 175 additions and 109 deletions.
2 changes: 1 addition & 1 deletion TODO.txt
@@ -1,5 +1,5 @@
TODO:
extract demos
ch08: BP, EP, NB
ch08: BP, EP, NBMn
ch14: Cart

16 changes: 11 additions & 5 deletions chapter08/demo.m
@@ -1,12 +1,18 @@
% demo for ch08

%% Naive Bayes with Gauss
%% Naive Bayes with independent Gaussian
d = 2;
k = 3;
n = 1000;
% [X, t] = kmeansRnd(d,k,n);
[X, t] = kmeansRnd(d,k,n);
plotClass(X,t);

model = nbGauss(X,t);
y = nbGaussPred(model,X);
plotClass(X,y);
m = floor(n/2);
X1 = X(:,1:m);
X2 = X(:,(m+1):end);
t1 = t(1:m);
model = nbGauss(X1,t1);
y2 = nbGaussPred(model,X2);
plotClass(X2,y2);

%% Naive Bayes with independent Bernoulli
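
The hunk ends at this new section header, so the body of the Bernoulli demo is not shown. A minimal sketch, assuming the functions added in this commit (mixBernRnd, nbBern, nbBernPred) and mirroring the train/test split of the Gaussian demo above, might be:

% sketch only -- not part of the commit; mirrors the Gaussian demo above
d = 10;
k = 3;
n = 1000;
[X, t] = mixBernRnd(d,k,n);   % binary d x n data with component labels t

m = floor(n/2);
X1 = X(:,1:m);
X2 = X(:,(m+1):end);
t1 = t(1:m);
model = nbBern(X1,t1);        % no smoothing: zero counts give -Inf log-probabilities
y2 = nbBernPred(model,X2);
fprintf('held-out accuracy: %.2f\n', mean(y2 == t((m+1):end)));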
17 changes: 17 additions & 0 deletions chapter08/nbBern.m
@@ -0,0 +1,17 @@
function model = nbBern(X, t)
% Naive Bayes classifier with independent Bernoulli.
% Input:
% X: d x n binary data matrix
% t: 1 x n label (1~k)
% Output:
% model: trained model structure
% Written by Mo Chen ([email protected]).
n = size(X,2);
k = max(t);
E = sparse(t,1:n,1,k,n,n);                     % k x n class indicator matrix
nk = full(sum(E,2));                           % k x 1 class counts
w = nk/n;                                      % k x 1 class priors
mu = full(sparse(X)*E'*spdiags(1./nk,0,k,k));  % d x k per-class feature means

model.mu = mu; % d x k means
model.w = w;   % k x 1 class priors
13 changes: 13 additions & 0 deletions chapter08/nbBernPred.m
@@ -0,0 +1,13 @@
function y = nbBernPred(model, X)
% Prediction of naive Bayes classifier with independent Bernoulli.
% Input:
% model: trained model structure
% X: d x n binary data matrix
% Output:
% y: 1 x n predicted class label
% Written by Mo Chen ([email protected]).
mu = model.mu;
w = model.w;
% log joint for class c: log w(c) + sum_d X(d,j)*log(mu(d,c)) + (1-X(d,j))*log(1-mu(d,c))
R = bsxfun(@plus,(log(mu)-log(1-mu))'*sparse(X),sum(log(1-mu),1)'+log(w));
[~,y] = max(R,[],1);

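For reference, the scores in nbBernPred come from the independent-Bernoulli log joint, rearranged so the data enter through a single matrix product (an algebra sketch, not part of the commit):

log p(x,c) = log w(c) + sum_d [ x_d log mu(d,c) + (1-x_d) log(1-mu(d,c)) ]
           = log w(c) + sum_d log(1-mu(d,c)) + sum_d x_d [ log mu(d,c) - log(1-mu(d,c)) ]

The last term is (log(mu)-log(1-mu))'*X; everything else is a per-class constant added to each column.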
7 changes: 0 additions & 7 deletions chapter08/nbMn.m

This file was deleted.

7 changes: 0 additions & 7 deletions chapter08/nbMnPred.m

This file was deleted.

141 changes: 81 additions & 60 deletions chapter09/demo.m
@@ -1,27 +1,27 @@
% demos for ch09

%% Empirical Bayesian linear regression via EM
% close all; clear;
% d = 5;
% n = 200;
% [x,t] = linRnd(d,n);
% [model,llh] = linRegEm(x,t);
% plot(llh);
%
% %% RVM classification via EM
% clear; close all
% k = 2;
% d = 2;
% n = 1000;
% [X,t] = kmeansRnd(d,k,n);
% [x1,x2] = meshgrid(linspace(min(X(1,:)),max(X(1,:)),n), linspace(min(X(2,:)),max(X(2,:)),n));
%
% [model, llh] = rvmBinEm(X,t-1);
% plot(llh);
% y = rvmBinPred(model,X)+1;
% figure;
% binPlot(model,X,y);
%% kmeans
close all; clear;
d = 5;
n = 200;
[x,t] = linRnd(d,n);
[model,llh] = linRegEm(x,t);
plot(llh);

%% RVM classification via EM
clear; close all
k = 2;
d = 2;
n = 1000;
[X,t] = kmeansRnd(d,k,n);
[x1,x2] = meshgrid(linspace(min(X(1,:)),max(X(1,:)),n), linspace(min(X(2,:)),max(X(2,:)),n));

[model, llh] = rvmBinEm(X,t-1);
plot(llh);
y = rvmBinPred(model,X)+1;
figure;
binPlot(model,X,y);
% kmeans
close all; clear;
d = 20;
k = 6;
@@ -33,44 +33,65 @@
tic
y = kmeans(X',k);
toc
% y = kmedoids(X,k);
% plotClass(X,label);
% figure;
% plotClass(X,y);
y = kmedoids(X,k);
plotClass(X,label);
figure;
plotClass(X,y);

%% Gaussian Mixture via EM
% close all; clear;
% d = 2;
% k = 3;
% n = 1000;
% [X,label] = mixGaussRnd(d,k,n);
% plotClass(X,label);
%
% m = floor(n/2);
% X1 = X(:,1:m);
% X2 = X(:,(m+1):end);
% % train
% [z1,model,llh] = mixGaussEm(X1,k);
% figure;
% plot(llh);
% figure;
% plotClass(X1,z1);
% % predict
% z2 = mixGaussPred(X2,model);
% figure;
% plotClass(X2,z2);
% %% Gauss mixture initialized by kmeans
% close all; clear;
% d = 2;
% k = 3;
% n = 500;
% [X,label] = mixGaussRnd(d,k,n);
% init = kmeans(X,k);
% [z,model,llh] = mixGaussEm(X,init);
% plotClass(X,label);
% figure;
% plotClass(X,init);
% figure;
% plotClass(X,z);
% figure;
% plot(llh);
close all; clear;
d = 2;
k = 3;
n = 1000;
[X,label] = mixGaussRnd(d,k,n);
plotClass(X,label);

m = floor(n/2);
X1 = X(:,1:m);
X2 = X(:,(m+1):end);
% train
[z1,model,llh] = mixGaussEm(X1,k);
figure;
plot(llh);
figure;
plotClass(X1,z1);
% predict
z2 = mixGaussPred(X2,model);
figure;
plotClass(X2,z2);
%% Gauss mixture initialized by kmeans
close all; clear;
d = 2;
k = 3;
n = 500;
[X,label] = mixGaussRnd(d,k,n);
init = kmeans(X,k);
[z,model,llh] = mixGaussEm(X,init);
plotClass(X,label);
figure;
plotClass(X,init);
figure;
plotClass(X,z);
figure;
plot(llh);

%% Bernoulli Mixture via EM
close all; clear;
d = 2;
k = 3;
n = 1000;
[X,z] = mixBernRnd(d,k,n);

m = floor(n/2);
X1 = X(:,1:m);
X2 = X(:,(m+1):end);
% train (a dedicated Bernoulli-mixture EM is not part of this commit; mixGaussEm is a stand-in)
[z1,model,llh] = mixGaussEm(X1,k);
figure;
plot(llh);
figure;
plotClass(X1,z1);
% predict
z2 = mixGaussPred(X2,model);
figure;
plotClass(X2,z2);
8 changes: 4 additions & 4 deletions chapter09/kmeansRnd.m
@@ -1,4 +1,4 @@
function [X, z, center] = kmeansRnd(d, k, n)
function [X, z, mu] = kmeansRnd(d, k, n)
% Generate samples from a Gaussian mixture distribution with common variances (kmeans model).
% Input:
% d: dimension of data
@@ -7,7 +7,7 @@
% Output:
% X: d x n data matrix
% z: 1 x n response variable
% center: d x k centers of clusters
% mu: d x k centers of clusters
% Written by Mo Chen ([email protected]).
alpha = 1;
beta = nthroot(k,d); % so that a volume of x^d holds k points: x^d = k
@@ -16,5 +16,5 @@
w = dirichletRnd(alpha,ones(1,k)/k);
z = discreteRnd(w,n);
E = full(sparse(z,1:n,1,k,n,n));
center = randn(d,k)*beta;
X = X+center*E;
mu = randn(d,k)*beta;
X = X+mu*E;
22 changes: 22 additions & 0 deletions chapter09/mixBernRnd.m
@@ -0,0 +1,22 @@
function [X, z, mu] = mixBernRnd(d, k, n)
% Generate samples from a Bernoulli mixture distribution.
% Input:
% d: dimension of data
% k: number of components
% n: number of data
% Output:
% X: d x n binary data matrix
% z: 1 x n component labels
% mu: d x k Bernoulli parameters of components
% Written by Mo Chen ([email protected]).
alpha = 1;

w = dirichletRnd(alpha,ones(1,k)/k);
z = discreteRnd(w,n);
mu = rand(d,k);  % d x k Bernoulli parameters, one column per component

X = zeros(d,n);
for i = 1:k
    idx = z==i;
    % each dimension is an independent Bernoulli draw with component i's parameters
    X(:,idx) = bsxfun(@lt,rand(d,sum(idx)),mu(:,i));
end
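
Since the commit message flags mixBernRnd as not done, a quick sanity check (a sketch, not part of the commit) is to compare each component's empirical feature means against the sampled parameters:

d = 10; k = 3; n = 5000;
[X, z, mu] = mixBernRnd(d,k,n);
for i = 1:k
    % empirical means of component i should approach mu(:,i) as counts grow
    fprintf('component %d: max |mean - mu| = %.3f\n', ...
        i, max(abs(mean(X(:,z==i),2) - mu(:,i))));
end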
51 changes: 26 additions & 25 deletions chapter10/demo.m
@@ -1,42 +1,43 @@
% demos for ch10
% chapter10/12: prediction functions for VB
%% Variational Bayesian for linear/RVM regression
% clear; close all;
%
% d = 100;
% beta = 1e-1;
% X = rand(1,d);
% w = randn;
% b = randn;
% t = w'*X+b+beta*randn(1,d);
% x = linspace(min(X),max(X),d); % test data
%
% [model,llh] = linRegVb(X,t);
% % [model,llh] = rvmRegVb(X,t);
% plot(llh);
% [y, sigma] = linRegPred(model,x,t);
% figure
% plotCurveBar(x,y,sigma);
% hold on;
% plot(X,t,'o');
% hold off
clear; close all;

d = 100;
beta = 1e-1;
X = rand(1,d);
w = randn;
b = randn;
t = w'*X+b+beta*randn(1,d);
x = linspace(min(X),max(X),d); % test data

[model,llh] = linRegVb(X,t);
% [model,llh] = rvmRegVb(X,t);
plot(llh);
[y, sigma] = linRegPred(model,x,t);
figure
plotCurveBar(x,y,sigma);
hold on;
plot(X,t,'o');
hold off
%% Variational Bayesian for Gaussian Mixture Model
close all; clear;
d = 2;
k = 3;
n = 2000;
[X,z] = mixGaussRnd(d,k,n);
plotClass(X,z);
Xt = X(:,n/2+1:end);
X = X(:,1:n/2);
m = floor(n/2);
X1 = X(:,1:m);
X2 = X(:,(m+1):end);
% VB fitting
[y, model, L] = mixGaussVb(X,10);
[y1, model, L] = mixGaussVb(X1,10);
figure;
plotClass(X,y);
plotClass(X1,y1);
figure;
plot(L)
% Predict testing data
[yt, R] = mixGaussVbPred(model,Xt);
[y2, R] = mixGaussVbPred(model,X2);
figure;
plotClass(Xt,yt);
plotClass(X2,y2);
