forked from PRML/PRMLT
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
33 changed files
with
1,410 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
function y = logDirichlet(X, a) | ||
% Compute log pdf of a Dirichlet distribution. | ||
% X: d x n data matrix satifying (sum(X,1)==ones(1,n) && X>=0) | ||
% a: d x k parameters | ||
% y: k x n probability density | ||
% Written by Mo Chen ([email protected]). | ||
X = bsxfun(@times,X,1./sum(X,1)); | ||
if size(a,1) == 1 | ||
a = repmat(a,size(X,1),1); | ||
end | ||
c = gammaln(sum(a,1))-sum(gammaln(a),1); | ||
g = (a-1)'*log(X); | ||
y = bsxfun(@plus,g,c'); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
function y = logGauss(X, mu, sigma) | ||
% Compute log pdf of a Gaussian distribution. | ||
% Written by Mo Chen ([email protected]). | ||
|
||
[d,n] = size(X); | ||
k = size(mu,2); | ||
if n == k && size(sigma,1) == 1 | ||
X = bsxfun(@times,X-mu,1./sigma); | ||
q = dot(X,X,1); % M distance | ||
c = d*log(2*pi)+2*log(sigma); % normalization constant | ||
y = -0.5*(c+q); | ||
elseif size(sigma,1)==d && size(sigma,2)==d && k==1 % one mu and one dxd sigma | ||
X = bsxfun(@minus,X,mu); | ||
[R,p]= chol(sigma); | ||
if p ~= 0 | ||
error('ERROR: sigma is not PD.'); | ||
end | ||
Q = R'\X; | ||
q = dot(Q,Q,1); % quadratic term (M distance) | ||
c = d*log(2*pi)+2*sum(log(diag(R))); % normalization constant | ||
y = -0.5*(c+q); | ||
elseif size(sigma,1)==d && size(sigma,2)==k % k mu and k diagonal sigma | ||
lambda = 1./sigma; | ||
ml = mu.*lambda; | ||
q = bsxfun(@plus,X'.^2*lambda-2*X'*ml,dot(mu,ml,1)); % M distance | ||
c = d*log(2*pi)+2*sum(log(sigma),1); % normalization constant | ||
y = -0.5*bsxfun(@plus,q,c); | ||
elseif size(sigma,1)==1 && (size(sigma,2)==k || size(sigma,2)==1) % k mu and (k or one) scalar sigma | ||
X2 = repmat(dot(X,X,1)',1,k); | ||
D = bsxfun(@plus,X2-2*X'*mu,dot(mu,mu,1)); | ||
q = bsxfun(@times,D,1./sigma); % M distance | ||
c = d*(log(2*pi)+2*log(sigma)); % normalization constant | ||
y = -0.5*bsxfun(@plus,q,c); | ||
else | ||
error('Parameters mismatched.'); | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
function z = logKde (X, Y, sigma) | ||
% Compute log pdf of kernel density estimator. | ||
% Written by Mo Chen ([email protected]). | ||
D = bsxfun(@plus,full(dot(X,X,1)),full(dot(Y,Y,1))')-full(2*(Y'*X)); | ||
z = logSumExp(D/(-2*sigma^2),1)-0.5*log(2*pi)-log(sigma*size(Y,2)); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
function z = logMn (x, p) | ||
% Compute log pdf of a multinomial distribution. | ||
% Written by Mo Chen ([email protected]). | ||
if numel(x) ~= numel(p) | ||
n = numel(x); | ||
x = reshape(x,1,n); | ||
[u,~,label] = unique(x); | ||
x = full(sum(sparse(label,1:n,1,n,numel(u),n),2)); | ||
end | ||
z = gammaln(sum(x)+1)-sum(gammaln(x+1))+dot(x,log(p)); | ||
endfunction |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
function y = logMvGamma(x,d) | ||
% Compute logarithm multivariate Gamma function. | ||
% Gamma_p(x) = pi^(p(p-1)/4) prod_(j=1)^p Gamma(x+(1-j)/2) | ||
% log Gamma_p(x) = p(p-1)/4 log pi + sum_(j=1)^p log Gamma(x+(1-j)/2) | ||
% Written by Michael Chen ([email protected]). | ||
s = size(x); | ||
x = reshape(x,1,prod(s)); | ||
x = bsxfun(@plus,repmat(x,d,1),(1-(1:d)')/2); | ||
y = d*(d-1)/4*log(pi)+sum(gammaln(x),1); | ||
y = reshape(y,s); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
function y = logSt(X, mu, sigma, v) | ||
% Compute log pdf of a student-t distribution. | ||
% Written by mo Chen ([email protected]). | ||
[d,k] = size(mu); | ||
|
||
if size(sigma,1)==d && size(sigma,2)==d && k==1 | ||
[R,p]= cholcov(sigma,0); | ||
if p ~= 0 | ||
error('ERROR: sigma is not SPD.'); | ||
end | ||
X = bsxfun(@minus,X,mu); | ||
Q = R'\X; | ||
q = dot(Q,Q,1); % quadratic term (M distance) | ||
o = -log(1+q/v)*((v+d)/2); | ||
c = gammaln((v+d)/2)-gammaln(v/2)-(d*log(v*pi)+2*sum(log(diag(R))))/2; | ||
y = c+o; | ||
elseif size(sigma,1)==d && size(sigma,2)==k | ||
lambda = 1./sigma; | ||
ml = mu.*lambda; | ||
q = bsxfun(@plus,X'.^2*lambda-2*X'*ml,dot(mu,ml,1)); % M distance | ||
o = bsxfun(@times,log(1+bsxfun(@times,q,1./v)),-(v+d)/2); | ||
c = gammaln((v+d)/2)-gammaln(v/2)-(d*log(pi*v)+sum(log(sigma),1))/2; | ||
y = bsxfun(@plus,o,c); | ||
elseif size(sigma,1)==1 && size(sigma,2)==k | ||
X2 = repmat(dot(X,X,1)',1,k); | ||
D = bsxfun(@plus,X2-2*X'*mu,dot(mu,mu,1)); | ||
q = bsxfun(@times,D,1./sigma); % M distance | ||
o = bsxfun(@times,log(1+bsxfun(@times,q,1./v)),-(v+d)/2); | ||
c = gammaln((v+d)/2)-gammaln(v/2)-d*log(pi*v.*sigma)/2; | ||
y = bsxfun(@plus,o,c); | ||
else | ||
error('Parameters mismatched.'); | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
function y = logVmf(X, mu, kappa) | ||
% Compute log pdf of a von Mises-Fisher distribution. | ||
% Written by Mo Chen ([email protected]). | ||
d = size(X,1); | ||
c = (d/2-1)*log(kappa)-(d/2)*log(2*pi)-logbesseli(d/2-1,kappa); | ||
q = bsxfun(@times,mu,kappa)'*X; | ||
y = bsxfun(@plus,q,c'); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
function y = logWishart(Sigma, v, W) | ||
% Compute log pdf of a Wishart distribution. | ||
% Written by Mo Chen ([email protected]). | ||
d = length(Sigma); | ||
B = -0.5*v*logdet(W)-0.5*v*d*log(2)-logmvgamma(0.5*v,d); | ||
y = B+0.5*(v-d-1)*logdet(Sigma)-0.5*trace(W\Sigma); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
function model = linReg(X, t, lambda) | ||
% Fit linear regression model t=w'x+w0 | ||
% X: d x n data | ||
% t: 1 x n response | ||
% Written by Mo Chen ([email protected]). | ||
if nargin < 3 | ||
lambda = 0; | ||
end | ||
d = size(X,1); | ||
xbar = mean(X,2); | ||
tbar = mean(t,2); | ||
|
||
X = bsxfun(@minus,X,xbar); | ||
t = bsxfun(@minus,t,tbar); | ||
|
||
S = X*X'; | ||
dg = sub2ind([d,d],1:d,1:d); | ||
S(dg) = S(dg)+lambda; | ||
% w = S\(X*t'); | ||
R = chol(S); | ||
w = R\(R'\(X*t')); % 3.15 & 3.28 | ||
w0 = tbar-dot(w,xbar); % 3.19 | ||
|
||
model.w = w; | ||
model.w0 = w0; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
function [model, llh] = linRegEbEm(X, t, alpha, beta) | ||
% Fit empirical Bayesian linear model with EM | ||
% X: d x n data | ||
% t: 1 x n response | ||
% Written by Mo Chen ([email protected]). | ||
if nargin < 3 | ||
alpha = 0.02; | ||
beta = 0.5; | ||
end | ||
[d,n] = size(X); | ||
|
||
xbar = mean(X,2); | ||
tbar = mean(t,2); | ||
|
||
X = bsxfun(@minus,X,xbar); | ||
t = bsxfun(@minus,t,tbar); | ||
|
||
C = X*X'; | ||
Xt = X*t'; | ||
dg = sub2ind([d,d],1:d,1:d); | ||
I = eye(d); | ||
tol = 1e-4; | ||
maxiter = 100; | ||
llh = -inf(1,maxiter+1); | ||
for iter = 2:maxiter | ||
A = beta*C; | ||
A(dg) = A(dg)+alpha; | ||
U = chol(A); | ||
V = U\I; | ||
|
||
w = beta*(V*(V'*Xt)); | ||
w2 = dot(w,w); | ||
err = sum((t-w'*X).^2); | ||
|
||
logdetA = 2*sum(log(diag(U))); | ||
llh(iter) = 0.5*(d*log(alpha)+n*log(beta)-alpha*w2-beta*err-logdetA-n*log(2*pi)); | ||
if llh(iter)-llh(iter-1) < tol*abs(llh(iter-1)); break; end | ||
|
||
trS = dot(V(:),V(:)); | ||
alpha = d/(w2+trS); % 9.63 | ||
|
||
gamma = d-alpha*trS; | ||
beta = n/(err+gamma/beta); | ||
end | ||
w0 = tbar-dot(w,xbar); | ||
|
||
llh = llh(2:iter); | ||
model.w0 = w0; | ||
model.w = w; | ||
model.alpha = alpha; | ||
model.beta = beta; | ||
model.xbar = xbar; | ||
model.V = V; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
function [model, llh] = linRegEbFp(X, t, alpha, beta) | ||
% Fit empirical Bayesian linear model with Mackay fixed point method | ||
% X: d x n data | ||
% t: 1 x n response | ||
% Written by Mo Chen ([email protected]). | ||
if nargin < 3 | ||
alpha = 0.02; | ||
beta = 0.5; | ||
end | ||
[d,n] = size(X); | ||
|
||
xbar = mean(X,2); | ||
tbar = mean(t,2); | ||
|
||
X = bsxfun(@minus,X,xbar); | ||
t = bsxfun(@minus,t,tbar); | ||
|
||
C = X*X'; | ||
Xt = X*t'; | ||
dg = sub2ind([d,d],1:d,1:d); | ||
I = eye(d); | ||
tol = 1e-4; | ||
maxiter = 100; | ||
llh = -inf(1,maxiter+1); | ||
for iter = 2:maxiter | ||
A = beta*C; | ||
A(dg) = A(dg)+alpha; | ||
U = chol(A); | ||
V = U\I; | ||
|
||
w = beta*(V*(V'*Xt)); | ||
w2 = dot(w,w); | ||
err = sum((t-w'*X).^2); | ||
|
||
logdetA = 2*sum(log(diag(U))); | ||
llh(iter) = 0.5*(d*log(alpha)+n*log(beta)-alpha*w2-beta*err-logdetA-n*log(2*pi)); | ||
if llh(iter)-llh(iter-1) < tol*abs(llh(iter-1)); break; end | ||
|
||
trS = dot(V(:),V(:)); | ||
gamma = d-alpha*trS; | ||
alpha = gamma/w2; | ||
beta = (n-gamma)/err; | ||
end | ||
w0 = tbar-dot(w,xbar); | ||
|
||
llh = llh(2:iter); | ||
model.w0 = w0; | ||
model.w = w; | ||
model.alpha = alpha; | ||
model.beta = beta; | ||
model.xbar = xbar; | ||
model.V = V; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
function U = fda(X, y, d) | ||
% Fisher (linear) discriminant analysis | ||
% Written by Mo Chen ([email protected]). | ||
n = size(X,2); | ||
k = max(y); | ||
|
||
E = sparse(1:n,y,true,n,k,n); % transform label into indicator matrix | ||
nk = full(sum(E)); | ||
|
||
m = mean(X,2); | ||
Xo = bsxfun(@minus,X,m); | ||
St = (Xo*Xo')/n; | ||
|
||
mk = bsxfun(@times,X*E,1./nk); | ||
mo = bsxfun(@minus,mk,m); | ||
mo = bsxfun(@times,mo,sqrt(nk/n)); | ||
Sb = mo*mo'; | ||
% Sw = St-Sb; | ||
|
||
[U,A] = eig(Sb,St,'chol'); | ||
[~,idx] = sort(diag(A),'descend'); | ||
U = U(:,idx(1:d)); | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
function [model, llh] = logitReg(X, t, lambda) | ||
% logistic regression for binary classification (Bernoulli likelihood) | ||
% Written by Mo Chen ([email protected]). | ||
if nargin < 3 | ||
lambda = 1e-4; | ||
end | ||
[d,n] = size(X); | ||
dg = sub2ind([d,d],1:d,1:d); | ||
X = [X; ones(1,n)]; | ||
d = d+1; | ||
|
||
tol = 1e-4; | ||
maxiter = 100; | ||
llh = -inf(1,maxiter); | ||
|
||
h = ones(1,n); | ||
h(t==0) = -1; | ||
w = zeros(d,1); | ||
z = w'*X; | ||
for iter = 2:maxiter | ||
y = sigmoid(z); | ||
Xw = bsxfun(@times, X, sqrt(y.*(1-y))); | ||
H = Xw*Xw'; | ||
H(dg) = H(dg)+lambda; | ||
g = X*(y-t)'+lambda*w; | ||
p = -H\g; | ||
wo = w; | ||
while true | ||
w = wo+p; | ||
z = w'*X; | ||
llh(iter) = -sum(log1pexp(-h.*z))-0.5*lambda*dot(w,w); | ||
progress = llh(iter)-llh(iter-1); | ||
if progress < 0 | ||
p = p/2; | ||
else | ||
break; | ||
end | ||
end | ||
if progress < tol | ||
break | ||
end | ||
end | ||
llh = llh(2:iter); | ||
model.w = w; |
Oops, something went wrong.