forked from PRML/PRMLT
Showing 14 changed files with 141 additions and 81 deletions.
@@ -1,6 +1,13 @@
function [model, llh] = rvmBinFp(X, t, alpha)
% Relevance Vector Machine (ARD sparse prior) for binary classification
% training by empirical bayesian (type II ML) using fix point update (Mackay update)
% Relevance Vector Machine (ARD sparse prior) for binary classification,
% trained by empirical Bayesian (type II ML) using the Mackay fixed point update.
% Input:
%   X: d x n data matrix
%   t: 1 x n label (0/1)
%   alpha: prior parameter
% Output:
%   model: trained model structure
%   llh: loglikelihood
% Written by Mo Chen ([email protected]).
if nargin < 3
    alpha = 1;
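The Mackay fixed-point step that rvmBinFp iterates can be sketched in NumPy. This is a hypothetical illustration, not the toolbox code: `alpha` holds the ARD precisions, `m` and `Sigma` the posterior mean and covariance of the weights.

```python
import numpy as np

def mackay_update(alpha, m, Sigma):
    """One Mackay fixed-point update of the ARD precisions:
    gamma_i = 1 - alpha_i * Sigma_ii measures how well weight i is
    determined by the data, and alpha_i <- gamma_i / m_i^2."""
    gamma = 1.0 - alpha * np.diag(Sigma)
    return gamma / m ** 2

# toy posterior: weight 0 is well determined, weight 1 is not
alpha = np.array([1.0, 1.0])
m = np.array([2.0, 0.1])
Sigma = np.diag([0.1, 0.9])
print(mackay_update(alpha, m, Sigma))  # [0.225, 10.0]: weight 1 is being pruned
```

Iterating this update drives the precision of irrelevant weights to infinity, which is how the RVM obtains sparsity.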
@@ -1,70 +1,70 @@
% demos for ch09

%% Empirical Bayesian linear regression via EM
close all; clear;
d = 5;
n = 200;
[x,t] = linRnd(d,n);
[model,llh] = linRegEm(x,t);
plot(llh);

%% demo: EM linear regression
% close all; clear;
% d = 5;
% n = 200;
% [x,t] = linRnd(d,n);
% [model,llh] = linRegEm(x,t);
% plot(llh);
%% RVM classification via EM
clear; close all
k = 2;
d = 2;
n = 1000;
[X,t] = kmeansRnd(d,k,n);
[x1,x2] = meshgrid(linspace(min(X(1,:)),max(X(1,:)),n), linspace(min(X(2,:)),max(X(2,:)),n));

%% classification
% clear; close all
% k = 2;
% d = 2;
% n = 1000;
% [X,t] = kmeansRnd(d,k,n);
% [x1,x2] = meshgrid(linspace(min(X(1,:)),max(X(1,:)),n), linspace(min(X(2,:)),max(X(2,:)),n));
%
% [model, llh] = rvmBinEm(X,t-1);
% plot(llh);
% y = rvmBinPred(model,X)+1;
% figure;
% binPlot(model,X,y);
%% demo: kmeans
% close all; clear;
% d = 2;
% k = 3;
% n = 500;
% [X,label] = kmeansRnd(d,k,n);
% y = kmeans(X,k);
% plotClass(X,label);
% figure;
% plotClass(X,y);
[model, llh] = rvmBinEm(X,t-1);
plot(llh);
y = rvmBinPred(model,X)+1;
figure;
binPlot(model,X,y);
%% kmeans
close all; clear;
d = 2;
k = 3;
n = 500;
[X,label] = kmeansRnd(d,k,n);
y = kmeans(X,k);
plotClass(X,label);
figure;
plotClass(X,y);

%% demo: EM for Gauss mixture
% close all; clear;
% d = 2;
% k = 3;
% n = 1000;
% [X,label] = mixGaussRnd(d,k,n);
% plotClass(X,label);
%
% m = floor(n/2);
% X1 = X(:,1:m);
% X2 = X(:,(m+1):end);
% % train
% [z1,model,llh] = mixGaussEm(X1,k);
% figure;
% plot(llh);
% figure;
% plotClass(X1,z1);
% % predict
% z2 = mixGaussPred(X2,model);
% figure;
% plotClass(X2,z2);
%% demo: EM for Gauss mixture initialized with kmeans
% close all; clear;
% d = 2;
% k = 3;
% n = 500;
% [X,label] = mixGaussRnd(d,k,n);
% init = kmeans(X,k);
% [z,model,llh] = mixGaussEm(X,init);
% plotClass(X,label);
% figure;
% plotClass(X,init);
% figure;
% plotClass(X,z);
% figure;
% plot(llh);
%% Gaussian mixture via EM
close all; clear;
d = 2;
k = 3;
n = 1000;
[X,label] = mixGaussRnd(d,k,n);
plotClass(X,label);

m = floor(n/2);
X1 = X(:,1:m);
X2 = X(:,(m+1):end);
% train
[z1,model,llh] = mixGaussEm(X1,k);
figure;
plot(llh);
figure;
plotClass(X1,z1);
% predict
z2 = mixGaussPred(X2,model);
figure;
plotClass(X2,z2);
%% Gauss mixture initialized by kmeans
close all; clear;
d = 2;
k = 3;
n = 500;
[X,label] = mixGaussRnd(d,k,n);
init = kmeans(X,k);
[z,model,llh] = mixGaussEm(X,init);
plotClass(X,label);
figure;
plotClass(X,init);
figure;
plotClass(X,z);
figure;
plot(llh);
@@ -1,7 +1,12 @@
function [label, energy, model] = kmeans(X, init)
% Perform k-means clustering.
% Input:
%   X: d x n data matrix
%   k: number of seeds
%   init: k number of clusters or label (1 x n vector)
% Output:
%   label: 1 x n cluster label
%   energy: optimization target value
%   model: trained model structure
% Written by Mo Chen ([email protected]).
n = size(X,2);
if numel(init)==1
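For readers less at home in MATLAB, the assignment/update loop behind kmeans can be sketched in NumPy. This is a hypothetical port for illustration; the toolbox implementation differs in its vectorization details.

```python
import numpy as np

def kmeans_sketch(X, k, iters=100, seed=0):
    """Lloyd iterations on a d x n data matrix, returning
    (label, energy, means) like the MATLAB function above."""
    rng = np.random.default_rng(seed)
    d, n = X.shape
    means = X[:, rng.choice(n, k, replace=False)]            # seed from data points
    for _ in range(iters):
        D = ((means[:, :, None] - X[:, None, :]) ** 2).sum(0)  # k x n squared distances
        label = D.argmin(0)                                  # assignment step
        new = np.stack([X[:, label == j].mean(1) for j in range(k)], axis=1)
        if np.allclose(new, means):                          # converged
            break
        means = new                                          # update step
    energy = D.min(0).sum()                                  # within-cluster squared error
    return label, energy, means
```

On well-separated data this recovers the clusters; like the MATLAB original it can still land in a local optimum, which is why the demo scripts compare the learned labels against the generating ones.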
@@ -1,7 +1,11 @@
function [label, energy] = kmeansPred(model, Xt)
% Prediction for kmeans clustering
% Input:
%   model: trained model structure
%   Xt: d x n testing data
% Output:
%   label: 1 x n cluster label
%   energy: optimization target value
% Written by Mo Chen ([email protected]).
[val,label] = min(sqdist(model.means, Xt));
energy = sum(val);
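The core of kmeansPred, `min(sqdist(model.means, Xt))`, is a nearest-center lookup. A NumPy equivalent (hypothetical sketch):

```python
import numpy as np

def kmeans_pred(means, Xt):
    """Nearest-center prediction: means is d x k, Xt is d x n.
    Returns the label vector and summed squared distance, mirroring
    [val,label] = min(sqdist(model.means, Xt)) above."""
    D = ((means[:, :, None] - Xt[:, None, :]) ** 2).sum(0)  # k x n distances
    return D.argmin(0), D.min(0).sum()

means = np.array([[0.0, 10.0]])      # two 1-D centers, at 0 and 10
Xt = np.array([[1.0, 9.0, 2.0]])
label, energy = kmeans_pred(means, Xt)
print(label, energy)                 # [0 1 0] 6.0
```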
@@ -1,6 +1,14 @@
function [X, z, center] = kmeansRnd(d, k, n)
% Sampling from a Gaussian mixture distribution with common variances (kmeans model).
% Written by Michael Chen ([email protected]).
% Generate samples from a Gaussian mixture distribution with common variances (kmeans model).
% Input:
%   d: dimension of data
%   k: number of components
%   n: number of data
% Output:
%   X: d x n data matrix
%   z: 1 x n response variable
%   center: d x k centers of clusters
% Written by Mo Chen ([email protected]).
alpha = 1;
beta = nthroot(k,d); % in volume x^d there are k points: x^d=k
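The `beta = nthroot(k,d)` line sets the spread of the centers so that a volume of beta^d holds roughly k of them. A simplified NumPy version of the sampler (hypothetical sketch with uniform mixing weights, not the toolbox code):

```python
import numpy as np

def kmeans_rnd(d, k, n, seed=0):
    """Sample n points from k unit-variance Gaussian clusters.
    Centers are scaled by nthroot(k, d) so cluster density stays
    roughly constant as k grows."""
    rng = np.random.default_rng(seed)
    beta = k ** (1.0 / d)                           # nthroot(k, d)
    center = beta * rng.standard_normal((d, k))     # d x k cluster centers
    z = rng.integers(k, size=n)                     # component label per point
    X = center[:, z] + rng.standard_normal((d, n))  # common unit variance
    return X, z, center
```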
@@ -1,7 +1,12 @@
function [label, model, llh] = mixBernEm(X, k)
% Perform EM algorithm for fitting the Bernoulli mixture model.
% Input:
%   X: d x n data matrix
%   init: k (1 x 1) or label (1 x n, 1<=label(i)<=k) or center (d x k)
%   k: number of clusters (1 x 1) or label (1 x n, 1<=label(i)<=k) or model structure
% Output:
%   label: 1 x n cluster label
%   model: trained model structure
%   llh: loglikelihood
% Written by Mo Chen ([email protected]).
%% initialization
fprintf('EM for mixture model: running ... \n');
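The E-step of mixBernEm computes each component's responsibility for each binary data point; a log-domain NumPy sketch (hypothetical names, not the toolbox code):

```python
import numpy as np

def bern_estep(X, mu, w, eps=1e-10):
    """Responsibilities R (k x n) for a Bernoulli mixture.
    X: d x n binary data, mu: d x k component means, w: length-k weights.
    Works in the log domain to avoid underflow on high-dimensional data."""
    logp = X.T @ np.log(mu + eps) + (1 - X).T @ np.log(1 - mu + eps)  # n x k
    logr = logp + np.log(w)                    # add log mixing weights
    logr -= logr.max(axis=1, keepdims=True)    # stabilize before exponentiating
    R = np.exp(logr)
    R /= R.sum(axis=1, keepdims=True)          # normalize over components
    return R.T

X = np.array([[1.0, 0.0]])                     # two 1-bit observations
mu = np.array([[0.9, 0.1]])                    # two components
R = bern_estep(X, mu, np.array([0.5, 0.5]))
print(R)                                       # first column close to [0.9, 0.1]
```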
@@ -1,7 +1,11 @@
function [label, R] = mixGaussPred(X, model)
% Predict label and responsibility for Gaussian mixture model.
% Input:
%   X: d x n data matrix
%   model: trained model structure output by the EM algorithm
% Output:
%   label: 1 x n cluster label
%   R: k x n responsibility
% Written by Mo Chen ([email protected]).
mu = model.mu;
Sigma = model.Sigma;
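mixGaussPred evaluates each component's log Gaussian density and normalizes into responsibilities. A NumPy sketch (hypothetical port; it uses Cholesky factors for numerical stability):

```python
import numpy as np

def mix_gauss_pred(X, mu, Sigma, w):
    """Label (length n) and responsibility R (k x n) for a fitted mixture.
    X: d x n data, mu: d x k means, Sigma: list of k (d x d) covariances,
    w: length-k mixing weights."""
    d, n = X.shape
    k = mu.shape[1]
    logR = np.empty((k, n))
    for j in range(k):
        L = np.linalg.cholesky(Sigma[j])
        dev = np.linalg.solve(L, X - mu[:, [j]])   # whitened deviations
        logdet = 2.0 * np.log(np.diag(L)).sum()
        logR[j] = np.log(w[j]) - 0.5 * ((dev ** 2).sum(0) + logdet + d * np.log(2 * np.pi))
    logR -= logR.max(0, keepdims=True)             # log-sum-exp stabilization
    R = np.exp(logR)
    R /= R.sum(0, keepdims=True)
    return R.argmax(0), R
```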
@@ -1,6 +1,14 @@
function [X, z, model] = mixGaussRnd(d, k, n)
% Sampling from a Gaussian mixture distribution.
% Written by Michael Chen ([email protected]).
% Input:
%   d: dimension of data
%   k: number of components
%   n: number of data
% Output:
%   X: d x n data matrix
%   z: 1 x n response variable
%   model: model structure
% Written by Mo Chen ([email protected]).
alpha0 = 1; % hyperparameter of Dirichlet prior
W0 = eye(d); % hyperparameter of inverse Wishart prior of covariances
v0 = d+1; % hyperparameter of inverse Wishart prior of covariances
@@ -1,7 +1,12 @@
function [label, model, llh] = mixMnEm(X, k)
% Perform EM algorithm for fitting the multinomial mixture model.
% Input:
%   X: d x n data matrix
%   init: k (1 x 1) or label (1 x n, 1<=label(i)<=k) or center (d x k)
%   k: number of clusters (1 x 1) or label (1 x n, 1<=label(i)<=k) or model structure
% Output:
%   label: 1 x n cluster label
%   model: trained model structure
%   llh: loglikelihood
% Written by Mo Chen ([email protected]).
%% initialization
fprintf('EM for mixture model: running ... \n');
@@ -1,6 +1,13 @@
function [model, llh] = rvmBinEm(X, t, alpha)
% Relevance Vector Machine (ARD sparse prior) for binary classification
% training by empirical bayesian (type II ML) using fix point update (Mackay update)
% Relevance Vector Machine (ARD sparse prior) for binary classification,
% trained by empirical Bayesian (type II ML) using EM.
% Input:
%   X: d x n data matrix
%   t: 1 x n label (0/1)
%   alpha: prior parameter
% Output:
%   model: trained model structure
%   llh: loglikelihood
% Written by Mo Chen ([email protected]).
if nargin < 3
    alpha = 1;
@@ -1,6 +1,14 @@
function [model, llh] = rvmRegEm(X, t, alpha, beta)
% Relevance Vector Machine (ARD sparse prior) for regression
% training by empirical bayesian (type II ML) using standard EM update
% trained by empirical Bayesian (type II ML) using EM
% Input:
%   X: d x n data
%   t: 1 x n response
%   alpha: prior parameter
%   beta: prior parameter
% Output:
%   model: trained model structure
%   llh: loglikelihood
% Written by Mo Chen ([email protected]).
if nargin < 3
    alpha = 0.02;
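For contrast with the Mackay fixed point used in rvmBinFp, the EM update behind rvmRegEm sets each precision from the posterior second moment of its weight: alpha_i <- 1/(m_i^2 + Sigma_ii). A hypothetical NumPy sketch of just that M-step:

```python
import numpy as np

def em_alpha_update(m, Sigma):
    """EM (type II ML) update of the ARD precisions:
    E[w_i^2] = m_i^2 + Sigma_ii under the weight posterior,
    and alpha_i <- 1 / E[w_i^2]."""
    return 1.0 / (m ** 2 + np.diag(Sigma))

m = np.array([2.0, 0.1])
Sigma = np.diag([0.1, 0.9])
print(em_alpha_update(m, Sigma))  # [1/4.1, 1/0.91]
```

The EM update is typically slower to prune weights than the Mackay fixed point, but each step is a guaranteed ascent on the marginal likelihood.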
@@ -9,7 +9,6 @@
% y: 1 x n response variable
% W: d+1 x k weight matrix
% Written by Mo Chen ([email protected]).
% Written by Mo Chen ([email protected]).
W = randn(d+1,k);
[X, z] = kmeansRnd(d, k, n);
y = zeros(1,n);