forked from Mikoto10032/DeepLearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6920cd8
commit c959a4e
Showing
186 changed files
with
5,096 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
reference/* | ||
*.m~ | ||
*.asv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
Introduction | ||
------- | ||
This package is a Matlab implementation of the algorithms described in the classical machine learning textbook: | ||
Pattern Recognition and Machine Learning by C. Bishop ([PRML](http://research.microsoft.com/en-us/um/people/cmbishop/prml/)). | ||
|
||
Note: this package requires Matlab **R2016b** or latter, since it utilizes a new syntax of Matlab called [Implicit expansion](https://cn.mathworks.com/help/matlab/release-notes.html?rntext=implicit+expansion&startrelease=R2016b&endrelease=R2016b&groupby=release&sortby=descending) (a.k.a. broadcasting in Python). | ||
|
||
Description | ||
------- | ||
While developing this package, I stick to following principles | ||
|
||
* Succinct: The code is extremely terse. Minimizing the number of lines is one of the primal goals. As a result, the core of the algorithms can be easily spot. | ||
* Efficient: Many tricks for making Matlab scripts fast were applied (eg. vectorization and matrix factorization). Many functions are even comparable with C implementations. Usually, functions in this package are orders faster than Matlab builtin ones which provide the same functionality (eg. kmeans). If anyone have found any Matlab implementation that is faster than mine, I am happy to further optimize. | ||
* Robust: Many tricks for numerical stability are applied, such as probability computation in log scale and square root matrix update to enforce matrix symmetry, etc. | ||
* Readable: The code is heavily commented. Reference formulas in PRML book are indicated for corresponding code lines. Symbols are in sync with the book. | ||
* Practical: The package is designed not only to be easily read, but also to be easily used to facilitate ML research. Many functions in this package are already widely used (see [Matlab file exchange](http://www.mathworks.com/matlabcentral/fileexchange/?term=authorid%3A49739)). | ||
|
||
Installation | ||
------- | ||
1. Download the package to your local path (e.g. PRMLT/) by running: `git clone https://github.com/PRML/PRMLT.git`. | ||
|
||
2. Run Matlab and navigate to PRMLT/, then run the init.m script. | ||
|
||
3. Try demos in PRMLT/demo directory to verify installation correctness. Enjoy! | ||
|
||
FeedBack | ||
------- | ||
If you found any bug or have any suggestion, please do file issues. I am graceful for any feedback and will do my best to improve this package. | ||
|
||
License | ||
------- | ||
Currently Released Under GPLv3 | ||
|
||
|
||
Contact | ||
------- | ||
sth4nth at gmail dot com |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
function z = condEntropy (x, y) | ||
% Compute conditional entropy z=H(x|y) of two discrete variables x and y. | ||
% Input: | ||
% x, y: two integer vector of the same length | ||
% Output: | ||
% z: conditional entropy z=H(x|y) | ||
% Written by Mo Chen ([email protected]). | ||
assert(numel(x) == numel(y)); | ||
n = numel(x); | ||
x = reshape(x,1,n); | ||
y = reshape(y,1,n); | ||
|
||
l = min(min(x),min(y)); | ||
x = x-l+1; | ||
y = y-l+1; | ||
k = max(max(x),max(y)); | ||
|
||
idx = 1:n; | ||
Mx = sparse(idx,x,1,n,k,n); | ||
My = sparse(idx,y,1,n,k,n); | ||
Pxy = nonzeros(Mx'*My/n); %joint distribution of x and y | ||
Hxy = -dot(Pxy,log2(Pxy)); | ||
|
||
Py = nonzeros(mean(My,1)); | ||
Hy = -dot(Py,log2(Py)); | ||
|
||
% conditional entropy H(x|y) | ||
z = Hxy-Hy; | ||
z = max(0,z); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
function z = entropy(x) | ||
% Compute entropy z=H(x) of a discrete variable x. | ||
% Input: | ||
% x: a integer vectors | ||
% Output: | ||
% z: entropy z=H(x) | ||
% Written by Mo Chen ([email protected]). | ||
n = numel(x); | ||
[~,~,x] = unique(x); | ||
Px = accumarray(x, 1)/n; | ||
Hx = -dot(Px,log2(Px)); | ||
z = max(0,Hx); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
function z = jointEntropy(x, y) | ||
% Compute joint entropy z=H(x,y) of two discrete variables x and y. | ||
% Input: | ||
% x, y: two integer vector of the same length | ||
% Output: | ||
% z: joint entroy z=H(x,y) | ||
% Written by Mo Chen ([email protected]). | ||
assert(numel(x) == numel(y)); | ||
n = numel(x); | ||
x = reshape(x,1,n); | ||
y = reshape(y,1,n); | ||
|
||
l = min(min(x),min(y)); | ||
x = x-l+1; | ||
y = y-l+1; | ||
k = max(max(x),max(y)); | ||
|
||
idx = 1:n; | ||
p = nonzeros(sparse(idx,x,1,n,k,n)'*sparse(idx,y,1,n,k,n)/n); %joint distribution of x and y | ||
|
||
z = -dot(p,log2(p)); | ||
z = max(0,z); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
function z = mutInfo(x, y) | ||
% Compute mutual information I(x,y) of two discrete variables x and y. | ||
% Input: | ||
% x, y: two integer vector of the same length | ||
% Output: | ||
% z: mutual information z=I(x,y) | ||
% Written by Mo Chen ([email protected]). | ||
assert(numel(x) == numel(y)); | ||
n = numel(x); | ||
x = reshape(x,1,n); | ||
y = reshape(y,1,n); | ||
|
||
l = min(min(x),min(y)); | ||
x = x-l+1; | ||
y = y-l+1; | ||
k = max(max(x),max(y)); | ||
|
||
idx = 1:n; | ||
Mx = sparse(idx,x,1,n,k,n); | ||
My = sparse(idx,y,1,n,k,n); | ||
Pxy = nonzeros(Mx'*My/n); %joint distribution of x and y | ||
Hxy = -dot(Pxy,log2(Pxy)); | ||
|
||
Px = nonzeros(mean(Mx,1)); | ||
Py = nonzeros(mean(My,1)); | ||
|
||
% entropy of Py and Px | ||
Hx = -dot(Px,log2(Px)); | ||
Hy = -dot(Py,log2(Py)); | ||
% mutual information | ||
z = Hx+Hy-Hxy; | ||
z = max(0,z); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
function z = nmi(x, y) | ||
% Compute normalized mutual information I(x,y)/sqrt(H(x)*H(y)) of two discrete variables x and y. | ||
% Input: | ||
% x, y: two integer vector of the same length | ||
% Ouput: | ||
% z: normalized mutual information z=I(x,y)/sqrt(H(x)*H(y)) | ||
% Written by Mo Chen ([email protected]). | ||
assert(numel(x) == numel(y)); | ||
n = numel(x); | ||
x = reshape(x,1,n); | ||
y = reshape(y,1,n); | ||
|
||
l = min(min(x),min(y)); | ||
x = x-l+1; | ||
y = y-l+1; | ||
k = max(max(x),max(y)); | ||
|
||
idx = 1:n; | ||
Mx = sparse(idx,x,1,n,k,n); | ||
My = sparse(idx,y,1,n,k,n); | ||
Pxy = nonzeros(Mx'*My/n); %joint distribution of x and y | ||
Hxy = -dot(Pxy,log2(Pxy)); | ||
|
||
|
||
% hacking, to elimative the 0log0 issue | ||
Px = nonzeros(mean(Mx,1)); | ||
Py = nonzeros(mean(My,1)); | ||
|
||
% entropy of Py and Px | ||
Hx = -dot(Px,log2(Px)); | ||
Hy = -dot(Py,log2(Py)); | ||
|
||
% mutual information | ||
MI = Hx + Hy - Hxy; | ||
|
||
% normalized mutual information | ||
z = sqrt((MI/Hx)*(MI/Hy)); | ||
z = max(0,z); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
function z = nvi(x, y) | ||
% Compute normalized variation information z=(1-I(x,y)/H(x,y)) of two discrete variables x and y. | ||
% Input: | ||
% x, y: two integer vector of the same length | ||
% Output: | ||
% z: normalized variation information z=(1-I(x,y)/H(x,y)) | ||
% Written by Mo Chen ([email protected]). | ||
assert(numel(x) == numel(y)); | ||
n = numel(x); | ||
x = reshape(x,1,n); | ||
y = reshape(y,1,n); | ||
|
||
l = min(min(x),min(y)); | ||
x = x-l+1; | ||
y = y-l+1; | ||
k = max(max(x),max(y)); | ||
|
||
idx = 1:n; | ||
Mx = sparse(idx,x,1,n,k,n); | ||
My = sparse(idx,y,1,n,k,n); | ||
Pxy = nonzeros(Mx'*My/n); %joint distribution of x and y | ||
Hxy = -dot(Pxy,log2(Pxy)); | ||
|
||
Px = nonzeros(mean(Mx,1)); | ||
Py = nonzeros(mean(My,1)); | ||
|
||
% entropy of Py and Px | ||
Hx = -dot(Px,log2(Px)); | ||
Hy = -dot(Py,log2(Py)); | ||
|
||
% nvi | ||
z = 2-(Hx+Hy)/Hxy; | ||
z = max(0,z); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
function z = relatEntropy (x, y) | ||
% Compute relative entropy (a.k.a KL divergence) z=KL(p(x)||p(y)) of two discrete variables x and y. | ||
% Input: | ||
% x, y: two integer vector of the same length | ||
% Output: | ||
% z: relative entropy (a.k.a KL divergence) z=KL(p(x)||p(y)) | ||
% Written by Mo Chen ([email protected]). | ||
assert(numel(x) == numel(y)); | ||
n = numel(x); | ||
x = reshape(x,1,n); | ||
y = reshape(y,1,n); | ||
|
||
l = min(min(x),min(y)); | ||
x = x-l+1; | ||
y = y-l+1; | ||
k = max(max(x),max(y)); | ||
|
||
idx = 1:n; | ||
Mx = sparse(idx,x,1,n,k,n); | ||
My = sparse(idx,y,1,n,k,n); | ||
Px = nonzeros(mean(Mx,1)); | ||
Py = nonzeros(mean(My,1)); | ||
|
||
z = -dot(Px,log2(Py)-log2(Px)); | ||
z = max(0,z); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
function y = logDirichlet(X, a) | ||
% Compute log pdf of a Dirichlet distribution. | ||
% Input: | ||
% X: d x n data matrix, each column sums to one (sum(X,1)==ones(1,n) && X>=0) | ||
% a: d x k parameter of Dirichlet | ||
% y: k x n probability density | ||
% Output: | ||
% y: k x n probability density in logrithm scale y=log p(x) | ||
% Written by Mo Chen ([email protected]). | ||
X = bsxfun(@times,X,1./sum(X,1)); | ||
if size(a,1) == 1 | ||
a = repmat(a,size(X,1),1); | ||
end | ||
c = gammaln(sum(a,1))-sum(gammaln(a),1); | ||
g = (a-1)'*log(X); | ||
y = bsxfun(@plus,g,c'); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
function y = logGauss(X, mu, sigma) | ||
% Compute log pdf of a Gaussian distribution. | ||
% Input: | ||
% X: d x n data matrix | ||
% mu: d x 1 mean vector of Gaussian | ||
% sigma: d x d covariance matrix of Gaussian | ||
% Output: | ||
% y: 1 x n probability density in logrithm scale y=log p(x) | ||
% Written by Mo Chen ([email protected]). | ||
d = size(X,1); | ||
X = X-mu; | ||
[U,p]= chol(sigma); | ||
if p ~= 0 | ||
error('ERROR: sigma is not PD.'); | ||
end | ||
Q = U'\X; | ||
q = dot(Q,Q,1); % quadratic term (M distance) | ||
c = d*log(2*pi)+2*sum(log(diag(U))); % normalization constant | ||
y = -(c+q)/2; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
function z = logKde (X, Y, sigma) | ||
% Compute log pdf of kernel density estimator. | ||
% Input: | ||
% X: d x n data matrix to be evaluate | ||
% Y: d x k data matrix served as database | ||
% Output: | ||
% z: probability density in logrithm scale z=log p(x|y) | ||
% Written by Mo Chen ([email protected]). | ||
D = dot(X,X,1)+dot(Y,Y,1)'-2*(Y'*X); | ||
z = logsumexp(D/(-2*sigma^2),1)-0.5*log(2*pi)-log(sigma*size(Y,2),1); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
function z = logMn(x, p) | ||
% Compute log pdf of a multinomial distribution. | ||
% Input: | ||
% x: d x 1 integer vector | ||
% p: d x 1 probability | ||
% Output: | ||
% z: probability density in logrithm scale z=log p(x) | ||
% Written by Mo Chen ([email protected]). | ||
z = gammaln(sum(x)+1)-sum(gammaln(x+1))+dot(x,log(p)); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
function y = logMvGamma(x, d) | ||
% Compute logarithm multivariate Gamma function | ||
% which is used in the probability density function of the Wishart and inverse Wishart distributions. | ||
% Gamma_d(x) = pi^(d(d-1)/4) \prod_(j=1)^d Gamma(x+(1-j)/2) | ||
% log(Gamma_d(x)) = d(d-1)/4 log(pi) + \sum_(j=1)^d log(Gamma(x+(1-j)/2)) | ||
% Input: | ||
% x: m x n data matrix | ||
% d: dimension | ||
% Output: | ||
% y: m x n logarithm multivariate Gamma | ||
% Written by Michael Chen ([email protected]). | ||
y = d*(d-1)/4*log(pi)+sum(gammaln(x(:)+(1-(1:d))/2),2); | ||
y = reshape(y,size(x)); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
function y = logSt(X, mu, sigma, v) | ||
% Compute log pdf of a Student's t distribution. | ||
% Input: | ||
% X: d x n data matrix | ||
% mu: mean | ||
% sigma: variance | ||
% v: degree of freedom | ||
% Output: | ||
% y: probability density in logrithm scale y=log p(x) | ||
% Written by mo Chen ([email protected]). | ||
[d,k] = size(mu); | ||
|
||
if size(sigma,1)==d && size(sigma,2)==d && k==1 | ||
[R,p]= chol(sigma); | ||
if p ~= 0 | ||
error('ERROR: sigma is not SPD.'); | ||
end | ||
X = bsxfun(@minus,X,mu); | ||
Q = R'\X; | ||
q = dot(Q,Q,1); % quadratic term (M distance) | ||
o = -log(1+q/v)*((v+d)/2); | ||
c = gammaln((v+d)/2)-gammaln(v/2)-(d*log(v*pi)+2*sum(log(diag(R))))/2; | ||
y = c+o; | ||
elseif size(sigma,1)==d && size(sigma,2)==k | ||
lambda = 1./sigma; | ||
ml = mu.*lambda; | ||
q = bsxfun(@plus,X'.^2*lambda-2*X'*ml,dot(mu,ml,1)); % M distance | ||
o = bsxfun(@times,log(1+bsxfun(@times,q,1./v)),-(v+d)/2); | ||
c = gammaln((v+d)/2)-gammaln(v/2)-(d*log(pi*v)+sum(log(sigma),1))/2; | ||
y = bsxfun(@plus,o,c); | ||
elseif size(sigma,1)==1 && size(sigma,2)==k | ||
X2 = repmat(dot(X,X,1)',1,k); | ||
D = bsxfun(@plus,X2-2*X'*mu,dot(mu,mu,1)); | ||
q = bsxfun(@times,D,1./sigma); % M distance | ||
o = bsxfun(@times,log(1+bsxfun(@times,q,1./v)),-(v+d)/2); | ||
c = gammaln((v+d)/2)-gammaln(v/2)-d*log(pi*v.*sigma)/2; | ||
y = bsxfun(@plus,o,c); | ||
else | ||
error('Parameters are mismatched.'); | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
function y = logVmf(X, mu, kappa) | ||
% Compute log pdf of a von Mises-Fisher distribution. | ||
% Input: | ||
% X: d x n data matrix | ||
% mu: d x k mean | ||
% kappa: 1 x k variance | ||
% Output: | ||
% y: k x n probability density in logrithm scale y=log p(x) | ||
% Written by Mo Chen ([email protected]). | ||
d = size(X,1); | ||
c = (d/2-1)*log(kappa)-(d/2)*log(2*pi)-logbesseli(d/2-1,kappa); | ||
q = bsxfun(@times,mu,kappa)'*X; | ||
y = bsxfun(@plus,q,c'); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
function y = logWishart(Sigma, W, v) | ||
% Compute log pdf of a Wishart distribution. | ||
% Input: | ||
% Sigma: d x d covariance matrix | ||
% W: d x d covariance parameter | ||
% v: degree of freedom | ||
% Output: | ||
% y: probability density in logrithm scale y=log p(Sigma) | ||
% Written by Mo Chen ([email protected]). | ||
d = length(Sigma); | ||
B = -0.5*v*logdet(W)-0.5*v*d*log(2)-logmvgamma(0.5*v,d); | ||
y = B+0.5*(v-d-1)*logdet(Sigma)-0.5*trace(W\Sigma); |
Oops, something went wrong.